Christina Theodoris commited on
Commit
acd253c
·
1 Parent(s): 45b9d69

Update isp to allow modeling single perturbation in multiple cells as batches

Browse files
examples/in_silico_perturbation.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  },
14
  {
15
  "cell_type": "code",
16
- "execution_count": 2,
17
  "id": "67b44366-f255-4415-a865-6a27a8ffcce7",
18
  "metadata": {
19
  "tags": []
@@ -24,21 +24,20 @@
24
  "# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
25
  "# the embedding towards non-failing (nf) state\n",
26
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
27
- " perturb_rank_shift=None,\n",
28
- " genes_to_perturb=\"all\",\n",
29
- " combos=0,\n",
30
- " anchor_gene=None,\n",
31
- " model_type=\"CellClassifier\",\n",
32
- " num_classes=3,\n",
33
- " emb_mode=\"cell\",\n",
34
- " cell_emb_style=\"mean_pool\",\n",
35
- " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
- " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])},\n",
37
- " max_ncells=2000,\n",
38
- " emb_layer=0,\n",
39
- " forward_batch_size=400,\n",
40
- " nproc=16,\n",
41
- " save_raw_data=True)"
42
  ]
43
  },
44
  {
@@ -50,22 +49,23 @@
50
  "source": [
51
  "# outputs intermediate files from in silico perturbation\n",
52
  "isp.perturb_data(\"path/to/model\",\n",
53
- " \"path/to/input_data\",\n",
54
- " \"path/to/output_directory\",\n",
55
- " \"output_prefix\")"
56
  ]
57
  },
58
  {
59
  "cell_type": "code",
60
- "execution_count": 2,
61
  "id": "f8aadabb-516a-4dc0-b307-6de880e64e26",
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
65
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
66
- " combos=0,\n",
67
- " anchor_gene=None,\n",
68
- " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])})"
 
69
  ]
70
  },
71
  {
 
13
  },
14
  {
15
  "cell_type": "code",
16
+ "execution_count": null,
17
  "id": "67b44366-f255-4415-a865-6a27a8ffcce7",
18
  "metadata": {
19
  "tags": []
 
24
  "# deletion in the dilated cardiomyopathy (dcm) state significantly shifts\n",
25
  "# the embedding towards non-failing (nf) state\n",
26
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
27
+ " perturb_rank_shift=None,\n",
28
+ " genes_to_perturb=\"all\",\n",
29
+ " combos=0,\n",
30
+ " anchor_gene=None,\n",
31
+ " model_type=\"CellClassifier\",\n",
32
+ " num_classes=3,\n",
33
+ " emb_mode=\"cell\",\n",
34
+ " cell_emb_style=\"mean_pool\",\n",
35
+ " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
+ " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])},\n",
37
+ " max_ncells=2000,\n",
38
+ " emb_layer=0,\n",
39
+ " forward_batch_size=400,\n",
40
+ " nproc=16)"
 
41
  ]
42
  },
43
  {
 
49
  "source": [
50
  "# outputs intermediate files from in silico perturbation\n",
51
  "isp.perturb_data(\"path/to/model\",\n",
52
+ " \"path/to/input_data\",\n",
53
+ " \"path/to/output_directory\",\n",
54
+ " \"output_prefix\")"
55
  ]
56
  },
57
  {
58
  "cell_type": "code",
59
+ "execution_count": null,
60
  "id": "f8aadabb-516a-4dc0-b307-6de880e64e26",
61
  "metadata": {},
62
  "outputs": [],
63
  "source": [
64
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
65
+ " genes_perturbed=\"all\",\n",
66
+ " combos=0,\n",
67
+ " anchor_gene=None,\n",
68
+ " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])})"
69
  ]
70
  },
71
  {
geneformer/in_silico_perturber.py CHANGED
@@ -17,8 +17,7 @@ Usage:
17
  max_ncells=None,
18
  emb_layer=-1,
19
  forward_batch_size=100,
20
- nproc=4,
21
- save_raw_data=False)
22
  isp.perturb_data("path/to/model",
23
  "path/to/input_data",
24
  "path/to/output_directory",
@@ -28,7 +27,9 @@ Usage:
28
  # imports
29
  import itertools as it
30
  import logging
 
31
  import pickle
 
32
  import seaborn as sns; sns.set()
33
  import torch
34
  from collections import defaultdict
@@ -47,9 +48,16 @@ def quant_layers(model):
47
  layer_nums += [int(name.split("layer.")[1].split(".")[0])]
48
  return int(max(layer_nums))+1
49
 
 
 
 
50
  def flatten_list(megalist):
51
  return [item for sublist in megalist for item in sublist]
52
 
 
 
 
 
53
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
54
  example_cell.set_format(type="torch")
55
  input_data = example_cell["input_ids"]
@@ -66,15 +74,16 @@ def perturb_emb_by_index(emb, indices):
66
  mask[indices] = False
67
  return emb[mask]
68
 
69
- def delete_index(example):
70
- indexes = example["perturb_index"]
71
- if len(indexes)>1:
72
- indexes = flatten_list(indexes)
73
- for index in sorted(indexes, reverse=True):
74
  del example["input_ids"][index]
75
  return example
76
 
77
- def overexpress_index(example):
 
78
  indexes = example["perturb_index"]
79
  if len(indexes)>1:
80
  indexes = flatten_list(indexes)
@@ -82,11 +91,19 @@ def overexpress_index(example):
82
  example["input_ids"].insert(0, example["input_ids"].pop(index))
83
  return example
84
 
 
 
 
 
 
 
 
 
85
  def make_perturbation_batch(example_cell,
86
  perturb_type,
87
  tokens_to_perturb,
88
  anchor_token,
89
- combo_lvl,
90
  num_proc):
91
  if tokens_to_perturb == "all":
92
  if perturb_type in ["overexpress","activate"]:
@@ -114,21 +131,38 @@ def make_perturbation_batch(example_cell,
114
  all_indices = [index for index in all_indices if index not in indices_to_perturb]
115
  indices_to_perturb = [[[j for i in indices_to_perturb for j in i], x] for x in all_indices]
116
  length = len(indices_to_perturb)
117
- perturbation_dataset = Dataset.from_dict({"input_ids": example_cell["input_ids"]*length, "perturb_index": indices_to_perturb})
 
118
  if length<400:
119
  num_proc_i = 1
120
  else:
121
  num_proc_i = num_proc
122
  if perturb_type == "delete":
123
- perturbation_dataset = perturbation_dataset.map(delete_index, num_proc=num_proc_i)
124
  elif perturb_type == "overexpress":
125
- perturbation_dataset = perturbation_dataset.map(overexpress_index, num_proc=num_proc_i)
126
  return perturbation_dataset, indices_to_perturb
127
 
128
- # original cell emb removing the respective perturbed gene emb
129
- def make_comparison_batch(original_emb, indices_to_perturb):
 
 
130
  all_embs_list = []
131
- for indices in indices_to_perturb:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  emb_list = []
133
  start = 0
134
  if len(indices)>1 and isinstance(indices[0],list):
@@ -138,28 +172,22 @@ def make_comparison_batch(original_emb, indices_to_perturb):
138
  start = i+1
139
  emb_list += [original_emb[start:]]
140
  all_embs_list += [torch.cat(emb_list)]
 
 
 
 
141
  return torch.stack(all_embs_list)
142
 
143
- # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
144
- # so that only non-perturbed gene embeddings are compared to each other
145
- # in original or perturbed context
146
- def make_perturbed_remainder_batch(emb_batch, indices_to_remove):
147
- if type(indices_to_remove) == int:
148
- indices_to_keep = [i for i in range(emb_batch.size()[1])]
149
- indices_to_keep.pop(indices_to_remove)
150
- perturbed_remainder_batch = torch.stack([emb[indices_to_keep,:] for emb in emb_batch])
151
- elif type(indices_to_remove) == list:
152
- perturbed_remainder_batch = torch.stack([make_comparison_batch(emb_batch[i],indices_to_remove[i]) for i in range(len(emb_batch))])
153
- return perturbed_remainder_batch
154
-
155
  # average embedding position of goal cell states
156
  def get_cell_state_avg_embs(model,
157
  filtered_input_data,
158
  cell_states_to_model,
159
  layer_to_quant,
160
- token_dictionary,
161
  forward_batch_size,
162
  num_proc):
 
 
163
  possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
164
  state_embs_dict = dict()
165
  for possible_state in possible_states:
@@ -179,7 +207,10 @@ def get_cell_state_avg_embs(model,
179
  state_minibatch.set_format(type="torch")
180
 
181
  input_data_minibatch = state_minibatch["input_ids"]
182
- input_data_minibatch = pad_tensor_list(input_data_minibatch, max_len, token_dictionary)
 
 
 
183
 
184
  with torch.no_grad():
185
  outputs = model(
@@ -204,51 +235,131 @@ def quant_cos_sims(model,
204
  perturbation_batch,
205
  forward_batch_size,
206
  layer_to_quant,
207
- original_emb,
 
208
  indices_to_perturb,
 
209
  cell_states_to_model,
210
- state_embs_dict):
 
 
 
 
211
  cos = torch.nn.CosineSimilarity(dim=2)
212
  total_batch_length = len(perturbation_batch)
213
  if ((total_batch_length-1)/forward_batch_size).is_integer():
214
  forward_batch_size = forward_batch_size-1
215
  if cell_states_to_model is None:
216
- comparison_batch = make_comparison_batch(original_emb, indices_to_perturb)
 
217
  cos_sims = []
218
  else:
219
  possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
220
  cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
 
 
 
 
 
 
221
  for i in range(0, total_batch_length, forward_batch_size):
222
  max_range = min(i+forward_batch_size, total_batch_length)
223
 
224
  perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  perturbation_minibatch.set_format(type="torch")
226
 
227
  input_data_minibatch = perturbation_minibatch["input_ids"]
228
-
 
229
  with torch.no_grad():
230
  outputs = model(
231
  input_ids = input_data_minibatch.to("cuda")
232
  )
233
  del input_data_minibatch
234
  del perturbation_minibatch
235
- # cosine similarity between original emb and batch items
236
  if len(indices_to_perturb)>1:
237
  minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
238
  else:
239
  minibatch_emb = outputs.hidden_states[layer_to_quant]
240
- if cell_states_to_model is None:
241
- minibatch_comparison = comparison_batch[i:max_range]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  if perturb_type == "overexpress":
243
- index_to_remove = 0
244
- minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
245
- # elif (perturb_type == "inhibit") or (perturb_type == "activate"):
246
- # index_to_remove = placeholder
247
- # minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
249
  elif cell_states_to_model is not None:
250
  for state in possible_states:
251
- cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
 
 
 
 
 
 
 
 
 
252
  del outputs
253
  del minibatch_emb
254
  if cell_states_to_model is None:
@@ -263,17 +374,55 @@ def quant_cos_sims(model,
263
  return cos_sims_vs_alt_dict
264
 
265
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
266
- def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
267
  cos = torch.nn.CosineSimilarity(dim=2)
268
- original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
 
 
269
  origin_v_end = cos(original_emb,alt_emb)
270
- perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
 
271
  return [(perturb_v_end-origin_v_end).to("cpu")]
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  # pad list of tensors and convert to tensor
274
- def pad_tensor_list(tensor_list, dynamic_or_constant, token_dictionary):
275
-
276
- pad_token_id = token_dictionary.get("<pad>")
277
 
278
  # Determine maximum tensor length
279
  if dynamic_or_constant == "dynamic":
@@ -281,15 +430,13 @@ def pad_tensor_list(tensor_list, dynamic_or_constant, token_dictionary):
281
  elif type(dynamic_or_constant) == int:
282
  max_len = dynamic_or_constant
283
  else:
 
284
  logger.warning(
285
  "If padding style is constant, must provide integer value. " \
286
- "Setting padding to max input size 2048.")
287
 
288
  # pad all tensors to maximum length
289
- tensor_list = [torch.nn.functional.pad(tensor, pad=(0,
290
- max_len - tensor.numel()),
291
- mode='constant',
292
- value=pad_token_id) for tensor in tensor_list]
293
 
294
  # return stacked tensors
295
  return torch.stack(tensor_list)
@@ -299,7 +446,7 @@ class InSilicoPerturber:
299
  "perturb_type": {"delete","overexpress","inhibit","activate"},
300
  "perturb_rank_shift": {None, 1, 2, 3},
301
  "genes_to_perturb": {"all", list},
302
- "combos": {0, 1, 2},
303
  "anchor_gene": {None, str},
304
  "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
305
  "num_classes": {int},
@@ -311,7 +458,6 @@ class InSilicoPerturber:
311
  "emb_layer": {-1, 0},
312
  "forward_batch_size": {int},
313
  "nproc": {int},
314
- "save_raw_data": {False, True},
315
  }
316
  def __init__(
317
  self,
@@ -330,7 +476,6 @@ class InSilicoPerturber:
330
  emb_layer=-1,
331
  forward_batch_size=100,
332
  nproc=4,
333
- save_raw_data=False,
334
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
335
  ):
336
  """
@@ -358,8 +503,10 @@ class InSilicoPerturber:
358
  genes_to_perturb : "all", list
359
  Default is perturbing each gene detected in each cell in the dataset.
360
  Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
361
- combos : {0,1,2}
362
- Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
 
 
363
  anchor_gene : None, str
364
  ENSEMBL ID of gene to use as anchor in combination perturbations.
365
  For example, if combos=1 and anchor_gene="ENSG00000148400":
@@ -393,8 +540,6 @@ class InSilicoPerturber:
393
  Batch size for forward pass.
394
  nproc : int
395
  Number of CPU processes to use.
396
- save_raw_data: {False,True}
397
- Whether to save raw perturbation data for each gene/cell.
398
  token_dictionary_file : Path
399
  Path to pickle file containing token dictionary (Ensembl ID:token).
400
  """
@@ -404,6 +549,18 @@ class InSilicoPerturber:
404
  self.genes_to_perturb = genes_to_perturb
405
  self.combos = combos
406
  self.anchor_gene = anchor_gene
 
 
 
 
 
 
 
 
 
 
 
 
407
  self.model_type = model_type
408
  self.num_classes = num_classes
409
  self.emb_mode = emb_mode
@@ -414,7 +571,6 @@ class InSilicoPerturber:
414
  self.emb_layer = emb_layer
415
  self.forward_batch_size = forward_batch_size
416
  self.nproc = nproc
417
- self.save_raw_data = save_raw_data
418
 
419
  self.validate_options()
420
 
@@ -422,22 +578,39 @@ class InSilicoPerturber:
422
  with open(token_dictionary_file, "rb") as f:
423
  self.gene_token_dict = pickle.load(f)
424
 
425
- if anchor_gene is None:
 
 
426
  self.anchor_token = None
427
  else:
428
- self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
 
 
 
 
 
 
429
 
430
- if genes_to_perturb == "all":
431
  self.tokens_to_perturb = "all"
432
  else:
433
- self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
 
 
 
 
 
 
 
 
 
434
 
435
  def validate_options(self):
436
  # first disallow options under development
437
  if self.perturb_type in ["inhibit", "activate"]:
438
  logger.error(
439
- f"In silico inhibition and activation currently under developemnt. " \
440
- f"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
441
  )
442
  raise
443
 
@@ -462,7 +635,7 @@ class InSilicoPerturber:
462
  f"Valid options for {attr_name}: {valid_options}"
463
  )
464
  raise
465
-
466
  if self.perturb_type in ["delete","overexpress"]:
467
  if self.perturb_rank_shift is not None:
468
  if self.perturb_type == "delete":
@@ -538,9 +711,9 @@ class InSilicoPerturber:
538
  input_data_file : Path
539
  Path to directory containing .dataset inputs
540
  output_directory : Path
541
- Path to directory where perturbation data will be saved as .csv
542
  output_prefix : str
543
- Prefix for output .dataset
544
  """
545
 
546
  filtered_input_data = self.load_and_filter(input_data_file)
@@ -555,7 +728,7 @@ class InSilicoPerturber:
555
  filtered_input_data,
556
  self.cell_states_to_model,
557
  layer_to_quant,
558
- self.gene_token_dict,
559
  self.forward_batch_size,
560
  self.nproc)
561
  # filter for start state cells
@@ -571,13 +744,6 @@ class InSilicoPerturber:
571
  state_embs_dict,
572
  output_directory,
573
  output_prefix)
574
-
575
- # if self.save_raw_data is False:
576
- # # delete intermediate dictionaries
577
- # output_dir = os.listdir(output_directory)
578
- # for output_file in output_dir:
579
- # if output_file.endswith("_raw.pickle"):
580
- # os.remove(os.path.join(output_directory, output_file))
581
 
582
  # load data and filter by defined criteria
583
  def load_and_filter(self, input_data_file):
@@ -632,6 +798,7 @@ class InSilicoPerturber:
632
  output_prefix):
633
 
634
  output_path_prefix = f"{output_directory}in_silico_{self.perturb_type}_{output_prefix}_dict_1Kbatch"
 
635
 
636
  # filter dataset for cells that have tokens to be perturbed
637
  if self.anchor_token is not None:
@@ -639,183 +806,290 @@ class InSilicoPerturber:
639
  return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
640
  filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
641
  logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
642
- if self.tokens_to_perturb != "all":
 
 
643
  def if_has_tokens_to_perturb(example):
644
- return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>self.combos)
645
  filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
646
 
647
  cos_sims_dict = defaultdict(list)
648
  pickle_batch = -1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
 
650
- for i in trange(len(filtered_input_data)):
651
- example_cell = filtered_input_data.select([i])
652
- original_emb = forward_pass_single_cell(model, example_cell, layer_to_quant)
653
- gene_list = torch.squeeze(example_cell["input_ids"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
 
655
- # reset to original type to prevent downstream issues due to forward_pass_single_cell modifying as torch format in place
656
- example_cell = filtered_input_data.select([i])
 
 
 
657
 
658
- if self.anchor_token is None:
659
- for combo_lvl in range(self.combos+1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
661
- self.perturb_type,
662
- self.tokens_to_perturb,
663
- self.anchor_token,
664
- combo_lvl,
665
- self.nproc)
666
  cos_sims_data = quant_cos_sims(model,
667
  self.perturb_type,
668
- perturbation_batch,
669
- self.forward_batch_size,
670
- layer_to_quant,
671
- original_emb,
 
672
  indices_to_perturb,
 
673
  self.cell_states_to_model,
674
- state_embs_dict)
675
-
676
- if self.cell_states_to_model is None:
677
- # update cos sims dict
678
- # key is tuple of (perturbed_gene, affected_gene)
679
- # or (perturbed_gene, "cell_emb") for avg cell emb change
680
- cos_sims_data = cos_sims_data.to("cuda")
681
- for j in range(cos_sims_data.shape[0]):
682
- if self.genes_to_perturb != "all":
683
- j_index = torch.tensor(indices_to_perturb[j])
684
- if j_index.shape[0]>1:
685
- j_index = torch.squeeze(j_index)
686
- else:
687
- j_index = torch.tensor([j])
688
- perturbed_gene = torch.index_select(gene_list, 0, j_index)
689
-
690
- if perturbed_gene.shape[0]==1:
691
- perturbed_gene = perturbed_gene.item()
692
- elif perturbed_gene.shape[0]>1:
693
- perturbed_gene = tuple(perturbed_gene.tolist())
694
-
695
- cell_cos_sim = torch.mean(cos_sims_data[j]).item()
696
- cos_sims_dict[(perturbed_gene, "cell_emb")] += [cell_cos_sim]
697
-
698
- # not_j_index = list(set(i for i in range(gene_list.shape[0])).difference(j_index))
699
- # gene_list_j = torch.index_select(gene_list, 0, j_index)
700
- if self.emb_mode == "cell_and_gene":
701
- for k in range(cos_sims_data.shape[1]):
702
- cos_sim_value = cos_sims_data[j][k]
703
- affected_gene = gene_list[k].item()
704
- cos_sims_dict[(perturbed_gene, affected_gene)] += [cos_sim_value.item()]
705
- else:
706
- # update cos sims dict
707
- # key is tuple of (perturbed_gene, "cell_emb")
708
- # value is list of tuples of cos sims for cell_states_to_model
709
- origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
710
- cos_sims_origin = cos_sims_data[origin_state_key]
711
-
712
- for j in range(cos_sims_origin.shape[0]):
713
- if (self.genes_to_perturb != "all") or (combo_lvl>0):
714
- j_index = torch.tensor(indices_to_perturb[j])
715
- if j_index.shape[0]>1:
716
- j_index = torch.squeeze(j_index)
717
- else:
718
- j_index = torch.tensor([j])
719
- perturbed_gene = torch.index_select(gene_list, 0, j_index)
720
-
721
- if perturbed_gene.shape[0]==1:
722
- perturbed_gene = perturbed_gene.item()
723
- elif perturbed_gene.shape[0]>1:
724
- perturbed_gene = tuple(perturbed_gene.tolist())
725
-
726
- data_list = []
727
- for data in list(cos_sims_data.values()):
728
- data_item = data.to("cuda")
729
- cell_data = torch.mean(data_item[j]).item()
730
- data_list += [cell_data]
731
- cos_sims_dict[(perturbed_gene, "cell_emb")] += [tuple(data_list)]
732
-
733
- elif self.anchor_token is not None:
734
- perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
735
- self.perturb_type,
736
- self.tokens_to_perturb,
737
- None, # first run without anchor token to test individual gene perturbations
738
- 0,
739
- self.nproc)
740
- cos_sims_data = quant_cos_sims(model,
741
- self.perturb_type,
742
- perturbation_batch,
743
- self.forward_batch_size,
744
- layer_to_quant,
745
- original_emb,
746
- indices_to_perturb,
747
- self.cell_states_to_model,
748
- state_embs_dict)
749
- cos_sims_data = cos_sims_data.to("cuda")
750
 
751
- combo_perturbation_batch, combo_indices_to_perturb = make_perturbation_batch(example_cell,
752
- self.perturb_type,
753
- self.tokens_to_perturb,
754
- self.anchor_token,
755
- 1,
756
- self.nproc)
757
- combo_cos_sims_data = quant_cos_sims(model,
758
- self.perturb_type,
759
- combo_perturbation_batch,
760
- self.forward_batch_size,
761
- layer_to_quant,
762
- original_emb,
763
- combo_indices_to_perturb,
764
- self.cell_states_to_model,
765
- state_embs_dict)
766
- combo_cos_sims_data = combo_cos_sims_data.to("cuda")
 
 
 
 
 
767
 
768
- # update cos sims dict
769
- # key is tuple of (perturbed_gene, "cell_emb") for avg cell emb change
770
- anchor_index = example_cell["input_ids"][0].index(self.anchor_token[0])
771
- anchor_cell_cos_sim = torch.mean(cos_sims_data[anchor_index]).item()
772
- non_anchor_indices = [k for k in range(cos_sims_data.shape[0]) if k != anchor_index]
773
- cos_sims_data = cos_sims_data[non_anchor_indices,:]
774
 
775
- for j in range(cos_sims_data.shape[0]):
776
 
777
- if j<anchor_index:
778
- j_index = torch.tensor([j])
779
- else:
780
- j_index = torch.tensor([j+1])
781
 
782
- perturbed_gene = torch.index_select(gene_list, 0, j_index)
783
- perturbed_gene = perturbed_gene.item()
784
 
785
- cell_cos_sim = torch.mean(cos_sims_data[j]).item()
786
- combo_cos_sim = torch.mean(combo_cos_sims_data[j]).item()
787
- cos_sims_dict[(perturbed_gene, "cell_emb")] += [(anchor_cell_cos_sim, # cos sim anchor gene alone
788
- cell_cos_sim, # cos sim deleted gene alone
789
- combo_cos_sim)] # cos sim anchor gene + deleted gene
790
-
791
- # save dict to disk every 100 cells
792
- if (i/100).is_integer():
793
- with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
794
- pickle.dump(cos_sims_dict, fp)
795
- # reset and clear memory every 1000 cells
796
- if (i/1000).is_integer():
797
- pickle_batch = pickle_batch+1
798
- # clear memory
799
- del perturbed_gene
800
- del cos_sims_data
801
- if self.cell_states_to_model is None:
802
- del cell_cos_sim
803
- if self.cell_states_to_model is not None:
804
- del cell_data
805
- del data_list
806
- elif self.anchor_token is None:
807
- if self.emb_mode == "cell_and_gene":
808
- del affected_gene
809
- del cos_sim_value
810
- else:
811
- del combo_cos_sim
812
- del combo_cos_sims_data
813
- # reset dict
814
- del cos_sims_dict
815
- cos_sims_dict = defaultdict(list)
816
- torch.cuda.empty_cache()
817
-
818
- # save remainder cells
819
- with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
820
- pickle.dump(cos_sims_dict, fp)
821
 
 
17
  max_ncells=None,
18
  emb_layer=-1,
19
  forward_batch_size=100,
20
+ nproc=4)
 
21
  isp.perturb_data("path/to/model",
22
  "path/to/input_data",
23
  "path/to/output_directory",
 
27
  # imports
28
  import itertools as it
29
  import logging
30
+ import numpy as np
31
  import pickle
32
+ import re
33
  import seaborn as sns; sns.set()
34
  import torch
35
  from collections import defaultdict
 
48
  layer_nums += [int(name.split("layer.")[1].split(".")[0])]
49
  return int(max(layer_nums))+1
50
 
51
+ def get_model_input_size(model):
52
+ return int(re.split("\(|,",str(model.bert.embeddings.position_embeddings))[1])
53
+
54
  def flatten_list(megalist):
55
  return [item for sublist in megalist for item in sublist]
56
 
57
+ def measure_length(example):
58
+ example["length"] = len(example["input_ids"])
59
+ return example
60
+
61
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
62
  example_cell.set_format(type="torch")
63
  input_data = example_cell["input_ids"]
 
74
  mask[indices] = False
75
  return emb[mask]
76
 
77
+ def delete_indices(example):
78
+ indices = example["perturb_index"]
79
+ if len(indices)>1:
80
+ indices = flatten_list(indices)
81
+ for index in sorted(indices, reverse=True):
82
  del example["input_ids"][index]
83
  return example
84
 
85
+ # for genes_to_perturb = "all" where only genes within cell are overexpressed
86
+ def overexpress_indices(example):
87
  indexes = example["perturb_index"]
88
  if len(indexes)>1:
89
  indexes = flatten_list(indexes)
 
91
  example["input_ids"].insert(0, example["input_ids"].pop(index))
92
  return example
93
 
94
+ # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
95
+ def overexpress_tokens(example):
96
+ # -100 indicates tokens to overexpress are not present in rank value encoding
97
+ if example["perturb_index"] != [-100]:
98
+ example = delete_indices(example)
99
+ [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
100
+ return example
101
+
102
  def make_perturbation_batch(example_cell,
103
  perturb_type,
104
  tokens_to_perturb,
105
  anchor_token,
106
+ combo_lvl,
107
  num_proc):
108
  if tokens_to_perturb == "all":
109
  if perturb_type in ["overexpress","activate"]:
 
131
  all_indices = [index for index in all_indices if index not in indices_to_perturb]
132
  indices_to_perturb = [[[j for i in indices_to_perturb for j in i], x] for x in all_indices]
133
  length = len(indices_to_perturb)
134
+ perturbation_dataset = Dataset.from_dict({"input_ids": example_cell["input_ids"]*length,
135
+ "perturb_index": indices_to_perturb})
136
  if length<400:
137
  num_proc_i = 1
138
  else:
139
  num_proc_i = num_proc
140
  if perturb_type == "delete":
141
+ perturbation_dataset = perturbation_dataset.map(delete_indices, num_proc=num_proc_i)
142
  elif perturb_type == "overexpress":
143
+ perturbation_dataset = perturbation_dataset.map(overexpress_indices, num_proc=num_proc_i)
144
  return perturbation_dataset, indices_to_perturb
145
 
146
+ # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
147
+ # so that only non-perturbed gene embeddings are compared to each other
148
+ # in original or perturbed context
149
+ def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
150
  all_embs_list = []
151
+
152
+ # if making comparison batch for multiple perturbations in single cell
153
+ if perturb_group == False:
154
+ original_emb_list = [original_emb_batch]*len(indices_to_perturb)
155
+ # if making comparison batch for single perturbation in multiple cells
156
+ elif perturb_group == True:
157
+ original_emb_list = original_emb_batch
158
+
159
+
160
+ for i in range(len(original_emb_list)):
161
+ original_emb = original_emb_list[i]
162
+ indices = indices_to_perturb[i]
163
+ if indices == [-100]:
164
+ all_embs_list += [original_emb[:]]
165
+ continue
166
  emb_list = []
167
  start = 0
168
  if len(indices)>1 and isinstance(indices[0],list):
 
172
  start = i+1
173
  emb_list += [original_emb[start:]]
174
  all_embs_list += [torch.cat(emb_list)]
175
+ len_set = set([emb.size()[0] for emb in all_embs_list])
176
+ if len(len_set) > 1:
177
+ max_len = max(len_set)
178
+ all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
179
  return torch.stack(all_embs_list)
180
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  # average embedding position of goal cell states
182
  def get_cell_state_avg_embs(model,
183
  filtered_input_data,
184
  cell_states_to_model,
185
  layer_to_quant,
186
+ pad_token_id,
187
  forward_batch_size,
188
  num_proc):
189
+
190
+ model_input_size = get_model_input_size(model)
191
  possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
192
  state_embs_dict = dict()
193
  for possible_state in possible_states:
 
207
  state_minibatch.set_format(type="torch")
208
 
209
  input_data_minibatch = state_minibatch["input_ids"]
210
+ input_data_minibatch = pad_tensor_list(input_data_minibatch,
211
+ max_len,
212
+ pad_token_id,
213
+ model_input_size)
214
 
215
  with torch.no_grad():
216
  outputs = model(
 
235
  perturbation_batch,
236
  forward_batch_size,
237
  layer_to_quant,
238
+ original_emb,
239
+ tokens_to_perturb,
240
  indices_to_perturb,
241
+ perturb_group,
242
  cell_states_to_model,
243
+ state_embs_dict,
244
+ pad_token_id,
245
+ model_input_size,
246
+ nproc):
247
+
248
  cos = torch.nn.CosineSimilarity(dim=2)
249
  total_batch_length = len(perturbation_batch)
250
  if ((total_batch_length-1)/forward_batch_size).is_integer():
251
  forward_batch_size = forward_batch_size-1
252
  if cell_states_to_model is None:
253
+ if perturb_group == False: # (if perturb_group is True, original_emb is filtered_input_data)
254
+ comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
255
  cos_sims = []
256
  else:
257
  possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
258
  cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
259
+
260
+ # measure length of each element in perturbation_batch
261
+ perturbation_batch = perturbation_batch.map(
262
+ measure_length, num_proc=nproc
263
+ )
264
+
265
  for i in range(0, total_batch_length, forward_batch_size):
266
  max_range = min(i+forward_batch_size, total_batch_length)
267
 
268
  perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
269
+
270
+ # determine if need to pad or truncate batch
271
+ minibatch_length_set = set(perturbation_minibatch["length"])
272
+ if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
273
+ needs_pad_or_trunc = True
274
+ else:
275
+ needs_pad_or_trunc = False
276
+
277
+ if needs_pad_or_trunc == True:
278
+ max_len = min(max(minibatch_length_set),model_input_size)
279
+ def pad_or_trunc_example(example):
280
+ example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
281
+ pad_token_id,
282
+ max_len)
283
+ return example
284
+ perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
285
  perturbation_minibatch.set_format(type="torch")
286
 
287
  input_data_minibatch = perturbation_minibatch["input_ids"]
288
+
289
+ # extract embeddings for perturbation minibatch
290
  with torch.no_grad():
291
  outputs = model(
292
  input_ids = input_data_minibatch.to("cuda")
293
  )
294
  del input_data_minibatch
295
  del perturbation_minibatch
296
+
297
  if len(indices_to_perturb)>1:
298
  minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
299
  else:
300
  minibatch_emb = outputs.hidden_states[layer_to_quant]
301
+
302
+ if perturb_type == "overexpress":
303
+ # remove overexpressed genes to quantify effect on remaining genes
304
+ if perturb_group == False:
305
+ overexpressed_to_remove = 1
306
+ if perturb_group == True:
307
+ overexpressed_to_remove = len(tokens_to_perturb)
308
+ minibatch_emb = minibatch_emb[:,overexpressed_to_remove:,:]
309
+
310
+ # if quantifying single perturbation in multiple different cells, pad original batch and extract embs
311
+ if perturb_group == True:
312
+ # pad minibatch of original batch to extract embeddings
313
+ # truncate to the (model input size - # tokens to overexpress) to ensure comparability
314
+ # since max input size of perturb batch will be reduced by # tokens to overexpress
315
+ original_minibatch = original_emb.select([i for i in range(i, max_range)])
316
+ original_minibatch_length_set = set(original_minibatch["length"])
317
  if perturb_type == "overexpress":
318
+ new_max_len = model_input_size - len(tokens_to_perturb)
319
+ else:
320
+ new_max_len = model_input_size
321
+ if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
322
+ original_max_len = min(max(original_minibatch_length_set),new_max_len)
323
+ def pad_or_trunc_example(example):
324
+ example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, original_max_len)
325
+ return example
326
+ original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
327
+ original_minibatch.set_format(type="torch")
328
+ original_input_data_minibatch = original_minibatch["input_ids"]
329
+ # extract embeddings for original minibatch
330
+ with torch.no_grad():
331
+ original_outputs = model(
332
+ input_ids = original_input_data_minibatch.to("cuda")
333
+ )
334
+ del original_input_data_minibatch
335
+ del original_minibatch
336
+
337
+ if len(indices_to_perturb)>1:
338
+ original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
339
+ else:
340
+ original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
341
+
342
+ # cosine similarity between original emb and batch items
343
+ if cell_states_to_model is None:
344
+ if perturb_group == False:
345
+ minibatch_comparison = comparison_batch[i:max_range]
346
+ elif perturb_group == True:
347
+ minibatch_comparison = make_comparison_batch(original_minibatch_emb,
348
+ indices_to_perturb,
349
+ perturb_group)
350
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
351
  elif cell_states_to_model is not None:
352
  for state in possible_states:
353
+ if perturb_group == False:
354
+ cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb,
355
+ minibatch_emb,
356
+ state_embs_dict[state],
357
+ perturb_group)
358
+ elif perturb_group == True:
359
+ cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
360
+ minibatch_emb,
361
+ state_embs_dict[state],
362
+ perturb_group)
363
  del outputs
364
  del minibatch_emb
365
  if cell_states_to_model is None:
 
374
  return cos_sims_vs_alt_dict
375
 
376
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
377
+ def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group):
378
  cos = torch.nn.CosineSimilarity(dim=2)
379
+ original_emb = torch.mean(original_emb,dim=0,keepdim=True)
380
+ if perturb_group == False:
381
+ original_emb = original_emb[None, :]
382
  origin_v_end = cos(original_emb,alt_emb)
383
+ perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
384
+ perturb_v_end = cos(perturb_emb,alt_emb)
385
  return [(perturb_v_end-origin_v_end).to("cpu")]
386
 
387
+ def pad_list(input_ids, pad_token_id, max_len):
388
+ input_ids = np.pad(input_ids,
389
+ (0, max_len-len(input_ids)),
390
+ mode='constant', constant_values=pad_token_id)
391
+ return input_ids
392
+
393
+ def pad_tensor(tensor, pad_token_id, max_len):
394
+ tensor = torch.nn.functional.pad(tensor, pad=(0,
395
+ max_len - tensor.numel()),
396
+ mode='constant',
397
+ value=pad_token_id)
398
+ return tensor
399
+
400
+ def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
401
+ if dim == 0:
402
+ pad = (0, 0, 0, max_len - tensor.size()[dim])
403
+ elif dim == 1:
404
+ pad = (0, max_len - tensor.size()[dim], 0, 0)
405
+ tensor = torch.nn.functional.pad(tensor, pad=pad,
406
+ mode='constant',
407
+ value=pad_token_id)
408
+ return tensor
409
+
410
+ def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
411
+ if isinstance(encoding, torch.Tensor):
412
+ encoding_len = tensor.size()[0]
413
+ elif isinstance(encoding, list):
414
+ encoding_len = len(encoding)
415
+ if encoding_len > max_len:
416
+ encoding = encoding[0:max_len]
417
+ elif encoding_len < max_len:
418
+ if isinstance(encoding, torch.Tensor):
419
+ encoding = pad_tensor(encoding, pad_token_id, max_len)
420
+ elif isinstance(encoding, list):
421
+ encoding = pad_list(encoding, pad_token_id, max_len)
422
+ return encoding
423
+
424
  # pad list of tensors and convert to tensor
425
+ def pad_tensor_list(tensor_list, dynamic_or_constant, pad_token_id, model_input_size):
 
 
426
 
427
  # Determine maximum tensor length
428
  if dynamic_or_constant == "dynamic":
 
430
  elif type(dynamic_or_constant) == int:
431
  max_len = dynamic_or_constant
432
  else:
433
+ max_len = model_input_size
434
  logger.warning(
435
  "If padding style is constant, must provide integer value. " \
436
+ f"Setting padding to max input size {model_input_size}.")
437
 
438
  # pad all tensors to maximum length
439
+ tensor_list = [pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list]
 
 
 
440
 
441
  # return stacked tensors
442
  return torch.stack(tensor_list)
 
446
  "perturb_type": {"delete","overexpress","inhibit","activate"},
447
  "perturb_rank_shift": {None, 1, 2, 3},
448
  "genes_to_perturb": {"all", list},
449
+ "combos": {0, 1},
450
  "anchor_gene": {None, str},
451
  "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
452
  "num_classes": {int},
 
458
  "emb_layer": {-1, 0},
459
  "forward_batch_size": {int},
460
  "nproc": {int},
 
461
  }
462
  def __init__(
463
  self,
 
476
  emb_layer=-1,
477
  forward_batch_size=100,
478
  nproc=4,
 
479
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
480
  ):
481
  """
 
503
  genes_to_perturb : "all", list
504
  Default is perturbing each gene detected in each cell in the dataset.
505
  Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
506
+ If gene list is provided, then perturber will only test perturbing them all together
507
+ (rather than testing each possible combination of the provided genes).
508
+ combos : {0,1}
509
+ Whether to perturb genes individually (0) or in pairs (1).
510
  anchor_gene : None, str
511
  ENSEMBL ID of gene to use as anchor in combination perturbations.
512
  For example, if combos=1 and anchor_gene="ENSG00000148400":
 
540
  Batch size for forward pass.
541
  nproc : int
542
  Number of CPU processes to use.
 
 
543
  token_dictionary_file : Path
544
  Path to pickle file containing token dictionary (Ensembl ID:token).
545
  """
 
549
  self.genes_to_perturb = genes_to_perturb
550
  self.combos = combos
551
  self.anchor_gene = anchor_gene
552
+ if self.genes_to_perturb == "all":
553
+ self.perturb_group = False
554
+ else:
555
+ self.perturb_group = True
556
+ if (self.anchor_gene != None) or (self.combos != 0):
557
+ self.anchor_gene = None
558
+ self.combos = 0
559
+ logger.warning(
560
+ "anchor_gene set to None and combos set to 0. " \
561
+ "If providing list of genes to perturb, " \
562
+ "list of genes_to_perturb will be perturbed together, "\
563
+ "without anchor gene or combinations.")
564
  self.model_type = model_type
565
  self.num_classes = num_classes
566
  self.emb_mode = emb_mode
 
571
  self.emb_layer = emb_layer
572
  self.forward_batch_size = forward_batch_size
573
  self.nproc = nproc
 
574
 
575
  self.validate_options()
576
 
 
578
  with open(token_dictionary_file, "rb") as f:
579
  self.gene_token_dict = pickle.load(f)
580
 
581
+ self.pad_token_id = self.gene_token_dict.get("<pad>")
582
+
583
+ if self.anchor_gene is None:
584
  self.anchor_token = None
585
  else:
586
+ try:
587
+ self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
588
+ except KeyError:
589
+ logger.error(
590
+ f"Anchor gene {self.anchor_gene} not in token dictionary."
591
+ )
592
+ raise
593
 
594
+ if self.genes_to_perturb == "all":
595
  self.tokens_to_perturb = "all"
596
  else:
597
+ missing_genes = [gene for gene in self.genes_to_perturb if gene not in self.gene_token_dict.keys()]
598
+ if len(missing_genes) == len(self.genes_to_perturb):
599
+ logger.error(
600
+ "None of the provided genes to perturb are in token dictionary."
601
+ )
602
+ raise
603
+ elif len(missing_genes)>0:
604
+ logger.warning(
605
+ f"Genes to perturb {missing_genes} are not in token dictionary.")
606
+ self.tokens_to_perturb = [self.gene_token_dict.get(gene) for gene in self.genes_to_perturb]
607
 
608
  def validate_options(self):
609
  # first disallow options under development
610
  if self.perturb_type in ["inhibit", "activate"]:
611
  logger.error(
612
+ "In silico inhibition and activation currently under development. " \
613
+ "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
614
  )
615
  raise
616
 
 
635
  f"Valid options for {attr_name}: {valid_options}"
636
  )
637
  raise
638
+
639
  if self.perturb_type in ["delete","overexpress"]:
640
  if self.perturb_rank_shift is not None:
641
  if self.perturb_type == "delete":
 
711
  input_data_file : Path
712
  Path to directory containing .dataset inputs
713
  output_directory : Path
714
+ Path to directory where perturbation data will be saved as batched pickle files
715
  output_prefix : str
716
+ Prefix for output files
717
  """
718
 
719
  filtered_input_data = self.load_and_filter(input_data_file)
 
728
  filtered_input_data,
729
  self.cell_states_to_model,
730
  layer_to_quant,
731
+ self.pad_token_id,
732
  self.forward_batch_size,
733
  self.nproc)
734
  # filter for start state cells
 
744
  state_embs_dict,
745
  output_directory,
746
  output_prefix)
 
 
 
 
 
 
 
747
 
748
  # load data and filter by defined criteria
749
  def load_and_filter(self, input_data_file):
 
798
  output_prefix):
799
 
800
  output_path_prefix = f"{output_directory}in_silico_{self.perturb_type}_{output_prefix}_dict_1Kbatch"
801
+ model_input_size = get_model_input_size(model)
802
 
803
  # filter dataset for cells that have tokens to be perturbed
804
  if self.anchor_token is not None:
 
806
  return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
807
  filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
808
  logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
809
+ if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
810
+ # minimum # genes needed for perturbation test
811
+ min_genes = len(self.tokens_to_perturb)
812
  def if_has_tokens_to_perturb(example):
813
+ return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>min_genes)
814
  filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
815
 
816
  cos_sims_dict = defaultdict(list)
817
  pickle_batch = -1
818
+
819
+ # make perturbation batch w/ single perturbation in multiple cells
820
+ if self.perturb_group == True:
821
+
822
+ def make_group_perturbation_batch(example):
823
+ example_input_ids = example["input_ids"]
824
+ example["tokens_to_perturb"] = self.tokens_to_perturb
825
+ indices_to_perturb = [example_input_ids.index(token) if token in example_input_ids else None for token in self.tokens_to_perturb]
826
+ indices_to_perturb = [item for item in indices_to_perturb if item is not None]
827
+ if len(indices_to_perturb) > 0:
828
+ example["perturb_index"] = indices_to_perturb
829
+ else:
830
+ # -100 indicates tokens to overexpress are not present in rank value encoding
831
+ example["perturb_index"] = [-100]
832
+ if self.perturb_type == "delete":
833
+ example = delete_indices(example)
834
+ elif self.perturb_type == "overexpress":
835
+ example = overexpress_tokens(example)
836
+ return example
837
+
838
+ perturbation_batch = filtered_input_data.map(make_group_perturbation_batch, num_proc=self.nproc)
839
+ indices_to_perturb = perturbation_batch["perturb_index"]
840
+
841
+ cos_sims_data = quant_cos_sims(model,
842
+ self.perturb_type,
843
+ perturbation_batch,
844
+ self.forward_batch_size,
845
+ layer_to_quant,
846
+ filtered_input_data,
847
+ self.tokens_to_perturb,
848
+ indices_to_perturb,
849
+ self.perturb_group,
850
+ self.cell_states_to_model,
851
+ state_embs_dict,
852
+ self.pad_token_id,
853
+ model_input_size,
854
+ self.nproc)
855
+
856
+ perturbed_genes = tuple(self.tokens_to_perturb)
857
+ original_lengths = filtered_input_data["length"]
858
+ if self.cell_states_to_model is None:
859
+ # update cos sims dict
860
+ # key is tuple of (perturbed_gene, affected_gene)
861
+ # or (perturbed_genes, "cell_emb") for avg cell emb change
862
+ cos_sims_data = cos_sims_data.to("cuda")
863
+ max_padded_len = cos_sims_data.shape[1]
864
 
865
+ for j in range(cos_sims_data.shape[0]):
866
+ # remove padding before mean pooling cell embedding
867
+ original_length = original_lengths[j]
868
+ gene_list = filtered_input_data[j]["input_ids"]
869
+ indices_removed = indices_to_perturb[j]
870
+ padding_to_remove = max_padded_len - (original_length \
871
+ - len(self.tokens_to_perturb) \
872
+ - len(indices_removed))
873
+ nonpadding_cos_sims_data = cos_sims_data[j][:-padding_to_remove]
874
+ cell_cos_sim = torch.mean(nonpadding_cos_sims_data).item()
875
+ cos_sims_dict[(perturbed_genes, "cell_emb")] += [cell_cos_sim]
876
+
877
+ if self.emb_mode == "cell_and_gene":
878
+ for k in range(cos_sims_data.shape[1]):
879
+ cos_sim_value = nonpadding_cos_sims_data[k]
880
+ affected_gene = gene_list[k].item()
881
+ cos_sims_dict[(perturbed_genes, affected_gene)] += [cos_sim_value.item()]
882
+ else:
883
+ # update cos sims dict
884
+ # key is tuple of (perturbed_genes, "cell_emb")
885
+ # value is list of tuples of cos sims for cell_states_to_model
886
+ origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
887
+ cos_sims_origin = cos_sims_data[origin_state_key]
888
+ for j in range(cos_sims_origin.shape[0]):
889
+ original_length = original_lengths[j]
890
+ max_padded_len = cos_sims_origin.shape[1]
891
+ indices_removed = indices_to_perturb[j]
892
+ padding_to_remove = max_padded_len - (original_length \
893
+ - len(self.tokens_to_perturb) \
894
+ - len(indices_removed))
895
+ data_list = []
896
+ for data in list(cos_sims_data.values()):
897
+ data_item = data.to("cuda")
898
+ nonpadding_data_item = data_item[j][:-padding_to_remove]
899
+ cell_data = torch.mean(nonpadding_data_item).item()
900
+ data_list += [cell_data]
901
+ cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
902
 
