fix cell state gene embeddings bug

#345
by davidjwen - opened
geneformer/__init__.py CHANGED
@@ -11,7 +11,7 @@ from .collator_for_classification import (
11
  DataCollatorForCellClassification,
12
  DataCollatorForGeneClassification,
13
  )
14
- from .emb_extractor import EmbExtractor
15
  from .in_silico_perturber import InSilicoPerturber
16
  from .in_silico_perturber_stats import InSilicoPerturberStats
17
  from .pretrainer import GeneformerPretrainer
 
11
  DataCollatorForCellClassification,
12
  DataCollatorForGeneClassification,
13
  )
14
+ from .emb_extractor import EmbExtractor, get_embs
15
  from .in_silico_perturber import InSilicoPerturber
16
  from .in_silico_perturber_stats import InSilicoPerturberStats
17
  from .pretrainer import GeneformerPretrainer
geneformer/in_silico_perturber.py CHANGED
@@ -39,6 +39,7 @@ import os
39
  import pickle
40
  from collections import defaultdict
41
  from typing import List
 
42
 
43
  import seaborn as sns
44
  import torch
@@ -47,7 +48,8 @@ from tqdm.auto import trange
47
 
48
  from . import perturber_utils as pu
49
  from .emb_extractor import get_embs
50
- from .tokenizer import TOKEN_DICTIONARY_FILE
 
51
 
52
  sns.set()
53
 
@@ -185,6 +187,10 @@ class InSilicoPerturber:
185
  token_dictionary_file : Path
186
  | Path to pickle file containing token dictionary (Ensembl ID:token).
187
  """
 
 
 
 
188
 
189
  self.perturb_type = perturb_type
190
  self.perturb_rank_shift = perturb_rank_shift
@@ -422,6 +428,7 @@ class InSilicoPerturber:
422
  self.max_len = pu.get_model_input_size(model)
423
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
424
 
 
425
  ### filter input data ###
426
  # general filtering of input data based on filter_data argument
427
  filtered_input_data = pu.load_and_filter(
@@ -520,6 +527,7 @@ class InSilicoPerturber:
520
  perturbed_data = filtered_input_data.map(
521
  make_group_perturbation_batch, num_proc=self.nproc
522
  )
 
523
  if self.perturb_type == "overexpress":
524
  filtered_input_data = filtered_input_data.add_column(
525
  "n_overflow", perturbed_data["n_overflow"]
 
39
  import pickle
40
  from collections import defaultdict
41
  from typing import List
42
+ from multiprocess import set_start_method
43
 
44
  import seaborn as sns
45
  import torch
 
48
 
49
  from . import perturber_utils as pu
50
  from .emb_extractor import get_embs
51
+ from .perturber_utils import TOKEN_DICTIONARY_FILE
52
+
53
 
54
  sns.set()
55
 
 
187
  token_dictionary_file : Path
188
  | Path to pickle file containing token dictionary (Ensembl ID:token).
189
  """
190
+ try:
191
+ set_start_method("spawn")
192
+ except RuntimeError:
193
+ pass
194
 
195
  self.perturb_type = perturb_type
196
  self.perturb_rank_shift = perturb_rank_shift
 
428
  self.max_len = pu.get_model_input_size(model)
429
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
430
 
431
+
432
  ### filter input data ###
433
  # general filtering of input data based on filter_data argument
434
  filtered_input_data = pu.load_and_filter(
 
527
  perturbed_data = filtered_input_data.map(
528
  make_group_perturbation_batch, num_proc=self.nproc
529
  )
530
+
531
  if self.perturb_type == "overexpress":
532
  filtered_input_data = filtered_input_data.add_column(
533
  "n_overflow", perturbed_data["n_overflow"]
geneformer/perturber_utils.py CHANGED
@@ -4,6 +4,8 @@ import pickle
4
  import re
5
  from collections import defaultdict
6
  from typing import List
 
 
7
 
8
  import numpy as np
9
  import pandas as pd
@@ -16,6 +18,11 @@ from transformers import (
16
  BertForTokenClassification,
17
  )
18
 
 
 
 
 
 
19
  sns.set()
20
 
21
  logger = logging.getLogger(__name__)
@@ -581,9 +588,11 @@ def quant_cos_sims(
581
  elif emb_mode == "cell":
582
  cos = torch.nn.CosineSimilarity(dim=1)
583
 
584
- if cell_states_to_model is None:
 
 
585
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
586
- else:
587
  possible_states = get_possible_states(cell_states_to_model)
588
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
589
  for state in possible_states:
@@ -705,3 +714,48 @@ def validate_cell_states_to_model(cell_states_to_model):
705
  "'alt_states': ['hcm', 'other1', 'other2']}"
706
  )
707
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import re
5
  from collections import defaultdict
6
  from typing import List
7
+ from pathlib import Path
8
+
9
 
10
  import numpy as np
11
  import pandas as pd
 
18
  BertForTokenClassification,
19
  )
20
 
21
+ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
22
+ TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
23
+ ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
24
+
25
+
26
  sns.set()
27
 
28
  logger = logging.getLogger(__name__)
 
588
  elif emb_mode == "cell":
589
  cos = torch.nn.CosineSimilarity(dim=1)
590
 
591
+ # if emb_mode == "gene", gene cos sims can only be calculated
592
+ # against the original cell anyway
593
+ if cell_states_to_model is None or emb_mode == "gene":
594
  cos_sims = cos(perturbation_emb, original_emb).to("cuda")
595
+ elif cell_states_to_model is not None and emb_mode == "cell":
596
  possible_states = get_possible_states(cell_states_to_model)
597
  cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
598
  for state in possible_states:
 
714
  "'alt_states': ['hcm', 'other1', 'other2']}"
715
  )
