ctheodoris committed on
Commit
916546e
1 Parent(s): f0ec9ca

adjust logging

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +15 -17
geneformer/emb_extractor.py CHANGED
@@ -75,19 +75,18 @@ def get_embs(
75
  if emb_mode == "cls":
76
  assert cls_present, "<cls> token missing in token dictionary"
77
  # Check to make sure that the first token of the filtered input data is cls token
78
- for key, value in token_gene_dict.items():
79
- if value == "<cls>":
80
- cls_token_id = key
81
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
82
  else:
83
  if cls_present:
84
- logger.warning("CLS token present in token dictionary, excluding from average")
85
  if eos_present:
86
- logger.warning("EOS token present in token dictionary, excluding from average")
87
 
88
  overall_max_len = 0
89
 
90
- for i in trange(0, total_batch_length, forward_batch_size, leave = (not silent)):
91
  max_range = min(i + forward_batch_size, total_batch_length)
92
 
93
  minibatch = filtered_input_data.select([i for i in range(i, max_range)])
@@ -163,7 +162,7 @@ def get_embs(
163
 
164
 
165
  if summary_stat is None:
166
- if emb_mode == "cell":
167
  embs_stack = torch.cat(embs_list, dim=0)
168
  elif emb_mode == "gene":
169
  embs_stack = pu.pad_tensor_list(
@@ -174,8 +173,6 @@ def get_embs(
174
  1,
175
  pu.pad_3d_tensor,
176
  )
177
- elif emb_mode == "cls":
178
- embs_stack = torch.cat(embs_list, dim=0)
179
 
180
  # calculate summary stat embs from approximated tdigests
181
  elif summary_stat is not None:
@@ -382,7 +379,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
382
  bbox_to_anchor=(0.5, 1),
383
  facecolor="white",
384
  )
385
- print(f"Output file: {output_file}")
386
  plt.savefig(output_file, bbox_inches="tight")
387
 
388
 
@@ -390,7 +387,7 @@ class EmbExtractor:
390
  valid_option_dict = {
391
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
392
  "num_classes": {int},
393
- "emb_mode": {"cell", "gene", "cls"},
394
  "cell_emb_style": {"mean_pool"},
395
  "gene_emb_style": {"mean_pool"},
396
  "filter_data": {None, dict},
@@ -431,10 +428,11 @@ class EmbExtractor:
431
  num_classes : int
432
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
433
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
434
- emb_mode : {"cell", "gene"}
435
- | Whether to output cell or gene embeddings.
436
- cell_emb_style : "mean_pool"
437
- | Method for summarizing cell embeddings.
 
438
  | Currently only option is mean pooling of gene embeddings for given cell.
439
  gene_emb_style : "mean_pool"
440
  | Method for summarizing gene embeddings.
@@ -469,7 +467,7 @@ class EmbExtractor:
469
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
470
  | Non-exact is slower but more memory-efficient.
471
  token_dictionary_file : Path
472
- | Default is to the geneformer token dictionary
473
  | Path to pickle file containing token dictionary (Ensembl ID:token).
474
 
475
  **Examples:**
@@ -841,4 +839,4 @@ class EmbExtractor:
841
  output_file = (
842
  Path(output_directory) / output_prefix_label
843
  ).with_suffix(".pdf")
844
- plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
 
75
  if emb_mode == "cls":
76
  assert cls_present, "<cls> token missing in token dictionary"
77
  # Check to make sure that the first token of the filtered input data is cls token
78
+ gene_token_dict = {v:k for k,v in token_gene_dict.items()}
79
+ cls_token_id = gene_token_dict["<cls>"]
 
80
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
81
  else:
82
  if cls_present:
83
+ logger.warning("CLS token present in token dictionary, excluding from average.")
84
  if eos_present:
85
+ logger.warning("EOS token present in token dictionary, excluding from average.")
86
 
87
  overall_max_len = 0
88
 
89
+ for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
90
  max_range = min(i + forward_batch_size, total_batch_length)
91
 
92
  minibatch = filtered_input_data.select([i for i in range(i, max_range)])
 
162
 
163
 
164
  if summary_stat is None:
165
+ if (emb_mode == "cell") or (emb_mode == "cls"):
166
  embs_stack = torch.cat(embs_list, dim=0)
167
  elif emb_mode == "gene":
168
  embs_stack = pu.pad_tensor_list(
 
173
  1,
174
  pu.pad_3d_tensor,
175
  )
 
 
176
 
177
  # calculate summary stat embs from approximated tdigests
178
  elif summary_stat is not None:
 
379
  bbox_to_anchor=(0.5, 1),
380
  facecolor="white",
381
  )
382
+ logger.info(f"Output file: {output_file}")
383
  plt.savefig(output_file, bbox_inches="tight")
384
 
385
 
 
387
  valid_option_dict = {
388
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
389
  "num_classes": {int},
390
+ "emb_mode": {"cls", "cell", "gene"},
391
  "cell_emb_style": {"mean_pool"},
392
  "gene_emb_style": {"mean_pool"},
393
  "filter_data": {None, dict},
 
428
  num_classes : int
429
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
430
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
431
+ emb_mode : {"cls", "cell", "gene"}
432
+ | Whether to output CLS, cell, or gene embeddings.
433
+ | CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
434
+ cell_emb_style : {"mean_pool"}
435
+ | Method for summarizing cell embeddings if not using CLS token.
436
  | Currently only option is mean pooling of gene embeddings for given cell.
437
  gene_emb_style : "mean_pool"
438
  | Method for summarizing gene embeddings.
 
467
  | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
468
  | Non-exact is slower but more memory-efficient.
469
  token_dictionary_file : Path
470
+ | Default is the Geneformer token dictionary
471
  | Path to pickle file containing token dictionary (Ensembl ID:token).
472
 
473
  **Examples:**
 
839
  output_file = (
840
  Path(output_directory) / output_prefix_label
841
  ).with_suffix(".pdf")
842
+ plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)