ctheodoris davidjwen commited on
Commit
9169bfd
1 Parent(s): 65b4915

Fixed error with perturbing individual genes and updated ways to specify cell_states_to_model (#146)

Browse files

- Fixed error with perturbing individual genes and updated ways to specify cell_states_to_model (771c8bd01b1754c4387d742680e65c34697a2336)
- Fix isp perturb_group dims, reformat cell states dict to keyed, add attn mask (c2679c41d352513a56feb5490c5302c6f25ae7ba)


Co-authored-by: David Wen <davidjwen@users.noreply.huggingface.co>

examples/extract_and_plot_cell_embeddings.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
examples/in_silico_perturbation.ipynb CHANGED
@@ -33,7 +33,10 @@
33
  " emb_mode=\"cell\",\n",
34
  " cell_emb_style=\"mean_pool\",\n",
35
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
- " cell_states_to_model={\"disease\":([\"dcm\"],[\"nf\"],[\"hcm\"])},\n",
 
 
 
37
  " max_ncells=2000,\n",
38
  " emb_layer=0,\n",
39
  " forward_batch_size=400,\n",
 
33
  " emb_mode=\"cell\",\n",
34
  " cell_emb_style=\"mean_pool\",\n",
35
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
36
+ " cell_states_to_model={'state_key': 'disease', \n",
37
+ " 'start_state': 'dcm', \n",
38
+ " 'goal_state': 'nf', \n",
39
+ " 'alt_states': ['hcm']},\n",
40
  " max_ncells=2000,\n",
41
  " emb_layer=0,\n",
42
  " forward_batch_size=400,\n",
geneformer/emb_extractor.py CHANGED
@@ -43,32 +43,17 @@ from transformers import BertForMaskedLM, BertForTokenClassification, BertForSeq
43
 
44
  from .tokenizer import TOKEN_DICTIONARY_FILE
45
 
46
- from .in_silico_perturber import load_and_filter, \
47
- downsample_and_sort, \
 
 
48
  load_model, \
49
- quant_layers, \
50
- downsample_and_sort, \
51
  pad_tensor_list, \
52
- get_model_input_size
53
-
54
 
55
  logger = logging.getLogger(__name__)
56
 
57
- # get cell embeddings excluding padding
58
- def mean_nonpadding_embs(embs, original_lens):
59
- # mask based on padding lengths
60
- mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
61
-
62
- # extend mask dimensions to match the embeddings tensor
63
- mask = mask.unsqueeze(2).expand_as(embs)
64
-
65
- # use the mask to zero out the embeddings in padded areas
66
- masked_embs = embs * mask.float()
67
-
68
- # sum and divide by the lengths to get the mean of non-padding embs
69
- mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
70
- return mean_embs
71
-
72
  # average embedding position of goal cell states
73
  def get_embs(model,
74
  filtered_input_data,
@@ -99,7 +84,8 @@ def get_embs(model,
99
 
100
  with torch.no_grad():
101
  outputs = model(
102
- input_ids = input_data_minibatch.to("cuda")
 
103
  )
104
 
105
  embs_i = outputs.hidden_states[layer_to_quant]
 
43
 
44
  from .tokenizer import TOKEN_DICTIONARY_FILE
45
 
46
+ from .in_silico_perturber import downsample_and_sort, \
47
+ gen_attention_mask, \
48
+ get_model_input_size, \
49
+ load_and_filter, \
50
  load_model, \
51
+ mean_nonpadding_embs, \
 
52
  pad_tensor_list, \
53
+ quant_layers
 
54
 
55
  logger = logging.getLogger(__name__)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # average embedding position of goal cell states
58
  def get_embs(model,
59
  filtered_input_data,
 
84
 
85
  with torch.no_grad():
86
  outputs = model(
87
+ input_ids = input_data_minibatch.to("cuda"),
88
+ attention_mask = gen_attention_mask(minibatch)
89
  )
90
 
91
  embs_i = outputs.hidden_states[layer_to_quant]
geneformer/in_silico_perturber.py CHANGED
@@ -13,7 +13,7 @@ Usage:
13
  emb_mode="cell",
14
  cell_emb_style="mean_pool",
15
  filter_data={"cell_type":["cardiomyocyte"]},
16
- cell_states_to_model={"disease":(["dcm"],["ctrl"],["hcm"])},
17
  max_ncells=None,
18
  emb_layer=-1,
19
  forward_batch_size=100,
@@ -105,6 +105,13 @@ def downsample_and_sort(data_shuffled, max_ncells):
105
  data_sorted = data_subset.sort("length",reverse=True)
106
  return data_sorted
107
 
 
 
 
 
 
 
 
108
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
109
  example_cell.set_format(type="torch")
110
  input_data = example_cell["input_ids"]
@@ -146,6 +153,21 @@ def overexpress_tokens(example):
146
  [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
147
  return example
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def make_perturbation_batch(example_cell,
150
  perturb_type,
151
  tokens_to_perturb,
@@ -235,13 +257,15 @@ def get_cell_state_avg_embs(model,
235
  num_proc):
236
 
237
  model_input_size = get_model_input_size(model)
238
- possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
239
  state_embs_dict = dict()
240
  for possible_state in possible_states:
241
  state_embs_list = []
 
242
 
243
  def filter_states(example):
244
- return example[list(cell_states_to_model.keys())[0]] in [possible_state]
 
245
  filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
246
  total_batch_length = len(filtered_input_data_state)
247
  if ((total_batch_length-1)/forward_batch_size).is_integer():
@@ -254,14 +278,17 @@ def get_cell_state_avg_embs(model,
254
  state_minibatch.set_format(type="torch")
255
 
256
  input_data_minibatch = state_minibatch["input_ids"]
 
257
  input_data_minibatch = pad_tensor_list(input_data_minibatch,
258
  max_len,
259
  pad_token_id,
260
  model_input_size)
 
261
 
262
  with torch.no_grad():
263
  outputs = model(
264
- input_ids = input_data_minibatch.to("cuda")
 
265
  )
266
 
267
  state_embs_i = outputs.hidden_states[layer_to_quant]
@@ -269,10 +296,13 @@ def get_cell_state_avg_embs(model,
269
  del outputs
270
  del state_minibatch
271
  del input_data_minibatch
 
272
  del state_embs_i
273
  torch.cuda.empty_cache()
274
- state_embs_stack = torch.cat(state_embs_list)
275
- avg_state_emb = torch.mean(state_embs_stack,dim=[0,1],keepdim=True)
 
 
276
  state_embs_dict[possible_state] = avg_state_emb
277
  return state_embs_dict
278
 
@@ -291,7 +321,6 @@ def quant_cos_sims(model,
291
  pad_token_id,
292
  model_input_size,
293
  nproc):
294
-
295
  cos = torch.nn.CosineSimilarity(dim=2)
296
  total_batch_length = len(perturbation_batch)
297
  if ((total_batch_length-1)/forward_batch_size).is_integer():
@@ -301,7 +330,7 @@ def quant_cos_sims(model,
301
  comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
302
  cos_sims = []
303
  else:
304
- possible_states = [value[0]+value[1]+value[2] for value in cell_states_to_model.values()][0]
305
  cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
306
 
307
  # measure length of each element in perturbation_batch
@@ -316,10 +345,12 @@ def quant_cos_sims(model,
316
 
317
  # determine if need to pad or truncate batch
318
  minibatch_length_set = set(perturbation_minibatch["length"])
 
319
  if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
320
  needs_pad_or_trunc = True
321
  else:
322
  needs_pad_or_trunc = False
 
323
 
324
  if needs_pad_or_trunc == True:
325
  max_len = min(max(minibatch_length_set),model_input_size)
@@ -332,14 +363,17 @@ def quant_cos_sims(model,
332
  perturbation_minibatch.set_format(type="torch")
333
 
334
  input_data_minibatch = perturbation_minibatch["input_ids"]
 
335
 
336
  # extract embeddings for perturbation minibatch
337
  with torch.no_grad():
338
  outputs = model(
339
- input_ids = input_data_minibatch.to("cuda")
 
340
  )
341
  del input_data_minibatch
342
  del perturbation_minibatch
 
343
 
344
  if len(indices_to_perturb)>1:
345
  minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
@@ -360,6 +394,7 @@ def quant_cos_sims(model,
360
  # truncate to the (model input size - # tokens to overexpress) to ensure comparability
361
  # since max input size of perturb batch will be reduced by # tokens to overexpress
362
  original_minibatch = original_emb.select([i for i in range(i, max_range)])
 
363
  original_minibatch_length_set = set(original_minibatch["length"])
364
  if perturb_type == "overexpress":
365
  new_max_len = model_input_size - len(tokens_to_perturb)
@@ -373,19 +408,30 @@ def quant_cos_sims(model,
373
  original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
374
  original_minibatch.set_format(type="torch")
375
  original_input_data_minibatch = original_minibatch["input_ids"]
 
376
  # extract embeddings for original minibatch
377
  with torch.no_grad():
378
  original_outputs = model(
379
- input_ids = original_input_data_minibatch.to("cuda")
 
380
  )
381
  del original_input_data_minibatch
382
  del original_minibatch
 
383
 
384
  if len(indices_to_perturb)>1:
385
  original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
386
  else:
387
  original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
388
-
 
 
 
 
 
 
 
 
389
  # cosine similarity between original emb and batch items
390
  if cell_states_to_model is None:
391
  if perturb_group == False:
@@ -394,6 +440,7 @@ def quant_cos_sims(model,
394
  minibatch_comparison = make_comparison_batch(original_minibatch_emb,
395
  indices_to_perturb,
396
  perturb_group)
 
397
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
398
  elif cell_states_to_model is not None:
399
  for state in possible_states:
@@ -406,7 +453,9 @@ def quant_cos_sims(model,
406
  cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
407
  minibatch_emb,
408
  state_embs_dict[state],
409
- perturb_group)
 
 
410
  del outputs
411
  del minibatch_emb
412
  if cell_states_to_model is None:
@@ -421,14 +470,41 @@ def quant_cos_sims(model,
421
  return cos_sims_vs_alt_dict
422
 
423
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
424
- def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group):
 
 
 
 
 
425
  cos = torch.nn.CosineSimilarity(dim=2)
426
- original_emb = torch.mean(original_emb,dim=0,keepdim=True)
427
- if perturb_group == False:
428
  original_emb = original_emb[None, :]
429
- origin_v_end = cos(original_emb,alt_emb)
430
- perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
431
- perturb_v_end = cos(perturb_emb,alt_emb)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  return [(perturb_v_end-origin_v_end).to("cpu")]
433
 
434
  def pad_list(input_ids, pad_token_id, max_len):
@@ -488,6 +564,30 @@ def pad_tensor_list(tensor_list, dynamic_or_constant, pad_token_id, model_input_
488
  # return stacked tensors
489
  return torch.stack(tensor_list)
490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  class InSilicoPerturber:
492
  valid_option_dict = {
493
  "perturb_type": {"delete","overexpress","inhibit","activate"},
@@ -573,9 +673,15 @@ class InSilicoPerturber:
573
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
574
  cell_states_to_model: None, dict
575
  Cell states to model if testing perturbations that achieve goal state change.
576
- Single-item dictionary with key being cell attribute (e.g. "disease").
577
- Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
578
- If no alternate possible end states, third list should be empty (i.e. the third list should be []).
 
 
 
 
 
 
579
  max_ncells : None, int
580
  Maximum number of cells to test.
581
  If None, will test all cells.
@@ -706,6 +812,17 @@ class InSilicoPerturber:
706
 
707
  if self.cell_states_to_model is not None:
708
  if len(self.cell_states_to_model.items()) == 1:
 
 
 
 
 
 
 
 
 
 
 
709
  for key,value in self.cell_states_to_model.items():
710
  if (len(value) == 3) and isinstance(value, tuple):
711
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
@@ -713,14 +830,50 @@ class InSilicoPerturber:
713
  all_values = value[0]+value[1]+value[2]
714
  if len(all_values) == len(set(all_values)):
715
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716
  else:
717
  logger.error(
718
- "Cell states to model must be a single-item dictionary with " \
719
- "key being cell attribute (e.g. 'disease') and value being " \
720
- "tuple of three lists indicating start state, goal end state, and alternate possible end states. " \
721
- "Values should all be unique. " \
722
- "For example: {'disease':(['dcm'],['ctrl'],['hcm'])}")
 
 
 
723
  raise
 
724
  if self.anchor_gene is not None:
725
  self.anchor_gene = None
726
  logger.warning(
@@ -770,6 +923,14 @@ class InSilicoPerturber:
770
  if self.cell_states_to_model is None:
771
  state_embs_dict = None
772
  else:
 
 
 
 
 
 
 
 
773
  # get dictionary of average cell state embeddings for comparison
774
  downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
775
  state_embs_dict = get_cell_state_avg_embs(model,
@@ -780,9 +941,9 @@ class InSilicoPerturber:
780
  self.forward_batch_size,
781
  self.nproc)
782
  # filter for start state cells
783
- start_state = list(self.cell_states_to_model.values())[0][0][0]
784
  def filter_for_origin(example):
785
- return example[list(self.cell_states_to_model.keys())[0]] in [start_state]
786
 
787
  filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
788
 
@@ -878,7 +1039,6 @@ class InSilicoPerturber:
878
  # or (perturbed_genes, "cell_emb") for avg cell emb change
879
  cos_sims_data = cos_sims_data.to("cuda")
880
  max_padded_len = cos_sims_data.shape[1]
881
-
882
  for j in range(cos_sims_data.shape[0]):
883
  # remove padding before mean pooling cell embedding
884
  original_length = original_lengths[j]
@@ -900,21 +1060,13 @@ class InSilicoPerturber:
900
  # update cos sims dict
901
  # key is tuple of (perturbed_genes, "cell_emb")
902
  # value is list of tuples of cos sims for cell_states_to_model
903
- origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
904
  cos_sims_origin = cos_sims_data[origin_state_key]
905
  for j in range(cos_sims_origin.shape[0]):
906
- original_length = original_lengths[j]
907
- max_padded_len = cos_sims_origin.shape[1]
908
- indices_removed = indices_to_perturb[j]
909
- padding_to_remove = max_padded_len - (original_length \
910
- - len(self.tokens_to_perturb) \
911
- - len(indices_removed))
912
  data_list = []
913
  for data in list(cos_sims_data.values()):
914
  data_item = data.to("cuda")
915
- nonpadding_data_item = data_item[j][:-padding_to_remove]
916
- cell_data = torch.mean(nonpadding_data_item).item()
917
- data_list += [cell_data]
918
  cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
919
 
920
  with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
@@ -987,7 +1139,7 @@ class InSilicoPerturber:
987
  # update cos sims dict
988
  # key is tuple of (perturbed_gene, "cell_emb")
989
  # value is list of tuples of cos sims for cell_states_to_model
990
- origin_state_key = [value[0] for value in self.cell_states_to_model.values()][0][0]
991
  cos_sims_origin = cos_sims_data[origin_state_key]
992
 
993
  for j in range(cos_sims_origin.shape[0]):
@@ -1108,5 +1260,4 @@ class InSilicoPerturber:
1108
 
1109
  # save remainder cells
1110
  with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1111
- pickle.dump(cos_sims_dict, fp)
1112
-
 
13
  emb_mode="cell",
14
  cell_emb_style="mean_pool",
15
  filter_data={"cell_type":["cardiomyocyte"]},
16
+ cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
17
  max_ncells=None,
18
  emb_layer=-1,
19
  forward_batch_size=100,
 
105
  data_sorted = data_subset.sort("length",reverse=True)
106
  return data_sorted
107
 
108
+ def get_possible_states(cell_states_to_model):
109
+ possible_states = []
110
+ for key in ["start_state","goal_state"]:
111
+ possible_states += [cell_states_to_model[key]]
112
+ possible_states += cell_states_to_model.get("alt_states",[])
113
+ return possible_states
114
+
115
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
116
  example_cell.set_format(type="torch")
117
  input_data = example_cell["input_ids"]
 
153
  [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
154
  return example
155
 
156
+ def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
157
+ # indices_to_remove is list of indices to remove
158
+ indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
159
+ num_dims = emb.dim()
160
+ emb_slice = [slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)]
161
+ sliced_emb = emb[emb_slice]
162
+ return sliced_emb
163
+
164
+ def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
165
+ output_batch = torch.stack([
166
+ remove_indices_from_emb(emb_batch[i, :, :], idx, gene_dim-1) for
167
+ i, idx in enumerate(list_of_indices_to_remove)
168
+ ])
169
+ return output_batch
170
+
171
  def make_perturbation_batch(example_cell,
172
  perturb_type,
173
  tokens_to_perturb,
 
257
  num_proc):
258
 
259
  model_input_size = get_model_input_size(model)
260
+ possible_states = get_possible_states(cell_states_to_model)
261
  state_embs_dict = dict()
262
  for possible_state in possible_states:
263
  state_embs_list = []
264
+ original_lens = []
265
 
266
  def filter_states(example):
267
+ state_key = cell_states_to_model["state_key"]
268
+ return example[state_key] in [possible_state]
269
  filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
270
  total_batch_length = len(filtered_input_data_state)
271
  if ((total_batch_length-1)/forward_batch_size).is_integer():
 
278
  state_minibatch.set_format(type="torch")
279
 
280
  input_data_minibatch = state_minibatch["input_ids"]
281
+ original_lens += state_minibatch["length"]
282
  input_data_minibatch = pad_tensor_list(input_data_minibatch,
283
  max_len,
284
  pad_token_id,
285
  model_input_size)
286
+ attention_mask = gen_attention_mask(state_minibatch, max_len)
287
 
288
  with torch.no_grad():
289
  outputs = model(
290
+ input_ids = input_data_minibatch.to("cuda"),
291
+ attention_mask = attention_mask
292
  )
293
 
294
  state_embs_i = outputs.hidden_states[layer_to_quant]
 
296
  del outputs
297
  del state_minibatch
298
  del input_data_minibatch
299
+ del attention_mask
300
  del state_embs_i
301
  torch.cuda.empty_cache()
302
+
303
+ state_embs = torch.cat(state_embs_list)
304
+ avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
305
+ avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
306
  state_embs_dict[possible_state] = avg_state_emb
307
  return state_embs_dict
308
 
 
321
  pad_token_id,
322
  model_input_size,
323
  nproc):
 
324
  cos = torch.nn.CosineSimilarity(dim=2)
325
  total_batch_length = len(perturbation_batch)
326
  if ((total_batch_length-1)/forward_batch_size).is_integer():
 
330
  comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
331
  cos_sims = []
332
  else:
333
+ possible_states = get_possible_states(cell_states_to_model)
334
  cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
335
 
336
  # measure length of each element in perturbation_batch
 
345
 
346
  # determine if need to pad or truncate batch
347
  minibatch_length_set = set(perturbation_minibatch["length"])
348
+ minibatch_lengths = perturbation_minibatch["length"]
349
  if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
350
  needs_pad_or_trunc = True
351
  else:
352
  needs_pad_or_trunc = False
353
+ max_len = max(minibatch_length_set)
354
 
355
  if needs_pad_or_trunc == True:
356
  max_len = min(max(minibatch_length_set),model_input_size)
 
363
  perturbation_minibatch.set_format(type="torch")
364
 
365
  input_data_minibatch = perturbation_minibatch["input_ids"]
366
+ attention_mask = gen_attention_mask(perturbation_minibatch, max_len)
367
 
368
  # extract embeddings for perturbation minibatch
369
  with torch.no_grad():
370
  outputs = model(
371
+ input_ids = input_data_minibatch.to("cuda"),
372
+ attention_mask = attention_mask
373
  )
374
  del input_data_minibatch
375
  del perturbation_minibatch
376
+ del attention_mask
377
 
378
  if len(indices_to_perturb)>1:
379
  minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
 
394
  # truncate to the (model input size - # tokens to overexpress) to ensure comparability
395
  # since max input size of perturb batch will be reduced by # tokens to overexpress
396
  original_minibatch = original_emb.select([i for i in range(i, max_range)])
397
+ original_minibatch_lengths = original_minibatch["length"]
398
  original_minibatch_length_set = set(original_minibatch["length"])
399
  if perturb_type == "overexpress":
400
  new_max_len = model_input_size - len(tokens_to_perturb)
 
408
  original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
409
  original_minibatch.set_format(type="torch")
410
  original_input_data_minibatch = original_minibatch["input_ids"]
411
+ attention_mask = gen_attention_mask(original_minibatch, original_max_len)
412
  # extract embeddings for original minibatch
413
  with torch.no_grad():
414
  original_outputs = model(
415
+ input_ids = original_input_data_minibatch.to("cuda"),
416
+ attention_mask = attention_mask
417
  )
418
  del original_input_data_minibatch
419
  del original_minibatch
420
+ del attention_mask
421
 
422
  if len(indices_to_perturb)>1:
423
  original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
424
  else:
425
  original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
426
+
427
+ # embedding dimension of the genes
428
+ gene_dim = 1
429
+ # exclude overexpression due to case when genes are not expressed but being overexpressed
430
+ if perturb_type != "overexpress":
431
+ original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
432
+ indices_to_perturb,
433
+ gene_dim)
434
+
435
  # cosine similarity between original emb and batch items
436
  if cell_states_to_model is None:
437
  if perturb_group == False:
 
440
  minibatch_comparison = make_comparison_batch(original_minibatch_emb,
441
  indices_to_perturb,
442
  perturb_group)
443
+
444
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
445
  elif cell_states_to_model is not None:
446
  for state in possible_states:
 
453
  cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
454
  minibatch_emb,
455
  state_embs_dict[state],
456
+ perturb_group,
457
+ torch.tensor(original_minibatch_lengths, device="cuda"),
458
+ torch.tensor(minibatch_lengths, device="cuda"))
459
  del outputs
460
  del minibatch_emb
461
  if cell_states_to_model is None:
 
470
  return cos_sims_vs_alt_dict
471
 
472
  # calculate cos sim shift of perturbation with respect to origin and alternative cell
473
+ def cos_sim_shift(original_emb,
474
+ minibatch_emb,
475
+ end_emb,
476
+ perturb_group,
477
+ original_minibatch_lengths = None,
478
+ minibatch_lengths = None):
479
  cos = torch.nn.CosineSimilarity(dim=2)
480
+ if not perturb_group:
481
+ original_emb = torch.mean(original_emb,dim=0,keepdim=True)
482
  original_emb = original_emb[None, :]
483
+ origin_v_end = torch.squeeze(cos(original_emb, end_emb)) #test
484
+ else:
485
+ if original_emb.size() != minibatch_emb.size():
486
+ logger.error(
487
+ f"Embeddings are not the same dimensions. " \
488
+ f"original_emb is {original_emb.size()}. " \
489
+ f"minibatch_emb is {minibatch_emb.size()}. "
490
+ )
491
+ raise
492
+
493
+ if original_minibatch_lengths is not None:
494
+ original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
495
+ # else:
496
+ # original_emb = torch.mean(original_emb,dim=1,keepdim=True)
497
+
498
+ end_emb = torch.unsqueeze(end_emb, 1)
499
+ origin_v_end = cos(original_emb, end_emb)
500
+ origin_v_end = torch.squeeze(origin_v_end)
501
+ if minibatch_lengths is not None:
502
+ perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
503
+ else:
504
+ perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
505
+
506
+ perturb_v_end = cos(perturb_emb, end_emb)
507
+ perturb_v_end = torch.squeeze(perturb_v_end)
508
  return [(perturb_v_end-origin_v_end).to("cpu")]
509
 
510
  def pad_list(input_ids, pad_token_id, max_len):
 
564
  # return stacked tensors
565
  return torch.stack(tensor_list)
566
 
567
+ def gen_attention_mask(minibatch_encoding, max_len = None):
568
+ if max_len == None:
569
+ max_len = max(minibatch_encoding["length"])
570
+ original_lens = minibatch_encoding["length"]
571
+ attention_mask = [[1]*original_len
572
+ +[0]*(max_len - original_len)
573
+ for original_len in original_lens]
574
+ return torch.tensor(attention_mask).to("cuda")
575
+
576
+ # get cell embeddings excluding padding
577
+ def mean_nonpadding_embs(embs, original_lens):
578
+ # mask based on padding lengths
579
+ mask = torch.arange(embs.size(1)).unsqueeze(0).to("cuda") < original_lens.unsqueeze(1)
580
+
581
+ # extend mask dimensions to match the embeddings tensor
582
+ mask = mask.unsqueeze(2).expand_as(embs)
583
+
584
+ # use the mask to zero out the embeddings in padded areas
585
+ masked_embs = embs * mask.float()
586
+
587
+ # sum and divide by the lengths to get the mean of non-padding embs
588
+ mean_embs = masked_embs.sum(1) / original_lens.view(-1, 1).float()
589
+ return mean_embs
590
+
591
  class InSilicoPerturber:
592
  valid_option_dict = {
593
  "perturb_type": {"delete","overexpress","inhibit","activate"},
 
673
  Otherwise, dictionary specifying .dataset column name and list of values to filter by.
674
  cell_states_to_model: None, dict
675
  Cell states to model if testing perturbations that achieve goal state change.
676
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
677
+ state_key: key specifying name of column in .dataset that defines the start/goal states
678
+ start_state: value in the state_key column that specifies the start state
679
+ goal_state: value in the state_key column taht specifies the goal end state
680
+ alt_states: list of values in the state_key column that specify the alternate end states
681
+ For example: {"state_key": "disease",
682
+ "start_state": "dcm",
683
+ "goal_state": "nf",
684
+ "alt_states": ["hcm", "other1", "other2"]}
685
  max_ncells : None, int
686
  Maximum number of cells to test.
687
  If None, will test all cells.
 
812
 
813
  if self.cell_states_to_model is not None:
814
  if len(self.cell_states_to_model.items()) == 1:
815
+ logger.warning(
816
+ "The single value dictionary for cell_states_to_model will be " \
817
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
818
+ "Please specify state_key, start_state, goal_state, and alt_states " \
819
+ "in the cell_states_to_model dictionary for future use. " \
820
+ "For example, cell_states_to_model={" \
821
+ "'state_key': 'disease', " \
822
+ "'start_state': 'dcm', " \
823
+ "'goal_state': 'nf', " \
824
+ "'alt_states': ['hcm', 'other1', 'other2']}"
825
+ )
826
  for key,value in self.cell_states_to_model.items():
827
  if (len(value) == 3) and isinstance(value, tuple):
828
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
830
  all_values = value[0]+value[1]+value[2]
831
  if len(all_values) == len(set(all_values)):
832
  continue
833
+ # reformat to the new named key format
834
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
835
+ self.cell_states_to_model = {
836
+ "state_key": list(self.cell_states_to_model.keys())[0],
837
+ "start_state": state_values[0][0],
838
+ "goal_state": state_values[1][0],
839
+ "alt_states": state_values[2:][0]
840
+ }
841
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
842
+ if (self.cell_states_to_model["state_key"] is None) \
843
+ or (self.cell_states_to_model["start_state"] is None) \
844
+ or (self.cell_states_to_model["goal_state"] is None):
845
+ logger.error(
846
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
847
+ raise
848
+
849
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
850
+ logger.error(
851
+ "All states must be unique.")
852
+ raise
853
+
854
+ if self.cell_states_to_model["alt_states"] is not None:
855
+ if type(self.cell_states_to_model["alt_states"]) is not list:
856
+ logger.error(
857
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
858
+ )
859
+ raise
860
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
861
+ logger.error(
862
+ "All states must be unique.")
863
+ raise
864
+
865
  else:
866
  logger.error(
867
+ "cell_states_to_model must only have the following four keys: " \
868
+ "'state_key', 'start_state', 'goal_state', 'alt_states'." \
869
+ "For example, cell_states_to_model={" \
870
+ "'state_key': 'disease', " \
871
+ "'start_state': 'dcm', " \
872
+ "'goal_state': 'nf', " \
873
+ "'alt_states': ['hcm', 'other1', 'other2']}"
874
+ )
875
  raise
876
+
877
  if self.anchor_gene is not None:
878
  self.anchor_gene = None
879
  logger.warning(
 
923
  if self.cell_states_to_model is None:
924
  state_embs_dict = None
925
  else:
926
+ # confirm that all states are valid to prevent futile filtering
927
+ state_name = self.cell_states_to_model["state_key"]
928
+ state_values = filtered_input_data[state_name]
929
+ for value in get_possible_states(self.cell_states_to_model):
930
+ if value not in state_values:
931
+ logger.error(
932
+ f"{value} is not present in the dataset's {state_name} attribute.")
933
+ raise
934
  # get dictionary of average cell state embeddings for comparison
935
  downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
936
  state_embs_dict = get_cell_state_avg_embs(model,
 
941
  self.forward_batch_size,
942
  self.nproc)
943
  # filter for start state cells
944
+ start_state = self.cell_states_to_model["start_state"]
945
  def filter_for_origin(example):
946
+ return example[state_name] in [start_state]
947
 
948
  filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
949
 
 
1039
  # or (perturbed_genes, "cell_emb") for avg cell emb change
1040
  cos_sims_data = cos_sims_data.to("cuda")
1041
  max_padded_len = cos_sims_data.shape[1]
 
1042
  for j in range(cos_sims_data.shape[0]):
1043
  # remove padding before mean pooling cell embedding
1044
  original_length = original_lengths[j]
 
1060
  # update cos sims dict
1061
  # key is tuple of (perturbed_genes, "cell_emb")
1062
  # value is list of tuples of cos sims for cell_states_to_model
1063
+ origin_state_key = self.cell_states_to_model["start_state"]
1064
  cos_sims_origin = cos_sims_data[origin_state_key]
1065
  for j in range(cos_sims_origin.shape[0]):
 
 
 
 
 
 
1066
  data_list = []
1067
  for data in list(cos_sims_data.values()):
1068
  data_item = data.to("cuda")
1069
+ data_list += [data_item[j].item()]
 
 
1070
  cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
1071
 
1072
  with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
 
1139
  # update cos sims dict
1140
  # key is tuple of (perturbed_gene, "cell_emb")
1141
  # value is list of tuples of cos sims for cell_states_to_model
1142
+ origin_state_key = self.cell_states_to_model["start_state"]
1143
  cos_sims_origin = cos_sims_data[origin_state_key]
1144
 
1145
  for j in range(cos_sims_origin.shape[0]):
 
1260
 
1261
  # save remainder cells
1262
  with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
1263
+ pickle.dump(cos_sims_dict, fp)
 
geneformer/in_silico_perturber_stats.py CHANGED
@@ -6,7 +6,10 @@ Usage:
6
  ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
  combos=0,
8
  anchor_gene=None,
9
- cell_states_to_model={"disease":(["dcm"],["ctrl"],["hcm"])})
 
 
 
10
  ispstats.get_stats("path/to/input_data",
11
  None,
12
  "path/to/output_directory",
@@ -26,6 +29,8 @@ from scipy.stats import ranksums
26
  from sklearn.mixture import GaussianMixture
27
  from tqdm.notebook import trange, tqdm
28
 
 
 
29
  from .tokenizer import TOKEN_DICTIONARY_FILE
30
 
31
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
@@ -123,10 +128,10 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list):
123
 
124
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
125
  def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
126
- cell_state_key = list(cell_states_to_model.keys())[0]
127
- if cell_states_to_model[cell_state_key][2] == []:
128
  alt_end_state_exists = False
129
- elif (len(cell_states_to_model[cell_state_key][2]) > 0) and (cell_states_to_model[cell_state_key][2] != [None]):
130
  alt_end_state_exists = True
131
 
132
  # for single perturbation in multiple cells, there are no random perturbations to compare to
@@ -231,10 +236,12 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_
231
  # quantify number of detections of each gene
232
  cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
233
 
234
- # sort by shift to desired state
235
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_to_goal_end",
 
 
236
  "Goal_end_FDR"],
237
- ascending=[False,True])
238
 
239
  return cos_sims_full_df
240
 
@@ -272,9 +279,11 @@ def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
272
 
273
  cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
274
 
275
- cos_sims_full_df = cos_sims_full_df.sort_values(by=["Test_vs_null_avg_shift",
 
 
276
  "Test_vs_null_FDR"],
277
- ascending=[False,True])
278
  return cos_sims_full_df
279
 
280
  # stats for identifying perturbations with largest effect within a given set of cells
@@ -441,9 +450,15 @@ class InSilicoPerturberStats:
441
  analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
442
  cell_states_to_model: None, dict
443
  Cell states to model if testing perturbations that achieve goal state change.
444
- Single-item dictionary with key being cell attribute (e.g. "disease").
445
- Value is tuple of three lists indicating start state, goal end state, and alternate possible end states.
446
- If no alternate possible end states, third list should be empty (i.e. the third list should be []).
 
 
 
 
 
 
447
  token_dictionary_file : Path
448
  Path to pickle file containing token dictionary (Ensembl ID:token).
449
  gene_name_id_dictionary_file : Path
@@ -494,6 +509,17 @@ class InSilicoPerturberStats:
494
 
495
  if self.cell_states_to_model is not None:
496
  if len(self.cell_states_to_model.items()) == 1:
 
 
 
 
 
 
 
 
 
 
 
497
  for key,value in self.cell_states_to_model.items():
498
  if (len(value) == 3) and isinstance(value, tuple):
499
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
@@ -501,14 +527,50 @@ class InSilicoPerturberStats:
501
  all_values = value[0]+value[1]+value[2]
502
  if len(all_values) == len(set(all_values)):
503
  continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  else:
505
  logger.error(
506
- "Cell states to model must be a single-item dictionary with " \
507
- "key being cell attribute (e.g. 'disease') and value being " \
508
- "tuple of three lists indicating start state, goal end state, and alternate possible end states. " \
509
- "Values should all be unique. " \
510
- "For example: {'disease':(['start_state'],['ctrl'],['alt_end'])}")
 
 
 
511
  raise
 
512
  if self.anchor_gene is not None:
513
  self.anchor_gene = None
514
  logger.warning(
@@ -565,6 +627,7 @@ class InSilicoPerturberStats:
565
  "Gene_name": gene name
566
  "Ensembl_ID": gene Ensembl ID
567
  "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
 
568
 
569
  "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
570
  "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
 
6
  ispstats = InSilicoPerturberStats(mode="goal_state_shift",
7
  combos=0,
8
  anchor_gene=None,
9
+ cell_states_to_model={"state_key": "disease",
10
+ "start_state": "dcm",
11
+ "goal_state": "nf",
12
+ "alt_states": ["hcm", "other1", "other2"]})
13
  ispstats.get_stats("path/to/input_data",
14
  None,
15
  "path/to/output_directory",
 
29
  from sklearn.mixture import GaussianMixture
30
  from tqdm.notebook import trange, tqdm
31
 
32
+ from .in_silico_perturber import flatten_list
33
+
34
  from .tokenizer import TOKEN_DICTIONARY_FILE
35
 
36
  GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
 
128
 
129
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
130
  def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model, genes_perturbed):
131
+ cell_state_key = cell_states_to_model["start_state"]
132
+ if "alt_states" not in cell_states_to_model.keys():
133
  alt_end_state_exists = False
134
+ elif (len(cell_states_to_model["alt_states"]) > 0) and (cell_states_to_model["alt_states"] != [None]):
135
  alt_end_state_exists = True
136
 
137
  # for single perturbation in multiple cells, there are no random perturbations to compare to
 
236
  # quantify number of detections of each gene
237
  cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
238
 
239
+ # sort by shift to desired state\
240
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]]
241
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
242
+ "Shift_to_goal_end",
243
  "Goal_end_FDR"],
244
+ ascending=[False,False,True])
245
 
246
  return cos_sims_full_df
247
 
 
279
 
280
  cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(cos_sims_full_df["Test_vs_null_pval"])
281
 
282
+ cos_sims_full_df["Sig"] = [1 if fdr<0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]]
283
+ cos_sims_full_df = cos_sims_full_df.sort_values(by=["Sig",
284
+ "Test_vs_null_avg_shift",
285
  "Test_vs_null_FDR"],
286
+ ascending=[False,False,True])
287
  return cos_sims_full_df
288
 
289
  # stats for identifying perturbations with largest effect within a given set of cells
 
450
  analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
451
  cell_states_to_model: None, dict
452
  Cell states to model if testing perturbations that achieve goal state change.
453
+ Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
454
+ state_key: key specifying name of column in .dataset that defines the start/goal states
455
+ start_state: value in the state_key column that specifies the start state
456
+ goal_state: value in the state_key column taht specifies the goal end state
457
+ alt_states: list of values in the state_key column that specify the alternate end states
458
+ For example: {"state_key": "disease",
459
+ "start_state": "dcm",
460
+ "goal_state": "nf",
461
+ "alt_states": ["hcm", "other1", "other2"]}
462
  token_dictionary_file : Path
463
  Path to pickle file containing token dictionary (Ensembl ID:token).
464
  gene_name_id_dictionary_file : Path
 
509
 
510
  if self.cell_states_to_model is not None:
511
  if len(self.cell_states_to_model.items()) == 1:
512
+ logger.warning(
513
+ "The single value dictionary for cell_states_to_model will be " \
514
+ "replaced with a dictionary with named keys for start, goal, and alternate states. " \
515
+ "Please specify state_key, start_state, goal_state, and alt_states " \
516
+ "in the cell_states_to_model dictionary for future use. " \
517
+ "For example, cell_states_to_model={" \
518
+ "'state_key': 'disease', " \
519
+ "'start_state': 'dcm', " \
520
+ "'goal_state': 'nf', " \
521
+ "'alt_states': ['hcm', 'other1', 'other2']}"
522
+ )
523
  for key,value in self.cell_states_to_model.items():
524
  if (len(value) == 3) and isinstance(value, tuple):
525
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
527
  all_values = value[0]+value[1]+value[2]
528
  if len(all_values) == len(set(all_values)):
529
  continue
530
+ # reformat to the new named key format
531
+ state_values = flatten_list(list(self.cell_states_to_model.values()))
532
+ self.cell_states_to_model = {
533
+ "state_key": list(self.cell_states_to_model.keys())[0],
534
+ "start_state": state_values[0][0],
535
+ "goal_state": state_values[1][0],
536
+ "alt_states": state_values[2:][0]
537
+ }
538
+ elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
539
+ if (self.cell_states_to_model["state_key"] is None) \
540
+ or (self.cell_states_to_model["start_state"] is None) \
541
+ or (self.cell_states_to_model["goal_state"] is None):
542
+ logger.error(
543
+ "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model.")
544
+ raise
545
+
546
+ if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
547
+ logger.error(
548
+ "All states must be unique.")
549
+ raise
550
+
551
+ if self.cell_states_to_model["alt_states"] is not None:
552
+ if type(self.cell_states_to_model["alt_states"]) is not list:
553
+ logger.error(
554
+ "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
555
+ )
556
+ raise
557
+ if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
558
+ logger.error(
559
+ "All states must be unique.")
560
+ raise
561
+
562
  else:
563
  logger.error(
564
+ "cell_states_to_model must only have the following four keys: " \
565
+ "'state_key', 'start_state', 'goal_state', 'alt_states'." \
566
+ "For example, cell_states_to_model={" \
567
+ "'state_key': 'disease', " \
568
+ "'start_state': 'dcm', " \
569
+ "'goal_state': 'nf', " \
570
+ "'alt_states': ['hcm', 'other1', 'other2']}"
571
+ )
572
  raise
573
+
574
  if self.anchor_gene is not None:
575
  self.anchor_gene = None
576
  logger.warning(
 
627
  "Gene_name": gene name
628
  "Ensembl_ID": gene Ensembl ID
629
  "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
630
+ "Sig": 1 if FDR<0.05, otherwise 0
631
 
632
  "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
633
  "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation