Update geneformer/in_silico_perturber.py

#355
by hchen725 - opened
geneformer/in_silico_perturber.py CHANGED
@@ -66,7 +66,7 @@ class InSilicoPerturber:
66
  "anchor_gene": {None, str},
67
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
68
  "num_classes": {int},
69
- "emb_mode": {"cell", "cell_and_gene"},
70
  "cell_emb_style": {"mean_pool"},
71
  "filter_data": {None, dict},
72
  "cell_states_to_model": {None, dict},
@@ -74,6 +74,7 @@ class InSilicoPerturber:
74
  "max_ncells": {None, int},
75
  "cell_inds_to_perturb": {"all", dict},
76
  "emb_layer": {-1, 0},
 
77
  "forward_batch_size": {int},
78
  "nproc": {int},
79
  }
@@ -97,7 +98,7 @@ class InSilicoPerturber:
97
  emb_layer=-1,
98
  forward_batch_size=100,
99
  nproc=4,
100
- token_dictionary_file=TOKEN_DICTIONARY_FILE,
101
  ):
102
  """
103
  Initialize in silico perturber.
@@ -137,11 +138,11 @@ class InSilicoPerturber:
137
  num_classes : int
138
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
139
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
140
- emb_mode : {"cell", "cell_and_gene"}
141
- | Whether to output impact of perturbation on cell and/or gene embeddings.
142
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
143
  cell_emb_style : "mean_pool"
144
- | Method for summarizing cell embeddings.
145
  | Currently only option is mean pooling of gene embeddings for given cell.
146
  filter_data : None, dict
147
  | Default is to use all input data for in silico perturbation study.
@@ -222,15 +223,32 @@ class InSilicoPerturber:
222
  self.emb_layer = emb_layer
223
  self.forward_batch_size = forward_batch_size
224
  self.nproc = nproc
 
225
 
226
  self.validate_options()
227
 
228
  # load token dictionary (Ensembl IDs:token)
 
 
229
  with open(token_dictionary_file, "rb") as f:
230
  self.gene_token_dict = pickle.load(f)
 
231
 
232
  self.pad_token_id = self.gene_token_dict.get("<pad>")
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  if self.anchor_gene is None:
235
  self.anchor_token = None
236
  else:
@@ -287,7 +305,7 @@ class InSilicoPerturber:
287
  continue
288
  valid_type = False
289
  for option in valid_options:
290
- if (option in [bool, int, list, dict]) and isinstance(
291
  attr_value, option
292
  ):
293
  valid_type = True
@@ -428,12 +446,21 @@ class InSilicoPerturber:
428
  self.max_len = pu.get_model_input_size(model)
429
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
430
 
431
-
432
  ### filter input data ###
433
  # general filtering of input data based on filter_data argument
434
  filtered_input_data = pu.load_and_filter(
435
  self.filter_data, self.nproc, input_data_file
436
  )
 
 
 
 
 
 
 
 
 
 
437
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
438
 
439
  if self.perturb_group is True:
@@ -506,7 +533,7 @@ class InSilicoPerturber:
506
  if self.perturb_type == "delete":
507
  example = pu.delete_indices(example)
508
  elif self.perturb_type == "overexpress":
509
- example = pu.overexpress_tokens(example, self.max_len)
510
  example["n_overflow"] = pu.calc_n_overflow(
511
  self.max_len,
512
  example["length"],
@@ -527,7 +554,6 @@ class InSilicoPerturber:
527
  perturbed_data = filtered_input_data.map(
528
  make_group_perturbation_batch, num_proc=self.nproc
529
  )
530
-
531
  if self.perturb_type == "overexpress":
532
  filtered_input_data = filtered_input_data.add_column(
533
  "n_overflow", perturbed_data["n_overflow"]
@@ -537,9 +563,14 @@ class InSilicoPerturber:
537
  # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
538
  # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
539
  # rather than only adding 2048)
540
- filtered_input_data = filtered_input_data.map(
541
- pu.truncate_by_n_overflow, num_proc=self.nproc
542
- )
 
 
 
 
 
543
 
544
  if self.emb_mode == "cell_and_gene":
545
  stored_gene_embs_dict = defaultdict(list)
@@ -560,6 +591,7 @@ class InSilicoPerturber:
560
  layer_to_quant,
561
  self.pad_token_id,
562
  self.forward_batch_size,
 
563
  summary_stat=None,
564
  silent=True,
565
  )
@@ -579,24 +611,32 @@ class InSilicoPerturber:
579
  layer_to_quant,
580
  self.pad_token_id,
581
  self.forward_batch_size,
 
582
  summary_stat=None,
583
  silent=True,
584
  )
585
 
586
- # remove overexpressed genes
 
 
 
 
 
 
 
 
587
  if self.perturb_type == "overexpress":
588
  perturbation_emb = full_perturbation_emb[
589
- :, len(self.tokens_to_perturb) :, :
590
  ]
591
-
592
  elif self.perturb_type == "delete":
593
  perturbation_emb = full_perturbation_emb[
594
- :, : max(perturbation_batch["length"]), :
595
  ]
596
 
597
  n_perturbation_genes = perturbation_emb.size()[1]
598
 
599
- # if no goal states, the cosine similarties are the mean of gene cosine similarities
600
  if (
601
  self.cell_states_to_model is None
602
  or self.emb_mode == "cell_and_gene"
@@ -611,16 +651,22 @@ class InSilicoPerturber:
611
 
612
  # if there are goal states, the cosine similarities are the cell cosine similarities
613
  if self.cell_states_to_model is not None:
614
- original_cell_emb = pu.mean_nonpadding_embs(
615
- full_original_emb,
616
- torch.tensor(minibatch["length"], device="cuda"),
617
- dim=1,
618
- )
619
- perturbation_cell_emb = pu.mean_nonpadding_embs(
620
- full_perturbation_emb,
621
- torch.tensor(perturbation_batch["length"], device="cuda"),
622
- dim=1,
623
- )
 
 
 
 
 
 
624
  cell_cos_sims = pu.quant_cos_sims(
625
  perturbation_cell_emb,
626
  original_cell_emb,
@@ -640,6 +686,9 @@ class InSilicoPerturber:
640
  ]
641
  for genes in gene_list
642
  ]
 
 
 
643
 
644
  for cell_i, genes in enumerate(gene_list):
645
  for gene_j, affected_gene in enumerate(genes):
@@ -672,9 +721,21 @@ class InSilicoPerturber:
672
  ]
673
  else:
674
  nonpadding_lens = perturbation_batch["length"]
675
- cos_sims_data = pu.mean_nonpadding_embs(
676
- gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
677
- )
 
 
 
 
 
 
 
 
 
 
 
 
678
  cos_sims_dict = self.update_perturbation_dictionary(
679
  cos_sims_dict,
680
  cos_sims_data,
@@ -694,9 +755,15 @@ class InSilicoPerturber:
694
  )
695
  del minibatch
696
  del perturbation_batch
 
697
  del original_emb
 
698
  del perturbation_emb
699
  del cos_sims_data
 
 
 
 
700
 
701
  torch.cuda.empty_cache()
702
 
@@ -738,6 +805,7 @@ class InSilicoPerturber:
738
  layer_to_quant,
739
  self.pad_token_id,
740
  self.forward_batch_size,
 
741
  summary_stat=None,
742
  silent=True,
743
  )
@@ -756,6 +824,7 @@ class InSilicoPerturber:
756
  self.anchor_token,
757
  self.combos,
758
  self.nproc,
 
759
  )
760
 
761
  full_perturbation_emb = get_embs(
@@ -765,21 +834,28 @@ class InSilicoPerturber:
765
  layer_to_quant,
766
  self.pad_token_id,
767
  self.forward_batch_size,
 
768
  summary_stat=None,
769
  silent=True,
770
  )
771
 
772
  num_inds_perturbed = 1 + self.combos
773
- # need to remove overexpressed gene to quantify cosine shifts
 
 
 
 
774
  if self.perturb_type == "overexpress":
775
- perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
776
  gene_list = gene_list[
777
- num_inds_perturbed:
778
- ] # index 0 is not overexpressed
779
 
780
  elif self.perturb_type == "delete":
781
- perturbation_emb = full_perturbation_emb
 
782
 
 
783
  original_batch = pu.make_comparison_batch(
784
  full_original_emb, indices_to_perturb, perturb_group=False
785
  )
@@ -792,13 +868,19 @@ class InSilicoPerturber:
792
  self.state_embs_dict,
793
  emb_mode="gene",
794
  )
 
795
  if self.cell_states_to_model is not None:
796
- original_cell_emb = pu.compute_nonpadded_cell_embedding(
797
- full_original_emb, "mean_pool"
798
- )
799
- perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
800
- full_perturbation_emb, "mean_pool"
801
- )
 
 
 
 
 
802
 
803
  cell_cos_sims = pu.quant_cos_sims(
804
  perturbation_cell_emb,
@@ -829,7 +911,18 @@ class InSilicoPerturber:
829
  ] = gene_cos_sims[perturbation_i, gene_j].item()
830
 
831
  if self.cell_states_to_model is None:
832
- cos_sims_data = torch.mean(gene_cos_sims, dim=1)
 
 
 
 
 
 
 
 
 
 
 
833
  cos_sims_dict = self.update_perturbation_dictionary(
834
  cos_sims_dict,
835
  cos_sims_data,
@@ -922,4 +1015,4 @@ class InSilicoPerturber:
922
  for i, cos in enumerate(cos_sims_data.tolist()):
923
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
924
 
925
- return cos_sims_dict
 
66
  "anchor_gene": {None, str},
67
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
68
  "num_classes": {int},
69
+ "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
70
  "cell_emb_style": {"mean_pool"},
71
  "filter_data": {None, dict},
72
  "cell_states_to_model": {None, dict},
 
74
  "max_ncells": {None, int},
75
  "cell_inds_to_perturb": {"all", dict},
76
  "emb_layer": {-1, 0},
77
+ "token_dictionary_file" : {None, str},
78
  "forward_batch_size": {int},
79
  "nproc": {int},
80
  }
 
98
  emb_layer=-1,
99
  forward_batch_size=100,
100
  nproc=4,
101
+ token_dictionary_file=None,
102
  ):
103
  """