903
+ with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
904
+ pickle.dump(cos_sims_dict, fp)
905
+
906
+ # make perturbation batch w/ multiple perturbations in single cell
907
+ if self.perturb_group == False:
908
 
909
+ for i in trange(len(filtered_input_data)):
910
+ example_cell = filtered_input_data.select([i])
911
+ original_emb = forward_pass_single_cell(model, example_cell, layer_to_quant)
912
+ gene_list = torch.squeeze(example_cell["input_ids"])
913
+
914
+ # reset to original type to prevent downstream issues due to forward_pass_single_cell modifying as torch format in place
915
+ example_cell = filtered_input_data.select([i])
916
+
917
+ if self.anchor_token is None:
918
+ for combo_lvl in range(self.combos+1):
919
+ perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
920
+ self.perturb_type,
921
+ self.tokens_to_perturb,
922
+ self.anchor_token,
923
+ combo_lvl,
924
+ self.nproc)
925
+ cos_sims_data = quant_cos_sims(model,
926
+ self.perturb_type,
927
+ perturbation_batch,
928
+ self.forward_batch_size,
929
+ layer_to_quant,
930
+ original_emb,
931
+ self.tokens_to_perturb,
932
+ indices_to_perturb,
933
+ self.perturb_group,
934
+ self.cell_states_to_model,
935
+ state_embs_dict,
936
+ self.pad_token_id,
937
+ model_input_size,
938
+ self.nproc)
939
+
940
+ if self.cell_states_to_model is None:
941
+ # update cos sims dict
942
+ # key is tuple of (perturbed_gene, affected_gene)
943
+ # or (perturbed_gene, "cell_emb") for avg cell emb change
944
+ cos_sims_data = cos_sims_data.to("cuda")
945
+ for j in range(cos_sims_data.shape[0]):
946
+ if self.tokens_to_perturb != "all":
947
+ j_index = torch.tensor(indices_to_perturb[j])
948
+ if j_index.shape[0]>1:
949
+ j_index = torch.squeeze(j_index)
950
+ else:
951
+ j_index = torch.tensor([j])
952
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
953
+
954
+ if perturbed_gene.shape[0]==1:
955
+ perturbed_gene = perturbed_gene.item()
956
+ elif perturbed_gene.shape[0]>1:
957
+ perturbed_gene = tuple(perturbed_gene.tolist())
958
+
959
+ cell_cos_sim = torch.mean(cos_sims_data[j]).item()
960
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [cell_cos_sim]
961
+
962
+ # not_j_index = list(set(i for i in range(gene_list.shape[0])).difference(j_index))
963
+ # gene_list_j = torch.index_select(gene_list, 0, j_index)
964
+ if self.emb_mode == "cell_and_gene":
965
+ for k in range(cos_sims_data.shape[1]):
966
+ cos_sim_value = cos_sims_data[j][k]
967
+ affected_gene = gene_list[k].item()
968
+ cos_sims_dict[(perturbed_gene, affected_gene)] += [cos_sim_value.item()]
969
+ else:
970
+ # update cos sims dict
971
+ # key is tuple of (perturbed_gene, "cell_emb")
972
+ # value is list of tuples of cos sims for cell_states_to_model
973
+ origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
974
+ cos_sims_origin = cos_sims_data[origin_state_key]
975
+
976
+ for j in range(cos_sims_origin.shape[0]):
977
+ if (self.tokens_to_perturb != "all") or (combo_lvl>0):
978
+ j_index = torch.tensor(indices_to_perturb[j])
979
+ if j_index.shape[0]>1:
980
+ j_index = torch.squeeze(j_index)
981
+ else:
982
+ j_index = torch.tensor([j])
983
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
984
+
985
+ if perturbed_gene.shape[0]==1:
986
+ perturbed_gene = perturbed_gene.item()
987
+ elif perturbed_gene.shape[0]>1:
988
+ perturbed_gene = tuple(perturbed_gene.tolist())
989
+
990
+ data_list = []
991
+ for data in list(cos_sims_data.values()):
992
+ data_item = data.to("cuda")
993
+ cell_data = torch.mean(data_item[j]).item()
994
+ data_list += [cell_data]
995
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [tuple(data_list)]
996
+
997
+ elif self.anchor_token is not None:
998
  perturbation_batch, indices_to_perturb = make_perturbation_batch(example_cell,
999
+ self.perturb_type,
1000
+ self.tokens_to_perturb,
1001
+ None, # first run without anchor token to test individual gene perturbations
1002
+ 0,
1003
+ self.nproc)
1004
  cos_sims_data = quant_cos_sims(model,
1005
  self.perturb_type,
1006
+ perturbation_batch,
1007
+ self.forward_batch_size,
1008
+ layer_to_quant,
1009
+ original_emb,
1010
+ self.tokens_to_perturb,
1011
  indices_to_perturb,
1012
+ self.perturb_group,
1013
  self.cell_states_to_model,
1014
+ state_embs_dict,
1015
+ self.pad_token_id,
1016
+ model_input_size,
1017
+ self.nproc)
1018
+ cos_sims_data = cos_sims_data.to("cuda")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1019
 
