geneformer/__init__.py CHANGED
@@ -1,10 +1,4 @@
1
  # ruff: noqa: F401
2
- from pathlib import Path
3
-
4
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
5
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
6
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
7
-
8
  from . import (
9
  collator_for_classification,
10
  emb_extractor,
@@ -17,11 +11,11 @@ from .collator_for_classification import (
17
  DataCollatorForCellClassification,
18
  DataCollatorForGeneClassification,
19
  )
20
- from .emb_extractor import EmbExtractor, get_embs
21
  from .in_silico_perturber import InSilicoPerturber
22
  from .in_silico_perturber_stats import InSilicoPerturberStats
23
  from .pretrainer import GeneformerPretrainer
24
  from .tokenizer import TranscriptomeTokenizer
25
 
26
  from . import classifier # noqa # isort:skip
27
- from .classifier import Classifier # noqa # isort:skip
 
1
  # ruff: noqa: F401
2
  from . import (
3
  collator_for_classification,
4
  emb_extractor,
 
11
  DataCollatorForCellClassification,
12
  DataCollatorForGeneClassification,
13
  )
14
+ from .emb_extractor import EmbExtractor
15
  from .in_silico_perturber import InSilicoPerturber
16
  from .in_silico_perturber_stats import InSilicoPerturberStats
17
  from .pretrainer import GeneformerPretrainer
18
  from .tokenizer import TranscriptomeTokenizer
19
 
20
  from . import classifier # noqa # isort:skip
21
+ from .classifier import Classifier # noqa # isort:skip
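For reference, a minimal sketch of how downstream code loads the token dictionary once the path constants are gone from the package root (the later hunks in this change follow the same pattern); this assumes the pickled dictionary still ships next to tokenizer.py:

    import pickle

    from geneformer.tokenizer import TOKEN_DICTIONARY_FILE

    # Ensembl gene ID -> integer token ID
    with open(TOKEN_DICTIONARY_FILE, "rb") as f:
        gene_token_dict = pickle.load(f)
    # reverse lookup: token ID -> Ensembl gene ID
    token_gene_dict = {v: k for k, v in gene_token_dict.items()}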
geneformer/classifier.py CHANGED
@@ -53,6 +53,7 @@ from pathlib import Path
53
  import numpy as np
54
  import pandas as pd
55
  import seaborn as sns
 
56
  from tqdm.auto import tqdm, trange
57
  from transformers import Trainer
58
  from transformers.training_args import TrainingArguments
@@ -61,7 +62,7 @@ from . import DataCollatorForCellClassification, DataCollatorForGeneClassificati
61
  from . import classifier_utils as cu
62
  from . import evaluation_utils as eu
63
  from . import perturber_utils as pu
64
- from . import TOKEN_DICTIONARY_FILE
65
 
66
  sns.set()
67
 
@@ -85,7 +86,6 @@ class Classifier:
85
  "no_eval": {bool},
86
  "stratify_splits_col": {None, str},
87
  "forward_batch_size": {int},
88
- "token_dictionary_file": {None, str},
89
  "nproc": {int},
90
  "ngpu": {int},
91
  }
@@ -107,7 +107,6 @@ class Classifier:
107
  stratify_splits_col=None,
108
  no_eval=False,
109
  forward_batch_size=100,
110
- token_dictionary_file=None,
111
  nproc=4,
112
  ngpu=1,
113
  ):
@@ -176,9 +175,6 @@ class Classifier:
176
  | Otherwise, will perform eval during training.
177
  forward_batch_size : int
178
  | Batch size for forward pass (for evaluation, not training).
179
- token_dictionary_file : None, str
180
- | Default is to use token dictionary file from Geneformer
181
- | Otherwise, will load custom gene token dictionary.
182
  nproc : int
183
  | Number of CPU processes to use.
184
  ngpu : int
@@ -187,10 +183,6 @@ class Classifier:
187
  """
188
 
189
  self.classifier = classifier
190
- if self.classifier == "cell":
191
- self.model_type = "CellClassifier"
192
- elif self.classifier == "gene":
193
- self.model_type = "GeneClassifier"
194
  self.cell_state_dict = cell_state_dict
195
  self.gene_class_dict = gene_class_dict
196
  self.filter_data = filter_data
@@ -209,7 +201,6 @@ class Classifier:
209
  self.stratify_splits_col = stratify_splits_col
210
  self.no_eval = no_eval
211
  self.forward_batch_size = forward_batch_size
212
- self.token_dictionary_file = token_dictionary_file
213
  self.nproc = nproc
214
  self.ngpu = ngpu
215
 
@@ -231,9 +222,7 @@ class Classifier:
231
  ] = self.cell_state_dict["states"]
232
 
233
  # load token dictionary (Ensembl IDs:token)
234
- if self.token_dictionary_file is None:
235
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE
236
- with open(self.token_dictionary_file, "rb") as f:
237
  self.gene_token_dict = pickle.load(f)
238
 
239
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
@@ -278,7 +267,7 @@ class Classifier:
278
  continue
279
  valid_type = False
280
  for option in valid_options:
281
- if (option in [int, float, list, dict, bool, str]) and isinstance(
282
  attr_value, option
283
  ):
284
  valid_type = True
@@ -445,8 +434,8 @@ class Classifier:
445
  test_data_output_path = (
446
  Path(output_directory) / f"{output_prefix}_labeled_test"
447
  ).with_suffix(".dataset")
448
- data_dict["train"].save_to_disk(str(train_data_output_path))
449
- data_dict["test"].save_to_disk(str(test_data_output_path))
450
  elif (test_size is not None) and (self.classifier == "cell"):
451
  if 1 > test_size > 0:
452
  if attr_to_split is None:
@@ -461,8 +450,8 @@ class Classifier:
461
  test_data_output_path = (
462
  Path(output_directory) / f"{output_prefix}_labeled_test"
463
  ).with_suffix(".dataset")
464
- data_dict["train"].save_to_disk(str(train_data_output_path))
465
- data_dict["test"].save_to_disk(str(test_data_output_path))
466
  else:
467
  data_dict, balance_df = cu.balance_attr_splits(
468
  data,
@@ -483,19 +472,19 @@ class Classifier:
483
  test_data_output_path = (
484
  Path(output_directory) / f"{output_prefix}_labeled_test"
485
  ).with_suffix(".dataset")
486
- data_dict["train"].save_to_disk(str(train_data_output_path))
487
- data_dict["test"].save_to_disk(str(test_data_output_path))
488
  else:
489
  data_output_path = (
490
  Path(output_directory) / f"{output_prefix}_labeled"
491
  ).with_suffix(".dataset")
492
- data.save_to_disk(str(data_output_path))
493
  print(data_output_path)
494
  else:
495
  data_output_path = (
496
  Path(output_directory) / f"{output_prefix}_labeled"
497
  ).with_suffix(".dataset")
498
- data.save_to_disk(str(data_output_path))
499
 
500
  def train_all_data(
501
  self,
@@ -641,6 +630,7 @@ class Classifier:
641
  | Number of trials to run for hyperparameter optimization
642
  | If 0, will not optimize hyperparameters
643
  """
 
644
  if self.num_crossval_splits == 0:
645
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
646
  raise
@@ -782,20 +772,17 @@ class Classifier:
782
  ]
783
  )
784
  assert len(targets) == len(labels)
785
- n_splits = int(1 / (1 - self.train_size))
786
- skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
787
  # (Cross-)validate
788
- test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
789
- for train_index, eval_index, test_index in tqdm(
790
- skf.split(targets, labels, test_ratio)
791
- ):
792
  print(
793
  f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
794
  )
795
  ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
796
  # filter data for examples containing classes for this split
797
  # subsample to max_ncells and relabel data in column "labels"
798
- train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
799
  data,
800
  targets,
801
  labels,
@@ -806,18 +793,6 @@ class Classifier:
806
  self.nproc,
807
  )
808
 
809
- if self.oos_test_size > 0:
810
- test_data = cu.prep_gene_classifier_split(
811
- data,
812
- targets,
813
- labels,
814
- test_index,
815
- "test",
816
- self.max_ncells,
817
- iteration_num,
818
- self.nproc,
819
- )
820
-
821
  if n_hyperopt_trials == 0:
822
  trainer = self.train_classifier(
823
  model_directory,
@@ -827,15 +802,6 @@ class Classifier:
827
  ksplit_output_dir,
828
  predict_trainer,
829
  )
830
- result = self.evaluate_model(
831
- trainer.model,
832
- num_classes,
833
- id_class_dict,
834
- eval_data,
835
- predict_eval,
836
- ksplit_output_dir,
837
- output_prefix,
838
- )
839
  else:
840
  trainer = self.hyperopt_classifier(
841
  model_directory,
@@ -845,27 +811,20 @@ class Classifier:
845
  ksplit_output_dir,
846
  n_trials=n_hyperopt_trials,
847
  )
848
-
849
- model = cu.load_best_model(
850
- ksplit_output_dir, self.model_type, num_classes
851
- )
852
-
853
- if self.oos_test_size > 0:
854
- result = self.evaluate_model(
855
- model,
856
- num_classes,
857
- id_class_dict,
858
- test_data,
859
- predict_eval,
860
- ksplit_output_dir,
861
- output_prefix,
862
- )
863
  else:
864
- if iteration_num == self.num_crossval_splits:
865
- return
866
- else:
867
- iteration_num = iteration_num + 1
868
- continue
869
  results += [result]
870
  all_conf_mat = all_conf_mat + result["conf_mat"]
871
  # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
@@ -966,7 +925,12 @@ class Classifier:
966
  subprocess.call(f"mkdir {output_directory}", shell=True)
967
 
968
  ##### Load model and training args #####
969
- model = pu.load_model(self.model_type, num_classes, model_directory, "train")
970
  def_training_args, def_freeze_layers = cu.get_default_train_args(
971
  model, self.classifier, train_data, output_directory
972
  )
@@ -982,9 +946,6 @@ class Classifier:
982
  if eval_data is None:
983
  def_training_args["evaluation_strategy"] = "no"
984
  def_training_args["load_best_model_at_end"] = False
985
- def_training_args.update(
986
- {"save_strategy": "epoch", "save_total_limit": 1}
987
- ) # only save last model for each run
988
  training_args_init = TrainingArguments(**def_training_args)
989
 
990
  ##### Fine-tune the model #####
@@ -996,9 +957,7 @@ class Classifier:
996
 
997
  # define function to initiate model
998
  def model_init():
999
- model = pu.load_model(
1000
- self.model_type, num_classes, model_directory, "train"
1001
- )
1002
 
1003
  if self.freeze_layers is not None:
1004
  def_freeze_layers = self.freeze_layers
@@ -1059,7 +1018,6 @@ class Classifier:
1059
  metric="eval_macro_f1",
1060
  metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1061
  ),
1062
- local_dir=output_directory,
1063
  )
1064
 
1065
  return trainer
@@ -1122,7 +1080,11 @@ class Classifier:
1122
  subprocess.call(f"mkdir {output_directory}", shell=True)
1123
 
1124
  ##### Load model and training args #####
1125
- model = pu.load_model(self.model_type, num_classes, model_directory, "train")
1126
 
1127
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1128
  model, self.classifier, train_data, output_directory
@@ -1276,7 +1238,11 @@ class Classifier:
1276
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1277
 
1278
  # load previously fine-tuned model
1279
- model = pu.load_model(self.model_type, num_classes, model_directory, "eval")
1280
 
1281
  # evaluate the model
1282
  result = self.evaluate_model(
 
53
  import numpy as np
54
  import pandas as pd
55
  import seaborn as sns
56
+ from sklearn.model_selection import StratifiedKFold
57
  from tqdm.auto import tqdm, trange
58
  from transformers import Trainer
59
  from transformers.training_args import TrainingArguments
 
62
  from . import classifier_utils as cu
63
  from . import evaluation_utils as eu
64
  from . import perturber_utils as pu
65
+ from .tokenizer import TOKEN_DICTIONARY_FILE
66
 
67
  sns.set()
68
 
 
86
  "no_eval": {bool},
87
  "stratify_splits_col": {None, str},
88
  "forward_batch_size": {int},
 
89
  "nproc": {int},
90
  "ngpu": {int},
91
  }
 
107
  stratify_splits_col=None,
108
  no_eval=False,
109
  forward_batch_size=100,
 
110
  nproc=4,
111
  ngpu=1,
112
  ):
 
175
  | Otherwise, will perform eval during training.
176
  forward_batch_size : int
177
  | Batch size for forward pass (for evaluation, not training).
 
 
 
178
  nproc : int
179
  | Number of CPU processes to use.
180
  ngpu : int
 
183
  """
184
 
185
  self.classifier = classifier
186
  self.cell_state_dict = cell_state_dict
187
  self.gene_class_dict = gene_class_dict
188
  self.filter_data = filter_data
 
201
  self.stratify_splits_col = stratify_splits_col
202
  self.no_eval = no_eval
203
  self.forward_batch_size = forward_batch_size
 
204
  self.nproc = nproc
205
  self.ngpu = ngpu
206
 
 
222
  ] = self.cell_state_dict["states"]
223
 
224
  # load token dictionary (Ensembl IDs:token)
225
+ with open(TOKEN_DICTIONARY_FILE, "rb") as f:
 
 
226
  self.gene_token_dict = pickle.load(f)
227
 
228
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
 
267
  continue
268
  valid_type = False
269
  for option in valid_options:
270
+ if (option in [int, float, list, dict, bool]) and isinstance(
271
  attr_value, option
272
  ):
273
  valid_type = True
 
434
  test_data_output_path = (
435
  Path(output_directory) / f"{output_prefix}_labeled_test"
436
  ).with_suffix(".dataset")
437
+ data_dict["train"].save_to_disk(train_data_output_path)
438
+ data_dict["test"].save_to_disk(test_data_output_path)
439
  elif (test_size is not None) and (self.classifier == "cell"):
440
  if 1 > test_size > 0:
441
  if attr_to_split is None:
 
450
  test_data_output_path = (
451
  Path(output_directory) / f"{output_prefix}_labeled_test"
452
  ).with_suffix(".dataset")
453
+ data_dict["train"].save_to_disk(train_data_output_path)
454
+ data_dict["test"].save_to_disk(test_data_output_path)
455
  else:
456
  data_dict, balance_df = cu.balance_attr_splits(
457
  data,
 
472
  test_data_output_path = (
473
  Path(output_directory) / f"{output_prefix}_labeled_test"
474
  ).with_suffix(".dataset")
475
+ data_dict["train"].save_to_disk(train_data_output_path)
476
+ data_dict["test"].save_to_disk(test_data_output_path)
477
  else:
478
  data_output_path = (
479
  Path(output_directory) / f"{output_prefix}_labeled"
480
  ).with_suffix(".dataset")
481
+ data.save_to_disk(data_output_path)
482
  print(data_output_path)
483
  else:
484
  data_output_path = (
485
  Path(output_directory) / f"{output_prefix}_labeled"
486
  ).with_suffix(".dataset")
487
+ data.save_to_disk(data_output_path)
488
 
489
  def train_all_data(
490
  self,
 
630
  | Number of trials to run for hyperparameter optimization
631
  | If 0, will not optimize hyperparameters
632
  """
633
+
634
  if self.num_crossval_splits == 0:
635
  logger.error("num_crossval_splits must be 1 or 5 to validate.")
636
  raise
 
772
  ]
773
  )
774
  assert len(targets) == len(labels)
775
+ n_splits = int(1 / self.eval_size)
776
+ skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
777
  # (Cross-)validate
778
+ for train_index, eval_index in tqdm(skf.split(targets, labels)):
 
 
 
779
  print(
780
  f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
781
  )
782
  ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
783
  # filter data for examples containing classes for this split
784
  # subsample to max_ncells and relabel data in column "labels"
785
+ train_data, eval_data = cu.prep_gene_classifier_split(
786
  data,
787
  targets,
788
  labels,
 
793
  self.nproc,
794
  )
795
796
  if n_hyperopt_trials == 0:
797
  trainer = self.train_classifier(
798
  model_directory,
 
802
  ksplit_output_dir,
803
  predict_trainer,
804
  )
805
  else:
806
  trainer = self.hyperopt_classifier(
807
  model_directory,
 
811
  ksplit_output_dir,
812
  n_trials=n_hyperopt_trials,
813
  )
814
+ if iteration_num == self.num_crossval_splits:
815
+ return
816
  else:
817
+ iteration_num = iteration_num + 1
818
+ continue
819
+ result = self.evaluate_model(
820
+ trainer.model,
821
+ num_classes,
822
+ id_class_dict,
823
+ eval_data,
824
+ predict_eval,
825
+ ksplit_output_dir,
826
+ output_prefix,
827
+ )
828
  results += [result]
829
  all_conf_mat = all_conf_mat + result["conf_mat"]
830
  # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
 
925
  subprocess.call(f"mkdir {output_directory}", shell=True)
926
 
927
  ##### Load model and training args #####
928
+ if self.classifier == "cell":
929
+ model_type = "CellClassifier"
930
+ elif self.classifier == "gene":
931
+ model_type = "GeneClassifier"
932
+
933
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
934
  def_training_args, def_freeze_layers = cu.get_default_train_args(
935
  model, self.classifier, train_data, output_directory
936
  )
 
946
  if eval_data is None:
947
  def_training_args["evaluation_strategy"] = "no"
948
  def_training_args["load_best_model_at_end"] = False
 
 
 
949
  training_args_init = TrainingArguments(**def_training_args)
950
 
951
  ##### Fine-tune the model #####
 
957
 
958
  # define function to initiate model
959
  def model_init():
960
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
 
 
961
 
962
  if self.freeze_layers is not None:
963
  def_freeze_layers = self.freeze_layers
 
1018
  metric="eval_macro_f1",
1019
  metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1020
  ),
 
1021
  )
1022
 
1023
  return trainer
 
1080
  subprocess.call(f"mkdir {output_directory}", shell=True)
1081
 
1082
  ##### Load model and training args #####
1083
+ if self.classifier == "cell":
1084
+ model_type = "CellClassifier"
1085
+ elif self.classifier == "gene":
1086
+ model_type = "GeneClassifier"
1087
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
1088
 
1089
  def_training_args, def_freeze_layers = cu.get_default_train_args(
1090
  model, self.classifier, train_data, output_directory
 
1238
  test_data = pu.load_and_filter(None, self.nproc, test_data_file)
1239
 
1240
  # load previously fine-tuned model
1241
+ if self.classifier == "cell":
1242
+ model_type = "CellClassifier"
1243
+ elif self.classifier == "gene":
1244
+ model_type = "GeneClassifier"
1245
+ model = pu.load_model(model_type, num_classes, model_directory, "eval")
1246
 
1247
  # evaluate the model
1248
  result = self.evaluate_model(
geneformer/classifier_utils.py CHANGED
@@ -1,6 +1,4 @@
1
- import json
2
  import logging
3
- import os
4
  import random
5
  from collections import Counter, defaultdict
6
 
@@ -8,7 +6,6 @@ import numpy as np
8
  import pandas as pd
9
  from scipy.stats import chisquare, ranksums
10
  from sklearn.metrics import accuracy_score, f1_score
11
- from sklearn.model_selection import StratifiedKFold, train_test_split
12
 
13
  from . import perturber_utils as pu
14
 
@@ -136,55 +133,61 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
136
  ]
137
 
138
 
139
- def prep_gene_classifier_train_eval_split(
140
- data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
141
- ):
142
- # generate cross-validation splits
143
- train_data = prep_gene_classifier_split(
144
- data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc
145
- )
146
- eval_data = prep_gene_classifier_split(
147
- data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc
148
- )
149
- return train_data, eval_data
150
-
151
-
152
  def prep_gene_classifier_split(
153
- data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc
154
  ):
155
  # generate cross-validation splits
156
  targets = np.array(targets)
157
  labels = np.array(labels)
158
- targets_subset = targets[index]
159
- labels_subset = labels[index]
160
- label_dict_subset = dict(zip(targets_subset, labels_subset))
 
161
 
162
  # function to filter by whether contains train or eval labels
163
- def if_contains_subset_label(example):
164
- a = targets_subset
165
  b = example["input_ids"]
166
  return not set(a).isdisjoint(b)
167
 
168
  # filter dataset for examples containing classes for this split
169
- logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
170
- subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
171
  logger.info(
172
- f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
173
  )
174
 
175
  # subsample to max_ncells
176
- subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
 
177
 
178
  # relabel genes for this split
179
- def subset_classes_to_ids(example):
180
  example["labels"] = [
181
- label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
182
  ]
183
  return example
184
 
185
- subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
186
 
187
- return subset_data
 
 
 
188
 
189
 
190
  def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
@@ -306,7 +309,7 @@ def balance_attr_splits(
306
  exp_counts[cat] * sum(obs) / sum(exp_counts.values())
307
  for cat in all_categ
308
  ]
309
- pval = chisquare(f_obs=obs, f_exp=exp).pvalue
310
  train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
311
  eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
312
  df_vals += [train_attr_counts, eval_attr_counts, pval]
@@ -420,45 +423,3 @@ def get_default_train_args(model, classifier, data, output_dir):
420
  training_args.update(default_training_args)
421
 
422
  return training_args, freeze_layers
423
-
424
-
425
- def load_best_model(directory, model_type, num_classes, mode="eval"):
426
- file_dict = dict()
427
- for subdir, dirs, files in os.walk(directory):
428
- for file in files:
429
- if file.endswith("result.json"):
430
- with open(f"{subdir}/{file}", "rb") as fp:
431
- result_json = json.load(fp)
432
- file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
433
- file_df = pd.DataFrame(
434
- {"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
435
- )
436
- model_superdir = (
437
- "run-"
438
- + file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
439
- .split("_objective_")[2]
440
- .split("_")[0]
441
- )
442
-
443
- for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
444
- for file in files:
445
- if file.endswith("model.safetensors"):
446
- model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
447
- return model
448
-
449
-
450
- class StratifiedKFold3(StratifiedKFold):
451
- def split(self, targets, labels, test_ratio=0.5, groups=None):
452
- s = super().split(targets, labels, groups)
453
- for train_indxs, test_indxs in s:
454
- if test_ratio == 0:
455
- yield train_indxs, test_indxs, None
456
- else:
457
- labels_test = np.array(labels)[test_indxs]
458
- valid_indxs, test_indxs = train_test_split(
459
- test_indxs,
460
- stratify=labels_test,
461
- test_size=test_ratio,
462
- random_state=0,
463
- )
464
- yield train_indxs, valid_indxs, test_indxs
 
 
1
  import logging
 
2
  import random
3
  from collections import Counter, defaultdict
4
 
 
6
  import pandas as pd
7
  from scipy.stats import chisquare, ranksums
8
  from sklearn.metrics import accuracy_score, f1_score
 
9
 
10
  from . import perturber_utils as pu
11
 
 
133
  ]
134
 
135
 
136
  def prep_gene_classifier_split(
137
+ data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
138
  ):
139
  # generate cross-validation splits
140
  targets = np.array(targets)
141
  labels = np.array(labels)
142
+ targets_train, targets_eval = targets[train_index], targets[eval_index]
143
+ labels_train, labels_eval = labels[train_index], labels[eval_index]
144
+ label_dict_train = dict(zip(targets_train, labels_train))
145
+ label_dict_eval = dict(zip(targets_eval, labels_eval))
146
 
147
  # function to filter by whether contains train or eval labels
148
+ def if_contains_train_label(example):
149
+ a = targets_train
150
+ b = example["input_ids"]
151
+ return not set(a).isdisjoint(b)
152
+
153
+ def if_contains_eval_label(example):
154
+ a = targets_eval
155
  b = example["input_ids"]
156
  return not set(a).isdisjoint(b)
157
 
158
  # filter dataset for examples containing classes for this split
159
+ logger.info(f"Filtering training data for genes in split {iteration_num}")
160
+ train_data = data.filter(if_contains_train_label, num_proc=num_proc)
161
  logger.info(
162
+ f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
163
+ )
164
+ logger.info(f"Filtering evaluation data for genes in split {iteration_num}")
165
+ eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
166
+ logger.info(
167
+ f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
168
  )
169
 
170
  # subsample to max_ncells
171
+ train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
172
+ eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
173
 
174
  # relabel genes for this split
175
+ def train_classes_to_ids(example):
176
  example["labels"] = [
177
+ label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
178
  ]
179
  return example
180
 
181
+ def eval_classes_to_ids(example):
182
+ example["labels"] = [
183
+ label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
184
+ ]
185
+ return example
186
 
187
+ train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
188
+ eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
189
+
190
+ return train_data, eval_data
191
 
192
 
193
  def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
 
309
  exp_counts[cat] * sum(obs) / sum(exp_counts.values())
310
  for cat in all_categ
311
  ]
312
+ pval = chisquare(f_obs=obs, f_exp=exp).pvalue
313
  train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
314
  eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
315
  df_vals += [train_attr_counts, eval_attr_counts, pval]
 
423
  training_args.update(default_training_args)
424
 
425
  return training_args, freeze_layers
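For reference, a sketch of the fold arithmetic behind the restored validation flow: n_splits is derived from eval_size, and StratifiedKFold yields the train/eval index pairs that Classifier.validate hands to cu.prep_gene_classifier_split (the targets/labels below are toy placeholders, not real gene tokens):

    import numpy as np
    from sklearn.model_selection import StratifiedKFold

    # toy stand-ins for gene token IDs and their class labels
    targets = np.arange(100)
    labels = np.array([0, 1] * 50)

    eval_size = 0.2                   # placeholder value
    n_splits = int(1 / eval_size)     # 0.2 -> 5 folds, as in the classifier.py hunk
    skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)

    for train_index, eval_index in skf.split(targets, labels):
        print(len(train_index), len(eval_index))  # 80 / 20 per fold
        # Classifier.validate passes each index pair on to
        # cu.prep_gene_classifier_split(data, targets, labels,
        #     train_index, eval_index, max_ncells, iteration_num, nproc)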
geneformer/collator_for_classification.py CHANGED
@@ -4,7 +4,6 @@ Geneformer collator for gene and cell classification.
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
  import numpy as np
7
- import pickle
8
  import torch
9
  import warnings
10
  from enum import Enum
@@ -18,11 +17,7 @@ from transformers import (
18
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
19
  from transformers.utils.generic import _is_tensorflow, _is_torch
20
 
21
- from . import TOKEN_DICTIONARY_FILE
22
-
23
- # load token dictionary (Ensembl IDs:token)
24
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
25
- token_dictionary = pickle.load(f)
26
 
27
  EncodedInput = List[int]
28
  logger = logging.get_logger(__name__)
 
4
  Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
5
  """
6
  import numpy as np
 
7
  import torch
8
  import warnings
9
  from enum import Enum
 
17
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
  from transformers.utils.generic import _is_tensorflow, _is_torch
19
 
20
+ from .pretrainer import token_dictionary
21
 
22
  EncodedInput = List[int]
23
  logger = logging.get_logger(__name__)
geneformer/emb_extractor.py CHANGED
@@ -25,7 +25,7 @@ from tdigest import TDigest
25
  from tqdm.auto import trange
26
 
27
  from . import perturber_utils as pu
28
- from . import TOKEN_DICTIONARY_FILE
29
 
30
  logger = logging.getLogger(__name__)
31
 
@@ -38,19 +38,19 @@ def get_embs(
38
  layer_to_quant,
39
  pad_token_id,
40
  forward_batch_size,
41
- token_gene_dict,
42
- special_token=False,
43
  summary_stat=None,
44
  silent=False,
45
  ):
46
  model_input_size = pu.get_model_input_size(model)
47
  total_batch_length = len(filtered_input_data)
48
-
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
52
- # get # of emb dims
53
- emb_dims = pu.get_model_emb_dims(model)
 
 
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
@@ -67,23 +67,8 @@ def get_embs(
67
  k: [TDigest() for _ in range(emb_dims)] for k in gene_set
68
  }
69
 
70
- # Check if CLS and EOS token is present in the token dictionary
71
- cls_present = any("<cls>" in value for value in token_gene_dict.values())
72
- eos_present = any("<eos>" in value for value in token_gene_dict.values())
73
- if emb_mode == "cls":
74
- assert cls_present, "<cls> token missing in token dictionary"
75
- # Check to make sure that the first token of the filtered input data is cls token
76
- gene_token_dict = {v:k for k,v in token_gene_dict.items()}
77
- cls_token_id = gene_token_dict["<cls>"]
78
- assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
79
- elif emb_mode == "cell":
80
- if cls_present:
81
- logger.warning("CLS token present in token dictionary, excluding from average.")
82
- if eos_present:
83
- logger.warning("EOS token present in token dictionary, excluding from average.")
84
-
85
  overall_max_len = 0
86
-
87
  for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
88
  max_range = min(i + forward_batch_size, total_batch_length)
89
 
@@ -107,14 +92,7 @@ def get_embs(
107
  embs_i = outputs.hidden_states[layer_to_quant]
108
 
109
  if emb_mode == "cell":
110
- if cls_present:
111
- non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
112
- if eos_present:
113
- mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
114
- else:
115
- mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
116
- else:
117
- mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
118
  if summary_stat is None:
119
  embs_list.append(mean_embs)
120
  elif summary_stat is not None:
@@ -143,13 +121,7 @@ def get_embs(
143
  accumulate_tdigests(
144
  embs_tdigests_dict[int(k)], dict_h[k], emb_dims
145
  )
146
- del embs_h
147
- del dict_h
148
- elif emb_mode == "cls":
149
- cls_embs = embs_i[:,0,:].clone().detach() # CLS token layer
150
- embs_list.append(cls_embs)
151
- del cls_embs
152
-
153
  overall_max_len = max(overall_max_len, max_len)
154
  del outputs
155
  del minibatch
@@ -157,10 +129,9 @@ def get_embs(
157
  del embs_i
158
 
159
  torch.cuda.empty_cache()
160
-
161
-
162
  if summary_stat is None:
163
- if (emb_mode == "cell") or (emb_mode == "cls"):
164
  embs_stack = torch.cat(embs_list, dim=0)
165
  elif emb_mode == "gene":
166
  embs_stack = pu.pad_tensor_list(
@@ -204,6 +175,7 @@ def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
204
  for j in range(emb_dims)
205
  ]
206
 
 
207
  def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
208
  embs_tdigests_dict[gene] = accumulate_tdigests(
209
  embs_tdigests_dict[gene], gene_embs, emb_dims
@@ -237,6 +209,14 @@ def tdigest_median(embs_tdigests, emb_dims):
237
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
238
 
239
 
 
 
 
 
 
 
 
 
240
  def label_cell_embs(embs, downsampled_data, emb_labels):
241
  embs_df = pd.DataFrame(embs.cpu().numpy())
242
  if emb_labels is not None:
@@ -272,7 +252,7 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
272
  return embs_df
273
 
274
 
275
- def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
276
  only_embs_df = embs_df.iloc[:, :emb_dims]
277
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
278
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
@@ -282,17 +262,15 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
282
  obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
283
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
284
  sc.tl.pca(adata, svd_solver="arpack")
285
- sc.pp.neighbors(adata, random_state=seed)
286
- sc.tl.umap(adata, random_state=seed)
287
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
288
  sns.set_style("white")
289
  default_kwargs_dict = {"palette": "Set2", "size": 200}
290
  if kwargs_dict is not None:
291
  default_kwargs_dict.update(kwargs_dict)
292
 
293
- with plt.rc_context():
294
- sc.pl.umap(adata, color=label, **default_kwargs_dict)
295
- plt.savefig(output_file, bbox_inches="tight")
296
 
297
 
298
  def gen_heatmap_class_colors(labels, df):
@@ -368,8 +346,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
368
  bbox_to_anchor=(0.5, 1),
369
  facecolor="white",
370
  )
371
- plt.show()
372
- logger.info(f"Output file: {output_file}")
373
  plt.savefig(output_file, bbox_inches="tight")
374
 
375
 
@@ -377,7 +354,7 @@ class EmbExtractor:
377
  valid_option_dict = {
378
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
379
  "num_classes": {int},
380
- "emb_mode": {"cls", "cell", "gene"},
381
  "cell_emb_style": {"mean_pool"},
382
  "gene_emb_style": {"mean_pool"},
383
  "filter_data": {None, dict},
@@ -386,7 +363,6 @@ class EmbExtractor:
386
  "emb_label": {None, list},
387
  "labels_to_plot": {None, list},
388
  "forward_batch_size": {int},
389
- "token_dictionary_file" : {None, str},
390
  "nproc": {int},
391
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
392
  }
@@ -406,7 +382,7 @@ class EmbExtractor:
406
  forward_batch_size=100,
407
  nproc=4,
408
  summary_stat=None,
409
- token_dictionary_file=None,
410
  ):
411
  """
412
  Initialize embedding extractor.
@@ -418,11 +394,10 @@ class EmbExtractor:
418
  num_classes : int
419
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
420
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
421
- emb_mode : {"cls", "cell", "gene"}
422
- | Whether to output CLS, cell, or gene embeddings.
423
- | CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
424
- cell_emb_style : {"mean_pool"}
425
- | Method for summarizing cell embeddings if not using CLS token.
426
  | Currently only option is mean pooling of gene embeddings for given cell.
427
  gene_emb_style : "mean_pool"
428
  | Method for summarizing gene embeddings.
@@ -457,7 +432,6 @@ class EmbExtractor:
457
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
458
  | Non-exact is slower but more memory-efficient.
459
  token_dictionary_file : Path
460
- | Default is the Geneformer token dictionary
461
  | Path to pickle file containing token dictionary (Ensembl ID:token).
462
 
463
  **Examples:**
@@ -487,7 +461,6 @@ class EmbExtractor:
487
  self.emb_layer = emb_layer
488
  self.emb_label = emb_label
489
  self.labels_to_plot = labels_to_plot
490
- self.token_dictionary_file = token_dictionary_file
491
  self.forward_batch_size = forward_batch_size
492
  self.nproc = nproc
493
  if (summary_stat is not None) and ("exact" in summary_stat):
@@ -500,8 +473,6 @@ class EmbExtractor:
500
  self.validate_options()
501
 
502
  # load token dictionary (Ensembl IDs:token)
503
- if self.token_dictionary_file is None:
504
- token_dictionary_file = TOKEN_DICTIONARY_FILE
505
  with open(token_dictionary_file, "rb") as f:
506
  self.gene_token_dict = pickle.load(f)
507
 
@@ -517,7 +488,7 @@ class EmbExtractor:
517
  continue
518
  valid_type = False
519
  for option in valid_options:
520
- if (option in [int, list, dict, bool, str]) and isinstance(
521
  attr_value, option
522
  ):
523
  valid_type = True
@@ -591,14 +562,13 @@ class EmbExtractor:
591
  )
592
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
593
  embs = get_embs(
594
- model=model,
595
- filtered_input_data=downsampled_data,
596
- emb_mode=self.emb_mode,
597
- layer_to_quant=layer_to_quant,
598
- pad_token_id=self.pad_token_id,
599
- forward_batch_size=self.forward_batch_size,
600
- token_gene_dict=self.token_gene_dict,
601
- summary_stat=self.summary_stat,
602
  )
603
 
604
  if self.emb_mode == "cell":
@@ -612,8 +582,6 @@ class EmbExtractor:
612
  elif self.summary_stat is not None:
613
  embs_df = pd.DataFrame(embs).T
614
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
615
- elif self.emb_mode == "cls":
616
- embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
617
 
618
  # save embeddings to output_path
619
  if cell_state is None:
@@ -622,15 +590,13 @@ class EmbExtractor:
622
 
623
  if self.exact_summary_stat == "exact_mean":
624
  embs = embs.mean(dim=0)
625
- emb_dims = pu.get_model_emb_dims(model)
626
  embs_df = pd.DataFrame(
627
- embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
628
  ).T
629
  elif self.exact_summary_stat == "exact_median":
630
  embs = torch.median(embs, dim=0)[0]
631
- emb_dims = pu.get_model_emb_dims(model)
632
  embs_df = pd.DataFrame(
633
- embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
634
  ).T
635
 
636
  if cell_state is not None:
@@ -813,11 +779,11 @@ class EmbExtractor:
813
  f"not present in provided embeddings dataframe."
814
  )
815
  continue
816
- output_prefix_label = output_prefix + f"_umap_{label}"
817
  output_file = (
818
  Path(output_directory) / output_prefix_label
819
  ).with_suffix(".pdf")
820
- plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
821
 
822
  if plot_style == "heatmap":
823
  for label in self.labels_to_plot:
@@ -831,4 +797,4 @@ class EmbExtractor:
831
  output_file = (
832
  Path(output_directory) / output_prefix_label
833
  ).with_suffix(".pdf")
834
- plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
 
25
  from tqdm.auto import trange
26
 
27
  from . import perturber_utils as pu
28
+ from .tokenizer import TOKEN_DICTIONARY_FILE
29
 
30
  logger = logging.getLogger(__name__)
31
 
 
38
  layer_to_quant,
39
  pad_token_id,
40
  forward_batch_size,
 
 
41
  summary_stat=None,
42
  silent=False,
43
  ):
44
  model_input_size = pu.get_model_input_size(model)
45
  total_batch_length = len(filtered_input_data)
46
+
47
  if summary_stat is None:
48
  embs_list = []
49
  elif summary_stat is not None:
50
+ # test embedding extraction for example cell and extract # emb dims
51
+ example = filtered_input_data.select([i for i in range(1)])
52
+ example.set_format(type="torch")
53
+ emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
 
67
  k: [TDigest() for _ in range(emb_dims)] for k in gene_set
68
  }
69
 
 
70
  overall_max_len = 0
71
+
72
  for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
73
  max_range = min(i + forward_batch_size, total_batch_length)
74
 
 
92
  embs_i = outputs.hidden_states[layer_to_quant]
93
 
94
  if emb_mode == "cell":
95
+ mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
 
96
  if summary_stat is None:
97
  embs_list.append(mean_embs)
98
  elif summary_stat is not None:
 
121
  accumulate_tdigests(
122
  embs_tdigests_dict[int(k)], dict_h[k], emb_dims
123
  )
124
+
125
  overall_max_len = max(overall_max_len, max_len)
126
  del outputs
127
  del minibatch
 
129
  del embs_i
130
 
131
  torch.cuda.empty_cache()
132
+
 
133
  if summary_stat is None:
134
+ if emb_mode == "cell":
135
  embs_stack = torch.cat(embs_list, dim=0)
136
  elif emb_mode == "gene":
137
  embs_stack = pu.pad_tensor_list(
 
175
  for j in range(emb_dims)
176
  ]
177
 
178
+
179
  def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
180
  embs_tdigests_dict[gene] = accumulate_tdigests(
181
  embs_tdigests_dict[gene], gene_embs, emb_dims
 
209
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
210
 
211
 
212
+ def test_emb(model, example, layer_to_quant):
213
+ with torch.no_grad():
214
+ outputs = model(input_ids=example.to("cuda"))
215
+
216
+ embs_test = outputs.hidden_states[layer_to_quant]
217
+ return embs_test.size()[2]
218
+
219
+
220
  def label_cell_embs(embs, downsampled_data, emb_labels):
221
  embs_df = pd.DataFrame(embs.cpu().numpy())
222
  if emb_labels is not None:
 
252
  return embs_df
253
 
254
 
255
+ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict):
256
  only_embs_df = embs_df.iloc[:, :emb_dims]
257
  only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
258
  only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
 
262
  obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
263
  adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
264
  sc.tl.pca(adata, svd_solver="arpack")
265
+ sc.pp.neighbors(adata)
266
+ sc.tl.umap(adata)
267
  sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
268
  sns.set_style("white")
269
  default_kwargs_dict = {"palette": "Set2", "size": 200}
270
  if kwargs_dict is not None:
271
  default_kwargs_dict.update(kwargs_dict)
272
 
273
+ sc.pl.umap(adata, color=label, save=output_file, **default_kwargs_dict)
 
 
274
 
275
 
276
  def gen_heatmap_class_colors(labels, df):
 
346
  bbox_to_anchor=(0.5, 1),
347
  facecolor="white",
348
  )
349
+
 
350
  plt.savefig(output_file, bbox_inches="tight")
351
 
352
 
 
354
  valid_option_dict = {
355
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
356
  "num_classes": {int},
357
+ "emb_mode": {"cell", "gene"},
358
  "cell_emb_style": {"mean_pool"},
359
  "gene_emb_style": {"mean_pool"},
360
  "filter_data": {None, dict},
 
363
  "emb_label": {None, list},
364
  "labels_to_plot": {None, list},
365
  "forward_batch_size": {int},
 
366
  "nproc": {int},
367
  "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
368
  }
 
382
  forward_batch_size=100,
383
  nproc=4,
384
  summary_stat=None,
385
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
386
  ):
387
  """
388
  Initialize embedding extractor.
 
394
  num_classes : int
395
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
396
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
397
+ emb_mode : {"cell", "gene"}
398
+ | Whether to output cell or gene embeddings.
399
+ cell_emb_style : "mean_pool"
400
+ | Method for summarizing cell embeddings.
 
401
  | Currently only option is mean pooling of gene embeddings for given cell.
402
  gene_emb_style : "mean_pool"
403
  | Method for summarizing gene embeddings.
 
432
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
433
  | Non-exact is slower but more memory-efficient.
434
  token_dictionary_file : Path
 
435
  | Path to pickle file containing token dictionary (Ensembl ID:token).
436
 
437
  **Examples:**
 
461
  self.emb_layer = emb_layer
462
  self.emb_label = emb_label
463
  self.labels_to_plot = labels_to_plot
 
464
  self.forward_batch_size = forward_batch_size
465
  self.nproc = nproc
466
  if (summary_stat is not None) and ("exact" in summary_stat):
 
473
  self.validate_options()
474
 
475
  # load token dictionary (Ensembl IDs:token)
 
 
476
  with open(token_dictionary_file, "rb") as f:
477
  self.gene_token_dict = pickle.load(f)
478
 
 
488
  continue
489
  valid_type = False
490
  for option in valid_options:
491
+ if (option in [int, list, dict, bool]) and isinstance(
492
  attr_value, option
493
  ):
494
  valid_type = True
 
562
  )
563
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
564
  embs = get_embs(
565
+ model,
566
+ downsampled_data,
567
+ self.emb_mode,
568
+ layer_to_quant,
569
+ self.pad_token_id,
570
+ self.forward_batch_size,
571
+ self.summary_stat,
 
572
  )
573
 
574
  if self.emb_mode == "cell":
 
582
  elif self.summary_stat is not None:
583
  embs_df = pd.DataFrame(embs).T
584
  embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
 
 
585
 
586
  # save embeddings to output_path
587
  if cell_state is None:
 
590
 
591
  if self.exact_summary_stat == "exact_mean":
592
  embs = embs.mean(dim=0)
 
593
  embs_df = pd.DataFrame(
594
+ embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
595
  ).T
596
  elif self.exact_summary_stat == "exact_median":
597
  embs = torch.median(embs, dim=0)[0]
 
598
  embs_df = pd.DataFrame(
599
+ embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
600
  ).T
601
 
602
  if cell_state is not None:
 
779
  f"not present in provided embeddings dataframe."
780
  )
781
  continue
782
+ output_prefix_label = "_" + output_prefix + f"_umap_{label}"
783
  output_file = (
784
  Path(output_directory) / output_prefix_label
785
  ).with_suffix(".pdf")
786
+ plot_umap(embs, emb_dims, label, output_prefix_label, kwargs_dict)
787
 
788
  if plot_style == "heatmap":
789
  for label in self.labels_to_plot:
 
797
  output_file = (
798
  Path(output_directory) / output_prefix_label
799
  ).with_suffix(".pdf")
800
+ plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
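For reference, an illustrative EmbExtractor call after this change, with placeholder paths; keyword arguments beyond those visible in the hunks (max_ncells, emb_layer) are assumptions based on the surrounding class definition, and emb_mode is again limited to "cell" or "gene" with the token dictionary defaulting to the packaged TOKEN_DICTIONARY_FILE:

    from geneformer import EmbExtractor

    embex = EmbExtractor(
        model_type="Pretrained",   # or "CellClassifier" / "GeneClassifier"
        num_classes=0,             # 0 for the pretrained, non-classifier model
        emb_mode="cell",
        max_ncells=1000,
        emb_layer=-1,
        forward_batch_size=100,
        nproc=4,
    )
    embs = embex.extract_embs(
        "path/to/model",           # placeholder paths
        "path/to/data.dataset",
        "path/to/output_dir",
        "output_prefix",
    )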
geneformer/evaluation_utils.py CHANGED
@@ -21,7 +21,7 @@ from sklearn.metrics import (
21
  from tqdm.auto import trange
22
 
23
  from .emb_extractor import make_colorbar
24
- from . import TOKEN_DICTIONARY_FILE
25
 
26
  logger = logging.getLogger(__name__)
27
 
 
21
  from tqdm.auto import trange
22
 
23
  from .emb_extractor import make_colorbar
24
+ from .tokenizer import TOKEN_DICTIONARY_FILE
25
 
26
  logger = logging.getLogger(__name__)
27
 
geneformer/in_silico_perturber.py CHANGED
@@ -38,18 +38,19 @@ import logging
38
  import os
39
  import pickle
40
  from collections import defaultdict
41
- from multiprocess import set_start_method
42
  from typing import List
43
 
 
44
  import torch
45
- from datasets import Dataset, disable_progress_bars
46
  from tqdm.auto import trange
47
 
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
- from . import TOKEN_DICTIONARY_FILE
 
 
51
 
52
- disable_progress_bars()
53
 
54
  logger = logging.getLogger(__name__)
55
 
@@ -184,10 +185,6 @@ class InSilicoPerturber:
184
  token_dictionary_file : Path
185
  | Path to pickle file containing token dictionary (Ensembl ID:token).
186
  """
187
- try:
188
- set_start_method("spawn")
189
- except RuntimeError:
190
- pass
191
 
192
  self.perturb_type = perturb_type
193
  self.perturb_rank_shift = perturb_rank_shift
@@ -225,7 +222,6 @@ class InSilicoPerturber:
225
  # load token dictionary (Ensembl IDs:token)
226
  with open(token_dictionary_file, "rb") as f:
227
  self.gene_token_dict = pickle.load(f)
228
- self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
229
 
230
  self.pad_token_id = self.gene_token_dict.get("<pad>")
231
 
@@ -426,7 +422,6 @@ class InSilicoPerturber:
426
  self.max_len = pu.get_model_input_size(model)
427
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
428
 
429
-
430
  ### filter input data ###
431
  # general filtering of input data based on filter_data argument
432
  filtered_input_data = pu.load_and_filter(
@@ -525,7 +520,6 @@ class InSilicoPerturber:
525
  perturbed_data = filtered_input_data.map(
526
  make_group_perturbation_batch, num_proc=self.nproc
527
  )
528
-
529
  if self.perturb_type == "overexpress":
530
  filtered_input_data = filtered_input_data.add_column(
531
  "n_overflow", perturbed_data["n_overflow"]
@@ -558,7 +552,6 @@ class InSilicoPerturber:
558
  layer_to_quant,
559
  self.pad_token_id,
560
  self.forward_batch_size,
561
- token_gene_dict=self.token_gene_dict,
562
  summary_stat=None,
563
  silent=True,
564
  )
@@ -578,7 +571,6 @@ class InSilicoPerturber:
578
  layer_to_quant,
579
  self.pad_token_id,
580
  self.forward_batch_size,
581
- token_gene_dict=self.token_gene_dict,
582
  summary_stat=None,
583
  silent=True,
584
  )
@@ -738,7 +730,6 @@ class InSilicoPerturber:
738
  layer_to_quant,
739
  self.pad_token_id,
740
  self.forward_batch_size,
741
- token_gene_dict=self.token_gene_dict,
742
  summary_stat=None,
743
  silent=True,
744
  )
@@ -766,7 +757,6 @@ class InSilicoPerturber:
766
  layer_to_quant,
767
  self.pad_token_id,
768
  self.forward_batch_size,
769
- token_gene_dict=self.token_gene_dict,
770
  summary_stat=None,
771
  silent=True,
772
  )
 
38
  import os
39
  import pickle
40
  from collections import defaultdict
 
41
  from typing import List
42
 
43
+ import seaborn as sns
44
  import torch
45
+ from datasets import Dataset
46
  from tqdm.auto import trange
47
 
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
+ from .tokenizer import TOKEN_DICTIONARY_FILE
51
+
52
+ sns.set()
53
 
 
54
 
55
  logger = logging.getLogger(__name__)
56
 
 
185
  token_dictionary_file : Path
186
  | Path to pickle file containing token dictionary (Ensembl ID:token).
187
  """
 
 
 
 
188
 
189
  self.perturb_type = perturb_type
190
  self.perturb_rank_shift = perturb_rank_shift
 
222
  # load token dictionary (Ensembl IDs:token)
223
  with open(token_dictionary_file, "rb") as f:
224
  self.gene_token_dict = pickle.load(f)
 
225
 
226
  self.pad_token_id = self.gene_token_dict.get("<pad>")
227
 
 
422
  self.max_len = pu.get_model_input_size(model)
423
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
424
 
 
425
  ### filter input data ###
426
  # general filtering of input data based on filter_data argument
427
  filtered_input_data = pu.load_and_filter(
 
520
  perturbed_data = filtered_input_data.map(
521
  make_group_perturbation_batch, num_proc=self.nproc
522
  )
 
523
  if self.perturb_type == "overexpress":
524
  filtered_input_data = filtered_input_data.add_column(
525
  "n_overflow", perturbed_data["n_overflow"]
 
552
  layer_to_quant,
553
  self.pad_token_id,
554
  self.forward_batch_size,
 
555
  summary_stat=None,
556
  silent=True,
557
  )
 
571
  layer_to_quant,
572
  self.pad_token_id,
573
  self.forward_batch_size,
 
574
  summary_stat=None,
575
  silent=True,
576
  )
 
730
  layer_to_quant,
731
  self.pad_token_id,
732
  self.forward_batch_size,
 
733
  summary_stat=None,
734
  silent=True,
735
  )
 
757
  layer_to_quant,
758
  self.pad_token_id,
759
  self.forward_batch_size,
 
760
  summary_stat=None,
761
  silent=True,
762
  )
geneformer/in_silico_perturber_stats.py CHANGED
@@ -38,7 +38,9 @@ from sklearn.mixture import GaussianMixture
38
  from tqdm.auto import tqdm, trange
39
 
40
  from .perturber_utils import flatten_list, validate_cell_states_to_model
41
- from . import TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
 
 
42
 
43
  logger = logging.getLogger(__name__)
44
 
@@ -190,48 +192,22 @@ def get_impact_component(test_value, gaussian_mixture_model):
190
 
191
 
192
  # aggregate data for single perturbation in multiple cells
193
- def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
194
- names = ["Cosine_sim", "Gene"]
195
- cos_sims_full_dfs = []
196
- if isinstance(genes_perturbed,list):
197
- if len(genes_perturbed)>1:
198
- gene_ids_df = cos_sims_df.loc[np.isin([set(idx) for idx in cos_sims_df["Ensembl_ID"]], set(genes_perturbed)), :]
199
- else:
200
- gene_ids_df = cos_sims_df.loc[np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :]
201
- else:
202
- logger.error(
203
- "aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
204
- )
205
- raise
206
-
207
- if gene_ids_df.empty:
208
- logger.error(
209
- "genes_to_perturb not found in data."
210
- )
211
- raise
212
-
213
- tokens = gene_ids_df["Gene"]
214
- symbols = gene_ids_df["Gene_name"]
215
-
216
- for token, symbol in zip(tokens, symbols):
217
- cos_shift_data = []
218
- for dict_i in dict_list:
219
- cos_shift_data += dict_i.get((token, "cell_emb"), [])
220
 
221
- df = pd.DataFrame(columns=names)
222
- df["Cosine_sim"] = cos_shift_data
223
- df["Gene"] = symbol
224
- cos_sims_full_dfs.append(df)
225
-
226
- return pd.concat(cos_sims_full_dfs)
227
 
228
 
229
  def find(variable, x):
230
  try:
231
  if x in variable: # Test if variable is iterable and contains x
232
  return True
233
- elif x == variable:
234
- return True
235
  except (ValueError, TypeError):
236
  return x == variable # Test if variable is x if non-iterable
237
 
@@ -272,15 +248,15 @@ def isp_aggregate_gene_shifts(
272
  cos_sims_full_df["Affected_Ensembl_ID"] = [
273
  gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
274
  ]
275
- cos_sims_full_df["Cosine_sim_mean"] = [v[0] for k, v in cos_data_mean.items()]
276
- cos_sims_full_df["Cosine_sim_stdev"] = [v[1] for k, v in cos_data_mean.items()]
277
  cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
278
 
279
  specific_val = "cell_emb"
280
  cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
281
- # reorder so cell embs are at the top and all are subordered by magnitude of cosine sim
282
  cos_sims_full_df = cos_sims_full_df.sort_values(
283
- by=(["temp", "Cosine_sim_mean"]), ascending=[False, True]
284
  ).drop("temp", axis=1)
285
 
286
  return cos_sims_full_df
@@ -671,7 +647,7 @@ class InSilicoPerturberStats:
671
  cell_states_to_model=None,
672
  pickle_suffix="_raw.pickle",
673
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
674
- gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
675
  ):
676
  """
677
  Initialize in silico perturber stats generator.
@@ -938,11 +914,11 @@ class InSilicoPerturberStats:
938
  | 1: within impact component; 0: not within impact component
939
  | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
940
 
941
- | In case of aggregating data / gene shifts:
942
  | "Perturbed": ID(s) of gene(s) being perturbed
943
  | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
944
- | "Cosine_sim_mean": mean of cosine similarity of cell or affected gene in original vs. perturbed
945
- | "Cosine_sim_stdev": standard deviation of cosine similarity of cell or affected gene in original vs. perturbed
946
  """
947
 
948
  if self.mode not in [
@@ -1041,8 +1017,8 @@ class InSilicoPerturberStats:
1041
  cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1042
  )
1043
 
1044
- elif self.mode == "aggregate_data":
1045
- cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed)
1046
 
1047
  elif self.mode == "aggregate_gene_shifts":
1048
  cos_sims_df = isp_aggregate_gene_shifts(
 
38
  from tqdm.auto import tqdm, trange
39
 
40
  from .perturber_utils import flatten_list, validate_cell_states_to_model
41
+ from .tokenizer import TOKEN_DICTIONARY_FILE
42
+
43
+ GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
44
 
45
  logger = logging.getLogger(__name__)
46
 
 
192
 
193
 
194
  # aggregate data for single perturbation in multiple cells
195
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
196
+ names = ["Cosine_shift"]
197
+ cos_sims_full_df = pd.DataFrame(columns=names)
198
 
199
+ cos_shift_data = []
200
+ token = cos_sims_df["Gene"][0]
201
+ for dict_i in dict_list:
202
+ cos_shift_data += dict_i.get((token, "cell_emb"), [])
203
+ cos_sims_full_df["Cosine_shift"] = cos_shift_data
204
+ return cos_sims_full_df
205
 
206
 
207
  def find(variable, x):
208
  try:
209
  if x in variable: # Test if variable is iterable and contains x
210
  return True
 
 
211
  except (ValueError, TypeError):
212
  return x == variable # Test if variable is x if non-iterable
213
 
 
248
  cos_sims_full_df["Affected_Ensembl_ID"] = [
249
  gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
250
  ]
251
+ cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
252
+ cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
253
  cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
254
 
255
  specific_val = "cell_emb"
256
  cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
257
+ # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
258
  cos_sims_full_df = cos_sims_full_df.sort_values(
259
+ by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
260
  ).drop("temp", axis=1)
261
 
262
  return cos_sims_full_df
 
647
  cell_states_to_model=None,
648
  pickle_suffix="_raw.pickle",
649
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
650
+ gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
651
  ):
652
  """
653
  Initialize in silico perturber stats generator.
 
914
  | 1: within impact component; 0: not within impact component
915
  | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
916
 
917
+ | In case of aggregating gene shifts:
918
  | "Perturbed": ID(s) of gene(s) being perturbed
919
  | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
920
+ | "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
921
+ | "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
922
  """
923
 
924
  if self.mode not in [
 
1017
  cos_sims_df_initial, dict_list, self.combos, self.anchor_token
1018
  )
1019
 
1020
+ elif self.mode == "aggregate_data":
1021
+ cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
1022
 
1023
  elif self.mode == "aggregate_gene_shifts":
1024
  cos_sims_df = isp_aggregate_gene_shifts(
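A hedged usage sketch for the restored "aggregate_data" mode. The mode string and the genes_perturbed attribute come from this diff; whether they are constructor keywords, and the get_stats entry point with its argument order, are assumptions about the surrounding API, and all paths are illustrative.

```python
from geneformer import InSilicoPerturberStats

ispstats = InSilicoPerturberStats(
    mode="aggregate_data",                # routed to isp_aggregate_grouped_perturb above
    genes_perturbed=["ENSG00000136574"],  # hypothetical single perturbed gene
)

# Assumed entry point (illustrative paths and prefix):
ispstats.get_stats(
    "path/to/in_silico_perturber_output",
    None,
    "path/to/stats_output",
    "perturb_stats",
)
```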
geneformer/perturber_utils.py CHANGED
@@ -4,8 +4,6 @@ import pickle
4
  import re
5
  from collections import defaultdict
6
  from typing import List
7
- from pathlib import Path
8
-
9
 
10
  import numpy as np
11
  import pandas as pd
@@ -18,8 +16,7 @@ from transformers import (
18
  BertForTokenClassification,
19
  )
20
 
21
- from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
22
-
23
 
24
  logger = logging.getLogger(__name__)
25
 
@@ -152,12 +149,8 @@ def quant_layers(model):
152
  return int(max(layer_nums)) + 1
153
 
154
 
155
- def get_model_emb_dims(model):
156
- return model.config.hidden_size
157
-
158
-
159
  def get_model_input_size(model):
160
- return model.config.max_position_embeddings
161
 
162
 
163
  def flatten_list(megalist):
@@ -588,11 +581,9 @@ def quant_cos_sims(
588
  elif emb_mode == "cell":
589
  cos = torch.nn.CosineSimilarity(dim=1)
590
 
591
- # if emb_mode == "gene", can only calculate gene cos sims
592
- # against original cell anyways
593
- if cell_states_to_model is None or emb_mode == "gene":
594
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
595
- elif cell_states_to_model is not None and emb_mode == "cell":
596
  possible_states = get_possible_states(cell_states_to_model)
597
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
598
  for state in possible_states:
@@ -714,48 +705,3 @@ def validate_cell_states_to_model(cell_states_to_model):
714
  "'alt_states': ['hcm', 'other1', 'other2']}"
715
  )
716
  raise
717
-
718
- class GeneIdHandler:
719
- def __init__(self, raise_errors=False):
720
- def invert_dict(dict_obj):
721
- return {v:k for k,v in dict_obj.items()}
722
-
723
- self.raise_errors = raise_errors
724
-
725
- with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
726
- self.gene_token_dict = pickle.load(f)
727
- self.token_gene_dict = invert_dict(self.gene_token_dict)
728
-
729
- with open(ENSEMBL_DICTIONARY_FILE, 'rb') as f:
730
- self.id_gene_dict = pickle.load(f)
731
- self.gene_id_dict = invert_dict(self.id_gene_dict)
732
-
733
- def ens_to_token(self, ens_id):
734
- if not self.raise_errors:
735
- return self.gene_token_dict.get(ens_id, ens_id)
736
- else:
737
- return self.gene_token_dict[ens_id]
738
-
739
- def token_to_ens(self, token):
740
- if not self.raise_errors:
741
- return self.token_gene_dict.get(token, token)
742
- else:
743
- return self.token_gene_dict[token]
744
-
745
- def ens_to_symbol(self, ens_id):
746
- if not self.raise_errors:
747
- return self.gene_id_dict.get(ens_id, ens_id)
748
- else:
749
- return self.gene_id_dict[ens_id]
750
-
751
- def symbol_to_ens(self, symbol):
752
- if not self.raise_errors:
753
- return self.id_gene_dict.get(symbol, symbol)
754
- else:
755
- return self.id_gene_dict[symbol]
756
-
757
- def token_to_symbol(self, token):
758
- return self.ens_to_symbol(self.token_to_ens(token))
759
-
760
- def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
4
  import re
5
  from collections import defaultdict
6
  from typing import List
 
 
7
 
8
  import numpy as np
9
  import pandas as pd
 
16
  BertForTokenClassification,
17
  )
18
 
19
+ sns.set()
 
20
 
21
  logger = logging.getLogger(__name__)
22
 
 
149
  return int(max(layer_nums)) + 1
150
 
151
 
 
 
 
 
152
  def get_model_input_size(model):
153
+ return int(re.split(r"\(|,", str(model.bert.embeddings.position_embeddings))[1])
154
 
155
 
156
  def flatten_list(megalist):
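get_model_input_size above now recovers the maximum input length by parsing the repr of the position-embedding layer rather than reading model.config.max_position_embeddings. A minimal demonstration of the same parsing on a bare torch embedding (sizes are illustrative):

```python
import re
import torch

position_embeddings = torch.nn.Embedding(2048, 256)   # (max positions, hidden size)
print(str(position_embeddings))                        # Embedding(2048, 256)

max_input_size = int(re.split(r"\(|,", str(position_embeddings))[1])
print(max_input_size)                                  # 2048
```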
 
581
  elif emb_mode == "cell":
582
  cos = torch.nn.CosineSimilarity(dim=1)
583
 
584
+ if cell_states_to_model is None:
 
 
585
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
586
+ else:
587
  possible_states = get_possible_states(cell_states_to_model)
588
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
589
  for state in possible_states:
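In quant_cos_sims, the perturbed and original embeddings are compared row-wise with torch's CosineSimilarity along dim=1 (the GPU transfer in the diff is omitted here for portability). Toy shapes and values:

```python
import torch

cos = torch.nn.CosineSimilarity(dim=1)

original_emb = torch.randn(4, 256)      # 4 rows (genes or cells) x 256-dim embeddings
perturbation_emb = torch.randn(4, 256)

cos_sims = cos(perturbation_emb, original_emb)
print(cos_sims.shape)                   # torch.Size([4]) -> one similarity per row
```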
 
705
  "'alt_states': ['hcm', 'other1', 'other2']}"
706
  )
707
  raise
 
geneformer/pretrainer.py CHANGED
@@ -32,7 +32,7 @@ from transformers.training_args import ParallelMode
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
35
- from . import TOKEN_DICTIONARY_FILE
36
 
37
  logger = logging.get_logger(__name__)
38
  EncodedInput = List[int]
@@ -106,8 +106,9 @@ class TensorType(ExplicitEnum):
106
 
107
  class GeneformerPreCollator(SpecialTokensMixin):
108
  def __init__(self, *args, **kwargs) -> None:
109
- super().__init__(mask_token="<mask>", pad_token="<pad>")
110
-
 
111
  self.token_dictionary = kwargs.get("token_dictionary")
112
  # self.mask_token = "<mask>"
113
  # self.mask_token_id = self.token_dictionary.get("<mask>")
@@ -119,8 +120,8 @@ class GeneformerPreCollator(SpecialTokensMixin):
119
  # self.token_dictionary.get("<pad>"),
120
  # ]
121
  self.model_input_names = ["input_ids"]
122
-
123
- def convert_ids_to_tokens(self, value):
124
  return self.token_dictionary.get(value)
125
 
126
  def _get_padding_truncation_strategies(
@@ -390,6 +391,7 @@ class GeneformerPreCollator(SpecialTokensMixin):
390
 
391
  for key, value in encoded_inputs.items():
392
  encoded_inputs[key] = to_py_obj(value)
 
393
 
394
  # Convert padding_strategy in PaddingStrategy
395
  padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
@@ -594,17 +596,15 @@ class GeneformerPreCollator(SpecialTokensMixin):
594
 
595
  class GeneformerPretrainer(Trainer):
596
  def __init__(self, *args, **kwargs):
597
- data_collator = kwargs.get("data_collator", None)
598
  token_dictionary = kwargs.pop("token_dictionary")
599
- mlm = kwargs.pop("mlm", True)
600
- mlm_probability = kwargs.pop("mlm_probability", 0.15)
601
 
602
  if data_collator is None:
603
  precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
604
 
605
  # # Data Collator Functions
606
  data_collator = DataCollatorForLanguageModeling(
607
- tokenizer=precollator, mlm=mlm, mlm_probability=mlm_probability
608
  )
609
  kwargs["data_collator"] = data_collator
610
 
@@ -694,7 +694,6 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
694
  Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
695
  length while keeping a bit of randomness.
696
  """
697
-
698
  # Copied and adapted from PyTorch DistributedSampler.
699
  def __init__(
700
  self,
@@ -758,7 +757,7 @@ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
758
  # Deterministically shuffle based on epoch and seed
759
  g = torch.Generator()
760
  g.manual_seed(self.seed + self.epoch)
761
-
762
  indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
763
 
764
  if not self.drop_last:
 
32
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
33
  from transformers.utils.generic import _is_tensorflow, _is_torch
34
 
35
+ from .tokenizer import TOKEN_DICTIONARY_FILE
36
 
37
  logger = logging.get_logger(__name__)
38
  EncodedInput = List[int]
 
106
 
107
  class GeneformerPreCollator(SpecialTokensMixin):
108
  def __init__(self, *args, **kwargs) -> None:
109
+
110
+ super().__init__(mask_token="<mask>", pad_token="<pad>")
111
+
112
  self.token_dictionary = kwargs.get("token_dictionary")
113
  # self.mask_token = "<mask>"
114
  # self.mask_token_id = self.token_dictionary.get("<mask>")
 
120
  # self.token_dictionary.get("<pad>"),
121
  # ]
122
  self.model_input_names = ["input_ids"]
123
+
124
+ def convert_ids_to_tokens(self, value):
125
  return self.token_dictionary.get(value)
126
 
127
  def _get_padding_truncation_strategies(
 
391
 
392
  for key, value in encoded_inputs.items():
393
  encoded_inputs[key] = to_py_obj(value)
394
+
395
 
396
  # Convert padding_strategy in PaddingStrategy
397
  padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
 
596
 
597
  class GeneformerPretrainer(Trainer):
598
  def __init__(self, *args, **kwargs):
599
+ data_collator = kwargs.get("data_collator", None)
600
  token_dictionary = kwargs.pop("token_dictionary")
 
 
601
 
602
  if data_collator is None:
603
  precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
604
 
605
  # # Data Collator Functions
606
  data_collator = DataCollatorForLanguageModeling(
607
+ tokenizer=precollator, mlm=True, mlm_probability=0.15
608
  )
609
  kwargs["data_collator"] = data_collator
610
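After this change the masking settings are fixed inside GeneformerPretrainer rather than taken from mlm / mlm_probability kwargs. A sketch of the equivalent default collator construction (the token dictionary path is illustrative):

```python
import pickle
from transformers import DataCollatorForLanguageModeling
from geneformer.pretrainer import GeneformerPreCollator

with open("token_dictionary.pkl", "rb") as f:   # illustrative path
    token_dictionary = pickle.load(f)

# Mirrors the internal default built by GeneformerPretrainer when no data_collator is passed:
precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=precollator, mlm=True, mlm_probability=0.15
)
```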
 
 
694
  Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
695
  length while keeping a bit of randomness.
696
  """
 
697
  # Copied and adapted from PyTorch DistributedSampler.
698
  def __init__(
699
  self,
 
757
  # Deterministically shuffle based on epoch and seed
758
  g = torch.Generator()
759
  g.manual_seed(self.seed + self.epoch)
760
+
761
  indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
762
 
763
  if not self.drop_last:
geneformer/tokenizer.py CHANGED
@@ -52,7 +52,8 @@ import loompy as lp # noqa
52
 
53
  logger = logging.getLogger(__name__)
54
 
55
- from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
 
56
 
57
 
58
  def rank_genes(gene_vector, gene_tokens):
@@ -102,7 +103,7 @@ class TranscriptomeTokenizer:
102
  model_input_size : int = 2048
103
  | Max input size of model to truncate input to.
104
  special_token : bool = False
105
- | Adds CLS token before and EOS token after rank value encoding.
106
  gene_median_file : Path
107
  | Path to pickle file containing dictionary of non-zero median
108
  | gene expression values across Genecorpus-30M.
@@ -122,7 +123,7 @@ class TranscriptomeTokenizer:
122
  # input size for tokenization
123
  self.model_input_size = model_input_size
124
 
125
- # add CLS and EOS tokens
126
  self.special_token = special_token
127
 
128
  # load dictionary of gene normalization factors
@@ -175,7 +176,7 @@ class TranscriptomeTokenizer:
175
  )
176
 
177
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
178
- tokenized_dataset.save_to_disk(str(output_path))
179
 
180
  def tokenize_files(
181
  self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
@@ -377,14 +378,14 @@ class TranscriptomeTokenizer:
377
  if self.special_token:
378
  example["input_ids"] = example["input_ids"][
379
  0 : self.model_input_size - 2
380
- ] # truncate to leave space for CLS and EOS token
381
  example["input_ids"] = np.insert(
382
  example["input_ids"], 0, self.gene_token_dict.get("<cls>")
383
  )
384
  example["input_ids"] = np.insert(
385
  example["input_ids"],
386
  len(example["input_ids"]),
387
- self.gene_token_dict.get("<eos>"),
388
  )
389
  else:
390
  # Truncate/Crop input_ids to input size
 
52
 
53
  logger = logging.getLogger(__name__)
54
 
55
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
56
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
57
 
58
 
59
  def rank_genes(gene_vector, gene_tokens):
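With this diff the dictionary path constants live in geneformer/tokenizer.py again, and the other modules import them from there (as pretrainer.py and in_silico_perturber_stats.py do above). A quick check of the package-relative paths:

```python
from geneformer.tokenizer import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE

print(GENE_MEDIAN_FILE.name)       # gene_median_dictionary.pkl
print(TOKEN_DICTIONARY_FILE.name)  # token_dictionary.pkl
```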
 
103
  model_input_size : int = 2048
104
  | Max input size of model to truncate input to.
105
  special_token : bool = False
106
+ | Adds CLS token before and SEP token after rank value encoding.
107
  gene_median_file : Path
108
  | Path to pickle file containing dictionary of non-zero median
109
  | gene expression values across Genecorpus-30M.
 
123
  # input size for tokenization
124
  self.model_input_size = model_input_size
125
 
126
+ # add CLS and SEP tokens
127
  self.special_token = special_token
128
 
129
  # load dictionary of gene normalization factors
 
176
  )
177
 
178
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
179
+ tokenized_dataset.save_to_disk(output_path)
180
 
181
  def tokenize_files(
182
  self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
 
378
  if self.special_token:
379
  example["input_ids"] = example["input_ids"][
380
  0 : self.model_input_size - 2
381
+ ] # truncate to leave space for CLS and SEP token
382
  example["input_ids"] = np.insert(
383
  example["input_ids"], 0, self.gene_token_dict.get("<cls>")
384
  )
385
  example["input_ids"] = np.insert(
386
  example["input_ids"],
387
  len(example["input_ids"]),
388
+ self.gene_token_dict.get("<sep>"),
389
  )
390
  else:
391
  # Truncate/Crop input_ids to input size
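The special_token branch above truncates to model_input_size - 2 and then adds the `<cls>` and `<sep>` ids with np.insert. A toy illustration with made-up token ids and a tiny input size:

```python
import numpy as np

model_input_size = 8                        # illustrative; the real default is 2048
gene_token_dict = {"<cls>": 0, "<sep>": 1}  # hypothetical special-token ids
input_ids = np.arange(10, 22)               # 12 toy rank-value-encoded gene tokens

input_ids = input_ids[0 : model_input_size - 2]                                 # room for <cls>/<sep>
input_ids = np.insert(input_ids, 0, gene_token_dict.get("<cls>"))               # prepend <cls>
input_ids = np.insert(input_ids, len(input_ids), gene_token_dict.get("<sep>"))  # append <sep>

print(input_ids)        # [ 0 10 11 12 13 14 15  1]  -> length == model_input_size
```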