104
  Initialize in silico perturber.
 
138
  num_classes : int
139
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
140
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
141
+ emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
142
+ | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
143
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
144
  cell_emb_style : "mean_pool"
145
+ | Method for summarizing cell embeddings if not using CLS token.
146
  | Currently only option is mean pooling of gene embeddings for given cell.
147
  filter_data : None, dict
148
  | Default is to use all input data for in silico perturbation study.
 
223
  self.emb_layer = emb_layer
224
  self.forward_batch_size = forward_batch_size
225
  self.nproc = nproc
226
+ self.token_dictionary_file = token_dictionary_file
227
 
228
  self.validate_options()
229
 
230
  # load token dictionary (Ensembl IDs:token)
231
+ if self.token_dictionary_file is None:
232
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
233
  with open(token_dictionary_file, "rb") as f:
234
  self.gene_token_dict = pickle.load(f)
235
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
236
 
237
  self.pad_token_id = self.gene_token_dict.get("<pad>")
238
 
239
+
240
+ # Identify if special token is present in the token dictionary
241
+ lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
242
+ cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
243
+ eos_present = any("eos" in value for value in lowercase_token_gene_dict.values())
244
+ if cls_present or eos_present:
245
+ self.special_token = True
246
+ else:
247
+ if "cls" in self.emb_mode:
248
+ logger.error(f"emb_mode set to {self.emb_mode} but <cls> token not in token dictionary.")
249
+ raise
250
+ self.special_token = False
251
+
252
  if self.anchor_gene is None:
