Update geneformer/perturber_utils.py

#362
by hchen725 - opened
geneformer/in_silico_perturber.py CHANGED
@@ -38,21 +38,17 @@ import logging
38
  import os
39
  import pickle
40
  from collections import defaultdict
41
- from typing import List
42
  from multiprocess import set_start_method
43
 
44
- import seaborn as sns
45
  import torch
46
- from datasets import Dataset
47
  from tqdm.auto import trange
48
 
49
  from . import perturber_utils as pu
50
  from .emb_extractor import get_embs
51
  from .perturber_utils import TOKEN_DICTIONARY_FILE
52
 
53
-
54
- sns.set()
55
-
56
 
57
  logger = logging.getLogger(__name__)
58
 
@@ -66,7 +62,7 @@ class InSilicoPerturber:
66
  "anchor_gene": {None, str},
67
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
68
  "num_classes": {int},
69
- "emb_mode": {"cell", "cell_and_gene"},
70
  "cell_emb_style": {"mean_pool"},
71
  "filter_data": {None, dict},
72
  "cell_states_to_model": {None, dict},
@@ -74,6 +70,7 @@ class InSilicoPerturber:
74
  "max_ncells": {None, int},
75
  "cell_inds_to_perturb": {"all", dict},
76
  "emb_layer": {-1, 0},
 
77
  "forward_batch_size": {int},
78
  "nproc": {int},
79
  }
@@ -97,7 +94,8 @@ class InSilicoPerturber:
97
  emb_layer=-1,
98
  forward_batch_size=100,
99
  nproc=4,
100
- token_dictionary_file=TOKEN_DICTIONARY_FILE,
 
101
  ):
102
  """
103
  Initialize in silico perturber.
@@ -137,11 +135,11 @@ class InSilicoPerturber:
137
  num_classes : int
138
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
139
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
140
- emb_mode : {"cell", "cell_and_gene"}
141
- | Whether to output impact of perturbation on cell and/or gene embeddings.
142
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
143
  cell_emb_style : "mean_pool"
144
- | Method for summarizing cell embeddings.
145
  | Currently only option is mean pooling of gene embeddings for given cell.
146
  filter_data : None, dict
147
  | Default is to use all input data for in silico perturbation study.
@@ -186,6 +184,8 @@ class InSilicoPerturber:
186
  | Number of CPU processes to use.
187
  token_dictionary_file : Path
188
  | Path to pickle file containing token dictionary (Ensembl ID:token).
189
  """
190
  try:
191
  set_start_method("spawn")
@@ -222,14 +222,31 @@ class InSilicoPerturber:
222
  self.emb_layer = emb_layer
223
  self.forward_batch_size = forward_batch_size
224
  self.nproc = nproc
225
 
226
  self.validate_options()
227
 
228
  # load token dictionary (Ensembl IDs:token)
229
  with open(token_dictionary_file, "rb") as f:
230
  self.gene_token_dict = pickle.load(f)
 
231
 
232
  self.pad_token_id = self.gene_token_dict.get("<pad>")
233
 
234
  if self.anchor_gene is None:
235
  self.anchor_token = None
@@ -287,7 +304,7 @@ class InSilicoPerturber:
287
  continue
288
  valid_type = False
289
  for option in valid_options:
290
- if (option in [bool, int, list, dict]) and isinstance(
291
  attr_value, option
292
  ):
293
  valid_type = True
@@ -428,22 +445,46 @@ class InSilicoPerturber:
428
  self.max_len = pu.get_model_input_size(model)
429
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
430
 
431
-
432
  ### filter input data ###
433
  # general filtering of input data based on filter_data argument
434
  filtered_input_data = pu.load_and_filter(
435
  self.filter_data, self.nproc, input_data_file
436
  )
437
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
438
 
439
  if self.perturb_group is True:
440
- self.isp_perturb_set(
441
- model, filtered_input_data, layer_to_quant, output_path_prefix
442
- )
443
  else:
444
- self.isp_perturb_all(
445
- model, filtered_input_data, layer_to_quant, output_path_prefix
446
- )
447
 
448
  def apply_additional_filters(self, filtered_input_data):
449
  # additional filtering of input data dependent on isp mode
@@ -488,6 +529,7 @@ class InSilicoPerturber:
488
  layer_to_quant: int,
489
  output_path_prefix: str,
490
  ):
 
491
  def make_group_perturbation_batch(example):
492
  example_input_ids = example["input_ids"]
493
  example["tokens_to_perturb"] = self.tokens_to_perturb
@@ -506,7 +548,7 @@ class InSilicoPerturber:
506
  if self.perturb_type == "delete":
507
  example = pu.delete_indices(example)
508
  elif self.perturb_type == "overexpress":
509
- example = pu.overexpress_tokens(example, self.max_len)
510
  example["n_overflow"] = pu.calc_n_overflow(
511
  self.max_len,
512
  example["length"],
@@ -560,6 +602,7 @@ class InSilicoPerturber:
560
  layer_to_quant,
561
  self.pad_token_id,
562
  self.forward_batch_size,
 
563
  summary_stat=None,
564
  silent=True,
565
  )
@@ -579,6 +622,7 @@ class InSilicoPerturber:
579
  layer_to_quant,
580
  self.pad_token_id,
581
  self.forward_batch_size,
 
582
  summary_stat=None,
583
  silent=True,
584
  )
@@ -678,8 +722,6 @@ class InSilicoPerturber:
678
  cos_sims_dict = self.update_perturbation_dictionary(
679
  cos_sims_dict,
680
  cos_sims_data,
681
- filtered_input_data,
682
- indices_to_perturb,
683
  gene_list,
684
  )
685
  else:
@@ -688,8 +730,6 @@ class InSilicoPerturber:
688
  cos_sims_dict[state] = self.update_perturbation_dictionary(
689
  cos_sims_dict[state],
690
  cos_sims_data[state],
691
- filtered_input_data,
692
- indices_to_perturb,
693
  gene_list,
694
  )
695
  del minibatch
@@ -711,6 +751,256 @@ class InSilicoPerturber:
711
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
712
  )
713
 
714
  def isp_perturb_all(
715
  self,
716
  model,
@@ -738,10 +1028,10 @@ class InSilicoPerturber:
738
  layer_to_quant,
739
  self.pad_token_id,
740
  self.forward_batch_size,
 
741
  summary_stat=None,
742
  silent=True,
743
  )
744
-
745
  # gene_list is used to assign cos sims back to genes
746
  # need to remove the anchor gene
747
  gene_list = example_cell["input_ids"][0][:]
@@ -765,10 +1055,13 @@ class InSilicoPerturber:
765
  layer_to_quant,
766
  self.pad_token_id,
767
  self.forward_batch_size,
 
768
  summary_stat=None,
769
  silent=True,
770
  )
771
 
772
  num_inds_perturbed = 1 + self.combos
773
  # need to remove overexpressed gene to quantify cosine shifts
774
  if self.perturb_type == "overexpress":
@@ -780,11 +1073,11 @@ class InSilicoPerturber:
780
  elif self.perturb_type == "delete":
781
  perturbation_emb = full_perturbation_emb
782
 
783
- original_batch = pu.make_comparison_batch(
784
- full_original_emb, indices_to_perturb, perturb_group=False
785
- )
786
 
787
  if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
788
  gene_cos_sims = pu.quant_cos_sims(
789
  perturbation_emb,
790
  original_batch,
@@ -792,6 +1085,8 @@ class InSilicoPerturber:
792
  self.state_embs_dict,
793
  emb_mode="gene",
794
  )
795
  if self.cell_states_to_model is not None:
796
  original_cell_emb = pu.compute_nonpadded_cell_embedding(
797
  full_original_emb, "mean_pool"
@@ -807,6 +1102,8 @@ class InSilicoPerturber:
807
  self.state_embs_dict,
808
  emb_mode="cell",
809
  )
810
 
811
  if self.emb_mode == "cell_and_gene":
812
  # remove perturbed index for gene list
@@ -828,13 +1125,14 @@ class InSilicoPerturber:
828
  (perturbed_gene, affected_gene)
829
  ] = gene_cos_sims[perturbation_i, gene_j].item()
830
 
831
  if self.cell_states_to_model is None:
832
  cos_sims_data = torch.mean(gene_cos_sims, dim=1)
833
  cos_sims_dict = self.update_perturbation_dictionary(
834
  cos_sims_dict,
835
  cos_sims_data,
836
- filtered_input_data,
837
- indices_to_perturb,
838
  gene_list,
839
  )
840
  else:
@@ -843,25 +1141,23 @@ class InSilicoPerturber:
843
  cos_sims_dict[state] = self.update_perturbation_dictionary(
844
  cos_sims_dict[state],
845
  cos_sims_data[state],
846
- filtered_input_data,
847
- indices_to_perturb,
848
  gene_list,
849
  )
850
 
851
  # save dict to disk every 100 cells
852
- if i % 100 == 0:
853
  pu.write_perturbation_dictionary(
854
  cos_sims_dict,
855
- f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
856
  )
857
  if self.emb_mode == "cell_and_gene":
858
  pu.write_perturbation_dictionary(
859
  stored_gene_embs_dict,
860
- f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
861
  )
862
 
863
  # reset and clear memory every 1000 cells
864
- if i % 1000 == 0:
865
  pickle_batch += 1
866
  if self.cell_states_to_model is None:
867
  cos_sims_dict = defaultdict(list)
@@ -877,28 +1173,270 @@ class InSilicoPerturber:
877
  torch.cuda.empty_cache()
878
 
879
  pu.write_perturbation_dictionary(
880
- cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
881
  )
882
 
883
  if self.emb_mode == "cell_and_gene":
884
  pu.write_perturbation_dictionary(
885
  stored_gene_embs_dict,
886
- f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
887
  )
888
 
 
889
  def update_perturbation_dictionary(
890
  self,
891
  cos_sims_dict: defaultdict,
892
  cos_sims_data: torch.Tensor,
893
- filtered_input_data: Dataset,
894
- indices_to_perturb: List[List[int]],
895
  gene_list=None,
896
  ):
897
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
898
  logger.error(
899
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
900
- cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
901
- len(gene_list) = {len(gene_list)}."
902
  )
903
  raise
904
 
@@ -922,4 +1460,4 @@ class InSilicoPerturber:
922
  for i, cos in enumerate(cos_sims_data.tolist()):
923
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
924
 
925
- return cos_sims_dict
 
38
  import os
39
  import pickle
40
  from collections import defaultdict
 
41
  from multiprocess import set_start_method
42
 
 
43
  import torch
44
+ from datasets import Dataset, disable_progress_bars
45
  from tqdm.auto import trange
46
 
47
  from . import perturber_utils as pu
48
  from .emb_extractor import get_embs
49
  from .perturber_utils import TOKEN_DICTIONARY_FILE
50
 
51
+ disable_progress_bars()
52
 
53
  logger = logging.getLogger(__name__)
54
 
 
62
  "anchor_gene": {None, str},
63
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
64
  "num_classes": {int},
65
+ "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
66
  "cell_emb_style": {"mean_pool"},
67
  "filter_data": {None, dict},
68
  "cell_states_to_model": {None, dict},
 
70
  "max_ncells": {None, int},
71
  "cell_inds_to_perturb": {"all", dict},
72
  "emb_layer": {-1, 0},
73
+ "token_dictionary_file" : {None, str},
74
  "forward_batch_size": {int},
75
  "nproc": {int},
76
  }
 
94
  emb_layer=-1,
95
  forward_batch_size=100,
96
  nproc=4,
97
+ token_dictionary_file=None,
98
+ clear_mem_ncells=1000,
99
  ):
100
  """
101
  Initialize in silico perturber.
 
135
  num_classes : int
136
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
137
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
138
+ emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
139
+ | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
140
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
141
  cell_emb_style : "mean_pool"
142
+ | Method for summarizing cell embeddings if not using CLS token.
143
  | Currently only option is mean pooling of gene embeddings for given cell.
144
  filter_data : None, dict
145
  | Default is to use all input data for in silico perturbation study.
 
184
  | Number of CPU processes to use.
185
  token_dictionary_file : Path
186
  | Path to pickle file containing token dictionary (Ensembl ID:token).
187
+ clear_mem_ncells : int
188
+ | Clear memory every n cells.
189
  """
190
  try:
191
  set_start_method("spawn")
 
222
  self.emb_layer = emb_layer
223
  self.forward_batch_size = forward_batch_size
224
  self.nproc = nproc
225
+ self.token_dictionary_file = token_dictionary_file
226
+ self.clear_mem_ncells = clear_mem_ncells
227
 
228
  self.validate_options()
229
 
230
  # load token dictionary (Ensembl IDs:token)
231
+ if self.token_dictionary_file is None:
232
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
233
  with open(token_dictionary_file, "rb") as f:
234
  self.gene_token_dict = pickle.load(f)
235
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
236
 
237
  self.pad_token_id = self.gene_token_dict.get("<pad>")
238
+ self.cls_token_id = self.gene_token_dict.get("<cls>")
239
+ self.eos_token_id = self.gene_token_dict.get("<eos>")
240
+
241
+
242
+ # Identify if special token is present in the token dictionary
243
+ if (self.cls_token_id is not None) and (self.eos_token_id is not None):
244
+ self.special_token = True
245
+ else:
246
+ if "cls" in self.emb_mode:
247
+ logger.error(f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary.")
248
+ raise
249
+ self.special_token = False
250
 
251
  if self.anchor_gene is None:
252
  self.anchor_token = None
 
304
  continue
305
  valid_type = False
306
  for option in valid_options:
307
+ if (option in [bool, int, list, dict, str]) and isinstance(
308
  attr_value, option
309
  ):
310
  valid_type = True
 
445
  self.max_len = pu.get_model_input_size(model)
446
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
447
 
 
448
  ### filter input data ###
449
  # general filtering of input data based on filter_data argument
450
  filtered_input_data = pu.load_and_filter(
451
  self.filter_data, self.nproc, input_data_file
452
  )
453
+
454
+ # Ensure emb_mode is cls if first token of the filtered input data is cls token
455
+ if self.special_token:
456
+ if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and ("cls" not in self.emb_mode):
457
+ logger.error(
458
+ "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
459
+ )
460
+ raise
461
+ if ("cls" in self.emb_mode):
462
+ if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (filtered_input_data["input_ids"][0][-1] != self.eos_token_id):
463
+ logger.error(
464
+ "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
465
+ )
466
+ raise
467
+
468
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
469
 
470
  if self.perturb_group is True:
471
+ if (self.special_token) and ("cls" in self.emb_mode):
472
+ self.isp_perturb_set_special(
473
+ model, filtered_input_data, layer_to_quant, output_path_prefix
474
+ )
475
+ else:
476
+ self.isp_perturb_set(
477
+ model, filtered_input_data, layer_to_quant, output_path_prefix
478
+ )
479
  else:
480
+ if (self.special_token) and ("cls" in self.emb_mode):
481
+ self.isp_perturb_all_special(
482
+ model, filtered_input_data, layer_to_quant, output_path_prefix
483
+ )
484
+ else:
485
+ self.isp_perturb_all(
486
+ model, filtered_input_data, layer_to_quant, output_path_prefix
487
+ )
488
 
489
  def apply_additional_filters(self, filtered_input_data):
490
  # additional filtering of input data dependent on isp mode
 
529
  layer_to_quant: int,
530
  output_path_prefix: str,
531
  ):
532
+
533
  def make_group_perturbation_batch(example):
534
  example_input_ids = example["input_ids"]
535
  example["tokens_to_perturb"] = self.tokens_to_perturb
 
548
  if self.perturb_type == "delete":
549
  example = pu.delete_indices(example)
550
  elif self.perturb_type == "overexpress":
551
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
552
  example["n_overflow"] = pu.calc_n_overflow(
553
  self.max_len,
554
  example["length"],
 
602
  layer_to_quant,
603
  self.pad_token_id,
604
  self.forward_batch_size,
605
+ token_gene_dict=self.token_gene_dict,
606
  summary_stat=None,
607
  silent=True,
608
  )
 
622
  layer_to_quant,
623
  self.pad_token_id,
624
  self.forward_batch_size,
625
+ token_gene_dict=self.token_gene_dict,
626
  summary_stat=None,
627
  silent=True,
628
  )
 
722
  cos_sims_dict = self.update_perturbation_dictionary(
723
  cos_sims_dict,
724
  cos_sims_data,
 
  gene_list,
726
  )
727
  else:
 
730
  cos_sims_dict[state] = self.update_perturbation_dictionary(
731
  cos_sims_dict[state],
732
  cos_sims_data[state],
733
  gene_list,
734
  )
735
  del minibatch
 
751
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
752
  )
753
 
754
+
755
+ def isp_perturb_set_special(
756
+ self,
757
+ model,
758
+ filtered_input_data: Dataset,
759
+ layer_to_quant: int,
760
+ output_path_prefix: str,
761
+ ):
762
+
763
+ def make_group_perturbation_batch(example):
764
+ example_input_ids = example["input_ids"]
765
+ example["tokens_to_perturb"] = self.tokens_to_perturb
766
+ indices_to_perturb = [
767
+ example_input_ids.index(token) if token in example_input_ids else None
768
+ for token in self.tokens_to_perturb
769
+ ]
770
+ indices_to_perturb = [
771
+ item for item in indices_to_perturb if item is not None
772
+ ]
773
+ if len(indices_to_perturb) > 0:
774
+ example["perturb_index"] = indices_to_perturb
775
+ else:
776
+ # -100 indicates tokens to overexpress are not present in rank value encoding
777
+ example["perturb_index"] = [-100]
778
+ if self.perturb_type == "delete":
779
+ example = pu.delete_indices(example)
780
+ elif self.perturb_type == "overexpress":
781
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
782
+ example["n_overflow"] = pu.calc_n_overflow(
783
+ self.max_len,
784
+ example["length"],
785
+ self.tokens_to_perturb,
786
+ indices_to_perturb,
787
+ )
788
+ return example
789
+
790
+ total_batch_length = len(filtered_input_data)
791
+ if self.cell_states_to_model is None:
792
+ cos_sims_dict = defaultdict(list)
793
+ else:
794
+ cos_sims_dict = {
795
+ state: defaultdict(list)
796
+ for state in pu.get_possible_states(self.cell_states_to_model)
797
+ }
798
+
799
+ perturbed_data = filtered_input_data.map(
800
+ make_group_perturbation_batch, num_proc=self.nproc
801
+ )
802
+
803
+ if self.perturb_type == "overexpress":
804
+ filtered_input_data = filtered_input_data.add_column(
805
+ "n_overflow", perturbed_data["n_overflow"]
806
+ )
807
+ filtered_input_data = filtered_input_data.map(
808
+ pu.truncate_by_n_overflow_special, num_proc=self.nproc
809
+ )
810
+
811
+ if self.emb_mode == "cls_and_gene":
812
+ stored_gene_embs_dict = defaultdict(list)
813
+
814
+ # iterate through batches
815
+ for i in trange(0, total_batch_length, self.forward_batch_size):
816
+ max_range = min(i + self.forward_batch_size, total_batch_length)
817
+ inds_select = [i for i in range(i, max_range)]
818
+
819
+ minibatch = filtered_input_data.select(inds_select)
820
+ perturbation_batch = perturbed_data.select(inds_select)
821
+
822
+ ##### CLS Embedding Mode #####
823
+ if self.emb_mode == "cls":
824
+ indices_to_perturb = perturbation_batch["perturb_index"]
825
+
826
+ original_cls_emb = get_embs(
827
+ model,
828
+ minibatch,
829
+ "cls",
830
+ layer_to_quant,
831
+ self.pad_token_id,
832
+ self.forward_batch_size,
833
+ token_gene_dict=self.token_gene_dict,
834
+ summary_stat=None,
835
+ silent=True,
836
+ )
837
+
838
+ perturbation_cls_emb = get_embs(
839
+ model,
840
+ perturbation_batch,
841
+ "cls",
842
+ layer_to_quant,
843
+ self.pad_token_id,
844
+ self.forward_batch_size,
845
+ token_gene_dict=self.token_gene_dict,
846
+ summary_stat=None,
847
+ silent=True,
848
+ )
849
+
850
+ # Calculate the cosine similarities
851
+ cls_cos_sims = pu.quant_cos_sims(
852
+ perturbation_cls_emb,
853
+ original_cls_emb,
854
+ self.cell_states_to_model,
855
+ self.state_embs_dict,
856
+ emb_mode="cell")
857
+
858
+ # Update perturbation dictionary
859
+ if self.cell_states_to_model is None:
860
+ cos_sims_dict = self.update_perturbation_dictionary(
861
+ cos_sims_dict,
862
+ cls_cos_sims,
863
+ gene_list = None,
864
+ )
865
+ else:
866
+ for state in cos_sims_dict.keys():
867
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
868
+ cos_sims_dict[state],
869
+ cls_cos_sims[state],
870
+ gene_list = None,
871
+ )
872
+
873
+ ##### CLS and Gene Embedding Mode #####
874
+ elif self.emb_mode == "cls_and_gene":
875
+ full_original_emb = get_embs(
876
+ model,
877
+ minibatch,
878
+ "gene",
879
+ layer_to_quant,
880
+ self.pad_token_id,
881
+ self.forward_batch_size,
882
+ self.token_gene_dict,
883
+ summary_stat=None,
884
+ silent=True,
885
+ )
886
+ indices_to_perturb = perturbation_batch["perturb_index"]
887
+ # remove indices that were perturbed
888
+ original_emb = pu.remove_perturbed_indices_set(
889
+ full_original_emb,
890
+ self.perturb_type,
891
+ indices_to_perturb,
892
+ self.tokens_to_perturb,
893
+ minibatch["length"],
894
+ )
895
+ full_perturbation_emb = get_embs(
896
+ model,
897
+ perturbation_batch,
898
+ "gene",
899
+ layer_to_quant,
900
+ self.pad_token_id,
901
+ self.forward_batch_size,
902
+ self.token_gene_dict,
903
+ summary_stat=None,
904
+ silent=True,
905
+ )
906
+
907
+ # remove special tokens and padding
908
+ original_emb = original_emb[:, 1:-1, :]
909
+ if self.perturb_type == "overexpress":
910
+ perturbation_emb = full_perturbation_emb[:,1+len(self.tokens_to_perturb):-1,:]
911
+ elif self.perturb_type == "delete":
912
+ perturbation_emb = full_perturbation_emb[:,1:max(perturbation_batch["length"])-1,:]
913
+
914
+ n_perturbation_genes = perturbation_emb.size()[1]
915
+
916
+ gene_cos_sims = pu.quant_cos_sims(
917
+ perturbation_emb,
918
+ original_emb,
919
+ self.cell_states_to_model,
920
+ self.state_embs_dict,
921
+ emb_mode="gene",
922
+ )
923
+
924
+ # get cls emb
925
+ original_cls_emb = full_original_emb[:,0,:]
926
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
927
+
928
+ cls_cos_sims = pu.quant_cos_sims(
929
+ perturbation_cls_emb,
930
+ original_cls_emb,
931
+ self.cell_states_to_model,
932
+ self.state_embs_dict,
933
+ emb_mode="cell",
934
+ )
935
+
936
+ # get cosine similarities in gene embeddings
937
+ # since getting gene embeddings, need gene names
938
+
939
+ gene_list = minibatch["input_ids"]
940
+ # need to truncate gene_list
941
+ genes_to_exclude = self.tokens_to_perturb + [self.cls_token_id, self.eos_token_id]
942
+ gene_list = [
943
+ [g for g in genes if g not in genes_to_exclude][
944
+ :n_perturbation_genes
945
+ ]
946
+ for genes in gene_list
947
+ ]
948
+
949
+ for cell_i, genes in enumerate(gene_list):
950
+ for gene_j, affected_gene in enumerate(genes):
951
+ if len(self.genes_to_perturb) > 1:
952
+ tokens_to_perturb = tuple(self.tokens_to_perturb)
953
+ else:
954
+ tokens_to_perturb = self.tokens_to_perturb[0]
955
+
956
+ # fill in the gene cosine similarities
957
+ try:
958
+ stored_gene_embs_dict[
959
+ (tokens_to_perturb, affected_gene)
960
+ ].append(gene_cos_sims[cell_i, gene_j].item())
961
+ except KeyError:
962
+ stored_gene_embs_dict[
963
+ (tokens_to_perturb, affected_gene)
964
+ ] = gene_cos_sims[cell_i, gene_j].item()
965
+
966
+ if self.cell_states_to_model is None:
967
+ cos_sims_dict = self.update_perturbation_dictionary(
968
+ cos_sims_dict,
969
+ cls_cos_sims,
970
+ gene_list = None,
971
+ )
972
+ else:
973
+ for state in cos_sims_dict.keys():
974
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
975
+ cos_sims_dict[state],
976
+ cls_cos_sims[state],
977
+ gene_list = None,
978
+ )
979
+ del full_original_emb
980
+ del original_emb
981
+ del full_perturbation_emb
982
+ del perturbation_emb
983
+ del gene_cos_sims
984
+
985
+ del original_cls_emb
986
+ del perturbation_cls_emb
987
+ del cls_cos_sims
988
+ del minibatch
989
+ del perturbation_batch
990
+
991
+ torch.cuda.empty_cache()
992
+
993
+ pu.write_perturbation_dictionary(
994
+ cos_sims_dict,
995
+ f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
996
+ )
997
+
998
+ if self.emb_mode == "cls_and_gene":
999
+ pu.write_perturbation_dictionary(
1000
+ stored_gene_embs_dict,
1001
+ f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
1002
+ )
1003
+
1004
  def isp_perturb_all(
1005
  self,
1006
  model,
 
1028
  layer_to_quant,
1029
  self.pad_token_id,
1030
  self.forward_batch_size,
1031
+ self.token_gene_dict,
1032
  summary_stat=None,
1033
  silent=True,
1034
  )
 
1035
  # gene_list is used to assign cos sims back to genes
1036
  # need to remove the anchor gene
1037
  gene_list = example_cell["input_ids"][0][:]
 
1055
  layer_to_quant,
1056
  self.pad_token_id,
1057
  self.forward_batch_size,
1058
+ self.token_gene_dict,
1059
  summary_stat=None,
1060
  silent=True,
1061
  )
1062
 
1063
+ del perturbation_batch
1064
+
1065
  num_inds_perturbed = 1 + self.combos
1066
  # need to remove overexpressed gene to quantify cosine shifts
1067
  if self.perturb_type == "overexpress":
 
1073
  elif self.perturb_type == "delete":
1074
  perturbation_emb = full_perturbation_emb
1075
 
 
1076
 
1077
  if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
1078
+ original_batch = pu.make_comparison_batch(
1079
+ full_original_emb, indices_to_perturb, perturb_group=False
1080
+ )
1081
  gene_cos_sims = pu.quant_cos_sims(
1082
  perturbation_emb,
1083
  original_batch,
 
1085
  self.state_embs_dict,
1086
  emb_mode="gene",
1087
  )
1088
+ del original_batch
1089
+
1090
  if self.cell_states_to_model is not None:
1091
  original_cell_emb = pu.compute_nonpadded_cell_embedding(
1092
  full_original_emb, "mean_pool"
 
1102
  self.state_embs_dict,
1103
  emb_mode="cell",
1104
  )
1105
+ del original_cell_emb
1106
+ del perturbation_cell_emb
1107
 
1108
  if self.emb_mode == "cell_and_gene":
1109
  # remove perturbed index for gene list
 
1125
  (perturbed_gene, affected_gene)
1126
  ] = gene_cos_sims[perturbation_i, gene_j].item()
1127
 
1128
+ del full_original_emb
1129
+ del full_perturbation_emb
1130
+
1131
  if self.cell_states_to_model is None:
1132
  cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1133
  cos_sims_dict = self.update_perturbation_dictionary(
1134
  cos_sims_dict,
1135
  cos_sims_data,
1136
  gene_list,
1137
  )
1138
  else:
 
1141
  cos_sims_dict[state] = self.update_perturbation_dictionary(
1142
  cos_sims_dict[state],
1143
  cos_sims_data[state],
1144
  gene_list,
1145
  )
1146
 
1147
  # save dict to disk every 100 cells
1148
+ if i % self.clear_mem_ncells/10 == 0:
1149
  pu.write_perturbation_dictionary(
1150
  cos_sims_dict,
1151
+ f"{output_path_prefix}_dict_cell_embs_batch{pickle_batch}",
1152
  )
1153
  if self.emb_mode == "cell_and_gene":
1154
  pu.write_perturbation_dictionary(
1155
  stored_gene_embs_dict,
1156
+ f"{output_path_prefix}_dict_gene_embs_batch{pickle_batch}",
1157
  )
1158
 
1159
  # reset and clear memory every 1000 cells
1160
+ if i % self.clear_mem_ncells == 0:
1161
  pickle_batch += 1
1162
  if self.cell_states_to_model is None:
1163
  cos_sims_dict = defaultdict(list)
 
1173
  torch.cuda.empty_cache()
1174
 
1175
  pu.write_perturbation_dictionary(
1176
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_batch{pickle_batch}"
1177
  )
1178
 
1179
  if self.emb_mode == "cell_and_gene":
1180
  pu.write_perturbation_dictionary(
1181
  stored_gene_embs_dict,
1182
+ f"{output_path_prefix}_dict_gene_embs_batch{pickle_batch}",
1183
+ )
1184
+
1185
+
1186
+ def isp_perturb_all_special(
1187
+ self,
1188
+ model,
1189
+ filtered_input_data: Dataset,
1190
+ layer_to_quant: int,
1191
+ output_path_prefix: str,
1192
+ ):
1193
+ pickle_batch = -1
1194
+ if self.cell_states_to_model is None:
1195
+ cos_sims_dict = defaultdict(list)
1196
+ else:
1197
+ cos_sims_dict = {
1198
+ state: defaultdict(list)
1199
+ for state in pu.get_possible_states(self.cell_states_to_model)
1200
+ }
1201
+
1202
+ if self.emb_mode == "cls_and_gene":
1203
+ stored_gene_embs_dict = defaultdict(list)
1204
+
1205
+ num_inds_perturbed = 1 + self.combos
1206
+ for i in trange(len(filtered_input_data)):
1207
+ example_cell = filtered_input_data.select([i])
1208
+
1209
+ # gene_list is used to assign cos sims back to genes
1210
+ # need to remove the anchor gene and special tokens
1211
+ gene_list = example_cell["input_ids"][0][:]
1212
+
1213
+ for token in [self.cls_token_id, self.eos_token_id]:
1214
+ gene_list.remove(token)
1215
+
1216
+
1217
+ if self.anchor_token is not None:
1218
+ for token in self.anchor_token:
1219
+ gene_list.remove(token)
1220
+ else:
1221
+ if self.perturb_type == "overexpress":
1222
+ gene_list = gene_list[
1223
+ num_inds_perturbed:
1224
+ ] # index 0 is not overexpressed
1225
+
1226
+ perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1227
+ example_cell,
1228
+ self.perturb_type,
1229
+ self.tokens_to_perturb,
1230
+ self.anchor_token,
1231
+ self.combos,
1232
+ self.nproc,
1233
+ )
1234
+
1235
+ ##### CLS Embedding Mode #####
1236
+ if self.emb_mode == "cls":
1237
+ # Extract cls embeddings from original and perturbed cells
1238
+ perturbation_cls_emb = get_embs(
1239
+ model,
1240
+ perturbation_batch,
1241
+ "cls",
1242
+ layer_to_quant,
1243
+ self.pad_token_id,
1244
+ self.forward_batch_size,
1245
+ self.token_gene_dict,
1246
+ summary_stat=None,
1247
+ silent=True,
1248
+ )
1249
+ original_cls_emb = get_embs(
1250
+ model,
1251
+ example_cell,
1252
+ "cls",
1253
+ layer_to_quant,
1254
+ self.pad_token_id,
1255
+ self.forward_batch_size,
1256
+ self.token_gene_dict,
1257
+ summary_stat=None,
1258
+ silent=True,
1259
+ )
1260
+
1261
+ # Calculate cosine similarities
1262
+ cls_cos_sims = pu.quant_cos_sims(
1263
+ perturbation_cls_emb,
1264
+ original_cls_emb,
1265
+ self.cell_states_to_model,
1266
+ self.state_embs_dict,
1267
+ emb_mode="cell",
1268
+ )
1269
+
1270
+ if self.cell_states_to_model is None:
1271
+ cos_sims_dict = self.update_perturbation_dictionary(
1272
+ cos_sims_dict,
1273
+ cls_cos_sims,
1274
+ gene_list,
1275
+ )
1276
+ else:
1277
+
1278
+ for state in cos_sims_dict.keys():
1279
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1280
+ cos_sims_dict[state],
1281
+ cls_cos_sims[state],
1282
+ gene_list,
1283
+ )
1284
+
1285
+ del perturbation_batch
1286
+ del original_cls_emb
1287
+ del perturbation_cls_emb
1288
+ del cls_cos_sims
1289
+
1290
+ ##### CLS and Gene Embedding Mode #####
1291
+ elif self.emb_mode == "cls_and_gene":
1292
+ full_perturbation_emb = get_embs(
1293
+ model,
1294
+ perturbation_batch,
1295
+ "gene",
1296
+ layer_to_quant,
1297
+ self.pad_token_id,
1298
+ self.forward_batch_size,
1299
+ self.token_gene_dict,
1300
+ summary_stat=None,
1301
+ silent=True,
1302
+ )
1303
+
1304
+ # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1305
+ if self.perturb_type == "overexpress":
1306
+ perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :].clone().detach()
1307
+ elif self.perturb_type == "delete":
1308
+ perturbation_emb = full_perturbation_emb[:, 1:-1, :].clone().detach()
1309
+
1310
+ full_original_emb = get_embs(
1311
+ model,
1312
+ example_cell,
1313
+ "gene",
1314
+ layer_to_quant,
1315
+ self.pad_token_id,
1316
+ self.forward_batch_size,
1317
+ self.token_gene_dict,
1318
+ summary_stat=None,
1319
+ silent=True,
1320
+ )
1321
+
1322
+ original_batch = pu.make_comparison_batch(
1323
+ full_original_emb, indices_to_perturb, perturb_group=False
1324
+ )
1325
+
1326
+ original_batch = original_batch[:, 1:-1, :].clone().detach()
1327
+ gene_cos_sims = pu.quant_cos_sims(
1328
+ perturbation_emb,
1329
+ original_batch,
1330
+ self.cell_states_to_model,
1331
+ self.state_embs_dict,
1332
+ emb_mode="gene",
1333
+ )
1334
+
1335
+ # remove perturbed index for gene list
1336
+ perturbed_gene_dict = {
1337
+ gene: gene_list[:i] + gene_list[i + 1 :]
1338
+ for i, gene in enumerate(gene_list)
1339
+ }
1340
+
1341
+ for perturbation_i, perturbed_gene in enumerate(gene_list):
1342
+ for gene_j, affected_gene in enumerate(
1343
+ perturbed_gene_dict[perturbed_gene]
1344
+ ):
1345
+ try:
1346
+ stored_gene_embs_dict[
1347
+ (perturbed_gene, affected_gene)
1348
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
1349
+ except KeyError:
1350
+ stored_gene_embs_dict[
1351
+ (perturbed_gene, affected_gene)
1352
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
1353
+
1354
+ # get cls emb
1355
+ original_cls_emb = full_original_emb[:,0,:].clone().detach()
1356
+ perturbation_cls_emb = full_perturbation_emb[:,0,:].clone().detach()
1357
+
1358
+ cls_cos_sims = pu.quant_cos_sims(
1359
+ perturbation_cls_emb,
1360
+ original_cls_emb,
1361
+ self.cell_states_to_model,
1362
+ self.state_embs_dict,
1363
+ emb_mode="cell",
1364
+ )
1365
+
1366
+ if self.cell_states_to_model is None:
1367
+ cos_sims_dict = self.update_perturbation_dictionary(
1368
+ cos_sims_dict,
1369
+ cls_cos_sims,
1370
+ gene_list,
1371
+ )
1372
+ else:
1373
+ for state in cos_sims_dict.keys():
1374
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1375
+ cos_sims_dict[state],
1376
+ cls_cos_sims[state],
1377
+ gene_list,
1378
+ )
1379
+
1380
+ del perturbation_batch
1381
+ del original_batch
1382
+ del full_original_emb
1383
+ del full_perturbation_emb
1384
+ del perturbation_emb
1385
+ del original_cls_emb
1386
+ del perturbation_cls_emb
1387
+ del cls_cos_sims
1388
+ del gene_cos_sims
1389
+
1390
+ # save dict to disk every self.clear_mem_ncells/10 (default 100) cells
1391
+ if i % max(1,self.clear_mem_ncells/10) == 0:
1392
+ pu.write_perturbation_dictionary(
1393
+ cos_sims_dict,
1394
+ f"{output_path_prefix}_dict_cell_embs_batch{pickle_batch}",
1395
+ )
1396
+ if self.emb_mode == "cls_and_gene":
1397
+ pu.write_perturbation_dictionary(
1398
+ stored_gene_embs_dict,
1399
+ f"{output_path_prefix}_dict_gene_embs_batch{pickle_batch}",
1400
+ )
1401
+
1402
+ # reset and clear memory every self.clear_mem_ncells (default 1000) cells
1403
+ if i % self.clear_mem_ncells == 0:
1404
+ pickle_batch += 1
1405
+ if self.cell_states_to_model is None:
1406
+ cos_sims_dict = defaultdict(list)
1407
+ else:
1408
+ cos_sims_dict = {
1409
+ state: defaultdict(list)
1410
+ for state in pu.get_possible_states(self.cell_states_to_model)
1411
+ }
1412
+
1413
+ if self.emb_mode == "cls_and_gene":
1414
+ stored_gene_embs_dict = defaultdict(list)
1415
+
1416
+ torch.cuda.empty_cache()
1417
+
1418
+ pu.write_perturbation_dictionary(
1419
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_batch{pickle_batch}"
1420
+ )
1421
+
1422
+ if self.emb_mode == "cls_and_gene":
1423
+ pu.write_perturbation_dictionary(
1424
+ stored_gene_embs_dict,
1425
+ f"{output_path_prefix}_dict_gene_embs_batch{pickle_batch}",
1426
  )
1427
 
1428
+
1429
  def update_perturbation_dictionary(
1430
  self,
1431
  cos_sims_dict: defaultdict,
1432
  cos_sims_data: torch.Tensor,
 
  gene_list=None,
1434
  ):
1435
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1436
  logger.error(
1437
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1438
+ {cos_sims_data.shape[0]=}.\n \
1439
+ {len(gene_list)=}."
1440
  )
1441
  raise
1442
 
 
1460
  for i, cos in enumerate(cos_sims_data.tolist()):
1461
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1462
 
1463
+ return cos_sims_dict
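
For orientation, below is a minimal usage sketch (not part of this PR) of how the updated constructor might be called once these changes land. The model/dataset paths, the perturbation settings, and the perturb_data call pattern are illustrative assumptions; only emb_mode, token_dictionary_file, and clear_mem_ncells reflect options introduced or changed in this diff, and emb_mode="cls" assumes the tokenized data and token dictionary contain the <cls>/<eos> special tokens, as enforced by the new validation above.

from geneformer import InSilicoPerturber

# Sketch only: paths and perturbation settings below are placeholders.
isp = InSilicoPerturber(
    perturb_type="delete",
    genes_to_perturb="all",
    model_type="Pretrained",
    num_classes=0,
    emb_mode="cls",                  # new option: quantify shifts on the <cls> token embedding
    max_ncells=1000,
    emb_layer=0,
    forward_batch_size=100,
    nproc=4,
    token_dictionary_file=None,      # new default: None falls back to the packaged TOKEN_DICTIONARY_FILE
    clear_mem_ncells=1000,           # new: write dictionaries to disk and clear memory every n cells
)

# Assumed call pattern; replace the placeholder paths with real locations.
isp.perturb_data(
    "path/to/model",
    "path/to/tokenized_dataset.dataset",
    "path/to/output_dir",
    "output_prefix",
)
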
geneformer/perturber_utils.py CHANGED
@@ -23,8 +23,6 @@ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
23
  ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
24
 
25
 
26
- sns.set()
27
-
28
  logger = logging.getLogger(__name__)
29
 
30
 
@@ -156,8 +154,12 @@ def quant_layers(model):
156
  return int(max(layer_nums)) + 1
157
 
158
 
159
  def get_model_input_size(model):
160
- return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
161
 
162
 
163
  def flatten_list(megalist):
@@ -222,27 +224,47 @@ def overexpress_indices(example):
222
  indices = example["perturb_index"]
223
  if any(isinstance(el, list) for el in indices):
224
  indices = flatten_list(indices)
225
- for index in sorted(indices, reverse=True):
226
- example["input_ids"].insert(0, example["input_ids"].pop(index))
227
-
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
  if len(example["input_ids"]) > max_len:
244
- example["input_ids"] = example["input_ids"][0:max_len]
245
-
246
  example["length"] = len(example["input_ids"])
247
  return example
248
 
@@ -259,6 +281,13 @@ def truncate_by_n_overflow(example):
259
  example["length"] = len(example["input_ids"])
260
  return example
261
 
262
 
263
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
264
  # indices_to_remove is list of indices to remove
@@ -392,7 +421,81 @@ def make_perturbation_batch(
392
  return perturbation_dataset, indices_to_perturb
393
 
394
 
395
- # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
396
  # so that only non-perturbed gene embeddings are compared to each other
397
  # in original or perturbed context
398
  def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
@@ -589,9 +692,10 @@ def quant_cos_sims(
589
  cos = torch.nn.CosineSimilarity(dim=1)
590
 
591
  # if emb_mode == "gene", can only calculate gene cos sims
592
- # against original cell anyways
593
  if cell_states_to_model is None or emb_mode == "gene":
594
595
  elif cell_states_to_model is not None and emb_mode == "cell":
596
  possible_states = get_possible_states(cell_states_to_model)
597
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
@@ -758,4 +862,4 @@ class GeneIdHandler:
758
  return self.ens_to_symbol(self.token_to_ens(token))
759
 
760
  def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
23
  ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
24
 
25
 
26
  logger = logging.getLogger(__name__)
27
 
28
 
 
154
  return int(max(layer_nums)) + 1
155
 
156
 
157
+ def get_model_emb_dims(model):
158
+ return model.config.hidden_size
159
+
160
+
161
  def get_model_input_size(model):
162
+ return model.config.max_position_embeddings
163
 
164
 
165
  def flatten_list(megalist):
 
224
  indices = example["perturb_index"]
225
  if any(isinstance(el, list) for el in indices):
226
  indices = flatten_list(indices)
227
+ insert_pos = 0
228
+ for index in sorted(indices, reverse=False):
229
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
230
+ insert_pos += 1
231
  example["length"] = len(example["input_ids"])
232
  return example
233
 
234
+ # if CLS token present, move to 1st rather than 0th position
235
+ def overexpress_indices_special(example):
236
+ indices = example["perturb_index"]
237
+ if any(isinstance(el, list) for el in indices):
238
+ indices = flatten_list(indices)
239
+ insert_pos = 1 # Insert starting after CLS token
240
+ for index in sorted(indices, reverse=False):
241
+ example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
242
+ insert_pos += 1
243
+ example["length"] = len(example["input_ids"])
244
+ return example
245
 
246
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
247
+ def overexpress_tokens(example, max_len, special_token):
248
  # -100 indicates tokens to overexpress are not present in rank value encoding
249
  if example["perturb_index"] != [-100]:
250
  example = delete_indices(example)
251
+ if special_token:
252
+ [
253
+ example["input_ids"].insert(1, token)
254
+ for token in example["tokens_to_perturb"][::-1]
255
+ ]
256
+ else:
257
+ [
258
+ example["input_ids"].insert(0, token)
259
+ for token in example["tokens_to_perturb"][::-1]
260
+ ]
261
 
262
  # truncate to max input size, must also truncate original emb to be comparable
263
  if len(example["input_ids"]) > max_len:
264
+ if special_token:
265
+ example["input_ids"] = example["input_ids"][0:max_len-1]+[example["input_ids"][-1]]
266
+ else:
267
+ example["input_ids"] = example["input_ids"][0:max_len]
268
  example["length"] = len(example["input_ids"])
269
  return example
270
 
 
281
  example["length"] = len(example["input_ids"])
282
  return example
283
 
284
+ def truncate_by_n_overflow_special(example):
285
+ if example["n_overflow"] > 0:
286
+ new_max_len = example["length"] - example["n_overflow"]
287
+ example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]]
288
+ example["length"] = len(example["input_ids"])
289
+ return example
290
+
291
 
292
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
293
  # indices_to_remove is list of indices to remove
 
421
  return perturbation_dataset, indices_to_perturb
422
 
423
 
424
+ def make_perturbation_batch_special(
425
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
426
+ ) -> tuple[Dataset, List[int]]:
427
+ if combo_lvl == 0 and tokens_to_perturb == "all":
428
+ if perturb_type in ["overexpress", "activate"]:
429
+ range_start = 1
430
+ elif perturb_type in ["delete", "inhibit"]:
431
+ range_start = 0
432
+ range_start += 1 # Starting after the CLS token
433
+ indices_to_perturb = [
434
+ [i] for i in range(range_start, example_cell["length"][0]-1) # And excluding the EOS token
435
+ ]
436
+
437
+ # elif combo_lvl > 0 and anchor_token is None:
438
+ ## to implement
439
+ elif combo_lvl > 0 and (anchor_token is not None):
440
+ example_input_ids = example_cell["input_ids"][0]
441
+ anchor_index = example_input_ids.index(anchor_token[0])
442
+ indices_to_perturb = [
443
+ sorted([anchor_index, i]) if i != anchor_index else None
444
+ for i in range(1, example_cell["length"][0]-1) # Exclude CLS and EOS tokens
445
+ ]
446
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
447
+ else:
448
+ example_input_ids = example_cell["input_ids"][0]
449
+ indices_to_perturb = [
450
+ [example_input_ids.index(token)] if token in example_input_ids else None
451
+ for token in tokens_to_perturb
452
+ ]
453
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
454
+
455
+ # create all permutations of combo_lvl of modifiers from tokens_to_perturb
456
+ if combo_lvl > 0 and (anchor_token is None):
457
+ if tokens_to_perturb != "all":
458
+ if len(tokens_to_perturb) == combo_lvl + 1:
459
+ indices_to_perturb = [
460
+ list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
461
+ ]
462
+ else:
463
+ all_indices = [[i] for i in range(1, example_cell["length"][0]-1)] # Exclude CLS and EOS tokens
464
+ all_indices = [
465
+ index for index in all_indices if index not in indices_to_perturb
466
+ ]
467
+ indices_to_perturb = [
468
+ [[j for i in indices_to_perturb for j in i], x] for x in all_indices
469
+ ]
470
+
471
+ length = len(indices_to_perturb)
472
+ perturbation_dataset = Dataset.from_dict(
473
+ {
474
+ "input_ids": example_cell["input_ids"] * length,
475
+ "perturb_index": indices_to_perturb,
476
+ }
477
+ )
478
+
479
+ if length < 400:
480
+ num_proc_i = 1
481
+ else:
482
+ num_proc_i = num_proc
483
+
484
+ if perturb_type == "delete":
485
+ perturbation_dataset = perturbation_dataset.map(
486
+ delete_indices, num_proc=num_proc_i
487
+ )
488
+ elif perturb_type == "overexpress":
489
+ perturbation_dataset = perturbation_dataset.map(
490
+ overexpress_indices_special, num_proc=num_proc_i
491
+ )
492
+
493
+ perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
494
+
495
+ return perturbation_dataset, indices_to_perturb
496
+
497
+
498
+ # original cell emb removing the activated/overexpressed/inhibited gene emb
499
  # so that only non-perturbed gene embeddings are compared to each other
500
  # in original or perturbed context
501
  def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
 
692
  cos = torch.nn.CosineSimilarity(dim=1)
693
 
694
  # if emb_mode == "gene", can only calculate gene cos sims
695
+ # against original cell
696
  if cell_states_to_model is None or emb_mode == "gene":
697
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
698
+
699
  elif cell_states_to_model is not None and emb_mode == "cell":
700
  possible_states = get_possible_states(cell_states_to_model)
701
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
 
862
  return self.ens_to_symbol(self.token_to_ens(token))
863
 
864
  def symbol_to_token(self, symbol):
865
+ return self.ens_to_token(self.symbol_to_ens(symbol))
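
To make the reordering change above concrete, here is a small standalone sketch of what overexpress_indices and the new overexpress_indices_special do to a rank value encoding. move_to_front is a hypothetical helper that mirrors the loop body in this diff, and the integer token IDs are made up for illustration.

def move_to_front(input_ids, perturb_indices, insert_pos):
    # pop each perturbed token and reinsert it at insert_pos,
    # preserving the relative order of the perturbed tokens
    for index in sorted(perturb_indices, reverse=False):
        input_ids.insert(insert_pos, input_ids.pop(index))
        insert_pos += 1
    return input_ids

# without special tokens (overexpress_indices): perturbed genes move to the front
print(move_to_front([10, 11, 12, 13, 14], [2, 4], insert_pos=0))
# -> [12, 14, 10, 11, 13]

# with special tokens (overexpress_indices_special): insertion starts at position 1,
# so <cls> (0 here) stays first and <eos> (1 here) stays last
print(move_to_front([0, 10, 11, 12, 13, 1], [2, 4], insert_pos=1))
# -> [0, 11, 13, 10, 12, 1]
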