Fix bug in selecting a gene with "aggregate_data" option

#311
Files changed (1) hide show
  1. in_silico_perturber_stats.py +752 -0
in_silico_perturber_stats.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Geneformer in silico perturber stats generator.
3
+
4
+ Usage:
5
+ from geneformer import InSilicoPerturberStats
6
+ ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
+ combos=0,
8
+ anchor_gene=None,
9
+ cell_states_to_model={"state_key": "disease",
10
+ "start_state": "dcm",
11
+ "goal_state": "nf",
12
+ "alt_states": ["hcm", "other1", "other2"]})
13
+ ispstats.get_stats("path/to/input_data",
14
+ None,
15
+ "path/to/output_directory",
16
+ "output_prefix")
17
+ """
18
+
19
+
20
+ import os
21
+ import logging
22
+ import numpy as np
23
+ import pandas as pd
24
+ import pickle
25
+ import random
26
+ import statsmodels.stats.multitest as smt
27
+ from pathlib import Path
28
+ from scipy.stats import ranksums
29
+ from sklearn.mixture import GaussianMixture
30
+ from tqdm.auto import trange, tqdm
31
+
32
+ from .perturber_helpers import flatten_list
33
+
34
+ from .tokenizer import TOKEN_DICTIONARY_FILE
35
+
36
+ GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ # invert dictionary keys/values
41
+ def invert_dict(dictionary):
42
+ return {v: k for k, v in dictionary.items()}
43
+
44
+ def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
45
+ if cell_or_gene_emb == "cell":
46
+ cell_emb_dict = {k: v for k,
47
+ v in cos_sims_dict.items() if v and "cell_emb" in k}
48
+ return [cell_emb_dict]
49
+ elif cell_or_gene_emb == "gene":
50
+ gene_emb_dict = {k: v for k,
51
+ v in cos_sims_dict.items() if v and anchor_token == k[0]}
52
+ return [gene_emb_dict]
53
+
54
+
55
+ def recursive_search_dir(dir, pickle_suffix):
56
+
57
+
58
+ # read raw dictionary files
59
+ def read_dictionaries(input_data_directory,
60
+ cell_or_gene_emb,
61
+ anchor_token,
62
+ cell_states_to_model,
63
+ pickle_suffix,
64
+ recursive=False):
65
+
66
+ file_found = False
67
+ file_path_list = []
68
+ if cell_states_to_model is None:
69
+ dict_list = []
70
+ else:
71
+ state_dict = {state: [] for state in cell_states_to_model}
72
+
73
+ for file in os.listdir(input_data_directory):
74
+ # process only _raw.pickle files
75
+ if file.endswith(pickle_suffix):
76
+ file_found = True
77
+ file_path_list += [f"{input_data_directory}/{file}"]
78
+ for file_path in tqdm(file_path_list):
79
+ with open(file_path, 'rb') as fp:
80
+ cos_sims_dict = pickle.load(fp)
81
+ if cell_states_to_model is None:
82
+ dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
83
+ else:
84
+ for state in cell_states_to_model:
85
+ state_dict[state] += read_dict(cos_sims_dict[state], cell_or_gene_emb, anchor_token)
86
+ if not file_found:
87
+ logger.error(
88
+ f"No raw data for processing found within provided directory. " \
89
+ "Please ensure data files end with '{pickle_suffix}'.")
90
+ raise
91
+ if cell_states_to_model is None:
92
+ return dict_list
93
+ else:
94
+ return state_dict
95
+
96
+ # get complete gene list
97
+ def get_gene_list(dict_list,mode):
98
+ if mode == "cell":
99
+ position = 0
100
+ elif mode == "gene":
101
+ position = 1
102
+ gene_set = set()
103
+ for dict_i in dict_list:
104
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
105
+ gene_list = list(gene_set)
106
+ if mode == "gene":
107
+ gene_list.remove("cell_emb")
108
+ gene_list.sort()
109
+ return gene_list
110
+
111
+ def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
112
+ try:
113
+ return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
114
+ except TypeError as te:
115
+ return tuple(gene_token_id_dict.get(token_tuple, np.nan))
116
+
117
+ def n_detections(token, dict_list, mode, anchor_token):
118
+ cos_sim_megalist = []
119
+ for dict_i in dict_list:
120
+ if mode == "cell":
121
+ cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
122
+ elif mode == "gene":
123
+ cos_sim_megalist += dict_i.get((anchor_token, token),[])
124
+ return len(cos_sim_megalist)
125
+
126
+ def get_fdr(pvalues):
127
+ return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
128
+
129
+ def get_impact_component(test_value, gaussian_mixture_model):
130
+ impact_border = gaussian_mixture_model.means_[0][0]
131
+ nonimpact_border = gaussian_mixture_model.means_[1][0]
132
+ if test_value > nonimpact_border:
133
+ impact_component = 0
134
+ elif test_value < impact_border:
135
+ impact_component = 1
136
+ else:
137
+ impact_component_raw = gaussian_mixture_model.predict([[test_value]])[0]
138
+ if impact_component_raw == 1:
139
+ impact_component = 0
140
+ elif impact_component_raw == 0:
141
+ impact_component = 1
142
+ return impact_component
143
+
144
+ # aggregate data for single perturbation in multiple cells
145
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
146
+ names=["Cosine_shift"]
147
+ cos_sims_full_df = pd.DataFrame(columns=names)
148
+
149
+ cos_shift_data = []
150
+ token = cos_sims_df["Gene"][0]
151
+ for dict_i in dict_list:
152
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
153
+ cos_sims_full_df["Cosine_shift"] = cos_shift_data
154
+ return cos_sims_full_df
155
+
156
+ # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
157
+ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
158
+ cell_state_key = cell_states_to_model["start_state"]
159
+ if ("alt_states" not in cell_states_to_model.keys()) \
160
+ or (len(cell_states_to_model["alt_states"]) == 0) \
161
+ or (cell_states_to_model["alt_states"] == [None]):
162
+ alt_end_state_exists = False
163
+ elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]):
164
+ alt_end_state_exists = True
165
+
166
+ # for single perturbation in multiple cells, there are no random perturbations to compare to
167
+ if genes_perturbed != "all":
168
+ names=["Shift_to_goal_end",
169
+ "Shift_to_alt_end"]
170
+ if alt_end_state_exists == False:
171
+ names.remove("Shift_to_alt_end")
172
+ cos_sims_full_df = pd.DataFrame(columns=names)
173
+
174
+ cos_shift_data = []
175
+ token = cos_sims_df["Gene"][0]
176
+ for dict_i in dict_list:
177
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
178
+ if alt_end_state_exists == False:
179
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end in cos_shift_data]
180
+ if alt_end_state_exists == True:
181
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
182
+ cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
183
+
184
+ # sort by shift to desired state
185
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end"],
186
+ ascending=[False])
187
+ return cos_sims_full_df
188
+
189
+ elif genes_perturbed == "all":
190
+ random_tuples = []
191
+ for i in trange(cos_sims_df.shape[0]):
192
+ token = cos_sims_df["Gene"][i]
193
+ for dict_i in dict_list:
194
+ random_tuples += dict_i.get((token, "cell_emb"),[])
195
+
196
+ if alt_end_state_exists == False:
197
+ goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
198
+ elif alt_end_state_exists == True:
199
+ goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
200
+ alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
201
+
202
+ # downsample to improve speed of ranksums
203
+ if len(goal_end_random_megalist) > 100_000:
204
+ random.seed(42)
205
+ goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
206
+ if alt_end_state_exists == True:
207
+ if len(alt_end_random_megalist) > 100_000:
208
+ random.seed(42)
209
+ alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
210
+
211
+ names=["Gene",
212
+ "Gene_name",
213
+ "Ensembl_ID",
214
+ "Shift_to_goal_end",
215
+ "Shift_to_alt_end",
216
+ "Goal_end_vs_random_pval",
217
+ "Alt_end_vs_random_pval"]
218
+ if alt_end_state_exists == False:
219
+ names.remove("Shift_to_alt_end")
220
+ names.remove("Alt_end_vs_random_pval")
221
+ cos_sims_full_df = pd.DataFrame(columns=names)
222
+
223
+ for i in trange(cos_sims_df.shape[0]):
224
+ token = cos_sims_df["Gene"][i]
225
+ name = cos_sims_df["Gene_name"][i]
226
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
227
+ cos_shift_data = []
228
+
229
+ for dict_i in dict_list:
230
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
231
+
232
+ if alt_end_state_exists == False:
233
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
234
+ elif alt_end_state_exists == True:
235
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
236
+ alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
237
+ mean_alt_end = np.mean(alt_end_cos_sim_megalist)
238
+ pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
239
+
240
+ mean_goal_end = np.mean(goal_end_cos_sim_megalist)
241
+ pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
242
+
243
+ if alt_end_state_exists == False:
244
+ data_i = [token,
245
+ name,
246
+ ensembl_id,
247
+ mean_goal_end,
248
+ pval_goal_end]
249
+ elif alt_end_state_exists == True:
250
+ data_i = [token,
251
+ name,
252
+ ensembl_id,
253
+ mean_goal_end,
254
+ mean_alt_end,
255
+ pval_goal_end,
256
+ pval_alt_end]
257
+
258
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
259
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
260
+
261
+ cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
262
+ if alt_end_state_exists == True:
263
+ cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
264
+
265
+ # quantify number of detections of each gene
266
+ cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
267
+
268
+ # sort by shift to desired state\
269
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]]
270
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
271
+ "Shift_to_goal_end",
272
+ "Goal_end_FDR"],
273
+ ascending=[False,False,True])
274
+
275
+ return cos_sims_full_df
276
+
277
+ # stats comparing cos sim shifts of test perturbations vs null distribution
278
+ def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
279
+ cos_sims_full_df = cos_sims_df.copy()
280
+
281
+ cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
282
+ cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
283
+ cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
284
+ cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
285
+ cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
286
+ cos_sims_full_df["N_Detections_test"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
287
+ cos_sims_full_df["N_Detections_null"] = np.zeros(cos_sims_df.shape[0], dtype="uint32")
288
+
289
+ for i in trange(cos_sims_df.shape[0]):
290
+ token = cos_sims_df["Gene"][i]
291
+ test_shifts = []
292
+ null_shifts = []
293
+
294
+ for dict_i in dict_list:
295
+ test_shifts += dict_i.get((token, "cell_emb"),[])
296
+
297
+ for dict_i in null_dict_list:
298
+ null_shifts += dict_i.get((token, "cell_emb"),[])
299
+
300
+ cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
301
+ cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
302
+ cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(test_shifts)-np.mean(null_shifts)
303
+ cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(test_shifts,
304
+ null_shifts, nan_policy="omit").pvalue
305
+ # remove nan values
306
+ cos_sims_full_df.Test_vs_null_pval = np.where(np.isnan(cos_sims_full_df.Test_vs_null_pval), 1, cos_sims_full_df.Test_vs_null_pval)
307
+ cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
308
+ cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
309
+
310
+ cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
311
+
312
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]]
313
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
314
+ "Test_vs_null_avg_shift",
315
+ "Test_vs_null_FDR"],
316
+ ascending=[False,False,True])
317
+ return cos_sims_full_df
318
+
319
+ # stats for identifying perturbations with largest effect within a given set of cells
320
+ # fits a mixture model to 2 components (impact vs. non-impact) and
321
+ # reports the most likely component for each test perturbation
322
+ # Note: because assumes given perturbation has a consistent effect in the cells tested,
323
+ # we recommend only using the mixture model strategy with uniform cell populations
324
+ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
325
+
326
+ names=["Gene",
327
+ "Gene_name",
328
+ "Ensembl_ID"]
329
+
330
+ if combos == 0:
331
+ names += ["Test_avg_shift"]
332
+ elif combos == 1:
333
+ names += ["Anchor_shift",
334
+ "Test_token_shift",
335
+ "Sum_of_indiv_shifts",
336
+ "Combo_shift",
337
+ "Combo_minus_sum_shift"]
338
+
339
+ names += ["Impact_component",
340
+ "Impact_component_percent"]
341
+
342
+ cos_sims_full_df = pd.DataFrame(columns=names)
343
+ avg_values = []
344
+ gene_names = []
345
+
346
+ for i in trange(cos_sims_df.shape[0]):
347
+ token = cos_sims_df["Gene"][i]
348
+ name = cos_sims_df["Gene_name"][i]
349
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
350
+ cos_shift_data = []
351
+
352
+ for dict_i in dict_list:
353
+ if (combos == 0) and (anchor_token is not None):
354
+ cos_shift_data += dict_i.get((anchor_token, token),[])
355
+ else:
356
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
357
+
358
+ # Extract values for current gene
359
+ if combos == 0:
360
+ test_values = cos_shift_data
361
+ elif combos == 1:
362
+ test_values = []
363
+ for tup in cos_shift_data:
364
+ test_values.append(tup[2])
365
+
366
+ if len(test_values) > 0:
367
+ avg_value = np.mean(test_values)
368
+ avg_values.append(avg_value)
369
+ gene_names.append(name)
370
+
371
+ # fit Gaussian mixture model to dataset of mean for each gene
372
+ avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
373
+ gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
374
+
375
+ for i in trange(cos_sims_df.shape[0]):
376
+ token = cos_sims_df["Gene"][i]
377
+ name = cos_sims_df["Gene_name"][i]
378
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
379
+ cos_shift_data = []
380
+
381
+ for dict_i in dict_list:
382
+ if (combos == 0) and (anchor_token is not None):
383
+ cos_shift_data += dict_i.get((anchor_token, token),[])
384
+ else:
385
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
386
+
387
+ if combos == 0:
388
+ mean_test = np.mean(cos_shift_data)
389
+ impact_components = [get_impact_component(value,gm) for value in cos_shift_data]
390
+ elif combos == 1:
391
+ anchor_cos_sim_megalist = [anchor for anchor,token,combo in cos_shift_data]
392
+ token_cos_sim_megalist = [token for anchor,token,combo in cos_shift_data]
393
+ anchor_plus_token_cos_sim_megalist = [1-((1-anchor)+(1-token)) for anchor,token,combo in cos_shift_data]
394
+ combo_anchor_token_cos_sim_megalist = [combo for anchor,token,combo in cos_shift_data]
395
+ combo_minus_sum_cos_sim_megalist = [combo-(1-((1-anchor)+(1-token))) for anchor,token,combo in cos_shift_data]
396
+
397
+ mean_anchor = np.mean(anchor_cos_sim_megalist)
398
+ mean_token = np.mean(token_cos_sim_megalist)
399
+ mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
400
+ mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
401
+ mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
402
+
403
+ impact_components = [get_impact_component(value,gm) for value in combo_anchor_token_cos_sim_megalist]
404
+
405
+ impact_component = get_impact_component(mean_test,gm)
406
+ impact_component_percent = np.mean(impact_components)*100
407
+
408
+ data_i = [token,
409
+ name,
410
+ ensembl_id]
411
+ if combos == 0:
412
+ data_i += [mean_test]
413
+ elif combos == 1:
414
+ data_i += [mean_anchor,
415
+ mean_token,
416
+ mean_sum,
417
+ mean_test,
418
+ mean_combo_minus_sum]
419
+ data_i += [impact_component,
420
+ impact_component_percent]
421
+
422
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
423
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
424
+
425
+ # quantify number of detections of each gene
426
+ cos_sims_full_df["N_Detections"] = [n_detections(i,
427
+ dict_list,
428
+ "gene",
429
+ anchor_token) for i in cos_sims_full_df["Gene"]]
430
+
431
+ if combos == 0:
432
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
433
+ "Test_avg_shift"],
434
+ ascending=[False,True])
435
+ elif combos == 1:
436
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
437
+ "Combo_minus_sum_shift"],
438
+ ascending=[False,True])
439
+ return cos_sims_full_df
440
+
441
+ class InSilicoPerturberStats:
442
+ valid_option_dict = {
443
+ "mode": {"goal_state_shift","vs_null","mixture_model","aggregate_data"},
444
+ "combos": {0,1},
445
+ "anchor_gene": {None, str},
446
+ "cell_states_to_model": {None, dict},
447
+ "pickle_suffix": {None, str}
448
+ }
449
+ def __init__(
450
+ self,
451
+ mode="mixture_model",
452
+ genes_perturbed="all",
453
+ combos=0,
454
+ anchor_gene=None,
455
+ cell_states_to_model=None,
456
+ pickle_suffix="_raw.pickle",
457
+ token_dictionary_file=TOKEN_DICTIONARY_FILE,
458
+ gene_name_id_dictionary_file=GENE_NAME_ID_DICTIONARY_FILE,
459
+ ):
460
+ """
461
+ Initialize in silico perturber stats generator.
462
+
463
+ Parameters
464
+ ----------
465
+ mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"}
466
+ Type of stats.
467
+ "goal_state_shift": perturbation vs. random for desired cell state shift
468
+ "vs_null": perturbation vs. null from provided null distribution dataset
469
+ "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
470
+ "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
471
+ genes_perturbed : "all", list
472
+ Genes perturbed in isp experiment.
473
+ Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
474
+ Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
475
+ combos : {0,1,2}
476
+ Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
477
+ anchor_gene : None, str
478
+ ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
479
+ For example, if combos=1 and anchor_gene="ENSG00000136574":
480
+ analyzes data for anchor gene perturbed in combination with each other gene.
481
+ However, if combos=0 and anchor_gene="ENSG00000136574":
482
+ analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
483
+ cell_states_to_model: None, dict
484
+ Cell states to model if testing perturbations that achieve goal state change.
485
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
486
+ state_key: key specifying name of column in .dataset that defines the start/goal states
487
+ start_state: value in the state_key column that specifies the start state
488
+ goal_state: value in the state_key column taht specifies the goal end state
489
+ alt_states: list of values in the state_key column that specify the alternate end states
490
+ For example: {"state_key": "disease",
491
+ "start_state": "dcm",
492
+ "goal_state": "nf",
493
+ "alt_states": ["hcm", "other1", "other2"]}
494
+ token_dictionary_file : Path
495
+ Path to pickle file containing token dictionary (Ensembl ID:token).
496
+ gene_name_id_dictionary_file : Path
497
+ Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
498
+ """
499
+
500
+ self.mode = mode
501
+ self.genes_perturbed = genes_perturbed
502
+ self.combos = combos
503
+ self.anchor_gene = anchor_gene
504
+ self.cell_states_to_model = cell_states_to_model
505
+ self.pickle_suffix = pickle_suffix
506
+
507
+ self.validate_options()
508
+
509
+ # load token dictionary (Ensembl IDs:token)
510
+ with open(token_dictionary_file, "rb") as f:
511
+ self.gene_token_dict = pickle.load(f)
512
+
513
+ # load gene name dictionary (gene name:Ensembl ID)
514
+ with open(gene_name_id_dictionary_file, "rb") as f:
515
+ self.gene_name_id_dict = pickle.load(f)
516
+
517
+ if anchor_gene is None:
518
+ self.anchor_token = None
519
+ else:
520
+ self.anchor_token = self.gene_token_dict[self.anchor_gene]
521
+
522
+ def validate_options(self):
523
+ for attr_name,valid_options in self.valid_option_dict.items():
524
+ attr_value = self.__dict__[attr_name]
525
+ if type(attr_value) not in {list, dict}:
526
+ if attr_name in {"anchor_gene"}:
527
+ continue
528
+ elif attr_value in valid_options:
529
+ continue
530
+ valid_type = False
531
+ for option in valid_options:
532
+ # not sure what the last check is for?
533
+ if isinstance(attr_value, option): # and (option in [int,list,dict]):
534
+ valid_type = True
535
+ break
536
+ if not valid_type:
537
+ logger.error(
538
+ f"Invalid option for {attr_name}. " \
539
+ f"Valid options for {attr_name}: {valid_options}"
540
+ )
541
+ raise
542
+
543
+ if self.cell_states_to_model is not None:
544
+ if len(self.cell_states_to_model.items()) == 1:
545
+ logger.warning(
546
+ "The single value dictionary for cell_states_to_model will be " \
547
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
548
+ "Please specify state_key, start_state, goal_state, and alt_states " \
549
+ "in the cell_states_to_model dictionary for future use. " \
550
+ "For example, cell_states_to_model={" \
551
+ "'state_key': 'disease', " \
552
+ "'start_state': 'dcm', " \
553
+ "'goal_state': 'nf', " \
554
+ "'alt_states': ['hcm', 'other1', 'other2']}"
555
+ )
556
+ for key,value in self.cell_states_to_model.items():
557
+ if (len(value) == 3) and isinstance(value, tuple):
558
+ if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
559
+ if len(value[0]) == 1 and len(value[1]) == 1:
560
+ all_values = value[0]+value[1]+value[2]
561
+ if len(all_values) == len(set(all_values)):
562
+ continue
563
+ # reformat to the new named key format
564
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
565
+ self.cell_states_to_model = {
566
+ "state_key": list(self.cell_states_to_model.keys())[0],
567
+ "start_state": state_values[0][0],
568
+ "goal_state": state_values[1][0],
569
+ "alt_states": state_values[2:][0]
570
+ }
571
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
572
+ if (self.cell_states_to_model["state_key"] is None) \
573
+ or (self.cell_states_to_model["start_state"] is None) \
574
+ or (self.cell_states_to_model["goal_state"] is None):
575
+ logger.error(
576
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
577
+ raise
578
+
579
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
580
+ logger.error(
581
+ "All states must be unique.")
582
+ raise
583
+
584
+ if self.cell_states_to_model["alt_states"] is not None:
585
+ if type(self.cell_states_to_model["alt_states"]) is not list:
586
+ logger.error(
587
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
588
+ )
589
+ raise
590
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
591
+ logger.error(
592
+ "All states must be unique.")
593
+ raise
594
+
595
+ else:
596
+ logger.error(
597
+ "cell_states_to_model must only have the following four keys: " \
598
+ "'state_key', 'start_state', 'goal_state', 'alt_states'." \
599
+ "For example, cell_states_to_model={" \
600
+ "'state_key': 'disease', " \
601
+ "'start_state': 'dcm', " \
602
+ "'goal_state': 'nf', " \
603
+ "'alt_states': ['hcm', 'other1', 'other2']}"
604
+ )
605
+ raise
606
+
607
+ if self.anchor_gene is not None:
608
+ self.anchor_gene = None
609
+ logger.warning(
610
+ "anchor_gene set to None. " \
611
+ "Currently, anchor gene not available " \
612
+ "when modeling multiple cell states.")
613
+
614
+ if self.combos > 0:
615
+ if self.anchor_gene is None:
616
+ logger.error(
617
+ "Currently, stats are only supported for combination " \
618
+ "in silico perturbation run with anchor gene. Please add " \
619
+ "anchor gene when using with combos > 0. ")
620
+ raise
621
+
622
+ if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
623
+ logger.error(
624
+ "Mixture model mode requires multiple gene perturbations to fit model " \
625
+ "so is incompatible with a single grouped perturbation.")
626
+ raise
627
+ if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
628
+ logger.error(
629
+ "Simple data aggregation mode is for single perturbation in multiple cells " \
630
+ "so is incompatible with a genes_perturbed being 'all'.")
631
+ raise
632
+
633
+ def get_stats(self,
634
+ input_data_directory,
635
+ null_dist_data_directory,
636
+ output_directory,
637
+ output_prefix,
638
+ null_dict_list=None,
639
+ recursive=False):
640
+ """
641
+ Get stats for in silico perturbation data and save as results in output_directory.
642
+
643
+ Parameters
644
+ ----------
645
+ input_data_directory : Path
646
+ Path to directory containing cos_sim dictionary inputs
647
+ null_dist_data_directory : Path
648
+ Path to directory containing null distribution cos_sim dictionary inputs
649
+ output_directory : Path
650
+ Path to directory where perturbation data will be saved as .csv
651
+ output_prefix : str
652
+ Prefix for output .csv
653
+ null_dict_list: dict
654
+ List of loaded null distribtion dictionary if more than one comparison vs. the null is to be performed
655
+
656
+ Outputs
657
+ ----------
658
+ Definition of possible columns in .csv output file.
659
+
660
+ Of note, not all columns will be present in all output files.
661
+ Some columns are specific to particular perturbation modes.
662
+
663
+ "Gene": gene token
664
+ "Gene_name": gene name
665
+ "Ensembl_ID": gene Ensembl ID
666
+ "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
667
+ "Sig": 1 if FDR<0.05, otherwise 0
668
+
669
+ "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
670
+ "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
671
+ "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
672
+ pvalue compares shift caused by perturbing given gene compared to random genes
673
+ "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
674
+ pvalue compares shift caused by perturbing given gene compared to random genes
675
+ "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
676
+ "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
677
+
678
+ "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
679
+ "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
680
+ "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
681
+ (i.e. "Test_avg_shift" minus "Null_avg_shift")
682
+ "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
683
+ "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
684
+ "N_Detections_test": "N_Detections" in cells from test distribution
685
+ "N_Detections_null": "N_Detections" in cells from null distribution
686
+
687
+ "Anchor_shift": cosine shift in response to given perturbation of anchor gene
688
+ "Test_token_shift": cosine shift in response to given perturbation of test gene
689
+ "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
690
+ "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
691
+ "Combo_minus_sum_shift": difference of cosine shifts in response combo perturbation vs. sum of individual perturbations
692
+ (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
693
+ "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
694
+ 1: within impact component; 0: not within impact component
695
+ "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
696
+ """
697
+
698
+ if self.mode not in ["goal_state_shift", "vs_null", "mixture_model","aggregate_data"]:
699
+ logger.error(
700
+ "Currently, only modes available are stats for goal_state_shift, " \
701
+ "vs_null (comparing to null distribution), and " \
702
+ "mixture_model (fitting mixture model for perturbations with or without impact).")
703
+ raise
704
+
705
+ self.gene_token_id_dict = invert_dict(self.gene_token_dict)
706
+ self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
707
+
708
+ # obtain total gene list
709
+ if (self.combos == 0) and (self.anchor_token is not None):
710
+ # cos sim data for effect of gene perturbation on the embedding of each other gene
711
+ dict_list = read_dictionaries(input_data_directory, "gene", self.anchor_token, self.cell_states_to_model, self.pickle_suffix, recursive=recursive)
712
+ gene_list = get_gene_list(dict_list, "gene")
713
+ else:
714
+ # cos sim data for effect of gene perturbation on the embedding of each cell
715
+ dict_list = read_dictionaries(input_data_directory, "cell", self.anchor_token, self.cell_states_to_model, self.pickle_suffix, recursive=recursive)
716
+ gene_list = get_gene_list(dict_list, "cell")
717
+
718
+ # initiate results dataframe
719
+ cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
720
+ "Gene_name": [self.token_to_gene_name(item) \
721
+ for item in gene_list],
722
+ "Ensembl_ID": [token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict) \
723
+ if self.genes_perturbed != "all" else \
724
+ self.gene_token_id_dict[genes[1]] \
725
+ if isinstance(genes,tuple) else \
726
+ self.gene_token_id_dict[genes] \
727
+ for genes in gene_list]}, \
728
+ index=[i for i in range(len(gene_list))])
729
+
730
+ if self.mode == "goal_state_shift":
731
+ cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model, self.genes_perturbed)
732
+
733
+ elif self.mode == "vs_null":
734
+ if null_dict_list is None:
735
+ null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token, self.cell_states_to_model, self.pickle_suffix)
736
+ cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
737
+
738
+ elif self.mode == "mixture_model":
739
+ cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
740
+
741
+ elif self.mode == "aggregate_data":
742
+ cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
743
+
744
+ # save perturbation stats to output_path
745
+ output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
746
+ cos_sims_df.to_csv(output_path)
747
+
748
+ def token_to_gene_name(self, item):
749
+ if isinstance(item,int):
750
+ return self.gene_id_name_dict.get(self.gene_token_id_dict.get(item, np.nan), np.nan)
751
+ if isinstance(item,tuple):
752
+ return tuple([self.gene_id_name_dict.get(self.gene_token_id_dict.get(i, np.nan), np.nan) for i in item])