253
  self.anchor_token = None
254
  else:
 
305
  continue
306
  valid_type = False
307
  for option in valid_options:
308
+ if (option in [bool, int, list, dict, str]) and isinstance(
309
  attr_value, option
310
  ):
311
  valid_type = True
 
446
  self.max_len = pu.get_model_input_size(model)
447
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
448
 
 
449
  ### filter input data ###
450
  # general filtering of input data based on filter_data argument
451
  filtered_input_data = pu.load_and_filter(
452
  self.filter_data, self.nproc, input_data_file
453
  )
454
+
455
+ # Ensure emb_mode is cls if first token of the filtered input data is cls token
456
+ if self.special_token:
457
+ cls_token_id = self.gene_token_dict["<cls>"]
458
+ if (filtered_input_data["input_ids"][0][0] == cls_token_id) and ("cls" not in self.emb_mode):
459
+ logger.error(
460
+ "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
461
+ )
462
+ raise
463
+
464
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
465
 
466
  if self.perturb_group is True:
 
533
  if self.perturb_type == "delete":
534
  example = pu.delete_indices(example)
535
  elif self.perturb_type == "overexpress":
536
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
537
  example["n_overflow"] = pu.calc_n_overflow(
538
  self.max_len,
539
  example["length"],
 
554
  perturbed_data = filtered_input_data.map(
555
  make_group_perturbation_batch, num_proc=self.nproc
556
  )
 
557
  if self.perturb_type == "overexpress":
558
  filtered_input_data = filtered_input_data.add_column(
559
  "n_overflow", perturbed_data["n_overflow"]
 
563
  # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
564
  # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
565
  # rather than only adding 2048)
566
+ if self.special_token:
567
+ filtered_input_data = filtered_input_data.map(
568
+ pu.truncate_by_n_overflow_special, num_proc=self.nproc
569
+ )
570
+ else:
571
+ filtered_input_data = filtered_input_data.map(
572
+ pu.truncate_by_n_overflow, num_proc=self.nproc
573
+ )
574
 
575
  if self.emb_mode == "cell_and_gene":
576
  stored_gene_embs_dict = defaultdict(list)
 
591
  layer_to_quant,
592
  self.pad_token_id,
593
  self.forward_batch_size,
594
+ self.token_gene_dict,
595
  summary_stat=None,
596
  silent=True,
597
  )
 
611
  layer_to_quant,
612
  self.pad_token_id,
613
  self.forward_batch_size,
614
+ self.token_gene_dict,
615
  summary_stat=None,
616
  silent=True,
617
  )
