Christina Theodoris committed on
Commit
d20ad0a
1 Parent(s): 0637325

Add stats with mixture model to determine whether test perturbation is in impact component

Browse files
geneformer/in_silico_perturber_stats.py CHANGED
@@ -23,6 +23,7 @@ import random
23
  import statsmodels.stats.multitest as smt
24
  from pathlib import Path
25
  from scipy.stats import ranksums
 
26
  from tqdm.notebook import trange
27
 
28
  from .tokenizer import TOKEN_DICTIONARY_FILE
@@ -37,16 +38,23 @@ def invert_dict(dictionary):
37
 
38
  # read raw dictionary files
39
  def read_dictionaries(dir, cell_or_gene_emb):
 
40
  dict_list = []
41
  for file in os.listdir(dir):
42
  # process only _raw.pickle files
43
  if file.endswith("_raw.pickle"):
 
44
  with open(f"{dir}/{file}", "rb") as fp:
45
  cos_sims_dict = pickle.load(fp)
46
  if cell_or_gene_emb == "cell":
47
  cell_emb_dict = {k: v for k,
48
  v in cos_sims_dict.items() if v and "cell_emb" in k}
49
  dict_list += [cell_emb_dict]
 
 
 
 
 
50
  return dict_list
51
 
52
  # get complete gene list
@@ -67,6 +75,21 @@ def n_detections(token, dict_list):
67
  def get_fdr(pvalues):
68
  return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
71
  def isp_stats_to_goal_state(cos_sims_df, dict_list):
72
  random_tuples = []
@@ -102,13 +125,13 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list):
102
  token = cos_sims_df["Gene"][i]
103
  name = cos_sims_df["Gene_name"][i]
104
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
105
- token_tuples = []
106
 
107
  for dict_i in dict_list:
108
- token_tuples += dict_i.get((token, "cell_emb"),[])
109
 
110
- goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end,start_state in token_tuples]
111
- alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end,start_state in token_tuples]
112
 
113
  mean_goal_end = np.mean(goal_end_cos_sim_megalist)
114
  mean_alt_end = np.mean(alt_end_cos_sim_megalist)
@@ -130,6 +153,13 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list):
130
  cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
131
  cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
132
 
 
 
 
 
 
 
 
133
  return cos_sims_full_df
134
 
135
  # stats comparing cos sim shifts of test perturbations vs null distribution
@@ -165,18 +195,134 @@ def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
165
  cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
166
 
167
  cos_sims_full_df["Test_v_null_FDR"] = get_fdr(cos_sims_full_df["Test_v_null_pval"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  return cos_sims_full_df
169
 
170
  class InSilicoPerturberStats:
171
  valid_option_dict = {
172
- "mode": {"goal_state_shift","vs_null","vs_random"},
173
- "combos": {0,1,2},
174
  "anchor_gene": {None, str},
175
  "cell_states_to_model": {None, dict},
176
  }
177
  def __init__(
178
  self,
179
- mode="vs_random",
180
  combos=0,
181
  anchor_gene=None,
182
  cell_states_to_model=None,
@@ -188,11 +334,11 @@ class InSilicoPerturberStats:
188
 
189
  Parameters
190
  ----------
191
- mode : {"goal_state_shift","vs_null","vs_random"}
192
  Type of stats.
193
  "goal_state_shift": perturbation vs. random for desired cell state shift
194
  "vs_null": perturbation vs. null from provided null distribution dataset
195
- "vs_random": perturbation vs. random gene perturbations in that cell (no goal direction)
196
  combos : {0,1,2}
197
  Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
198
  anchor_gene : None, str
@@ -233,7 +379,9 @@ class InSilicoPerturberStats:
233
  for attr_name,valid_options in self.valid_option_dict.items():
234
  attr_value = self.__dict__[attr_name]
235
  if type(attr_value) not in {list, dict}:
236
- if attr_value in valid_options:
 
 
237
  continue
238
  valid_type = False
239
  for option in valid_options:
@@ -271,6 +419,14 @@ class InSilicoPerturberStats:
271
  "anchor_gene set to None. " \
272
  "Currently, anchor gene not available " \
273
  "when modeling multiple cell states.")
 
 
 
 
 
 
 
 
274
 
275
  def get_stats(self,
276
  input_data_directory,
@@ -292,10 +448,11 @@ class InSilicoPerturberStats:
292
  Prefix for output .dataset
293
  """
294
 
295
- if self.mode not in ["goal_state_shift", "vs_null"]:
296
  logger.error(
297
- "Currently, only modes available are stats for goal_state_shift \
298
- and vs_null (comparing to null distribution).")
 
299
  raise
300
 
301
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
@@ -318,19 +475,12 @@ class InSilicoPerturberStats:
318
  if self.mode == "goal_state_shift":
319
  cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list)
320
 
321
- # quantify number of detections of each gene
322
- cos_sims_df["N_Detections"] = [n_detections(i, dict_list) for i in cos_sims_df["Gene"]]
323
-
324
- # sort by shift to desired state
325
- cos_sims_df = cos_sims_df.sort_values(by=["Shift_from_goal_end",
326
- "Goal_end_FDR"])
327
  elif self.mode == "vs_null":
328
- dict_list = read_dictionaries(input_data_directory, "cell")
329
  null_dict_list = read_dictionaries(null_dist_data_directory, "cell")
330
- cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list,
331
- null_dict_list)
332
- cos_sims_df = cos_sims_df.sort_values(by=["Test_v_null_avg_shift",
333
- "Test_v_null_FDR"])
334
 
335
  # save perturbation stats to output_path
336
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
 
23
  import statsmodels.stats.multitest as smt
24
  from pathlib import Path
25
  from scipy.stats import ranksums
26
+ from sklearn.mixture import GaussianMixture
27
  from tqdm.notebook import trange
28
 
29
  from .tokenizer import TOKEN_DICTIONARY_FILE
 
38
 
39
  # read raw dictionary files
40
  def read_dictionaries(dir, cell_or_gene_emb):
41
+ file_found = 0
42
  dict_list = []
43
  for file in os.listdir(dir):
44
  # process only _raw.pickle files
45
  if file.endswith("_raw.pickle"):
46
+ file_found = 1
47
  with open(f"{dir}/{file}", "rb") as fp:
48
  cos_sims_dict = pickle.load(fp)
49
  if cell_or_gene_emb == "cell":
50
  cell_emb_dict = {k: v for k,
51
  v in cos_sims_dict.items() if v and "cell_emb" in k}
52
  dict_list += [cell_emb_dict]
53
+ if file_found == 0:
54
+ logger.error(
55
+ "No raw data for processing found within provided directory. " \
56
+ "Please ensure data files end with '_raw.pickle'.")
57
+ raise
58
  return dict_list
59
 
60
  # get complete gene list
 
75
  def get_fdr(pvalues):
76
  return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
77
 
78
def get_impact_component(test_value, gaussian_mixture_model):
    """Classify a value as impact (1) or non-impact (0) under a 2-component model.

    The component with the lower mean is treated as the impact component.
    Values above the higher mean are non-impact, values below the lower mean
    are impact, and values between the two means are assigned by the model's
    own ``predict``.

    Parameters
    ----------
    test_value : float
        Value to classify.
    gaussian_mixture_model : object
        Fitted model exposing ``means_`` (indexed as ``means_[k][0]`` for
        components 0 and 1) and ``predict`` for a ``[[value]]`` input.

    Returns
    -------
    int
        1 if the value falls in the impact component, otherwise 0.
    """
    first_mean = gaussian_mixture_model.means_[0][0]
    second_mean = gaussian_mixture_model.means_[1][0]

    # Do not rely on component 0 being the lower-mean component: pick the
    # impact component by comparing the means, so that either ordering of the
    # fitted components yields the same classification.
    if first_mean <= second_mean:
        impact_label = 0
        impact_border, nonimpact_border = first_mean, second_mean
    else:
        impact_label = 1
        impact_border, nonimpact_border = second_mean, first_mean

    if test_value > nonimpact_border:
        return 0
    if test_value < impact_border:
        return 1

    # Between the two means: defer to the model's assignment.
    predicted_label = gaussian_mixture_model.predict([[test_value]])[0]
    return 1 if predicted_label == impact_label else 0
92
+
93
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
94
  def isp_stats_to_goal_state(cos_sims_df, dict_list):
95
  random_tuples = []
 
125
  token = cos_sims_df["Gene"][i]
126
  name = cos_sims_df["Gene_name"][i]
127
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
128
+ cos_shift_data = []
129
 
130
  for dict_i in dict_list:
131
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
132
 
133
+ goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end,start_state in cos_shift_data]
134
+ alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end,start_state in cos_shift_data]
135
 
136
  mean_goal_end = np.mean(goal_end_cos_sim_megalist)
137
  mean_alt_end = np.mean(alt_end_cos_sim_megalist)
 
153
  cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
154
  cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
155
 
156
+ # quantify number of detections of each gene
157
+ cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list) for i in cos_sims_full_df["Gene"]]
158
+
159
+ # sort by shift to desired state
160
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_from_goal_end",
161
+ "Goal_end_FDR"])
162
+
163
  return cos_sims_full_df
164
 
165
  # stats comparing cos sim shifts of test perturbations vs null distribution
 
195
  cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
196
 
197
  cos_sims_full_df["Test_v_null_FDR"] = get_fdr(cos_sims_full_df["Test_v_null_pval"])
198
+
199
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Test_v_null_avg_shift",
200
+ "Test_v_null_FDR"])
201
+ return cos_sims_full_df
202
+
203
+ # stats for identifying perturbations with largest effect within a given set of cells
204
+ # fits a mixture model to 2 components (impact vs. non-impact) and
205
+ # reports the most likely component for each test perturbation
206
+ # Note: because assumes given perturbation has a consistent effect in the cells tested,
207
+ # we recommend only using the mixture model strategy with uniform cell populations
208
def _collect_cell_emb_shifts(token, dict_list):
    """Concatenate the ``(token, "cell_emb")`` entries of every dict in ``dict_list``."""
    cos_shift_data = []
    for dict_i in dict_list:
        cos_shift_data += dict_i.get((token, "cell_emb"), [])
    return cos_shift_data


def isp_stats_mixture_model(cos_sims_df, dict_list, combos):
    """Assign each test perturbation to the impact or non-impact mixture component.

    A 2-component Gaussian mixture is fitted to the per-gene mean shift
    (``combos == 0``) or per-gene mean combination shift (``combos == 1``).
    Each gene's mean, and each of its individual values, is then classified
    with ``get_impact_component``.

    Parameters
    ----------
    cos_sims_df : pandas.DataFrame
        Frame with ``Gene``, ``Gene_name`` and ``Ensembl_ID`` columns, read by
        label ``0..n-1``.
    dict_list : list of dict
        Dictionaries keyed by ``(token, "cell_emb")``. For ``combos == 0`` the
        values are averaged directly; for ``combos == 1`` each value is
        unpacked as an ``(anchor, token, combo)`` triple.
    combos : {0, 1}
        0 for individual perturbations, 1 for anchor/test pairs.

    Returns
    -------
    pandas.DataFrame
        One row per gene with the mean shift column(s), ``Impact_component``
        (classification of the gene's mean), ``Impact_component_percent``
        (percentage of the gene's individual values classified as impact) and
        ``N_Detections``. Sorted by ``Impact_component`` descending, then by
        ``Test_avg_shift`` (``combos == 0``) or ``Combo_minus_sum_shift``
        (``combos == 1``) ascending.

    Raises
    ------
    ValueError
        If ``combos`` is neither 0 nor 1.
    """
    if combos not in (0, 1):
        raise ValueError(f"combos must be 0 or 1 for mixture model stats, got {combos!r}")

    names = ["Gene", "Gene_name", "Ensembl_ID"]
    if combos == 0:
        names += ["Test_avg_shift"]
    else:
        names += ["Anchor_shift",
                  "Test_token_shift",
                  "Sum_of_indiv_shifts",
                  "Combo_shift",
                  "Combo_minus_sum_shift"]
    names += ["Impact_component", "Impact_component_percent"]

    n_rows = cos_sims_df.shape[0]

    # First pass: per-gene mean of the tested value, used to fit the model.
    # Genes without any recorded value do not contribute to the fit.
    avg_values = []
    for i in trange(n_rows):
        cos_shift_data = _collect_cell_emb_shifts(cos_sims_df["Gene"][i], dict_list)
        if combos == 0:
            test_values = cos_shift_data
        else:
            test_values = [shifts[2] for shifts in cos_shift_data]
        if len(test_values) > 0:
            avg_values.append(np.mean(test_values))

    # fit Gaussian mixture model to dataset of mean for each gene
    avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
    gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)

    # Second pass: classify each gene. Shift data is collected again rather
    # than cached so that only one gene's values are held at a time.
    rows = []
    for i in trange(n_rows):
        token = cos_sims_df["Gene"][i]
        name = cos_sims_df["Gene_name"][i]
        ensembl_id = cos_sims_df["Ensembl_ID"][i]
        cos_shift_data = _collect_cell_emb_shifts(token, dict_list)

        if combos == 0:
            tested_shifts = cos_shift_data
            mean_test = np.mean(tested_shifts)
            shift_columns = [mean_test]
        else:
            anchor_shifts = [anchor for anchor, _, _ in cos_shift_data]
            test_token_shifts = [test for _, test, _ in cos_shift_data]
            summed_shifts = [1 - ((1 - anchor) + (1 - test))
                             for anchor, test, _ in cos_shift_data]
            tested_shifts = [combo for _, _, combo in cos_shift_data]
            combo_minus_sum_shifts = [combo - (1 - ((1 - anchor) + (1 - test)))
                                      for anchor, test, combo in cos_shift_data]
            mean_test = np.mean(tested_shifts)
            shift_columns = [np.mean(anchor_shifts),
                             np.mean(test_token_shifts),
                             np.mean(summed_shifts),
                             mean_test,
                             np.mean(combo_minus_sum_shifts)]

        impact_components = [get_impact_component(value, gm) for value in tested_shifts]
        impact_component = get_impact_component(mean_test, gm)
        impact_component_percent = np.mean(impact_components) * 100

        rows.append([token, name, ensembl_id,
                     *shift_columns,
                     impact_component,
                     impact_component_percent])

    # Build the frame once instead of concatenating one row per iteration.
    cos_sims_full_df = pd.DataFrame(rows, columns=names)

    # quantify number of detections of each gene
    cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list)
                                        for i in cos_sims_full_df["Gene"]]

    sort_column = "Test_avg_shift" if combos == 0 else "Combo_minus_sum_shift"
    cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
                                                        sort_column],
                                                    ascending=[False, True])
    return cos_sims_full_df
315
 
316
  class InSilicoPerturberStats:
317
  valid_option_dict = {
318
+ "mode": {"goal_state_shift","vs_null","mixture_model"},
319
+ "combos": {0,1},
320
  "anchor_gene": {None, str},
321
  "cell_states_to_model": {None, dict},
322
  }
323
  def __init__(
324
  self,
325
+ mode="mixture_model",
326
  combos=0,
327
  anchor_gene=None,
328
  cell_states_to_model=None,
 
334
 
335
  Parameters
336
  ----------
337
+ mode : {"goal_state_shift","vs_null","mixture_model"}
338
  Type of stats.
339
  "goal_state_shift": perturbation vs. random for desired cell state shift
340
  "vs_null": perturbation vs. null from provided null distribution dataset
341
+ "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
342
  combos : {0,1,2}
343
  Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
344
  anchor_gene : None, str
 
379
  for attr_name,valid_options in self.valid_option_dict.items():
380
  attr_value = self.__dict__[attr_name]
381
  if type(attr_value) not in {list, dict}:
382
+ if attr_name in {"anchor_gene"}:
383
+ continue
384
+ elif attr_value in valid_options:
385
  continue
386
  valid_type = False
387
  for option in valid_options:
 
419
  "anchor_gene set to None. " \
420
  "Currently, anchor gene not available " \
421
  "when modeling multiple cell states.")
422
+
423
+ if self.combos > 0:
424
+ if self.anchor_gene is None:
425
+ logger.error(
426
+ "Currently, stats are only supported for combination " \
427
+ "in silico perturbation run with anchor gene. Please add " \
428
+ "anchor gene when using with combos > 0. ")
429
+ raise
430
 
431
  def get_stats(self,
432
  input_data_directory,
 
448
  Prefix for output .dataset
449
  """
450
 
451
+ if self.mode not in ["goal_state_shift", "vs_null", "mixture_model"]:
452
  logger.error(
453
+ "Currently, only modes available are stats for goal_state_shift, " \
454
+ "vs_null (comparing to null distribution), and " \
455
+ "mixture_model (fitting mixture model for perturbations with or without impact.")
456
  raise
457
 
458
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
 
475
  if self.mode == "goal_state_shift":
476
  cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list)
477
 
 
 
 
 
 
 
478
  elif self.mode == "vs_null":
 
479
  null_dict_list = read_dictionaries(null_dist_data_directory, "cell")
480
+ cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
481
+
482
+ elif self.mode == "mixture_model":
483
+ cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos)
484
 
485
  # save perturbation stats to output_path
486
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")