ctheodoris commited on
Commit
99f40b4
1 Parent(s): 3e24216

update isp for cls perturb set

Browse files
Files changed (1) hide show
  1. geneformer/perturber_utils.py +2 -1
geneformer/perturber_utils.py CHANGED
@@ -620,9 +620,10 @@ def quant_cos_sims(
620
  cos = torch.nn.CosineSimilarity(dim=1)
621
 
622
  # if emb_mode == "gene", can only calculate gene cos sims
623
- # against original cell anyways
624
  if cell_states_to_model is None or emb_mode == "gene":
625
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
 
626
  elif cell_states_to_model is not None and emb_mode == "cell":
627
  possible_states = get_possible_states(cell_states_to_model)
628
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
 
620
  cos = torch.nn.CosineSimilarity(dim=1)
621
 
622
  # if emb_mode == "gene", can only calculate gene cos sims
623
+ # against original cell
624
  if cell_states_to_model is None or emb_mode == "gene":
625
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
626
+
627
  elif cell_states_to_model is not None and emb_mode == "cell":
628
  possible_states = get_possible_states(cell_states_to_model)
629
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))