Update geneformer/in_silico_perturber.py
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -853,19 +853,14 @@ class InSilicoPerturber:
|
|
853 |
summary_stat=None,
|
854 |
silent=True,
|
855 |
)
|
856 |
-
|
857 |
-
#
|
858 |
-
|
859 |
-
|
860 |
-
|
861 |
-
|
862 |
-
|
863 |
-
|
864 |
-
cls_perturbation_emb,
|
865 |
-
cls_original_emb,
|
866 |
-
self.cell_states_to_model,
|
867 |
-
self.state_embs_dict,
|
868 |
-
emb_mode="cell")
|
869 |
|
870 |
# Update perturbation dictionary
|
871 |
if self.cell_states_to_model is None:
|
@@ -886,6 +881,9 @@ class InSilicoPerturber:
|
|
886 |
gene_list = None,
|
887 |
)
|
888 |
|
|
|
|
|
|
|
889 |
elif self.emb_mode == "cls_and_gene":
|
890 |
full_original_emb = get_embs(
|
891 |
model,
|
@@ -919,17 +917,11 @@ class InSilicoPerturber:
|
|
919 |
silent=True,
|
920 |
)
|
921 |
|
922 |
-
original_emb = original_emb[
|
923 |
-
:, 0 : 0, :
|
924 |
-
]
|
925 |
if self.perturb_type == "overexpress":
|
926 |
-
perturbation_emb = full_perturbation_emb[
|
927 |
-
:, 0+len(self.tokens_to_perturb) : 0, :
|
928 |
-
]
|
929 |
elif self.perturb_type == "delete":
|
930 |
-
perturbation_emb = full_perturbation_emb[
|
931 |
-
:, 0 : max(perturbation_batch["length"])+0, :
|
932 |
-
]
|
933 |
|
934 |
n_perturbation_genes = perturbation_emb.size()[1]
|
935 |
|
@@ -958,6 +950,7 @@ class InSilicoPerturber:
|
|
958 |
self.state_embs_dict,
|
959 |
emb_mode="cell",
|
960 |
)
|
|
|
961 |
|
962 |
# get cosine similarities in gene embeddings
|
963 |
# if getting gene embeddings, need gene names
|
@@ -973,11 +966,6 @@ class InSilicoPerturber:
|
|
973 |
]
|
974 |
for genes in gene_list
|
975 |
]
|
976 |
-
# remove CLS and EOS if present
|
977 |
-
# if "cls" in self.emb_mode:
|
978 |
-
# cls_token_id = self.gene_token_dict["<cls>"]
|
979 |
-
# eos_token_id = self.gene_token_dict["<eos>"]
|
980 |
-
# gene_list = [e for e in gene_list if e not in [cls_token_id,eos_token_id]]
|
981 |
|
982 |
for cell_i, genes in enumerate(gene_list):
|
983 |
for gene_j, affected_gene in enumerate(genes):
|
@@ -998,59 +986,58 @@ class InSilicoPerturber:
|
|
998 |
else:
|
999 |
gene_list = None
|
1000 |
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
|
1006 |
-
|
1007 |
-
|
1008 |
-
|
1009 |
-
|
1010 |
-
|
1011 |
-
|
1012 |
-
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
1018 |
-
|
1019 |
-
|
1020 |
-
|
1021 |
-
|
1022 |
-
|
1023 |
-
|
1024 |
-
|
1025 |
|
1026 |
-
|
1027 |
-
|
1028 |
-
|
1029 |
-
filtered_input_data,
|
1030 |
-
indices_to_perturb,
|
1031 |
-
gene_list,
|
1032 |
-
)
|
1033 |
-
else:
|
1034 |
-
cos_sims_data = cell_cos_sims
|
1035 |
-
for state in cos_sims_dict.keys():
|
1036 |
-
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1037 |
-
cos_sims_dict[state],
|
1038 |
-
cos_sims_data[state],
|
1039 |
filtered_input_data,
|
1040 |
indices_to_perturb,
|
1041 |
gene_list,
|
1042 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1043 |
|
1044 |
del minibatch
|
1045 |
del perturbation_batch
|
1046 |
-
del full_original_emb
|
1047 |
-
del original_emb
|
1048 |
-
del full_perturbation_emb
|
1049 |
-
del perturbation_emb
|
1050 |
-
del cos_sims_data
|
1051 |
-
if ("cls" in self.emb_mode) and (self.cell_states_to_model is None):
|
1052 |
-
del original_cls_emb
|
1053 |
-
del perturbation_cls_emb
|
1054 |
|
1055 |
torch.cuda.empty_cache()
|
1056 |
|
@@ -1462,6 +1449,7 @@ class InSilicoPerturber:
|
|
1462 |
gene_list,
|
1463 |
)
|
1464 |
else:
|
|
|
1465 |
for state in cos_sims_dict.keys():
|
1466 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1467 |
cos_sims_dict[state],
|
|
|
853 |
summary_stat=None,
|
854 |
silent=True,
|
855 |
)
|
856 |
+
|
857 |
+
# Calculate the cosine similarities
|
858 |
+
cls_cos_sims = pu.quant_cos_sims(
|
859 |
+
cls_perturbation_emb,
|
860 |
+
cls_original_emb,
|
861 |
+
self.cell_states_to_model,
|
862 |
+
self.state_embs_dict,
|
863 |
+
emb_mode="cell")
|
|
|
|
|
|
|
|
|
|
|
864 |
|
865 |
# Update perturbation dictionary
|
866 |
if self.cell_states_to_model is None:
|
|
|
881 |
gene_list = None,
|
882 |
)
|
883 |
|
884 |
+
del cls_original_emb
|
885 |
+
del cls_perturbation_emb
|
886 |
+
|
887 |
elif self.emb_mode == "cls_and_gene":
|
888 |
full_original_emb = get_embs(
|
889 |
model,
|
|
|
917 |
silent=True,
|
918 |
)
|
919 |
|
920 |
+
original_emb = original_emb[: 0:0,:]
|
|
|
|
|
921 |
if self.perturb_type == "overexpress":
|
922 |
+
perturbation_emb = full_perturbation_emb[:,0+len(self.tokens_to_perturb):0,:]
|
|
|
|
|
923 |
elif self.perturb_type == "delete":
|
924 |
+
perturbation_emb = full_perturbation_emb[:,0:max(perturbation_batch["length"])+0,:]
|
|
|
|
|
925 |
|
926 |
n_perturbation_genes = perturbation_emb.size()[1]
|
927 |
|
|
|
950 |
self.state_embs_dict,
|
951 |
emb_mode="cell",
|
952 |
)
|
953 |
+
cos_sims_data = cell_cos_sims
|
954 |
|
955 |
# get cosine similarities in gene embeddings
|
956 |
# if getting gene embeddings, need gene names
|
|
|
966 |
]
|
967 |
for genes in gene_list
|
968 |
]
|
|
|
|
|
|
|
|
|
|
|
969 |
|
970 |
for cell_i, genes in enumerate(gene_list):
|
971 |
for gene_j, affected_gene in enumerate(genes):
|
|
|
986 |
else:
|
987 |
gene_list = None
|
988 |
|
989 |
+
if self.cell_states_to_model is None:
|
990 |
+
# calculate the mean of the gene cosine similarities for cell shift
|
991 |
+
# tensor of nonpadding lengths for each cell
|
992 |
+
if self.perturb_type == "overexpress":
|
993 |
+
# subtract number of genes that were overexpressed
|
994 |
+
# since they are removed before getting cos sims
|
995 |
+
n_overexpressed = len(self.tokens_to_perturb)
|
996 |
+
nonpadding_lens = [
|
997 |
+
x - n_overexpressed for x in perturbation_batch["length"]
|
998 |
+
]
|
999 |
+
else:
|
1000 |
+
nonpadding_lens = perturbation_batch["length"]
|
1001 |
+
if "cls" not in self.emb_mode:
|
1002 |
+
cos_sims_data = pu.mean_nonpadding_embs(
|
1003 |
+
gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
|
1004 |
+
)
|
1005 |
+
else:
|
1006 |
+
cos_sims_data = pu.quant_cos_sims(
|
1007 |
+
perturbation_cls_emb,
|
1008 |
+
original_cls_emb,
|
1009 |
+
self.cell_states_to_model,
|
1010 |
+
self.state_embs_dict,
|
1011 |
+
emb_mode="cell",
|
1012 |
+
)
|
1013 |
|
1014 |
+
cos_sims_dict = self.update_perturbation_dictionary(
|
1015 |
+
cos_sims_dict,
|
1016 |
+
cos_sims_data,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1017 |
filtered_input_data,
|
1018 |
indices_to_perturb,
|
1019 |
gene_list,
|
1020 |
)
|
1021 |
+
else:
|
1022 |
+
for state in cos_sims_dict.keys():
|
1023 |
+
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1024 |
+
cos_sims_dict[state],
|
1025 |
+
cos_sims_data[state],
|
1026 |
+
filtered_input_data,
|
1027 |
+
indices_to_perturb,
|
1028 |
+
gene_list,
|
1029 |
+
)
|
1030 |
+
del full_original_emb
|
1031 |
+
del original_emb
|
1032 |
+
del full_perturbation_emb
|
1033 |
+
del perturbation_emb
|
1034 |
+
del cos_sims_data
|
1035 |
+
if self.cell_states_to_model is None:
|
1036 |
+
del original_cls_emb
|
1037 |
+
del perturbation_cls_emb
|
1038 |
|
1039 |
del minibatch
|
1040 |
del perturbation_batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1041 |
|
1042 |
torch.cuda.empty_cache()
|
1043 |
|
|
|
1449 |
gene_list,
|
1450 |
)
|
1451 |
else:
|
1452 |
+
cos_sims_data = cell_cos_sims
|
1453 |
for state in cos_sims_dict.keys():
|
1454 |
cos_sims_dict[state] = self.update_perturbation_dictionary(
|
1455 |
cos_sims_dict[state],
|