hchen725 committed on
Commit
bcaf65e
1 Parent(s): 428e3b0

add isp for perturb all with cls

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +222 -48
geneformer/in_silico_perturber.py CHANGED
@@ -476,9 +476,14 @@ class InSilicoPerturber:
476
  model, filtered_input_data, layer_to_quant, output_path_prefix
477
  )
478
  else:
479
- self.isp_perturb_all(
480
- model, filtered_input_data, layer_to_quant, output_path_prefix
481
- )
 
 
 
 
 
482
 
483
  def apply_additional_filters(self, filtered_input_data):
484
  # additional filtering of input data dependent on isp mode
@@ -812,7 +817,7 @@ class InSilicoPerturber:
812
  for state in pu.get_possible_states(self.cell_states_to_model)
813
  }
814
 
815
- if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
816
  stored_gene_embs_dict = defaultdict(list)
817
  for i in trange(len(filtered_input_data)):
818
  example_cell = filtered_input_data.select([i])
@@ -842,7 +847,6 @@ class InSilicoPerturber:
842
  self.anchor_token,
843
  self.combos,
844
  self.nproc,
845
- self.special_token,
846
  )
847
 
848
  full_perturbation_emb = get_embs(
@@ -859,30 +863,16 @@ class InSilicoPerturber:
859
 
860
  num_inds_perturbed = 1 + self.combos
861
 
862
- # need to remove overexpressed gene and cls/eos to quantify cosine shifts
863
- if "cls" not in self.emb_mode:
864
- start = 0
865
- end = None
866
- else:
867
- start = 1
868
- end = -1
869
  if self.perturb_type == "overexpress":
870
- perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:end, :]
871
- gene_list = gene_list[
872
- start+num_inds_perturbed:end
873
- ] # cls/eos and index 0 is not overexpressed
874
-
875
  elif self.perturb_type == "delete":
876
- perturbation_emb = full_perturbation_emb[:, start:end, :]
877
- gene_list = gene_list[start:end]
878
 
879
  original_batch = pu.make_comparison_batch(
880
  full_original_emb, indices_to_perturb, perturb_group=False
881
  )
882
 
883
- original_batch = original_batch[:, start:end, :]
884
-
885
- if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene" or self.emb_mode == "cls_and_gene":
886
  gene_cos_sims = pu.quant_cos_sims(
887
  perturbation_emb,
888
  original_batch,
@@ -892,17 +882,198 @@ class InSilicoPerturber:
892
  )
893
 
894
  if self.cell_states_to_model is not None:
895
- if "cls" not in self.emb_mode:
896
- original_cell_emb = pu.compute_nonpadded_cell_embedding(
897
- full_original_emb, "mean_pool"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
898
  )
899
- perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
900
- full_perturbation_emb, "mean_pool"
 
 
 
 
 
 
 
 
 
901
  )
 
 
 
 
 
 
902
  else:
903
- # get cls emb
904
- original_cell_emb = full_original_emb[:,0,:]
905
- perturbation_cell_emb = full_perturbation_emb[:,0,:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906
 
907
  cell_cos_sims = pu.quant_cos_sims(
908
  perturbation_cell_emb,
@@ -912,7 +1083,7 @@ class InSilicoPerturber:
912
  emb_mode="cell",
913
  )
914
 
915
- if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
916
  # remove perturbed index for gene list
917
  perturbed_gene_dict = {
918
  gene: gene_list[:i] + gene_list[i + 1 :]
@@ -933,18 +1104,15 @@ class InSilicoPerturber:
933
  ] = gene_cos_sims[perturbation_i, gene_j].item()
934
 
935
  if self.cell_states_to_model is None:
936
- if "cls" not in self.emb_mode:
937
- cos_sims_data = torch.mean(gene_cos_sims, dim=1)
938
- else:
939
- original_cls_emb = full_original_emb[:,0,:]
940
- perturbation_cls_emb = full_perturbation_emb[:,0,:]
941
- cos_sims_data = pu.quant_cos_sims(
942
- perturbation_cls_emb,
943
- original_cls_emb,
944
- self.cell_states_to_model,
945
- self.state_embs_dict,
946
- emb_mode="cell",
947
- )
948
  cos_sims_dict = self.update_perturbation_dictionary(
949
  cos_sims_dict,
950
  cos_sims_data,
@@ -969,7 +1137,7 @@ class InSilicoPerturber:
969
  cos_sims_dict,
970
  f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
971
  )
972
- if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
973
  pu.write_perturbation_dictionary(
974
  stored_gene_embs_dict,
975
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
@@ -986,21 +1154,27 @@ class InSilicoPerturber:
986
  for state in pu.get_possible_states(self.cell_states_to_model)
987
  }
988
 
989
- if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
990
  stored_gene_embs_dict = defaultdict(list)
991
 
 
 
 
 
 
992
  torch.cuda.empty_cache()
993
 
994
  pu.write_perturbation_dictionary(
995
  cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
996
  )
997
 
998
- if (self.emb_mode == "cell_and_gene") or (self.emb_mode == "cls_and_gene"):
999
  pu.write_perturbation_dictionary(
1000
  stored_gene_embs_dict,
1001
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
1002
  )
1003
 
 
1004
  def update_perturbation_dictionary(
1005
  self,
1006
  cos_sims_dict: defaultdict,
@@ -1012,8 +1186,8 @@ class InSilicoPerturber:
1012
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1013
  logger.error(
1014
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1015
- cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \
1016
- len(gene_list) = {len(gene_list)}."
1017
  )
1018
  raise
1019
 
 
476
  model, filtered_input_data, layer_to_quant, output_path_prefix
477
  )
478
  else:
479
+ if (self.special_token) and ("cls" in self.emb_mode):
480
+ self.isp_perturb_all_special(
481
+ model, filtered_input_data, layer_to_quant, output_path_prefix
482
+ )
483
+ else:
484
+ self.isp_perturb_all(
485
+ model, filtered_input_data, layer_to_quant, output_path_prefix
486
+ )
487
 
488
  def apply_additional_filters(self, filtered_input_data):
489
  # additional filtering of input data dependent on isp mode
 
817
  for state in pu.get_possible_states(self.cell_states_to_model)
818
  }
819
 
820
+ if self.emb_mode == "cell_and_gene":
821
  stored_gene_embs_dict = defaultdict(list)
822
  for i in trange(len(filtered_input_data)):
823
  example_cell = filtered_input_data.select([i])
 
847
  self.anchor_token,
848
  self.combos,
849
  self.nproc,
 
850
  )
851
 
852
  full_perturbation_emb = get_embs(
 
863
 
864
  num_inds_perturbed = 1 + self.combos
865
 
 
 
 
 
 
 
 
866
  if self.perturb_type == "overexpress":
867
+ perturbation_emb = full_perturbation_emb[:, 0+num_inds_perturbed:None, :]
 
 
 
 
868
  elif self.perturb_type == "delete":
869
+ perturbation_emb = full_perturbation_emb
 
870
 
871
  original_batch = pu.make_comparison_batch(
872
  full_original_emb, indices_to_perturb, perturb_group=False
873
  )
874
 
875
+ if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene":
 
 
876
  gene_cos_sims = pu.quant_cos_sims(
877
  perturbation_emb,
878
  original_batch,
 
882
  )
883
 
884
  if self.cell_states_to_model is not None:
885
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
886
+ full_original_emb, "mean_pool"
887
+ )
888
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
889
+ full_perturbation_emb, "mean_pool"
890
+ )
891
+
892
+ cell_cos_sims = pu.quant_cos_sims(
893
+ perturbation_cell_emb,
894
+ original_cell_emb,
895
+ self.cell_states_to_model,
896
+ self.state_embs_dict,
897
+ emb_mode="cell",
898
+ )
899
+
900
+ if self.emb_mode == "cell_and_gene":
901
+ # remove perturbed index for gene list
902
+ perturbed_gene_dict = {
903
+ gene: gene_list[:i] + gene_list[i + 1 :]
904
+ for i, gene in enumerate(gene_list)
905
+ }
906
+
907
+ for perturbation_i, perturbed_gene in enumerate(gene_list):
908
+ for gene_j, affected_gene in enumerate(
909
+ perturbed_gene_dict[perturbed_gene]
910
+ ):
911
+ try:
912
+ stored_gene_embs_dict[
913
+ (perturbed_gene, affected_gene)
914
+ ].append(gene_cos_sims[perturbation_i, gene_j].item())
915
+ except KeyError:
916
+ stored_gene_embs_dict[
917
+ (perturbed_gene, affected_gene)
918
+ ] = gene_cos_sims[perturbation_i, gene_j].item()
919
+
920
+ if self.cell_states_to_model is None:
921
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
922
+ cos_sims_dict = self.update_perturbation_dictionary(
923
+ cos_sims_dict,
924
+ cos_sims_data,
925
+ filtered_input_data,
926
+ indices_to_perturb,
927
+ gene_list,
928
+ )
929
+ else:
930
+ cos_sims_data = cell_cos_sims
931
+ for state in cos_sims_dict.keys():
932
+ cos_sims_dict[state] = self.update_perturbation_dictionary(
933
+ cos_sims_dict[state],
934
+ cos_sims_data[state],
935
+ filtered_input_data,
936
+ indices_to_perturb,
937
+ gene_list,
938
  )
939
+
940
+ # save dict to disk every 100 cells
941
+ if i % self.clear_mem_ncells/10 == 0:
942
+ pu.write_perturbation_dictionary(
943
+ cos_sims_dict,
944
+ f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
945
+ )
946
+ if self.emb_mode == "cell_and_gene":
947
+ pu.write_perturbation_dictionary(
948
+ stored_gene_embs_dict,
949
+ f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
950
  )
951
+
952
+ # reset and clear memory every 1000 cells
953
+ if i % self.clear_mem_ncells == 0:
954
+ pickle_batch += 1
955
+ if self.cell_states_to_model is None:
956
+ cos_sims_dict = defaultdict(list)
957
  else:
958
+ cos_sims_dict = {
959
+ state: defaultdict(list)
960
+ for state in pu.get_possible_states(self.cell_states_to_model)
961
+ }
962
+
963
+ if self.emb_mode == "cell_and_gene":
964
+ stored_gene_embs_dict = defaultdict(list)
965
+
966
+ del full_original_emb
967
+ del perturbation_batch
968
+ del full_perturbation_emb
969
+ del perturbation_emb
970
+ del original_batch
971
+
972
+ torch.cuda.empty_cache()
973
+
974
+ pu.write_perturbation_dictionary(
975
+ cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
976
+ )
977
+
978
+ if self.emb_mode == "cell_and_gene":
979
+ pu.write_perturbation_dictionary(
980
+ stored_gene_embs_dict,
981
+ f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
982
+ )
983
+
984
+
985
+ def isp_perturb_all_special(
986
+ self,
987
+ model,
988
+ filtered_input_data: Dataset,
989
+ layer_to_quant: int,
990
+ output_path_prefix: str,
991
+ ):
992
+ pickle_batch = -1
993
+ if self.cell_states_to_model is None:
994
+ cos_sims_dict = defaultdict(list)
995
+ else:
996
+ cos_sims_dict = {
997
+ state: defaultdict(list)
998
+ for state in pu.get_possible_states(self.cell_states_to_model)
999
+ }
1000
+
1001
+ if self.emb_mode == "cls_and_gene":
1002
+ stored_gene_embs_dict = defaultdict(list)
1003
+ for i in trange(len(filtered_input_data)):
1004
+ example_cell = filtered_input_data.select([i])
1005
+ full_original_emb = get_embs(
1006
+ model,
1007
+ example_cell,
1008
+ "gene",
1009
+ layer_to_quant,
1010
+ self.pad_token_id,
1011
+ self.forward_batch_size,
1012
+ self.token_gene_dict,
1013
+ summary_stat=None,
1014
+ silent=True,
1015
+ )
1016
+
1017
+ # gene_list is used to assign cos sims back to genes
1018
+ # need to remove the anchor gene
1019
+ gene_list = example_cell["input_ids"][0][:]
1020
+ if self.anchor_token is not None:
1021
+ for token in self.anchor_token:
1022
+ gene_list.remove(token)
1023
+
1024
+ # Also exclude special token from gene_list
1025
+ if self.special_token:
1026
+ for token in [self.cls_token_id, self.eos_token_id]:
1027
+ gene_list.remove(token)
1028
+
1029
+ perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1030
+ example_cell,
1031
+ self.perturb_type,
1032
+ self.tokens_to_perturb,
1033
+ self.anchor_token,
1034
+ self.combos,
1035
+ self.nproc,
1036
+ )
1037
+
1038
+ full_perturbation_emb = get_embs(
1039
+ model,
1040
+ perturbation_batch,
1041
+ "gene",
1042
+ layer_to_quant,
1043
+ self.pad_token_id,
1044
+ self.forward_batch_size,
1045
+ self.token_gene_dict,
1046
+ summary_stat=None,
1047
+ silent=True,
1048
+ )
1049
+
1050
+ num_inds_perturbed = 1 + self.combos
1051
+
1052
+ # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1053
+ if self.perturb_type == "overexpress":
1054
+ perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :]
1055
+ elif self.perturb_type == "delete":
1056
+ perturbation_emb = full_perturbation_emb[:, 1:-1, :]
1057
+
1058
+ original_batch = pu.make_comparison_batch(
1059
+ full_original_emb, indices_to_perturb, perturb_group=False
1060
+ )
1061
+
1062
+ original_batch = original_batch[:, 1:-1, :]
1063
+
1064
+ if self.cell_states_to_model is None or self.emb_mode == "cls_and_gene":
1065
+ gene_cos_sims = pu.quant_cos_sims(
1066
+ perturbation_emb,
1067
+ original_batch,
1068
+ self.cell_states_to_model,
1069
+ self.state_embs_dict,
1070
+ emb_mode="gene",
1071
+ )
1072
+
1073
+ if self.cell_states_to_model is not None:
1074
+ # get cls emb
1075
+ original_cell_emb = full_original_emb[:,0,:]
1076
+ perturbation_cell_emb = full_perturbation_emb[:,0,:]
1077
 
