hchen725 committed on
Commit
9870991
1 Parent(s): 038e0aa

Update geneformer/in_silico_perturber.py

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +60 -72
geneformer/in_silico_perturber.py CHANGED
@@ -853,19 +853,14 @@ class InSilicoPerturber:
853
  summary_stat=None,
854
  silent=True,
855
  )
856
-
857
- # if no goal states, the cosine similarities are the mean of cell cosine similarities?
858
- if self.cell_states_to_model is None:
859
- continue
860
-
861
- # if there are goal states, the cosine similarities are the CLS cosine similarities
862
- if self.cell_states_to_model is not None:
863
- cls_cos_sims = pu.quant_cos_sims(
864
- cls_perturbation_emb,
865
- cls_original_emb,
866
- self.cell_states_to_model,
867
- self.state_embs_dict,
868
- emb_mode="cell")
869
 
870
  # Update perturbation dictionary
871
  if self.cell_states_to_model is None:
@@ -886,6 +881,9 @@ class InSilicoPerturber:
886
  gene_list = None,
887
  )
888
 
 
 
 
889
  elif self.emb_mode == "cls_and_gene":
890
  full_original_emb = get_embs(
891
  model,
@@ -919,17 +917,11 @@ class InSilicoPerturber:
919
  silent=True,
920
  )
921
 
922
- original_emb = original_emb[
923
- :, 0 : 0, :
924
- ]
925
  if self.perturb_type == "overexpress":
926
- perturbation_emb = full_perturbation_emb[
927
- :, 0+len(self.tokens_to_perturb) : 0, :
928
- ]
929
  elif self.perturb_type == "delete":
930
- perturbation_emb = full_perturbation_emb[
931
- :, 0 : max(perturbation_batch["length"])+0, :
932
- ]
933
 
934
  n_perturbation_genes = perturbation_emb.size()[1]
935
 
@@ -958,6 +950,7 @@ class InSilicoPerturber:
958
  self.state_embs_dict,
959
  emb_mode="cell",
960
  )
 
961
 
962
  # get cosine similarities in gene embeddings
963
  # if getting gene embeddings, need gene names
@@ -973,11 +966,6 @@ class InSilicoPerturber:
973
  ]
974
  for genes in gene_list
975
  ]
976
- # remove CLS and EOS if present
977
- # if "cls" in self.emb_mode:
978
- # cls_token_id = self.gene_token_dict["<cls>"]
979
- # eos_token_id = self.gene_token_dict["<eos>"]
980
- # gene_list = [e for e in gene_list if e not in [cls_token_id,eos_token_id]]
981
 
982
  for cell_i, genes in enumerate(gene_list):
983
  for gene_j, affected_gene in enumerate(genes):
@@ -998,59 +986,58 @@ class InSilicoPerturber:
998
  else:
999
  gene_list = None
1000
 
1001
- if self.cell_states_to_model is None:
1002
- # calculate the mean of the gene cosine similarities for cell shift
1003
- # tensor of nonpadding lengths for each cell
1004
- if self.perturb_type == "overexpress":
1005
- # subtract number of genes that were overexpressed
1006
- # since they are removed before getting cos sims
1007
- n_overexpressed = len(self.tokens_to_perturb)
1008
- nonpadding_lens = [
1009
- x - n_overexpressed for x in perturbation_batch["length"]
1010
- ]
1011
- else:
1012
- nonpadding_lens = perturbation_batch["length"]
1013
- if "cls" not in self.emb_mode:
1014
- cos_sims_data = pu.mean_nonpadding_embs(
1015
- gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
1016
- )
1017
- else:
1018
- cos_sims_data = pu.quant_cos_sims(
1019
- perturbation_cls_emb,
1020
- original_cls_emb,
1021
- self.cell_states_to_model,
1022
- self.state_embs_dict,
1023
- emb_mode="cell",
1024
- )
1025
 
1026
- cos_sims_dict = self.update_perturbation_dictionary(
1027
- cos_sims_dict,
1028
- cos_sims_data,
1029
- filtered_input_data,
1030
- indices_to_perturb,
1031
- gene_list,
1032
- )
1033
- else:
1034
- cos_sims_data = cell_cos_sims
1035
- for state in cos_sims_dict.keys():
1036
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1037
- cos_sims_dict[state],
1038
- cos_sims_data[state],
1039
  filtered_input_data,
1040
  indices_to_perturb,
1041
  gene_list,
1042
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1043
 
1044
  del minibatch
1045
  del perturbation_batch
1046
- del full_original_emb
1047
- del original_emb
1048
- del full_perturbation_emb
1049
- del perturbation_emb
1050
- del cos_sims_data
1051
- if ("cls" in self.emb_mode) and (self.cell_states_to_model is None):
1052
- del original_cls_emb
1053
- del perturbation_cls_emb
1054
 
1055
  torch.cuda.empty_cache()
1056
 
@@ -1462,6 +1449,7 @@ class InSilicoPerturber:
1462
  gene_list,
1463
  )
1464
  else:
 
1465
  for state in cos_sims_dict.keys():
1466
  cos_sims_dict[state] = self.update_perturbation_dictionary(
1467
  cos_sims_dict[state],
 
853
  summary_stat=None,
854
  silent=True,
855
  )
856
+
857
+ # Calculate the cosine similarities
858
+ cls_cos_sims = pu.quant_cos_sims(
859
+ cls_perturbation_emb,
860
+ cls_original_emb,
861
+ self.cell_states_to_model,
862
+ self.state_embs_dict,
863
+ emb_mode="cell")
 
 
 
 
 
864
 
865
  # Update perturbation dictionary
866
  if self.cell_states_to_model is None:
 
881
  gene_list = None,
882
  )
883
 
884
+ del cls_original_emb
885
+ del cls_perturbation_emb
886
+
887
  elif self.emb_mode == "cls_and_gene":
888
  full_original_emb = get_embs(
889
  model,
 
917
  silent=True,
918
  )
919
 
920
+ original_emb = original_emb[: 0:0,:]
 
 
921
  if self.perturb_type == "overexpress":
922
+ perturbation_emb = full_perturbation_emb[:,0+len(self.tokens_to_perturb):0,:]
 
 
923
  elif self.perturb_type == "delete":
924
+ perturbation_emb = full_perturbation_emb[:,0:max(perturbation_batch["length"])+0,:]
 
 
925
 
926
  n_perturbation_genes = perturbation_emb.size()[1]
927
 
 
950
  self.state_embs_dict,
951
  emb_mode="cell",
952
  )
953
+ cos_sims_data = cell_cos_sims
954
 
955
  # get cosine similarities in gene embeddings
956
  # if getting gene embeddings, need gene names
 
966
  ]
967
  for genes in gene_list
968
  ]
 
 
 
 
 
969
 
970
  for cell_i, genes in enumerate(gene_list):
971
  for gene_j, affected_gene in enumerate(genes):
 
986
  else:
987
  gene_list = None
988
 
989
+ if self.cell_states_to_model is None:
990
+ # calculate the mean of the gene cosine similarities for cell shift
991
+ # tensor of nonpadding lengths for each cell
992
+ if self.perturb_type == "overexpress":
993
+ # subtract number of genes that were overexpressed
994
+ # since they are removed before getting cos sims
995
+ n_overexpressed = len(self.tokens_to_perturb)
996
+ nonpadding_lens = [
997
+ x - n_overexpressed for x in perturbation_batch["length"]
998
+ ]
999
+ else:
1000
+ nonpadding_lens = perturbation_batch["length"]
1001
+ if "cls" not in self.emb_mode:
1002
+ cos_sims_data = pu.mean_nonpadding_embs(
1003
+ gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
1004
+ )
1005
+ else:
1006
+ cos_sims_data = pu.quant_cos_sims(
1007
+ perturbation_cls_emb,
1008
+ original_cls_emb,
1009
+ self.cell_states_to_model,
1010
+ self.state_embs_dict,
1011
+ emb_mode="cell",
1012
+ )
1013
 
1014
+ cos_sims_dict = self.update_perturbation_dictionary(
1015
+ cos_sims_dict,
1016
+ cos_sims_data,
 
 
 
 
 
 
 
 
 
 
1017
  filtered_input_data,
1018
  indices_to_perturb,
1019
  gene_list,
1020
  )
1021
+ else:
1022
+ for state in cos_sims_dict.keys():
1023
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
1024
+ cos_sims_dict[state],
1025
+ cos_sims_data[state],
1026
+ filtered_input_data,
1027
+ indices_to_perturb,
1028
+ gene_list,
1029
+ )
1030
+ del full_original_emb
1031
+ del original_emb
1032
+ del full_perturbation_emb
1033
+ del perturbation_emb
1034
+ del cos_sims_data
1035
+ if self.cell_states_to_model is None:
1036
+ del original_cls_emb
1037
+ del perturbation_cls_emb
1038
 
1039
  del minibatch
1040
  del perturbation_batch
 
 
 
 
 
 
 
 
1041
 
1042
  torch.cuda.empty_cache()
1043
 
 
1449
  gene_list,
1450
  )
1451
  else:
1452
+ cos_sims_data = cell_cos_sims
1453
  for state in cos_sims_dict.keys():
1454
  cos_sims_dict[state] = self.update_perturbation_dictionary(
1455
  cos_sims_dict[state],