Fixed bugs related to overexpressing genes

#229
by davidjwen - opened
Files changed (1)
  1. geneformer/in_silico_perturber.py +92 -78
geneformer/in_silico_perturber.py CHANGED
@@ -151,6 +151,7 @@ def overexpress_tokens(example):
     if example["perturb_index"] != [-100]:
         example = delete_indices(example)
     [example["input_ids"].insert(0, token) for token in example["tokens_to_perturb"][::-1]]
+
     return example

 def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
@@ -163,8 +164,8 @@ def remove_indices_from_emb(emb, indices_to_remove, gene_dim):

 def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
     output_batch = torch.stack([
-        remove_indices_from_emb(emb_batch[i, :, :], idx, gene_dim-1) for
-        i, idx in enumerate(list_of_indices_to_remove)
+        remove_indices_from_emb(emb_batch[i, :, :], idxs, gene_dim-1) for
+        i, idxs in enumerate(list_of_indices_to_remove)
     ])
     return output_batch

@@ -179,7 +180,7 @@ def make_perturbation_batch(example_cell,
             range_start = 1
         elif perturb_type in ["delete","inhibit"]:
             range_start = 0
-        indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
+        indices_to_perturb = [[i] for i in range(range_start, example_cell["length"][0])]
     elif combo_lvl>0 and (anchor_token is not None):
         example_input_ids = example_cell["input_ids "][0]
         anchor_index = example_input_ids.index(anchor_token[0])
@@ -323,47 +324,52 @@ def quant_cos_sims(model,
                    nproc):
     cos = torch.nn.CosineSimilarity(dim=2)
     total_batch_length = len(perturbation_batch)
+
     if ((total_batch_length-1)/forward_batch_size).is_integer():
         forward_batch_size = forward_batch_size-1
+
+    if perturb_group == False:
+        comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
+
     if cell_states_to_model is None:
-        if perturb_group == False: # (if perturb_group is True, original_emb is filtered_input_data)
-            comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
         cos_sims = []
     else:
         possible_states = get_possible_states(cell_states_to_model)
-        cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
+        cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for _ in range(len(possible_states))]))

     # measure length of each element in perturbation_batch
     perturbation_batch = perturbation_batch.map(
         measure_length, num_proc=nproc
     )
-
-    for i in range(0, total_batch_length, forward_batch_size):
-        max_range = min(i+forward_batch_size, total_batch_length)
-
-        perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
-        # determine if need to pad or truncate batch
-        minibatch_length_set = set(perturbation_minibatch["length"])
-        minibatch_lengths = perturbation_minibatch["length"]
-        if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
+
+    def compute_batch_embeddings(minibatch, _max_len = None):
+        minibatch_lengths = minibatch["length"]
+        minibatch_length_set = set(minibatch_lengths)
+        max_len = model_input_size
+
+        if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > max_len):
            needs_pad_or_trunc = True
        else:
            needs_pad_or_trunc = False
            max_len = max(minibatch_length_set)

-        if needs_pad_or_trunc == True:
-            max_len = min(max(minibatch_length_set),model_input_size)
+
+        if needs_pad_or_trunc == True:
+            if _max_len is None:
+                max_len = min(max(minibatch_length_set), max_len)
+            else:
+                max_len = _max_len
            def pad_or_trunc_example(example):
                example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
                                                                pad_token_id,
                                                                max_len)
                return example
-            perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
+            minibatch = minibatch.map(pad_or_trunc_example, num_proc=nproc)

-        perturbation_minibatch.set_format(type="torch")
+        minibatch.set_format(type="torch")

-        input_data_minibatch = perturbation_minibatch["input_ids"]
-        attention_mask = gen_attention_mask(perturbation_minibatch, max_len)
+        input_data_minibatch = minibatch["input_ids"]
+        attention_mask = gen_attention_mask(minibatch, max_len)

        # extract embeddings for perturbation minibatch
        with torch.no_grad():
@@ -371,9 +377,13 @@ def quant_cos_sims(model,
                input_ids = input_data_minibatch.to("cuda"),
                attention_mask = attention_mask
            )
-        del input_data_minibatch
-        del perturbation_minibatch
-        del attention_mask
+
+        return outputs, max_len
+
+    for i in range(0, total_batch_length, forward_batch_size):
+        max_range = min(i+forward_batch_size, total_batch_length)
+        perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
+        outputs, mini_max_len = compute_batch_embeddings(perturbation_minibatch)

        if len(indices_to_perturb)>1:
            minibatch_emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
@@ -386,7 +396,8 @@ def quant_cos_sims(model,
            overexpressed_to_remove = 1
            if perturb_group == True:
                overexpressed_to_remove = len(tokens_to_perturb)
-            minibatch_emb = minibatch_emb[:,overexpressed_to_remove:,:]
+            minibatch_emb = minibatch_emb[:, overexpressed_to_remove: ,:]
+

        # if quantifying single perturbation in multiple different cells, pad original batch and extract embs
        if perturb_group == True:
@@ -394,56 +405,50 @@ def quant_cos_sims(model,
            # truncate to the (model input size - # tokens to overexpress) to ensure comparability
            # since max input size of perturb batch will be reduced by # tokens to overexpress
            original_minibatch = original_emb.select([i for i in range(i, max_range)])
-            original_minibatch_lengths = original_minibatch["length"]
-            original_minibatch_length_set = set(original_minibatch["length"])
-
-            indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
-
-            if perturb_type == "overexpress":
-                new_max_len = model_input_size - len(tokens_to_perturb)
-            else:
-                new_max_len = model_input_size
-            if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
-                new_max_len = min(max(original_minibatch_length_set),new_max_len)
-                def pad_or_trunc_example(example):
-                    example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, new_max_len)
-                    return example
-                original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
-            original_minibatch.set_format(type="torch")
-            original_input_data_minibatch = original_minibatch["input_ids"]
-            attention_mask = gen_attention_mask(original_minibatch, new_max_len)
-            # extract embeddings for original minibatch
-            with torch.no_grad():
-                original_outputs = model(
-                    input_ids = original_input_data_minibatch.to("cuda"),
-                    attention_mask = attention_mask
-                )
-            del original_input_data_minibatch
-            del original_minibatch
-            del attention_mask
+            original_outputs, orig_max_len = compute_batch_embeddings(original_minibatch, mini_max_len)

            if len(indices_to_perturb)>1:
                original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
            else:
                original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]

-            # embedding dimension of the genes
-            gene_dim = 1
-            # exclude overexpression due to case when genes are not expressed but being overexpressed
-            if perturb_type != "overexpress":
-                original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
-                                                                       indices_to_perturb_minibatch,
-                                                                       gene_dim)
+            # if we overexpress genes that aren't already expressed,
+            # we need to remove genes to make sure the embeddings are of a consistent size
+            # get rid of the bottom n genes/padding since those will get truncated anyways
+            # multiple perturbations is more complicated because if 1 out of n perturbed genes is expressed
+            # the idxs will still not be [-100]
+            if len(tokens_to_perturb) == 1:
+                indices_to_perturb_minibatch = [idx if idx != [-100] else [orig_max_len - 1]
+                                                for idx in indices_to_perturb[i:max_range]]
+            else:
+                num_perturbed = len(tokens_to_perturb)
+                indices_to_perturb_minibatch = []
+                end_range = [i for i in range(orig_max_len - tokens_to_perturb, orig_max_len)]
+                for idx in indices_to_perturb[i:i+max_range]:
+                    if idx == [-100]:
+                        indices_to_perturb_minibatch.append(end_range)
+                    elif len(idx) < len(tokens_to_perturb):
+                        indices_to_perturb_minibatch.append(idx + end_range[-num_perturbed:])
+                    else:
+                        indices_to_perturb_minibatch.append(idx)

+            original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
+                                                                   indices_to_perturb_minibatch,
+                                                                   gene_dim=1)
+
        # cosine similarity between original emb and batch items
        if cell_states_to_model is None:
            if perturb_group == False:
                minibatch_comparison = comparison_batch[i:max_range]
            elif perturb_group == True:
                minibatch_comparison = original_minibatch_emb
-
            cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
        elif cell_states_to_model is not None:
+            if perturb_group == False:
+                original_emb = comparison_batch[i:max_range]
+            else:
+                original_minibatch_lengths = torch.tensor(original_minibatch["length"], device="cuda")
+                minibatch_lengths = torch.tensor(perturbation_minibatch["length"], device="cuda")
            for state in possible_states:
                if perturb_group == False:
                    cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb,
@@ -455,12 +460,14 @@ def quant_cos_sims(model,
                                                                 minibatch_emb,
                                                                 state_embs_dict[state],
                                                                 perturb_group,
-                                                                torch.tensor(original_minibatch_lengths, device="cuda"),
-                                                                torch.tensor(minibatch_lengths, device="cuda"))
+                                                                original_minibatch_lengths,
+                                                                minibatch_lengths)
        del outputs
        del minibatch_emb
        if cell_states_to_model is None:
            del minibatch_comparison
+        if perturb_group == True:
+            del original_minibatch_emb
        torch.cuda.empty_cache()
    if cell_states_to_model is None:
        cos_sims_stack = torch.cat(cos_sims)
@@ -470,6 +477,7 @@ def quant_cos_sims(model,
            cos_sims_vs_alt_dict[state] = torch.cat(cos_sims_vs_alt_dict[state])
        return cos_sims_vs_alt_dict

+
# calculate cos sim shift of perturbation with respect to origin and alternative cell
def cos_sim_shift(original_emb,
                  minibatch_emb,
@@ -478,34 +486,32 @@ def cos_sim_shift(original_emb,
                  original_minibatch_lengths = None,
                  minibatch_lengths = None):
    cos = torch.nn.CosineSimilarity(dim=2)
+    if original_emb.size() != minibatch_emb.size():
+        logger.error(
+            f"Embeddings are not the same dimensions. " \
+            f"original_emb is {original_emb.size()}. " \
+            f"minibatch_emb is {minibatch_emb.size()}. "
+        )
+        raise
    if not perturb_group:
-        original_emb = torch.mean(original_emb,dim=0,keepdim=True)
-        original_emb = original_emb[None, :]
-        origin_v_end = torch.squeeze(cos(original_emb, end_emb)) #test
+        original_emb = torch.mean(original_emb,dim=1,keepdim=True)
+        origin_v_end = torch.squeeze(cos(original_emb, end_emb))
    else:
-        if original_emb.size() != minibatch_emb.size():
-            logger.error(
-                f"Embeddings are not the same dimensions. " \
-                f"original_emb is {original_emb.size()}. " \
-                f"minibatch_emb is {minibatch_emb.size()}. "
-            )
-            raise
-
        if original_minibatch_lengths is not None:
            original_emb = mean_nonpadding_embs(original_emb, original_minibatch_lengths)
        # else:
        #     original_emb = torch.mean(original_emb,dim=1,keepdim=True)

        end_emb = torch.unsqueeze(end_emb, 1)
-        origin_v_end = cos(original_emb, end_emb)
-        origin_v_end = torch.squeeze(origin_v_end)
+        origin_v_end = torch.squeeze(cos(original_emb, end_emb))
    if minibatch_lengths is not None:
        perturb_emb = mean_nonpadding_embs(minibatch_emb, minibatch_lengths)
    else:
        perturb_emb = torch.mean(minibatch_emb,dim=1,keepdim=True)
-
    perturb_v_end = cos(perturb_emb, end_emb)
    perturb_v_end = torch.squeeze(perturb_v_end)
+    if (perturb_v_end-origin_v_end).numel() == 1:
+        return [([perturb_v_end-origin_v_end]).to("cpu")]
    return [(perturb_v_end-origin_v_end).to("cpu")]

def pad_list(input_ids, pad_token_id, max_len):
@@ -1152,7 +1158,11 @@ class InSilicoPerturber:
                        j_index = torch.squeeze(j_index)
                    else:
                        j_index = torch.tensor([j])
-                    perturbed_gene = torch.index_select(gene_list, 0, j_index)
+
+                    if self.perturb_type in ("overexpress", "activate"):
+                        perturbed_gene = torch.index_select(gene_list, 0, j_index + 1)
+                    else:
+                        perturbed_gene = torch.index_select(gene_list, 0, j_index)

                    if perturbed_gene.shape[0]==1:
                        perturbed_gene = perturbed_gene.item()
@@ -1183,7 +1193,11 @@ class InSilicoPerturber:
                        j_index = torch.squeeze(j_index)
                    else:
                        j_index = torch.tensor([j])
-                    perturbed_gene = torch.index_select(gene_list, 0, j_index)
+
+                    if self.perturb_type in ("overexpress", "activate"):
+                        perturbed_gene = torch.index_select(gene_list, 0, j_index + 1)
+                    else:
+                        perturbed_gene = torch.index_select(gene_list, 0, j_index)

                    if perturbed_gene.shape[0]==1:
                        perturbed_gene = perturbed_gene.item()
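
For orientation on the second hunk: remove_indices_from_emb_batch stacks per-example removals along the gene dimension, and the comprehension now unpacks each example's own list of indices (idxs). A minimal, self-contained sketch of that pattern, with a stand-in remove_rows in place of remove_indices_from_emb (whose body is not part of this diff); shapes and values below are made up:

import torch

def remove_rows(emb, idxs, gene_dim):
    # stand-in for remove_indices_from_emb: drop the given gene positions along gene_dim
    keep = [i for i in range(emb.size(gene_dim)) if i not in set(idxs)]
    return torch.index_select(emb, gene_dim, torch.tensor(keep))

def remove_rows_batch(emb_batch, list_of_indices_to_remove, gene_dim):
    # per-example removal, as in the fixed remove_indices_from_emb_batch:
    # example i drops its own idxs; each example drops the same count,
    # so the results can be re-stacked into one tensor
    return torch.stack([
        remove_rows(emb_batch[i, :, :], idxs, gene_dim - 1)
        for i, idxs in enumerate(list_of_indices_to_remove)
    ])

emb_batch = torch.randn(2, 5, 4)                  # (batch, genes, hidden)
out = remove_rows_batch(emb_batch, [[0], [3]], gene_dim=1)
print(out.shape)                                  # torch.Size([2, 4, 4])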
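The largest hunk also handles overexpressing a gene that a cell does not already express: its perturb index is [-100], so the original minibatch embedding drops trailing positions instead, keeping it the same size as the perturbed embedding. A toy walk-through of that substitution for the single-gene case (values are hypothetical; orig_max_len stands in for the length returned by the new compute_batch_embeddings helper):

# If a gene is absent from a cell, its perturb index is [-100]; the original
# embedding then drops its final position (which the perturbed input loses to
# truncation when the overexpressed token is prepended), so sizes stay equal.
orig_max_len = 2048                        # hypothetical padded/truncated length
indices_to_perturb = [[5], [-100], [37]]   # hypothetical per-cell indices

indices_to_perturb_minibatch = [
    idx if idx != [-100] else [orig_max_len - 1]
    for idx in indices_to_perturb
]
print(indices_to_perturb_minibatch)        # [[5], [2047], [37]]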
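The last two hunks change which gene is reported for "overexpress"/"activate" perturbations. One reading, consistent with range_start = 1 in make_perturbation_batch, is that the j-th perturbation corresponds to position j + 1 of the cell's gene_list, hence j_index + 1. A hypothetical sketch of that mapping (token values invented):

import torch

gene_list = torch.tensor([111, 222, 333, 444])    # tokens of one cell, in rank order (invented)
indices_to_perturb = [[i] for i in range(1, 4)]   # [[1], [2], [3]] when range_start = 1

j = 0                                             # first perturbation in the loop
j_index = torch.tensor([j])

# with the offset, the reported gene matches the position that was actually perturbed
perturbed_gene = torch.index_select(gene_list, 0, j_index + 1)
assert perturbed_gene.item() == 222
assert indices_to_perturb[j] == [1]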