ctheodoris commited on
Commit
3e24216
1 Parent(s): d6e949b

Update geneformer/in_silico_perturber.py

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +57 -35
geneformer/in_silico_perturber.py CHANGED
@@ -99,6 +99,7 @@ class InSilicoPerturber:
99
  forward_batch_size=100,
100
  nproc=4,
101
  token_dictionary_file=None,
 
102
  ):
103
  """
104
  Initialize in silico perturber.
@@ -187,6 +188,8 @@ class InSilicoPerturber:
187
  | Number of CPU processes to use.
188
  token_dictionary_file : Path
189
  | Path to pickle file containing token dictionary (Ensembl ID:token).
 
 
190
  """
191
  try:
192
  set_start_method("spawn")
@@ -224,6 +227,7 @@ class InSilicoPerturber:
224
  self.forward_batch_size = forward_batch_size
225
  self.nproc = nproc
226
  self.token_dictionary_file = token_dictionary_file
 
227
 
228
  self.validate_options()
229
 
@@ -235,17 +239,16 @@ class InSilicoPerturber:
235
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
236
 
237
  self.pad_token_id = self.gene_token_dict.get("<pad>")
 
 
238
 
239
 
240
  # Identify if special token is present in the token dictionary
241
- lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
242
- cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
243
- eos_present = any("eos" in value for value in lowercase_token_gene_dict.values())
244
- if cls_present or eos_present:
245
  self.special_token = True
246
  else:
247
  if "cls" in self.emb_mode:
248
- logger.error(f"emb_mode set to {self.emb_mode} but <cls> token not in token dictionary.")
249
  raise
250
  self.special_token = False
251
 
@@ -454,12 +457,17 @@ class InSilicoPerturber:
454
 
455
  # Ensure emb_mode is cls if first token of the filtered input data is cls token
456
  if self.special_token:
457
- cls_token_id = self.gene_token_dict["<cls>"]
458
- if (filtered_input_data["input_ids"][0][0] == cls_token_id) and ("cls" not in self.emb_mode):
459
  logger.error(
460
  "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
461
  )
462
  raise
 
 
 
 
 
 
463
 
464
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
465
 
@@ -554,6 +562,7 @@ class InSilicoPerturber:
554
  perturbed_data = filtered_input_data.map(
555
  make_group_perturbation_batch, num_proc=self.nproc
556
  )
 
557
  if self.perturb_type == "overexpress":
558
  filtered_input_data = filtered_input_data.add_column(
559
  "n_overflow", perturbed_data["n_overflow"]
@@ -572,7 +581,7 @@ class InSilicoPerturber:
572
  pu.truncate_by_n_overflow, num_proc=self.nproc
573
  )
574
 
575
- if self.emb_mode == "cell_and_gene":
576
  stored_gene_embs_dict = defaultdict(list)
577
 
578
  # iterate through batches
@@ -618,20 +627,24 @@ class InSilicoPerturber:
618
 
619
  if "cls" not in self.emb_mode:
620
  start = 0
 
 
621
  else:
622
  start = 1
 
 
623
 
624
- # remove overexpressed genes and cls
625
  original_emb = original_emb[
626
- :, start :, :
627
  ]
628
  if self.perturb_type == "overexpress":
629
  perturbation_emb = full_perturbation_emb[
630
- :, start+len(self.tokens_to_perturb) :, :
631
  ]
632
  elif self.perturb_type == "delete":
633
  perturbation_emb = full_perturbation_emb[
634
- :, start : max(perturbation_batch["length"]), :
635
  ]
636
 
637
  n_perturbation_genes = perturbation_emb.size()[1]
@@ -640,6 +653,7 @@ class InSilicoPerturber:
640
  if (
641
  self.cell_states_to_model is None
642
  or self.emb_mode == "cell_and_gene"
 
643
  ):
644
  gene_cos_sims = pu.quant_cos_sims(
645
  perturbation_emb,
@@ -677,18 +691,23 @@ class InSilicoPerturber:
677
 
678
  # get cosine similarities in gene embeddings
679
  # if getting gene embeddings, need gene names
680
- if self.emb_mode == "cell_and_gene":
681
  gene_list = minibatch["input_ids"]
682
  # need to truncate gene_list
 
 
 
683
  gene_list = [
684
- [g for g in genes if g not in self.tokens_to_perturb][
685
  :n_perturbation_genes
686
  ]
687
  for genes in gene_list
688
  ]
689
- # remove CLS if present
690
- if "cls" in self.emb_mode:
691
- gene_list = gene_list[1:]
 
 
692
 
693
  for cell_i, genes in enumerate(gene_list):
694
  for gene_j, affected_gene in enumerate(genes):
@@ -760,10 +779,9 @@ class InSilicoPerturber:
760
  del full_perturbation_emb
761
  del perturbation_emb
762
  del cos_sims_data
763
- if "cls" in self.emb_mode:
764
  del original_cls_emb
765
  del perturbation_cls_emb
766
- del cls_cos_sims
767
 
768
  torch.cuda.empty_cache()
769
 
@@ -772,7 +790,7 @@ class InSilicoPerturber:
772
  f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
773
  )
774
 
775
- if self.emb_mode == "cell_and_gene":
776
  pu.write_perturbation_dictionary(
777
  stored_gene_embs_dict,
778
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
@@ -794,7 +812,7 @@ class InSilicoPerturber:
794
  for state in pu.get_possible_states(self.cell_states_to_model)
795
  }
796
 
797
- if self.emb_mode == "cell_and_gene":
798
  stored_gene_embs_dict = defaultdict(list)
799
  for i in trange(len(filtered_input_data)):
800
  example_cell = filtered_input_data.select([i])
@@ -840,27 +858,31 @@ class InSilicoPerturber:
840
  )
841
 
842
  num_inds_perturbed = 1 + self.combos
843
- # need to remove overexpressed gene and cls to quantify cosine shifts
 
844
  if "cls" not in self.emb_mode:
845
  start = 0
 
846
  else:
847
  start = 1
 
848
  if self.perturb_type == "overexpress":
849
- perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:, :]
850
  gene_list = gene_list[
851
- start+num_inds_perturbed:
852
- ] # cls and index 0 is not overexpressed
853
 
854
  elif self.perturb_type == "delete":
855
- perturbation_emb = full_perturbation_emb[:, start:, :]
856
- gene_list = gene_list[start:]
857
 
858
- full_original_emb = full_original_emb[:, start:, :]
859
  original_batch = pu.make_comparison_batch(
860
  full_original_emb, indices_to_perturb, perturb_group=False
861
  )
862
 
863
- if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
 
 
864
  gene_cos_sims = pu.quant_cos_sims(
865
  perturbation_emb,
866
  original_batch,
@@ -890,7 +912,7 @@ class InSilicoPerturber:
890
  emb_mode="cell",
891
  )
892
 
893
- if self.emb_mode == "cell_and_gene":
894
  # remove perturbed index for gene list
895
  perturbed_gene_dict = {
896
  gene: gene_list[:i] + gene_list[i + 1 :]
@@ -942,19 +964,19 @@ class InSilicoPerturber:
942
  )
943
 
944
  # save dict to disk every 100 cells
945
- if i % 100 == 0:
946
  pu.write_perturbation_dictionary(
947
  cos_sims_dict,
948
  f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
949
  )
950
- if self.emb_mode == "cell_and_gene":
951
  pu.write_perturbation_dictionary(
952
  stored_gene_embs_dict,
953
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
954
  )
955
 
956
  # reset and clear memory every 1000 cells
957
- if i % 1000 == 0:
958
  pickle_batch += 1
959
  if self.cell_states_to_model is None:
960
  cos_sims_dict = defaultdict(list)
@@ -964,7 +986,7 @@ class InSilicoPerturber:
964
  for state in pu.get_possible_states(self.cell_states_to_model)
965
  }
966
 
967
- if self.emb_mode == "cell_and_gene":
968
  stored_gene_embs_dict = defaultdict(list)
969
 
970
  torch.cuda.empty_cache()
@@ -973,7 +995,7 @@ class InSilicoPerturber:
973
  cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
974
  )
975
 
976
- if self.emb_mode == "cell_and_gene":
977
  pu.write_perturbation_dictionary(
978
  stored_gene_embs_dict,
979
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
 
99
  forward_batch_size=100,
100
  nproc=4,
101
  token_dictionary_file=None,
102
+ clear_mem_ncells=1000,
103
  ):
104
  """
105
  Initialize in silico perturber.
 
188
  | Number of CPU processes to use.
189
  token_dictionary_file : Path
190
  | Path to pickle file containing token dictionary (Ensembl ID:token).
191
+ clear_mem_ncells : int
192
+ | Clear memory every n cells.
193
  """
194
  try:
195
  set_start_method("spawn")
 
227
  self.forward_batch_size = forward_batch_size
228
  self.nproc = nproc
229
  self.token_dictionary_file = token_dictionary_file
230
+ self.clear_mem_ncells = clear_mem_ncells
231
 
232
  self.validate_options()
233
 
 
239
  self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
240
 
241
  self.pad_token_id = self.gene_token_dict.get("<pad>")
242
+ self.cls_token_id = self.gene_token_dict.get("<cls>")
243
+ self.eos_token_id = self.gene_token_dict.get("<eos>")
244
 
245
 
246
  # Identify if special token is present in the token dictionary
247
+ if (self.cls_token_id is not None) and (self.eos_token_id is not None):
 
 
 
248
  self.special_token = True
249
  else:
250
  if "cls" in self.emb_mode:
251
+ logger.error(f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary.")
252
  raise
253
  self.special_token = False
254
 
 
457
 
458
  # Ensure emb_mode is cls if first token of the filtered input data is cls token
459
  if self.special_token:
460
+ if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and ("cls" not in self.emb_mode):
 
461
  logger.error(
462
  "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
463
  )
464
  raise
465
+ if ("cls" in self.emb_mode):
466
+ if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (filtered_input_data["input_ids"][0][-1] != self.eos_token_id):
467
+ logger.error(
468
+ "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
469
+ )
470
+ raise
471
 
472
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
473
 
 
562
  perturbed_data = filtered_input_data.map(
563
  make_group_perturbation_batch, num_proc=self.nproc
564
  )
565
+
566
  if self.perturb_type == "overexpress":
567
  filtered_input_data = filtered_input_data.add_column(
568
  "n_overflow", perturbed_data["n_overflow"]
 
581
  pu.truncate_by_n_overflow, num_proc=self.nproc
582
  )
583
 
584
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
585
  stored_gene_embs_dict = defaultdict(list)
586
 
587
  # iterate through batches
 
627
 
628
  if "cls" not in self.emb_mode:
629
  start = 0
630
+ end_add = 0
631
+ end = None
632
  else:
633
  start = 1
634
+ end_add = -1
635
+ end = -1
636
 
637
+ # remove overexpressed genes and cls/eos
638
  original_emb = original_emb[
639
+ :, start : end, :
640
  ]
641
  if self.perturb_type == "overexpress":
642
  perturbation_emb = full_perturbation_emb[
643
+ :, start+len(self.tokens_to_perturb) : end, :
644
  ]
645
  elif self.perturb_type == "delete":
646
  perturbation_emb = full_perturbation_emb[
647
+ :, start : max(perturbation_batch["length"])+end_add, :
648
  ]
649
 
650
  n_perturbation_genes = perturbation_emb.size()[1]
 
653
  if (
654
  self.cell_states_to_model is None
655
  or self.emb_mode == "cell_and_gene"
656
+ or self.emb_mode == "cls_and_gene"
657
  ):
658
  gene_cos_sims = pu.quant_cos_sims(
659
  perturbation_emb,
 
691
 
692
  # get cosine similarities in gene embeddings
693
  # if getting gene embeddings, need gene names
694
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
695
  gene_list = minibatch["input_ids"]
696
  # need to truncate gene_list
697
+ genes_to_exclude = self.tokens_to_perturb
698
+ if self.emb_mode == "cls_and_gene":
699
+ genes_to_exclude = genes_to_exclude + [self.cls_token_id, self.eos_token_id]
700
  gene_list = [
701
+ [g for g in genes if g not in genes_to_exclude][
702
  :n_perturbation_genes
703
  ]
704
  for genes in gene_list
705
  ]
706
+ # remove CLS and EOS if present
707
+ # if "cls" in self.emb_mode:
708
+ # cls_token_id = self.gene_token_dict["<cls>"]
709
+ # eos_token_id = self.gene_token_dict["<eos>"]
710
+ # gene_list = [e for e in gene_list if e not in [cls_token_id,eos_token_id]]
711
 
712
  for cell_i, genes in enumerate(gene_list):
713
  for gene_j, affected_gene in enumerate(genes):
 
779
  del full_perturbation_emb
780
  del perturbation_emb
781
  del cos_sims_data
782
+ if ("cls" in self.emb_mode) and (self.cell_states_to_model is None):
783
  del original_cls_emb
784
  del perturbation_cls_emb
 
785
 
786
  torch.cuda.empty_cache()
787
 
 
790
  f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
791
  )
792
 
793
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
794
  pu.write_perturbation_dictionary(
795
  stored_gene_embs_dict,
796
  f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
 
812
  for state in pu.get_possible_states(self.cell_states_to_model)
813
  }
814
 
815
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
816
  stored_gene_embs_dict = defaultdict(list)
817
  for i in trange(len(filtered_input_data)):
818
  example_cell = filtered_input_data.select([i])
 
858
  )
859
 
860
  num_inds_perturbed = 1 + self.combos
861
+
862
+ # need to remove overexpressed gene and cls/eos to quantify cosine shifts
863
  if "cls" not in self.emb_mode:
864
  start = 0
865
+ end = None
866
  else:
867
  start = 1
868
+ end = -1
869
  if self.perturb_type == "overexpress":
870
+ perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:end, :]
871
  gene_list = gene_list[
872
+ start+num_inds_perturbed:end
873
+ ] # cls/eos and index 0 is not overexpressed
874
 
875
  elif self.perturb_type == "delete":
876
+ perturbation_emb = full_perturbation_emb[:, start:end, :]
877
+ gene_list = gene_list[start:end]
878
 
 
879
  original_batch = pu.make_comparison_batch(
880
  full_original_emb, indices_to_perturb, perturb_group=False
881
  )
882
 
883
+ original_batch = original_batch[:, start:end, :]
884
+
885
+ if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene" or self.emb_mode == "cls_and_gene":
886
  gene_cos_sims = pu.quant_cos_sims(
887
  perturbation_emb,
888
  original_batch,
 
912
  emb_mode="cell",
913
  )
914
 
915
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
916
  # remove perturbed index for gene list
917
  perturbed_gene_dict = {
918
  gene: gene_list[:i] + gene_list[i + 1 :]
 
964
  )
965
 
966
  # save dict to disk every 100 cells
967
+ if i % clear_mem_ncells/10 == 0:
968
  pu.write_perturbation_dictionary(
969
  cos_sims_dict,
970
  f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
971
  )
972
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
973
  pu.write_perturbation_dictionary(
974
  stored_gene_embs_dict,
975
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
976
  )
977
 
978
  # reset and clear memory every 1000 cells
979
+ if i % clear_mem_ncells == 0:
980
  pickle_batch += 1
981
  if self.cell_states_to_model is None:
982
  cos_sims_dict = defaultdict(list)
 
986
  for state in pu.get_possible_states(self.cell_states_to_model)
987
  }
988
 
989
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
990
  stored_gene_embs_dict = defaultdict(list)
991
 
992
  torch.cuda.empty_cache()
 
995
  cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
996
  )
997
 
998
+ if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
999
  pu.write_perturbation_dictionary(
1000
  stored_gene_embs_dict,
1001
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",