ctheodoris committed on
Commit
8e35e45
1 Parent(s): 9026cc1

incorporate prior changes

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +4 -4
geneformer/emb_extractor.py CHANGED
@@ -49,8 +49,8 @@ def get_embs(
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
52
- # test embedding extraction for example cell and extract # emb dims
53
- emb_dims = pu.get_model_embedding_dimensions(model)
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
@@ -76,7 +76,7 @@ def get_embs(
76
  gene_token_dict = {v:k for k,v in token_gene_dict.items()}
77
  cls_token_id = gene_token_dict["<cls>"]
78
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
79
- else:
80
  if cls_present:
81
  logger.warning("CLS token present in token dictionary, excluding from average.")
82
  if eos_present:
@@ -146,7 +146,7 @@ def get_embs(
146
  del embs_h
147
  del dict_h
148
  elif emb_mode == "cls":
149
- cls_embs = embs_i[:,0,:] # CLS token layer
150
  embs_list.append(cls_embs)
151
  del cls_embs
152
 
 
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
52
+ # get # of emb dims
53
+ emb_dims = pu.get_model_emb_dims(model)
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
 
76
  gene_token_dict = {v:k for k,v in token_gene_dict.items()}
77
  cls_token_id = gene_token_dict["<cls>"]
78
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
79
+ elif emb_mode == "cell":
80
  if cls_present:
81
  logger.warning("CLS token present in token dictionary, excluding from average.")
82
  if eos_present:
 
146
  del embs_h
147
  del dict_h
148
  elif emb_mode == "cls":
149
+ cls_embs = embs_i[:,0,:].clone().detach() # CLS token layer
150
  embs_list.append(cls_embs)
151
  del cls_embs
152