hchen725 committed on
Commit f0ec9ca
1 Parent(s): ead0550

Update geneformer/emb_extractor.py

Files changed (1)
  1. geneformer/emb_extractor.py +54 -12
geneformer/emb_extractor.py CHANGED
@@ -38,12 +38,14 @@ def get_embs(
     layer_to_quant,
     pad_token_id,
     forward_batch_size,
+    token_gene_dict,
+    special_token=False,
     summary_stat=None,
     silent=False,
 ):
     model_input_size = pu.get_model_input_size(model)
     total_batch_length = len(filtered_input_data)
-
+
     if summary_stat is None:
         embs_list = []
     elif summary_stat is not None:
@@ -67,9 +69,25 @@ def get_embs(
                     k: [TDigest() for _ in range(emb_dims)] for k in gene_set
                 }
 
+    # Check if CLS and EOS token is present in the token dictionary
+    cls_present = any("<cls>" in value for value in token_gene_dict.values())
+    eos_present = any("<eos>" in value for value in token_gene_dict.values())
+    if emb_mode == "cls":
+        assert cls_present, "<cls> token missing in token dictionary"
+        # Check to make sure that the first token of the filtered input data is cls token
+        for key, value in token_gene_dict.items():
+            if value == "<cls>":
+                cls_token_id = key
+        assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
+    else:
+        if cls_present:
+            logger.warning("CLS token present in token dictionary, excluding from average")
+        if eos_present:
+            logger.warning("EOS token present in token dictionary, excluding from average")
+
     overall_max_len = 0
-
-    for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
+
+    for i in trange(0, total_batch_length, forward_batch_size, leave = (not silent)):
         max_range = min(i + forward_batch_size, total_batch_length)
 
         minibatch = filtered_input_data.select([i for i in range(i, max_range)])
@@ -90,9 +108,16 @@ def get_embs(
             )
 
         embs_i = outputs.hidden_states[layer_to_quant]
-
+
         if emb_mode == "cell":
-            mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
+            if cls_present:
+                non_cls_embs = embs_i[:, 1:, :] # Get all layers except the embs
+                if eos_present:
+                    mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
+                else:
+                    mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
+            else:
+                mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
             if summary_stat is None:
                 embs_list.append(mean_embs)
             elif summary_stat is not None:
@@ -121,7 +146,13 @@ def get_embs(
                         accumulate_tdigests(
                             embs_tdigests_dict[int(k)], dict_h[k], emb_dims
                         )
-
+                    del embs_h
+                    del dict_h
+        elif emb_mode == "cls":
+            cls_embs = embs_i[:,0,:] # CLS token layer
+            embs_list.append(cls_embs)
+            del cls_embs
+
         overall_max_len = max(overall_max_len, max_len)
         del outputs
         del minibatch
@@ -129,7 +160,8 @@ def get_embs(
         del embs_i
 
         torch.cuda.empty_cache()
-
+
+
     if summary_stat is None:
         if emb_mode == "cell":
             embs_stack = torch.cat(embs_list, dim=0)
@@ -142,6 +174,8 @@ def get_embs(
                 1,
                 pu.pad_3d_tensor,
             )
+        elif emb_mode == "cls":
+            embs_stack = torch.cat(embs_list, dim=0)
 
     # calculate summary stat embs from approximated tdigests
     elif summary_stat is not None:
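
Annotation (not part of the commit): in "cell" mode, when a CLS token is present in the token dictionary, the mean now excludes the CLS position and shortens the effective length by one (or by two when an EOS token is also present); the new "cls" mode instead takes the hidden state at position 0. A minimal sketch of that behaviour, assuming hidden states of shape (batch, seq_len, dim) and a tensor of unpadded lengths; pool_embs and the mask-based mean are illustrative stand-ins for pu.mean_nonpadding_embs:

import torch

def pool_embs(embs, original_lens, emb_mode="cell", cls_present=False, eos_present=False):
    # Illustrative stand-in for the pooling logic in get_embs above.
    if emb_mode == "cls":
        return embs[:, 0, :]  # "cls" mode: hidden state at the CLS position
    if cls_present:
        embs = embs[:, 1:, :]  # drop the CLS position before averaging
        original_lens = original_lens - (2 if eos_present else 1)  # exclude CLS (and EOS if present)
    # mean over the remaining non-padding positions
    mask = (torch.arange(embs.size(1))[None, :] < original_lens[:, None]).unsqueeze(-1)
    return (embs * mask).sum(dim=1) / original_lens[:, None]

hidden = torch.randn(2, 6, 4)  # batch of 2, max length 6, 4 emb dims
lens = torch.tensor([6, 4])    # unpadded lengths, including CLS and EOS
cell_embs = pool_embs(hidden, lens, cls_present=True, eos_present=True)  # shape (2, 4)
cls_embs = pool_embs(hidden, lens, emb_mode="cls")                       # shape (2, 4)
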
@@ -348,7 +382,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
         bbox_to_anchor=(0.5, 1),
         facecolor="white",
     )
-
+    print(f"Output file: {output_file}")
     plt.savefig(output_file, bbox_inches="tight")
 
 
@@ -356,7 +390,7 @@ class EmbExtractor:
     valid_option_dict = {
         "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
         "num_classes": {int},
-        "emb_mode": {"cell", "gene"},
+        "emb_mode": {"cell", "gene", "cls"},
         "cell_emb_style": {"mean_pool"},
         "gene_emb_style": {"mean_pool"},
         "filter_data": {None, dict},
@@ -365,6 +399,7 @@ class EmbExtractor:
         "emb_label": {None, list},
         "labels_to_plot": {None, list},
         "forward_batch_size": {int},
+        "token_dictionary_file" : {None, str},
         "nproc": {int},
         "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
     }
@@ -384,7 +419,7 @@ class EmbExtractor:
         forward_batch_size=100,
         nproc=4,
         summary_stat=None,
-        token_dictionary_file=TOKEN_DICTIONARY_FILE,
+        token_dictionary_file=None,
     ):
         """
         Initialize embedding extractor.
@@ -434,6 +469,7 @@ class EmbExtractor:
             | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
             | Non-exact is slower but more memory-efficient.
         token_dictionary_file : Path
+            | Default is to the geneformer token dictionary
             | Path to pickle file containing token dictionary (Ensembl ID:token).
 
         **Examples:**
@@ -463,6 +499,7 @@ class EmbExtractor:
         self.emb_layer = emb_layer
         self.emb_label = emb_label
         self.labels_to_plot = labels_to_plot
+        self.token_dictionary_file = token_dictionary_file
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
         if (summary_stat is not None) and ("exact" in summary_stat):
@@ -475,6 +512,8 @@ class EmbExtractor:
         self.validate_options()
 
         # load token dictionary (Ensembl IDs:token)
+        if self.token_dictionary_file is None:
+            token_dictionary_file = TOKEN_DICTIONARY_FILE
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
@@ -490,7 +529,7 @@ class EmbExtractor:
                     continue
             valid_type = False
             for option in valid_options:
-                if (option in [int, list, dict, bool]) and isinstance(
+                if (option in [int, list, dict, bool, str]) and isinstance(
                     attr_value, option
                 ):
                     valid_type = True
@@ -570,6 +609,7 @@ class EmbExtractor:
             layer_to_quant,
             self.pad_token_id,
             self.forward_batch_size,
+            self.token_gene_dict,
             self.summary_stat,
         )
 
@@ -584,6 +624,8 @@ class EmbExtractor:
         elif self.summary_stat is not None:
             embs_df = pd.DataFrame(embs).T
             embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
+        elif self.emb_mode == "cls":
+            embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
 
         # save embeddings to output_path
         if cell_state is None:
@@ -781,7 +823,7 @@ class EmbExtractor:
                     f"not present in provided embeddings dataframe."
                 )
                 continue
-            output_prefix_label = "_" + output_prefix + f"_umap_{label}"
+            output_prefix_label = output_prefix + f"_umap_{label}"
            output_file = (
                Path(output_directory) / output_prefix_label
            ).with_suffix(".pdf")
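
Annotation (not part of the commit): taken together, the class-level changes add "cls" as a valid emb_mode, register token_dictionary_file as a validated option, and make it optional, with None falling back to the bundled TOKEN_DICTIONARY_FILE. A hedged usage sketch; keyword values are illustrative, and the extract_embs argument names are assumed from Geneformer's documented examples rather than shown in this diff:

from geneformer import EmbExtractor

embex = EmbExtractor(
    model_type="Pretrained",
    emb_mode="cls",              # new mode added by this commit
    forward_batch_size=100,
    nproc=4,
    summary_stat=None,
    token_dictionary_file=None,  # None falls back to the default TOKEN_DICTIONARY_FILE
)
# argument names below are assumed, not taken from this diff:
# embs = embex.extract_embs(model_directory, input_data_file, output_directory, output_prefix)
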
 