1020
+ combo_perturbation_batch, combo_indices_to_perturb = make_perturbation_batch(example_cell,
1021
+ self.perturb_type,
1022
+ self.tokens_to_perturb,
1023
+ self.anchor_token,
1024
+ 1,
1025
+ self.nproc)
1026
+ combo_cos_sims_data = quant_cos_sims(model,
1027
+ self.perturb_type,
1028
+ combo_perturbation_batch,
1029
+ self.forward_batch_size,
1030
+ layer_to_quant,
1031
+ original_emb,
1032
+ self.tokens_to_perturb,
1033
+ combo_indices_to_perturb,
1034
+ self.perturb_group,
1035
+ self.cell_states_to_model,
1036
+ state_embs_dict,
1037
+ self.pad_token_id,
1038
+ model_input_size,
1039
+ self.nproc)
1040
+ combo_cos_sims_data = combo_cos_sims_data.to("cuda")
1041
 
1042
+ # update cos sims dict
1043
+ # key is tuple of (perturbed_gene, "cell_emb") for avg cell emb change
1044
+ anchor_index = example_cell["input_ids"][0].index(self.anchor_token[0])
1045
+ anchor_cell_cos_sim = torch.mean(cos_sims_data[anchor_index]).item()
1046
+ non_anchor_indices = [k for k in range(cos_sims_data.shape[0]) if k != anchor_index]
1047
+ cos_sims_data = cos_sims_data[non_anchor_indices,:]
1048
 
1049
+ for j in range(cos_sims_data.shape[0]):
1050
 
1051
+ if j<anchor_index:
1052
+ j_index = torch.tensor([j])
1053
+ else:
1054
+ j_index = torch.tensor([j+1])
1055
 
1056
+ perturbed_gene = torch.index_select(gene_list, 0, j_index)
1057
+ perturbed_gene = perturbed_gene.item()
1058
 
1059
+ cell_cos_sim = torch.mean(cos_sims_data[j]).item()
1060
+ combo_cos_sim = torch.mean(combo_cos_sims_data[j]).item()
1061
+ cos_sims_dict[(perturbed_gene, "cell_emb")] += [(anchor_cell_cos_sim, # cos sim anchor gene alone
1062
+ cell_cos_sim, # cos sim deleted gene alone
1063
+ combo_cos_sim)] # cos sim anchor gene + deleted gene
1064
+
1065
+ # save dict to disk every 100 cells
1066
+ if (i/100).is_integer():
1067
+ with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1068
+ pickle.dump(cos_sims_dict, fp)
1069
+ # reset and clear memory every 1000 cells
1070
+ if (i/1000).is_integer():
1071
+ pickle_batch = pickle_batch+1
1072
+ # clear memory
1073
+ del perturbed_gene
1074
+ del cos_sims_data
1075
+ if self.cell_states_to_model is None:
1076
+ del cell_cos_sim
1077
+ if self.cell_states_to_model is not None:
1078
+ del cell_data
1079
+ del data_list
1080
+ elif self.anchor_token is None:
1081
+ if self.emb_mode == "cell_and_gene":
1082
+ del affected_gene
1083
+ del cos_sim_value
1084
+ else:
1085
+ del combo_cos_sim
1086
+ del combo_cos_sims_data
1087
+ # reset dict
1088
+ del cos_sims_dict
1089
+ cos_sims_dict = defaultdict(list)
1090
+ torch.cuda.empty_cache()
1091
+
1092
+ # save remainder cells
1093
+ with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1094
+ pickle.dump(cos_sims_dict, fp)
1095
 
geneformer/in_silico_perturber_stats.py CHANGED
@@ -79,6 +79,9 @@ def get_gene_list(dict_list,mode):
79
  gene_list.sort()
80
  return gene_list
81
 
 
 
 
82
  def n_detections(token, dict_list, mode, anchor_token):
83
  cos_sim_megalist = []
84
  for dict_i in dict_list:
@@ -106,98 +109,130 @@ def get_impact_component(test_value, gaussian_mixture_model):
106
  impact_component = 1
107
  return impact_component
108
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
110
- def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
111
  cell_state_key = list(cell_states_to_model.keys())[0]
112
  if cell_states_to_model[cell_state_key][2] == []:
113
  alt_end_state_exists = False
114
  elif (len(cell_states_to_model[cell_state_key][2]) > 0) and (cell_states_to_model[cell_state_key][2] != [None]):
115
  alt_end_state_exists = True
116
 
117
- random_tuples = []
118
- for i in trange(cos_sims_df.shape[0]):
119
- token = cos_sims_df["Gene"][i]
120
- for dict_i in dict_list:
121
- random_tuples += dict_i.get((token, "cell_emb"),[])
122
-
123
- if alt_end_state_exists == False:
124
- goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
125
- elif alt_end_state_exists == True:
126
- goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
127
- alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
128
-
129
- # downsample to improve speed of ranksums
130
- if len(goal_end_random_megalist) > 100_000:
131
- random.seed(42)
132
- goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
133
- if alt_end_state_exists == True:
134
- if len(alt_end_random_megalist) > 100_000:
135
- random.seed(42)
136
- alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
137
-
138
- names=["Gene",
139
- "Gene_name",
140
- "Ensembl_ID",
141
- "Shift_to_goal_end",
142
- "Shift_to_alt_end",
143
- "Goal_end_vs_random_pval",
144
- "Alt_end_vs_random_pval"]
145
- if alt_end_state_exists == False:
146
- names.remove("Shift_to_alt_end")
147
- names.remove("Alt_end_vs_random_pval")
148
- cos_sims_full_df = pd.DataFrame(columns=names)
149
-
150
- for i in trange(cos_sims_df.shape[0]):
151
- token = cos_sims_df["Gene"][i]
152
- name = cos_sims_df["Gene_name"][i]
153
- ensembl_id = cos_sims_df["Ensembl_ID"][i]
154
- cos_shift_data = []
155
 
 
 
156
  for dict_i in dict_list:
157
  cos_shift_data += dict_i.get((token, "cell_emb"),[])
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  if alt_end_state_exists == False:
160
- goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
161
  elif alt_end_state_exists == True:
162
- goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
163
- alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
164
- mean_alt_end = np.mean(alt_end_cos_sim_megalist)
165
- pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
166
-
167
- mean_goal_end = np.mean(goal_end_cos_sim_megalist)
168
- pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
169
-
 
 
 
 
 
 
 
 
 
 
 
170
  if alt_end_state_exists == False:
171
- data_i = [token,
172
- name,
173
- ensembl_id,
174
- mean_goal_end,
175
- pval_goal_end]
176
- elif alt_end_state_exists == True:
177
- data_i = [token,
178
- name,
179
- ensembl_id,
180
- mean_goal_end,
181
- mean_alt_end,
182
- pval_goal_end,
183
- pval_alt_end]
184
-
185
- cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
186
- cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
187
-
188
- cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
189
- if alt_end_state_exists == True:
190
- cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
191
-
192
- # quantify number of detections of each gene
193
- cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
194
 
195
- # sort by shift to desired state
196
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end",
197
- "Goal_end_FDR"],
198
- ascending=[False,True])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
- return cos_sims_full_df
201
 
202
  # stats comparing cos sim shifts of test perturbations vs null distribution
203
  def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
@@ -362,7 +397,7 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
362
 
363
  class InSilicoPerturberStats:
364
  valid_option_dict = {
365
- "mode": {"goal_state_shift","vs_null","mixture_model"},
366
  "combos": {0,1},
367
  "anchor_gene": {None, str},
368
  "cell_states_to_model": {None, dict},
@@ -370,6 +405,7 @@ class InSilicoPerturberStats:
370
  def __init__(
371
  self,
372
  mode="mixture_model",
 
373
  combos=0,
374
  anchor_gene=None,
375
  cell_states_to_model=None,
@@ -381,11 +417,16 @@ class InSilicoPerturberStats:
381
 
382
  Parameters
383
  ----------
384
- mode : {"goal_state_shift","vs_null","mixture_model"}
385
  Type of stats.
386
  "goal_state_shift": perturbation vs. random for desired cell state shift
387
  "vs_null": perturbation vs. null from provided null distribution dataset
388
  "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
 
 
 
 
 
389
  combos : {0,1,2}
390
  Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
391
  anchor_gene : None, str
@@ -406,6 +447,7 @@ class InSilicoPerturberStats:
406
  """
407
 
408
  self.mode = mode
 
409
  self.combos = combos
410
  self.anchor_gene = anchor_gene
411
  self.cell_states_to_model = cell_states_to_model
@@ -477,6 +519,17 @@ class InSilicoPerturberStats:
477
  "in silico perturbation run with anchor gene. Please add " \
478
  "anchor gene when using with combos > 0. ")
479
  raise
 
 
 
 
 
 
 
 
 
 
 
480
 
481
  def get_stats(self,
482
  input_data_directory,
@@ -495,7 +548,7 @@ class InSilicoPerturberStats:
495
  output_directory : Path
496
  Path to directory where perturbation data will be saved as .csv
497
  output_prefix : str
498
- Prefix for output .dataset
499
 
500
  Outputs
501
  ----------
@@ -538,11 +591,11 @@ class InSilicoPerturberStats:
538
  "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
539
  """
540
 
541
- if self.mode not in ["goal_state_shift", "vs_null", "mixture_model"]:
542
  logger.error(
543
  "Currently, only modes available are stats for goal_state_shift, " \
544
- "vs_null (comparing to null distribution), and " \
545
- "mixture_model (fitting mixture model for perturbations with or without impact.")
546
  raise
547
 
548
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
@@ -562,14 +615,16 @@ class InSilicoPerturberStats:
562
  cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
563
  "Gene_name": [self.token_to_gene_name(item) \
564
  for item in gene_list], \
565
- "Ensembl_ID": [self.gene_token_id_dict[genes[1]] \
 
 
566
  if isinstance(genes,tuple) else \
567
  self.gene_token_id_dict[genes] \
568
  for genes in gene_list]}, \
569
  index=[i for i in range(len(gene_list))])
570
 
571
  if self.mode == "goal_state_shift":
572
- cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model)
573
 
574
  elif self.mode == "vs_null":
575
  null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
@@ -577,6 +632,9 @@ class InSilicoPerturberStats:
577
 
578
  elif self.mode == "mixture_model":
579
  cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
 
 
 
580
 
581
  # save perturbation stats to output_path
582
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
 
79
  gene_list.sort()
80
  return gene_list
81
 
82
+ def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
83
+ return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
84
+
85
  def n_detections(token, dict_list, mode, anchor_token):
86
  cos_sim_megalist = []
87
  for dict_i in dict_list:
 
109
  impact_component = 1
110
  return impact_component
111
 
112
+ # aggregate data for single perturbation in multiple cells
113
+ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
114
+ names=["Cosine_shift"]
115
+ cos_sims_full_df = pd.DataFrame(columns=names)
116
+
117
+ cos_shift_data = []
118
+ token = cos_sims_df["Gene"][0]
119
+ for dict_i in dict_list:
120
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
121
+ cos_sims_full_df["Cosine_shift"] = cos_shift_data
122
+ return cos_sims_full_df
123
+
124
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
125
+ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
126
  cell_state_key = list(cell_states_to_model.keys())[0]
127
  if cell_states_to_model[cell_state_key][2] == []:
128
  alt_end_state_exists = False
129
  elif (len(cell_states_to_model[cell_state_key][2]) > 0) and (cell_states_to_model[cell_state_key][2] != [None]):
130
  alt_end_state_exists = True
131
 
132
+ # for single perturbation in multiple cells, there are no random perturbations to compare to
133
+ if genes_perturbed != "all":
134
+ names=["Shift_to_goal_end",
135
+ "Shift_to_alt_end"]
136
+ if alt_end_state_exists == False:
137
+ names.remove("Shift_to_alt_end")
138
+ cos_sims_full_df = pd.DataFrame(columns=names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ cos_shift_data = []
141
+ token = cos_sims_df["Gene"][0]
142
  for dict_i in dict_list:
143
  cos_shift_data += dict_i.get((token, "cell_emb"),[])
144
+ if alt_end_state_exists == False:
145
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end in cos_shift_data]
146
+ if alt_end_state_exists == True:
147
+ cos_sims_full_df["Shift_to_goal_end"] = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
148
+ cos_sims_full_df["Shift_to_alt_end"] = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
149
+ return cos_sims_full_df
150
+
151
+ elif genes_perturbed == "all":
152
+ random_tuples = []
153
+ for i in trange(cos_sims_df.shape[0]):
154
+ token = cos_sims_df["Gene"][i]
155
+ for dict_i in dict_list:
156
+ random_tuples += dict_i.get((token, "cell_emb"),[])
157
 
158
  if alt_end_state_exists == False:
159
+ goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
160
  elif alt_end_state_exists == True:
161
+ goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
162
+ alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
163
+
164
+ # downsample to improve speed of ranksums
165
+ if len(goal_end_random_megalist) > 100_000:
166
+ random.seed(42)
167
+ goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
168
+ if alt_end_state_exists == True:
169
+ if len(alt_end_random_megalist) > 100_000:
170
+ random.seed(42)
171
+ alt_end_random_megalist = random.sample(alt_end_random_megalist, k=100_000)
172
+
173
+ names=["Gene",
174
+ "Gene_name",
175
+ "Ensembl_ID",
176
+ "Shift_to_goal_end",
177
+ "Shift_to_alt_end",
178
+ "Goal_end_vs_random_pval",
179
+ "Alt_end_vs_random_pval"]
180
  if alt_end_state_exists == False:
181
+ names.remove("Shift_to_alt_end")
182
+ names.remove("Alt_end_vs_random_pval")
183
+ cos_sims_full_df = pd.DataFrame(columns=names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
+ for i in trange(cos_sims_df.shape[0]):
186
+ token = cos_sims_df["Gene"][i]
187
+ name = cos_sims_df["Gene_name"][i]
188
+ ensembl_id = cos_sims_df["Ensembl_ID"][i]
189
+ cos_shift_data = []
190
+
191
+ for dict_i in dict_list:
192
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
193
+
194
+ if alt_end_state_exists == False:
195
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
196
+ elif alt_end_state_exists == True:
197
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
198
+ alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
199
+ mean_alt_end = np.mean(alt_end_cos_sim_megalist)
200
+ pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
201
+
202
+ mean_goal_end = np.mean(goal_end_cos_sim_megalist)
203
+ pval_goal_end = ranksums(goal_end_random_megalist,goal_end_cos_sim_megalist).pvalue
204
+
205
+ if alt_end_state_exists == False:
206
+ data_i = [token,
207
+ name,
208
+ ensembl_id,
209
+ mean_goal_end,
210
+ pval_goal_end]
211
+ elif alt_end_state_exists == True:
212
+ data_i = [token,
213
+ name,
214
+ ensembl_id,
215
+ mean_goal_end,
216
+ mean_alt_end,
217
+ pval_goal_end,
218
+ pval_alt_end]
219
+
220
+ cos_sims_df_i = pd.DataFrame(dict(zip(names,data_i)),index=[i])
221
+ cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
222
+
223
+ cos_sims_full_df["Goal_end_FDR"] = get_fdr(list(cos_sims_full_df["Goal_end_vs_random_pval"]))
224
+ if alt_end_state_exists == True:
225
+ cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
226
+
227
+ # quantify number of detections of each gene
228
+ cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
229
+
230
+ # sort by shift to desired state
231
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end",
232
+ "Goal_end_FDR"],
233
+ ascending=[False,True])
234
 
235
+ return cos_sims_full_df
236
 
237
  # stats comparing cos sim shifts of test perturbations vs null distribution
238
  def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
 
397
 
398
  class InSilicoPerturberStats:
399
  valid_option_dict = {
400
+ "mode": {"goal_state_shift","vs_null","mixture_model","aggregate_data"},
401
  "combos": {0,1},
402
  "anchor_gene": {None, str},
403
  "cell_states_to_model": {None, dict},
 
405
  def __init__(
406
  self,
407
  mode="mixture_model",
408
+ genes_perturbed="all",
409
  combos=0,
410
  anchor_gene=None,
411
  cell_states_to_model=None,
 
417
 
418
  Parameters
419
  ----------
420
+ mode : {"goal_state_shift","vs_null","mixture_model","aggregate_data"}
421
  Type of stats.
422
  "goal_state_shift": perturbation vs. random for desired cell state shift
423
  "vs_null": perturbation vs. null from provided null distribution dataset
424
  "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
425
+ "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
426
+ genes_perturbed : "all", list
427
+ Genes perturbed in isp experiment.
428
+ Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
429
+ Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
430
  combos : {0,1,2}
431
  Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
432
  anchor_gene : None, str
 
447
  """
448
 
449
  self.mode = mode
450
+ self.genes_perturbed = genes_perturbed
451
  self.combos = combos
452
  self.anchor_gene = anchor_gene
453
  self.cell_states_to_model = cell_states_to_model
 
519
  "in silico perturbation run with anchor gene. Please add " \
520
  "anchor gene when using with combos > 0. ")
521
  raise
522
+
523
+ if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
524
+ logger.error(
525
+ "Mixture model mode requires multiple gene perturbations to fit model " \
526
+ "so is incompatible with a single grouped perturbation.")
527
+ raise
528
+ if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
529
+ logger.error(
530
+ "Simple data aggregation mode is for single perturbation in multiple cells " \
531
+ "so is incompatible with a genes_perturbed being 'all'.")
532
+ raise
533
 
534
  def get_stats(self,
535
  input_data_directory,
 
548
  output_directory : Path
549
  Path to directory where perturbation data will be saved as .csv
550
  output_prefix : str
551
+ Prefix for output .csv
552
 
553
  Outputs
554
  ----------
 
591
  "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
592
  """
593
 
594
+ if self.mode not in ["goal_state_shift", "vs_null", "mixture_model","aggregate_data"]:
595
  logger.error(
596
  "Currently, only modes available are stats for goal_state_shift, " \
597
+ "vs_null (comparing to null distribution), and " \
598
+ "mixture_model (fitting mixture model for perturbations with or without impact.")
599
  raise
600
 
601
  self.gene_token_id_dict = invert_dict(self.gene_token_dict)
 
615
  cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
616
  "Gene_name": [self.token_to_gene_name(item) \
617
  for item in gene_list], \
618
+ "Ensembl_ID": [token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict) \
619
+ if self.genes_perturbed != "all" else \
620
+ self.gene_token_id_dict[genes[1]] \
621
  if isinstance(genes,tuple) else \
622
  self.gene_token_id_dict[genes] \
623
  for genes in gene_list]}, \
624
  index=[i for i in range(len(gene_list))])
625
 
626
  if self.mode == "goal_state_shift":
627
+ cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list, self.cell_states_to_model, self.genes_perturbed)
628
 
629
  elif self.mode == "vs_null":
630
  null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
 
632
 
633
  elif self.mode == "mixture_model":
634
  cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
635
+
636
+ elif self.mode == "aggregate_data":
637
+ cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list)
638
 
639
  # save perturbation stats to output_path
640
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")