716
  raise
717
+
718
class GeneIdHandler:
    """Translate between Ensembl IDs, model tokens, and gene symbols.

    Two pickled dictionaries are loaded at construction time:

    * ``TOKEN_DICTIONARY_FILE`` -- Ensembl ID -> token.
    * ``ENSEMBL_DICTIONARY_FILE`` -- gene symbol -> Ensembl ID.

    Each dictionary is also inverted so lookups work in both directions.

    Parameters
    ----------
    raise_errors : bool
        | If False (default), a failed lookup returns the queried key unchanged.
        | If True, a failed lookup raises ``KeyError``.

    Attributes
    ----------
    gene_token_dict : dict
        | Ensembl ID -> token.
    token_gene_dict : dict
        | token -> Ensembl ID.
    id_gene_dict : dict
        | gene symbol -> Ensembl ID (despite the name, it is keyed by symbol).
    gene_id_dict : dict
        | Ensembl ID -> gene symbol (despite the name, it is keyed by Ensembl ID).
    """

    def __init__(self, raise_errors=False):
        self.raise_errors = raise_errors

        # Both pickles are resolved relative to the installed package; pickle
        # must never be pointed at files from an untrusted source.
        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
            self.gene_token_dict = pickle.load(f)
        self.token_gene_dict = self._invert(self.gene_token_dict)

        with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
            self.id_gene_dict = pickle.load(f)
        self.gene_id_dict = self._invert(self.id_gene_dict)

    @staticmethod
    def _invert(mapping):
        """Return a value -> key copy of ``mapping``.

        If several keys share one value, only the last key seen survives.
        """
        return {value: key for key, value in mapping.items()}

    def _lookup(self, mapping, key):
        """Look ``key`` up in ``mapping`` honoring ``self.raise_errors``."""
        if self.raise_errors:
            return mapping[key]
        # Lenient mode: hand the unknown key back so conversions can be chained.
        return mapping.get(key, key)

    def ens_to_token(self, ens_id):
        """Convert an Ensembl ID to its token."""
        return self._lookup(self.gene_token_dict, ens_id)

    def token_to_ens(self, token):
        """Convert a token to its Ensembl ID."""
        return self._lookup(self.token_gene_dict, token)

    def ens_to_symbol(self, ens_id):
        """Convert an Ensembl ID to its gene symbol."""
        return self._lookup(self.gene_id_dict, ens_id)

    def symbol_to_ens(self, symbol):
        """Convert a gene symbol to its Ensembl ID."""
        return self._lookup(self.id_gene_dict, symbol)

    def token_to_symbol(self, token):
        """Convert a token to its gene symbol via the Ensembl ID."""
        return self.ens_to_symbol(self.token_to_ens(token))

    def symbol_to_token(self, symbol):
        """Convert a gene symbol to its token via the Ensembl ID."""
        return self.ens_to_token(self.symbol_to_ens(symbol))
geneformer/tokenizer.py CHANGED
@@ -52,8 +52,8 @@ import loompy as lp # noqa
52
 
53
  logger = logging.getLogger(__name__)
54
 
55
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
56
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
57
 
58
  def rank_genes(gene_vector, gene_tokens):
59
  """
 
52
 
53
  logger = logging.getLogger(__name__)
54
 
55
+ from .perturber_utils import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
56
+
57
 
58
  def rank_genes(gene_vector, gene_tokens):
59
  """