618
 
619
+ if "cls" not in self.emb_mode:
620
+ start = 0
621
+ else:
622
+ start = 1
623
+
624
+ # remove overexpressed genes and cls
625
+ original_emb = original_emb[
626
+ :, start :, :
627
+ ]
628
  if self.perturb_type == "overexpress":
629
  perturbation_emb = full_perturbation_emb[
630
+ :, start+len(self.tokens_to_perturb) :, :
631
  ]
 
632
  elif self.perturb_type == "delete":
633
  perturbation_emb = full_perturbation_emb[
634
+ :, start : max(perturbation_batch["length"]), :
635
  ]
636
 
637
  n_perturbation_genes = perturbation_emb.size()[1]
638
 
639
+ # if no goal states, the cosine similarities are the mean of gene cosine similarities
640
  if (
641
  self.cell_states_to_model is None
642
  or self.emb_mode == "cell_and_gene"
 
651
 
652
  # if there are goal states, the cosine similarities are the cell cosine similarities
653
  if self.cell_states_to_model is not None:
654
+ if "cls" not in self.emb_mode:
655
+ original_cell_emb = pu.mean_nonpadding_embs(
656
+ full_original_emb,
657
+ torch.tensor(minibatch["length"], device="cuda"),
658
+ dim=1,
659
+ )
660
+ perturbation_cell_emb = pu.mean_nonpadding_embs(
661
+ full_perturbation_emb,
662
+ torch.tensor(perturbation_batch["length"], device="cuda"),
663
+ dim=1,
664
+ )
665
+ else:
666
+ # get cls emb
667
+ original_cell_emb = full_original_emb[:,0,:]
668
+ perturbation_cell_emb = full_perturbation_emb[:,0,:]
669
+
670
  cell_cos_sims = pu.quant_cos_sims(
671
  perturbation_cell_emb,
672
  original_cell_emb,
 
686
  ]
687
  for genes in gene_list
688
  ]
689
+ # remove CLS if present
690
+ if "cls" in self.emb_mode:
691
+ gene_list = gene_list[1:]
692
 
693
  for cell_i, genes in enumerate(gene_list):
694
  for gene_j, affected_gene in enumerate(genes):
 
721
  ]
722
  else:
723
  nonpadding_lens = perturbation_batch["length"]
724
+ if "cls" not in self.emb_mode:
725
+ cos_sims_data = pu.mean_nonpadding_embs(
726
+ gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
727
+ )
728
+ else:
729
+ original_cls_emb = full_original_emb[:,0,:]
730
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
731
+ cos_sims_data = pu.quant_cos_sims(
732
+ perturbation_cls_emb,
733
+ original_cls_emb,
734
+ self.cell_states_to_model,
735
+ self.state_embs_dict,
736
+ emb_mode="cell",
737
+ )
738
+
739
  cos_sims_dict = self.update_perturbation_dictionary(
740
  cos_sims_dict,
741
  cos_sims_data,
 
755
  )
756
  del minibatch
757
  del perturbation_batch
758
+ del full_original_emb
759
  del original_emb
760
+ del full_perturbation_emb
761
  del perturbation_emb
762
  del cos_sims_data
763
+ if "cls" in self.emb_mode:
764
+ del original_cls_emb
765
+ del perturbation_cls_emb
766
+ del cls_cos_sims
767
 
768
  torch.cuda.empty_cache()
769
 
 
805
  layer_to_quant,
806
  self.pad_token_id,
807
  self.forward_batch_size,
808
+ self.token_gene_dict,
809
  summary_stat=None,
810
  silent=True,
811
  )
 
824
  self.anchor_token,
825
  self.combos,
826
  self.nproc,
827
+ self.special_token,
828
  )
829
 
830
  full_perturbation_emb = get_embs(
 
834
  layer_to_quant,
835
  self.pad_token_id,
836
  self.forward_batch_size,
837
+ self.token_gene_dict,
838
  summary_stat=None,
839
  silent=True,
840
  )
841
 
842
  num_inds_perturbed = 1 + self.combos
843
+ # need to remove overexpressed gene and cls to quantify cosine shifts
844
+ if "cls" not in self.emb_mode:
845
+ start = 0
846
+ else:
847
+ start = 1
848
  if self.perturb_type == "overexpress":
849
+ perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:, :]
850
  gene_list = gene_list[
851
+ start+num_inds_perturbed:
852
+ ] # cls and index 0 is not overexpressed
853
 
854
  elif self.perturb_type == "delete":
855
+ perturbation_emb = full_perturbation_emb[:, start:, :]
856
+ gene_list = gene_list[start:]
857
 
858
+ full_original_emb = full_original_emb[:, start:, :]
859
  original_batch = pu.make_comparison_batch(
860
  full_original_emb, indices_to_perturb, perturb_group=False
861
  )
 
868
  self.state_embs_dict,
869
  emb_mode="gene",
870
  )
871
+
872
  if self.cell_states_to_model is not None:
873
+ if "cls" not in self.emb_mode:
874
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
875
+ full_original_emb, "mean_pool"
876
+ )
877
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
878
+ full_perturbation_emb, "mean_pool"
879
+ )
880
+ else:
881
+ # get cls emb
882
+ original_cell_emb = full_original_emb[:,0,:]
883
+ perturbation_cell_emb = full_perturbation_emb[:,0,:]
884
 
885
  cell_cos_sims = pu.quant_cos_sims(
886
  perturbation_cell_emb,
 
911
  ] = gene_cos_sims[perturbation_i, gene_j].item()
912
 
913
  if self.cell_states_to_model is None:
914
+ if "cls" not in self.emb_mode:
915
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
916
+ else:
917
+ original_cls_emb = full_original_emb[:,0,:]
918
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
919
+ cos_sims_data = pu.quant_cos_sims(
920
+ perturbation_cls_emb,
921
+ original_cls_emb,
922
+ self.cell_states_to_model,
923
+ self.state_embs_dict,
924
+ emb_mode="cell",
925
+ )
926
  cos_sims_dict = self.update_perturbation_dictionary(
927
  cos_sims_dict,
928
  cos_sims_data,
 
1015
  for i, cos in enumerate(cos_sims_data.tolist()):
1016
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1017
 
1018
+ return cos_sims_dict
geneformer/perturber_utils.py CHANGED
@@ -228,16 +228,32 @@ def overexpress_indices(example):
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
 
 
 
 
 
 
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
  if len(example["input_ids"]) > max_len:
@@ -259,6 +275,12 @@ def truncate_by_n_overflow(example):
259
  example["length"] = len(example["input_ids"])
260
  return example
261
 
 
 
 
 
 
 
262
 
263
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
264
  # indices_to_remove is list of indices to remove
@@ -321,7 +343,7 @@ def remove_perturbed_indices_set(
321
 
322
 
323
  def make_perturbation_batch(
324
- example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
325
  ) -> tuple[Dataset, List[int]]:
326
  if combo_lvl == 0 and tokens_to_perturb == "all":
327
  if perturb_type in ["overexpress", "activate"]:
@@ -383,9 +405,14 @@ def make_perturbation_batch(
383
  delete_indices, num_proc=num_proc_i
384
  )
385
  elif perturb_type == "overexpress":
386
- perturbation_dataset = perturbation_dataset.map(
387
- overexpress_indices, num_proc=num_proc_i
388
- )
 
 
 
 
 
389
 
390
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
391
 
@@ -758,4 +785,4 @@ class GeneIdHandler:
758
  return self.ens_to_symbol(self.token_to_ens(token))
759
 
760
  def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
231
+ # if CLS token present, move to 1st rather than 0th position
232
+ def overexpress_indices_special(example):
233
+ indices = example["perturb_index"]
234
+ if any(isinstance(el, list) for el in indices):
235
+ indices = flatten_list(indices)
236
+ for index in sorted(indices, reverse=True):
237
+ example["input_ids"].insert(1, example["input_ids"].pop(index))
238
+
239
+ example["length"] = len(example["input_ids"])
240
+ return example
241
 
242
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
243
+ def overexpress_tokens(example, max_len, special_token):
244
  # -100 indicates tokens to overexpress are not present in rank value encoding
245
  if example["perturb_index"] != [-100]:
246
  example = delete_indices(example)
247
+ if special_token:
248
+ [
249
+ example["input_ids"].insert(1, token)
250
+ for token in example["tokens_to_perturb"][::-1]
251
+ ]
252
+ else:
253
+ [
254
+ example["input_ids"].insert(0, token)
255
+ for token in example["tokens_to_perturb"][::-1]
256
+ ]
257
 
258
  # truncate to max input size, must also truncate original emb to be comparable
259
  if len(example["input_ids"]) > max_len:
 
275
  example["length"] = len(example["input_ids"])
276
  return example
277
 
278
+ def truncate_by_n_overflow_special(example):
279
+ new_max_len = example["length"] - example["n_overflow"]
280
+ example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]]
281
+ example["length"] = len(example["input_ids"])
282
+ return example
283
+
284
 
285
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
286
  # indices_to_remove is list of indices to remove
 
343
 
344
 
345
  def make_perturbation_batch(
346
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
347
  ) -> tuple[Dataset, List[int]]:
348
  if combo_lvl == 0 and tokens_to_perturb == "all":
349
  if perturb_type in ["overexpress", "activate"]:
 
405
  delete_indices, num_proc=num_proc_i
406
  )
407
  elif perturb_type == "overexpress":
408
+ if special_token:
409
+ perturbation_dataset = perturbation_dataset.map(
410
+ overexpress_indices_special, num_proc=num_proc_i
411
+ )
412
+ else:
413
+ perturbation_dataset = perturbation_dataset.map(
414
+ overexpress_indices, num_proc=num_proc_i
415
+ )
416
 
417
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
418
 
 
785
  return self.ens_to_symbol(self.token_to_ens(token))
786
 
787
  def symbol_to_token(self, symbol):
788
+ return self.ens_to_token(self.symbol_to_ens(symbol))