1078
  cell_cos_sims = pu.quant_cos_sims(
1079
  perturbation_cell_emb,
 
1083
  emb_mode="cell",
1084
  )
1085
 
1086
+ if self.emb_mode == "cls_and_gene":
1087
  # remove perturbed index for gene list
1088
  perturbed_gene_dict = {
1089
  gene: gene_list[:i] + gene_list[i + 1 :]
 
1104
  ] = gene_cos_sims[perturbation_i, gene_j].item()
1105
 
1106
  if self.cell_states_to_model is None:
1107
+ original_cls_emb = full_original_emb[:,0,:]
1108
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
1109
+ cos_sims_data = pu.quant_cos_sims(
1110
+ perturbation_cls_emb,
1111
+ original_cls_emb,
1112
+ self.cell_states_to_model,
1113
+ self.state_embs_dict,
1114
+ emb_mode="cell",
1115
+ )
 
 
 
1116
  cos_sims_dict = self.update_perturbation_dictionary(
1117
  cos_sims_dict,
1118
  cos_sims_data,
 
1137
  cos_sims_dict,
1138
  f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}",
1139
  )
1140
+ if self.emb_mode == "cls_and_gene":
1141
  pu.write_perturbation_dictionary(
1142
  stored_gene_embs_dict,
1143
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
 
1154
  for state in pu.get_possible_states(self.cell_states_to_model)
1155
  }
1156
 
1157
+ if self.emb_mode == "cls_and_gene":
1158
  stored_gene_embs_dict = defaultdict(list)
1159
 
1160
+ del full_original_emb
1161
+ del perturbation_batch
1162
+ del full_perturbation_emb
1163
+ del perturbation_emb
1164
+ del original_batch
1165
  torch.cuda.empty_cache()
1166
 
1167
  pu.write_perturbation_dictionary(
1168
  cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}"
1169
  )
1170
 
1171
+ if self.emb_mode == "cls_and_gene":
1172
  pu.write_perturbation_dictionary(
1173
  stored_gene_embs_dict,
1174
  f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}",
1175
  )
1176
 
1177
+
1178
  def update_perturbation_dictionary(
1179
  self,
1180
  cos_sims_dict: defaultdict,
 
1186
  if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1187
  logger.error(
1188
  f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1189
+ {cos_sims_data.shape[0]=}.\n \
1190
+ {len(gene_list)=}."
1191
  )
1192
  raise
1193