ctheodoris
commited on
fix cell state gene embeddings bug (#345)
Browse files- added quality of life improvements; fixed gene similarities with cell_states_to_model (4b4547f0634eed07560e599766c30326138b7a32)
- reinstate save_to_disk patch (344f263c6173a6bbe96eabcc5ac65e45fa4756e7)
- geneformer/__init__.py +1 -1
- geneformer/in_silico_perturber.py +9 -1
- geneformer/perturber_utils.py +56 -2
- geneformer/tokenizer.py +2 -2
geneformer/__init__.py
CHANGED
@@ -11,7 +11,7 @@ from .collator_for_classification import (
|
|
11 |
DataCollatorForCellClassification,
|
12 |
DataCollatorForGeneClassification,
|
13 |
)
|
14 |
-
from .emb_extractor import EmbExtractor
|
15 |
from .in_silico_perturber import InSilicoPerturber
|
16 |
from .in_silico_perturber_stats import InSilicoPerturberStats
|
17 |
from .pretrainer import GeneformerPretrainer
|
|
|
11 |
DataCollatorForCellClassification,
|
12 |
DataCollatorForGeneClassification,
|
13 |
)
|
14 |
+
from .emb_extractor import EmbExtractor, get_embs
|
15 |
from .in_silico_perturber import InSilicoPerturber
|
16 |
from .in_silico_perturber_stats import InSilicoPerturberStats
|
17 |
from .pretrainer import GeneformerPretrainer
|
geneformer/in_silico_perturber.py
CHANGED
@@ -39,6 +39,7 @@ import os
|
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
41 |
from typing import List
|
|
|
42 |
|
43 |
import seaborn as sns
|
44 |
import torch
|
@@ -47,7 +48,8 @@ from tqdm.auto import trange
|
|
47 |
|
48 |
from . import perturber_utils as pu
|
49 |
from .emb_extractor import get_embs
|
50 |
-
from .
|
|
|
51 |
|
52 |
sns.set()
|
53 |
|
@@ -185,6 +187,10 @@ class InSilicoPerturber:
|
|
185 |
token_dictionary_file : Path
|
186 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
187 |
"""
|
|
|
|
|
|
|
|
|
188 |
|
189 |
self.perturb_type = perturb_type
|
190 |
self.perturb_rank_shift = perturb_rank_shift
|
@@ -422,6 +428,7 @@ class InSilicoPerturber:
|
|
422 |
self.max_len = pu.get_model_input_size(model)
|
423 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
424 |
|
|
|
425 |
### filter input data ###
|
426 |
# general filtering of input data based on filter_data argument
|
427 |
filtered_input_data = pu.load_and_filter(
|
@@ -520,6 +527,7 @@ class InSilicoPerturber:
|
|
520 |
perturbed_data = filtered_input_data.map(
|
521 |
make_group_perturbation_batch, num_proc=self.nproc
|
522 |
)
|
|
|
523 |
if self.perturb_type == "overexpress":
|
524 |
filtered_input_data = filtered_input_data.add_column(
|
525 |
"n_overflow", perturbed_data["n_overflow"]
|
|
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
41 |
from typing import List
|
42 |
+
from multiprocess import set_start_method
|
43 |
|
44 |
import seaborn as sns
|
45 |
import torch
|
|
|
48 |
|
49 |
from . import perturber_utils as pu
|
50 |
from .emb_extractor import get_embs
|
51 |
+
from .perturber_utils import TOKEN_DICTIONARY_FILE
|
52 |
+
|
53 |
|
54 |
sns.set()
|
55 |
|
|
|
187 |
token_dictionary_file : Path
|
188 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
189 |
"""
|
190 |
+
try:
|
191 |
+
set_start_method("spawn")
|
192 |
+
except RuntimeError:
|
193 |
+
pass
|
194 |
|
195 |
self.perturb_type = perturb_type
|
196 |
self.perturb_rank_shift = perturb_rank_shift
|
|
|
428 |
self.max_len = pu.get_model_input_size(model)
|
429 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
430 |
|
431 |
+
|
432 |
### filter input data ###
|
433 |
# general filtering of input data based on filter_data argument
|
434 |
filtered_input_data = pu.load_and_filter(
|
|
|
527 |
perturbed_data = filtered_input_data.map(
|
528 |
make_group_perturbation_batch, num_proc=self.nproc
|
529 |
)
|
530 |
+
|
531 |
if self.perturb_type == "overexpress":
|
532 |
filtered_input_data = filtered_input_data.add_column(
|
533 |
"n_overflow", perturbed_data["n_overflow"]
|
geneformer/perturber_utils.py
CHANGED
@@ -4,6 +4,8 @@ import pickle
|
|
4 |
import re
|
5 |
from collections import defaultdict
|
6 |
from typing import List
|
|
|
|
|
7 |
|
8 |
import numpy as np
|
9 |
import pandas as pd
|
@@ -16,6 +18,11 @@ from transformers import (
|
|
16 |
BertForTokenClassification,
|
17 |
)
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
sns.set()
|
20 |
|
21 |
logger = logging.getLogger(__name__)
|
@@ -581,9 +588,11 @@ def quant_cos_sims(
|
|
581 |
elif emb_mode == "cell":
|
582 |
cos = torch.nn.CosineSimilarity(dim=1)
|
583 |
|
584 |
-
if
|
|
|
|
|
585 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
586 |
-
|
587 |
possible_states = get_possible_states(cell_states_to_model)
|
588 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
589 |
for state in possible_states:
|
@@ -705,3 +714,48 @@ def validate_cell_states_to_model(cell_states_to_model):
|
|
705 |
"'alt_states': ['hcm', 'other1', 'other2']}"
|
706 |
)
|
707 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import re
|
5 |
from collections import defaultdict
|
6 |
from typing import List
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
|
10 |
import numpy as np
|
11 |
import pandas as pd
|
|
|
18 |
BertForTokenClassification,
|
19 |
)
|
20 |
|
21 |
+
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
|
22 |
+
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
|
23 |
+
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
24 |
+
|
25 |
+
|
26 |
sns.set()
|
27 |
|
28 |
logger = logging.getLogger(__name__)
|
|
|
588 |
elif emb_mode == "cell":
|
589 |
cos = torch.nn.CosineSimilarity(dim=1)
|
590 |
|
591 |
+
# if emb_mode == "gene", can only calculate gene cos sims
|
592 |
+
# against original cell anyways
|
593 |
+
if cell_states_to_model is None or emb_mode == "gene":
|
594 |
cos_sims = cos(perturbation_emb, original_emb).to("cuda")
|
595 |
+
elif cell_states_to_model is not None and emb_mode == "cell":
|
596 |
possible_states = get_possible_states(cell_states_to_model)
|
597 |
cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
|
598 |
for state in possible_states:
|
|
|
714 |
"'alt_states': ['hcm', 'other1', 'other2']}"
|
715 |
)
|
716 |
raise
|
717 |
+
|
718 |
+
class GeneIdHandler:
|
719 |
+
def __init__(self, raise_errors=False):
|
720 |
+
def invert_dict(dict_obj):
|
721 |
+
return {v:k for k,v in dict_obj.items()}
|
722 |
+
|
723 |
+
self.raise_errors = raise_errors
|
724 |
+
|
725 |
+
with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
|
726 |
+
self.gene_token_dict = pickle.load(f)
|
727 |
+
self.token_gene_dict = invert_dict(self.gene_token_dict)
|
728 |
+
|
729 |
+
with open(ENSEMBL_DICTIONARY_FILE, 'rb') as f:
|
730 |
+
self.id_gene_dict = pickle.load(f)
|
731 |
+
self.gene_id_dict = invert_dict(self.id_gene_dict)
|
732 |
+
|
733 |
+
def ens_to_token(self, ens_id):
|
734 |
+
if not self.raise_errors:
|
735 |
+
return self.gene_token_dict.get(ens_id, ens_id)
|
736 |
+
else:
|
737 |
+
return self.gene_token_dict[ens_id]
|
738 |
+
|
739 |
+
def token_to_ens(self, token):
|
740 |
+
if not self.raise_errors:
|
741 |
+
return self.token_gene_dict.get(token, token)
|
742 |
+
else:
|
743 |
+
return self.token_gene_dict[token]
|
744 |
+
|
745 |
+
def ens_to_symbol(self, ens_id):
|
746 |
+
if not self.raise_errors:
|
747 |
+
return self.gene_id_dict.get(ens_id, ens_id)
|
748 |
+
else:
|
749 |
+
return self.gene_id_dict[ens_id]
|
750 |
+
|
751 |
+
def symbol_to_ens(self, symbol):
|
752 |
+
if not self.raise_errors:
|
753 |
+
return self.id_gene_dict.get(symbol, symbol)
|
754 |
+
else:
|
755 |
+
return self.id_gene_dict[symbol]
|
756 |
+
|
757 |
+
def token_to_symbol(self, token):
|
758 |
+
return self.ens_to_symbol(self.token_to_ens(token))
|
759 |
+
|
760 |
+
def symbol_to_token(self, symbol):
|
761 |
+
return self.ens_to_token(self.symbol_to_ens(symbol))
|
geneformer/tokenizer.py
CHANGED
@@ -52,8 +52,8 @@ import loompy as lp # noqa
|
|
52 |
|
53 |
logger = logging.getLogger(__name__)
|
54 |
|
55 |
-
|
56 |
-
|
57 |
|
58 |
def rank_genes(gene_vector, gene_tokens):
|
59 |
"""
|
|
|
52 |
|
53 |
logger = logging.getLogger(__name__)
|
54 |
|
55 |
+
from .perturber_utils import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
|
56 |
+
|
57 |
|
58 |
def rank_genes(gene_vector, gene_tokens):
|
59 |
"""
|