hchen725 committed on
Commit
9026cc1
1 Parent(s): 57f02a4

embs_df with all model embeddings

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +7 -15
geneformer/emb_extractor.py CHANGED
@@ -50,9 +50,7 @@ def get_embs(
50
  embs_list = []
51
  elif summary_stat is not None:
52
  # test embedding extraction for example cell and extract # emb dims
53
- example = filtered_input_data.select([i for i in range(1)])
54
- example.set_format(type="torch")
55
- emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
56
  if emb_mode == "cell":
57
  # initiate tdigests for # of emb dims
58
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
@@ -78,7 +76,7 @@ def get_embs(
78
  gene_token_dict = {v:k for k,v in token_gene_dict.items()}
79
  cls_token_id = gene_token_dict["<cls>"]
80
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
81
- elif emb_mode == "cell":
82
  if cls_present:
83
  logger.warning("CLS token present in token dictionary, excluding from average.")
84
  if eos_present:
@@ -148,7 +146,7 @@ def get_embs(
148
  del embs_h
149
  del dict_h
150
  elif emb_mode == "cls":
151
- cls_embs = embs_i[:,0,:].clone().detach() # CLS token layer
152
  embs_list.append(cls_embs)
153
  del cls_embs
154
 
@@ -239,14 +237,6 @@ def tdigest_median(embs_tdigests, emb_dims):
239
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
240
 
241
 
242
- def test_emb(model, example, layer_to_quant):
243
- with torch.no_grad():
244
- outputs = model(input_ids=example.to("cuda"))
245
-
246
- embs_test = outputs.hidden_states[layer_to_quant]
247
- return embs_test.size()[2]
248
-
249
-
250
  def label_cell_embs(embs, downsampled_data, emb_labels):
251
  embs_df = pd.DataFrame(embs.cpu().numpy())
252
  if emb_labels is not None:
@@ -632,13 +622,15 @@ class EmbExtractor:
632
 
633
  if self.exact_summary_stat == "exact_mean":
634
  embs = embs.mean(dim=0)
 
635
  embs_df = pd.DataFrame(
636
- embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
637
  ).T
638
  elif self.exact_summary_stat == "exact_median":
639
  embs = torch.median(embs, dim=0)[0]
 
640
  embs_df = pd.DataFrame(
641
- embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
642
  ).T
643
 
644
  if cell_state is not None:
 
50
  embs_list = []
51
  elif summary_stat is not None:
52
  # test embedding extraction for example cell and extract # emb dims
53
+ emb_dims = pu.get_model_embedding_dimensions(model)
 
 
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
 
76
  gene_token_dict = {v:k for k,v in token_gene_dict.items()}
77
  cls_token_id = gene_token_dict["<cls>"]
78
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
79
+ else:
80
  if cls_present:
81
  logger.warning("CLS token present in token dictionary, excluding from average.")
82
  if eos_present:
 
146
  del embs_h
147
  del dict_h
148
  elif emb_mode == "cls":
149
+ cls_embs = embs_i[:,0,:] # CLS token layer
150
  embs_list.append(cls_embs)
151
  del cls_embs
152
 
 
237
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
238
 
239
 
 
 
 
 
 
 
 
 
240
  def label_cell_embs(embs, downsampled_data, emb_labels):
241
  embs_df = pd.DataFrame(embs.cpu().numpy())
242
  if emb_labels is not None:
 
622
 
623
  if self.exact_summary_stat == "exact_mean":
624
  embs = embs.mean(dim=0)
625
+ emb_dims = pu.get_model_embedding_dimensions(model)
626
  embs_df = pd.DataFrame(
627
+ embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
628
  ).T
629
  elif self.exact_summary_stat == "exact_median":
630
  embs = torch.median(embs, dim=0)[0]
631
+ emb_dims = pu.get_model_embedding_dimensions(model)
632
  embs_df = pd.DataFrame(
633
+ embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
634
  ).T
635
 
636
  if cell_state is not None: