GenePooler #429
opened by jamieb-nvs

Files changed:
- README.md +0 -3
- examples/cell_classification.ipynb +3 -7
- examples/extract_and_plot_cell_embeddings.ipynb +4 -7
- examples/gene_classification.ipynb +3 -6
- examples/in_silico_perturbation.ipynb +8 -18
- examples/tokenizing_scRNAseq_data.ipynb +5 -15
- geneformer/emb_extractor.py +1 -13
- geneformer/gene_name_id_dict_gc95M.pkl +2 -2
- geneformer/in_silico_perturber.py +6 -28
- geneformer/in_silico_perturber_stats.py +4 -10
- geneformer/mtl/data.py +105 -117
- geneformer/pretrainer.py +176 -6
- geneformer/tokenizer.py +36 -59
- requirements.txt +1 -1
README.md
@@ -1,9 +1,6 @@
 ---
 datasets: ctheodoris/Genecorpus-30M
 license: apache-2.0
-tags:
-- single-cell
-- genomics
 ---
 # Geneformer
 Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
examples/cell_classification.ipynb
@@ -68,10 +68,6 @@
 " \"per_device_train_batch_size\": 12,\n",
 " \"seed\": 73,\n",
 "}\n",
-"\n",
-"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
-"# (otherwise the Classifier will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
 "cc = Classifier(classifier=\"cell\",\n",
 " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
 " filter_data=filter_data_dict,\n",
@@ -129,7 +125,7 @@
 " \"train\": train_ids+eval_ids,\n",
 " \"test\": test_ids}\n",
 "\n",
-"# Example input_data_file …
+"# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
 "cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n",
 " output_directory=output_dir,\n",
 " output_prefix=output_prefix,\n",
@@ -264,7 +260,7 @@
 " \"train\": train_ids,\n",
 " \"eval\": eval_ids}\n",
 "\n",
-"# …
+"# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
 "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
 " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
 " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
@@ -450,7 +446,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.…
+"version": "3.11.5"
 }
 },
 "nbformat": 4,
examples/extract_and_plot_cell_embeddings.ipynb
@@ -18,8 +18,6 @@
 "outputs": [],
 "source": [
 "# initiate EmbExtractor\n",
-"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
-"# (otherwise the EmbExtractor will use the current default model dictionary)\n",
 "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
 " num_classes=3,\n",
 " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
@@ -28,13 +26,12 @@
 " emb_label=[\"disease\",\"cell_type\"],\n",
 " labels_to_plot=[\"disease\"],\n",
 " forward_batch_size=200,\n",
-" nproc=16 …
-" token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n",
+" nproc=16)\n",
 "\n",
 "# extracts embedding from input data\n",
 "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
-"# example dataset …
-"embs = embex.extract_embs(\"../fine_tuned_models/ …
+"# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
+"embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
 " \"path/to/input_data/\",\n",
 " \"path/to/output_directory/\",\n",
 " \"output_prefix\")\n"
@@ -132,7 +129,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.…
+"version": "3.11.5"
 }
 },
 "nbformat": 4,
examples/gene_classification.ipynb
@@ -71,9 +71,6 @@
 }
 ],
 "source": [
-"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
-"# (otherwise the Classifier will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
 "cc = Classifier(classifier=\"gene\",\n",
 " gene_class_dict = gene_class_dict,\n",
 " max_ncells = 10_000,\n",
@@ -105,7 +102,7 @@
 }
 ],
 "source": [
-"# Example input_data_file …
+"# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
 "cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
 " output_directory=output_dir,\n",
 " output_prefix=output_prefix)"
@@ -843,7 +840,7 @@
 }
 ],
 "source": [
-"# 6 layer …
+"# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
 "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
 " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
 " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
@@ -1243,7 +1240,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.…
+"version": "3.11.5"
 }
 },
 "nbformat": 4,
examples/in_silico_perturbation.ipynb
@@ -39,10 +39,7 @@
 "\n",
 "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
 "\n",
-"# …
-"# (otherwise the EmbExtractor will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
-"embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
+"embex = EmbExtractor(model_type=\"CellClassifier\",\n",
 " num_classes=3,\n",
 " filter_data=filter_data_dict,\n",
 " max_ncells=1000,\n",
@@ -52,7 +49,7 @@
 " nproc=16)\n",
 "\n",
 "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
-" \" …
+" \"path/to/model\",\n",
 " \"path/to/input_data\",\n",
 " \"path/to/output_directory\",\n",
 " \"output_prefix\")"
@@ -67,15 +64,12 @@
 },
 "outputs": [],
 "source": [
-"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
-"# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
 "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
 " perturb_rank_shift=None,\n",
 " genes_to_perturb=\"all\",\n",
 " combos=0,\n",
 " anchor_gene=None,\n",
-" model_type=\"CellClassifier\" …
+" model_type=\"CellClassifier\",\n",
 " num_classes=3,\n",
 " emb_mode=\"cell\",\n",
 " cell_emb_style=\"mean_pool\",\n",
@@ -96,10 +90,9 @@
 "outputs": [],
 "source": [
 "# outputs intermediate files from in silico perturbation\n",
-"\n",
-"isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
+"isp.perturb_data(\"path/to/model\",\n",
 " \"path/to/input_data\",\n",
-" \"path/to/ …
+" \"path/to/output_directory\",\n",
 " \"output_prefix\")"
 ]
 },
@@ -110,9 +103,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
-"# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
-"# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
 "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
 " genes_perturbed=\"all\",\n",
 " combos=0,\n",
@@ -128,9 +118,9 @@
 "outputs": [],
 "source": [
 "# extracts data from intermediate files and processes stats to output in final .csv\n",
-"ispstats.get_stats(\"path/to/ …
+"ispstats.get_stats(\"path/to/input_data\",\n",
 " None,\n",
-" \"path/to/ …
+" \"path/to/output_directory\",\n",
 " \"output_prefix\")"
 ]
 }
@@ -151,7 +141,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.…
+"version": "3.10.11"
 }
 },
 "nbformat": 4,
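For orientation, the edited cells above chain into one pipeline once the placeholders are filled in. A minimal sketch under the notebook's assumptions (a fine-tuned 3-class cell classifier; cell_states_to_model shown with illustrative disease states; all paths are placeholders):

from geneformer import EmbExtractor, InSilicoPerturber, InSilicoPerturberStats

filter_data_dict = {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2", "Cardiomyocyte3"]}
# illustrative goal-state definition; adjust keys/values to your dataset
cell_states_to_model = {"state_key": "disease", "start_state": "dcm",
                        "goal_state": "nf", "alt_states": ["hcm"]}

# 1) embed each modeled cell state (exact mean summary, as in the notebook)
embex = EmbExtractor(model_type="CellClassifier", num_classes=3,
                     filter_data=filter_data_dict, max_ncells=1000,
                     summary_stat="exact_mean", forward_batch_size=200, nproc=16)
state_embs_dict = embex.get_state_embs(cell_states_to_model, "path/to/model",
                                       "path/to/input_data",
                                       "path/to/output_directory", "output_prefix")

# 2) delete each gene in silico; intermediate cosine-shift files are written out
isp = InSilicoPerturber(perturb_type="delete", genes_to_perturb="all",
                        model_type="CellClassifier", num_classes=3, emb_mode="cell",
                        cell_states_to_model=cell_states_to_model,
                        state_embs_dict=state_embs_dict, nproc=16)
isp.perturb_data("path/to/model", "path/to/input_data",
                 "path/to/output_directory", "output_prefix")

# 3) aggregate the intermediate files into the final stats .csv
ispstats = InSilicoPerturberStats(mode="goal_state_shift", genes_perturbed="all",
                                  combos=0, anchor_gene=None,
                                  cell_states_to_model=cell_states_to_model)
ispstats.get_stats("path/to/input_data", None,
                   "path/to/output_directory", "output_prefix")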
examples/tokenizing_scRNAseq_data.ipynb
@@ -12,7 +12,7 @@
 },
 {
 "cell_type": "markdown",
-"id": " …
+"id": "350e6252-b783-494b-9767-f087eb868a15",
 "metadata": {},
 "source": [
 "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
@@ -25,21 +25,11 @@
 "\n",
 "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
 "\n",
-"#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer …
-]
-},
-{
-"cell_type": "markdown",
-"id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
-"metadata": {},
-"source": [
-"**********************************************************************************************************\n",
+"#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.\n",
+"\n",
 "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
-"#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
 "\n",
-"#### …
-"#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
-"#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
+"#### The 95M model series also require the special_token argument to be set to True and model_input_size to be 4096."
 ]
 },
 {
@@ -83,7 +73,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.…
+"version": "3.10.11"
 }
 },
 "nbformat": 4,
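Since the consolidated markdown cell now leans on the 95M defaults, a short sketch with those defaults spelled out may help (paths and custom_attr_name_dict entries are placeholders; for a 30M-series model one would instead pass special_token=False, model_input_size=2048, plus the gc-30M token/median dictionaries):

from geneformer import TranscriptomeTokenizer

# 95M-series defaults made explicit: special_token=True, model_input_size=4096
tk = TranscriptomeTokenizer(
    custom_attr_name_dict={"cell_type": "cell_type", "organ_major": "organ"},
    nproc=16,
    special_token=True,
    model_input_size=4096,
)
tk.tokenize_data(
    "path/to/loom_or_h5ad_directory",  # directory of .loom or .h5ad files with raw counts
    "path/to/output_directory",
    "output_prefix",
    file_format="loom",                # or "h5ad", per the markdown cell above
)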
geneformer/emb_extractor.py
@@ -411,7 +411,7 @@ class EmbExtractor:
         self,
         model_type="Pretrained",
         num_classes=0,
-        emb_mode=" …
+        emb_mode="cell",
         cell_emb_style="mean_pool",
         gene_emb_style="mean_pool",
         filter_data=None,
@@ -596,12 +596,6 @@ class EmbExtractor:
         filtered_input_data = pu.load_and_filter(
             self.filter_data, self.nproc, input_data_file
         )
-
-        # Check to make sure that all the labels exist in the tokenized data:
-        if self.emb_label is not None:
-            for label in self.emb_label:
-                assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
-
         if cell_state is not None:
             filtered_input_data = pu.filter_by_dict(
                 filtered_input_data, cell_state, self.nproc
@@ -725,12 +719,6 @@ class EmbExtractor:
             )
             raise
 
-        if self.emb_label is not None:
-            logger.error(
-                "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
-            )
-            raise
-
         state_embs_dict = dict()
         state_key = cell_states_to_model["state_key"]
         for k, v in cell_states_to_model.items():
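Callers that relied on the dropped emb_label check may want to validate labels themselves before extraction; a minimal standalone equivalent of the removed assertion (function name is illustrative):

from datasets import load_from_disk

def check_emb_labels(input_data_file, emb_label):
    # replicates the removed guard: every requested label column
    # must exist in the tokenized dataset's features
    data = load_from_disk(input_data_file)
    missing = [label for label in (emb_label or []) if label not in data.features]
    if missing:
        raise KeyError(f"Attributes {missing} not present in dataset features")

check_emb_labels("/path/to/input_data.dataset", ["disease", "cell_type"])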
geneformer/gene_name_id_dict_gc95M.pkl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256: …
-size …
+oid sha256:8b0fd0521406ed18b2e341ef0acb5f53aa1a62457a07ca5840e1c142f46dd326
+size 2038812
geneformer/in_silico_perturber.py
@@ -40,7 +40,7 @@ import pickle
 from collections import defaultdict
 
 import torch
-from datasets import Dataset
+from datasets import Dataset, disable_progress_bars
 from multiprocess import set_start_method
 from tqdm.auto import trange
 
@@ -48,9 +48,7 @@ from . import TOKEN_DICTIONARY_FILE
 from . import perturber_utils as pu
 from .emb_extractor import get_embs
 
-
-datasets.logging.disable_progress_bar()
-
+disable_progress_bars()
 
 logger = logging.getLogger(__name__)
 
@@ -86,7 +84,7 @@ class InSilicoPerturber:
         anchor_gene=None,
         model_type="Pretrained",
         num_classes=0,
-        emb_mode=" …
+        emb_mode="cell",
         cell_emb_style="mean_pool",
         filter_data=None,
         cell_states_to_model=None,
@@ -796,8 +794,6 @@ class InSilicoPerturber:
             return example
 
         total_batch_length = len(filtered_input_data)
-
-
         if self.cell_states_to_model is None:
             cos_sims_dict = defaultdict(list)
         else:
@@ -882,7 +878,7 @@ class InSilicoPerturber:
             )
 
         ##### CLS and Gene Embedding Mode #####
-        elif self.emb_mode == "cls_and_gene":
+        elif self.emb_mode == "cls_and_gene":
             full_original_emb = get_embs(
                 model,
                 minibatch,
@@ -895,7 +891,6 @@ class InSilicoPerturber:
                 silent=True,
             )
             indices_to_perturb = perturbation_batch["perturb_index"]
-
             # remove indices that were perturbed
             original_emb = pu.remove_perturbed_indices_set(
                 full_original_emb,
@@ -904,7 +899,6 @@ class InSilicoPerturber:
                 self.tokens_to_perturb,
                 minibatch["length"],
             )
-
             full_perturbation_emb = get_embs(
                 model,
                 perturbation_batch,
@@ -916,7 +910,7 @@ class InSilicoPerturber:
                 summary_stat=None,
                 silent=True,
             )
-
+
             # remove special tokens and padding
             original_emb = original_emb[:, 1:-1, :]
             if self.perturb_type == "overexpress":
@@ -927,25 +921,9 @@ class InSilicoPerturber:
             perturbation_emb = full_perturbation_emb[
                 :, 1 : max(perturbation_batch["length"]) - 1, :
             ]
-
-            n_perturbation_genes = perturbation_emb.size()[1]
 
-
-            if self.perturb_type == "overexpress":
-                def calc_perturbation_length(ids):
-                    if ids == [-100]:
-                        return 0
-                    else:
-                        return len(ids)
-
-                max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])
+            n_perturbation_genes = perturbation_emb.size()[1]
 
-                max_n_overflow = max(minibatch["n_overflow"])
-                if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
-                    original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
-                elif perturbation_emb.size()[1] < original_emb.size()[1]:
-                    original_emb = original_emb[:, 0:max_tensor_size, :]
-
             gene_cos_sims = pu.quant_cos_sims(
                 perturbation_emb,
                 original_emb,
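The import swap above follows an API move in HuggingFace datasets, where the progress-bar toggle migrated to top-level enable/disable helpers. If this module must also run against older datasets releases, a compatibility shim along these lines could work (a sketch; the exact versions where each helper exists are not pinned here):

try:
    # newer datasets releases expose a top-level plural helper
    from datasets import disable_progress_bars
    disable_progress_bars()
except ImportError:
    # older releases kept a singular helper under datasets.utils.logging
    from datasets.utils.logging import disable_progress_bar
    disable_progress_bar()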
geneformer/in_silico_perturber_stats.py
@@ -640,16 +640,10 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
     cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
 
     # quantify number of detections of each gene
-    …
-        ]
-    else:
-        cos_sims_full_df["N_Detections"] = [
-            n_detections(i, dict_list, "gene", anchor_token)
-            for i in cos_sims_full_df["Gene"]
-        ]
+    cos_sims_full_df["N_Detections"] = [
+        n_detections(i, dict_list, "gene", anchor_token)
+        for i in cos_sims_full_df["Gene"]
+    ]
 
     if combos == 0:
         cos_sims_full_df = cos_sims_full_df.sort_values(
geneformer/mtl/data.py
@@ -1,162 +1,150 @@
 import os
+
 from .collators import DataCollatorForMultitaskCellClassification
 from .imports import *
 
-def validate_columns(dataset, required_columns, dataset_type):
-    """Ensures required columns are present in the dataset."""
-    missing_columns = [col for col in required_columns if col not in dataset.column_names]
-    if missing_columns:
-        raise KeyError(
-            f"Missing columns in {dataset_type} dataset: {missing_columns}. "
-            f"Available columns: {dataset.column_names}"
-        )
-
-
-def create_label_mappings(dataset, task_to_column):
-    """Creates label mappings for the dataset."""
-    task_label_mappings = {}
-    num_labels_list = []
-    for task, column in task_to_column.items():
-        unique_values = sorted(set(dataset[column]))
-        mapping = {label: idx for idx, label in enumerate(unique_values)}
-        task_label_mappings[task] = mapping
-        num_labels_list.append(len(unique_values))
-    return task_label_mappings, num_labels_list
-
-
-def save_label_mappings(mappings, path):
-    """Saves label mappings to a pickle file."""
-    with open(path, "wb") as f:
-        pickle.dump(mappings, f)
-
-
-def load_label_mappings(path):
-    """Loads label mappings from a pickle file."""
-    with open(path, "rb") as f:
-        return pickle.load(f)
-
-
-def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
-    """Transforms the dataset to the required format."""
-    transformed_dataset = []
-    cell_id_mapping = {}
-
-    for idx, record in enumerate(dataset):
-        transformed_record = {
-            "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
-            "cell_id": idx,  # Index-based cell ID
-        }
-
-        if not is_test:
-            label_dict = {
-                task: task_label_mappings[task][record[column]]
-                for task, column in task_to_column.items()
-            }
-        else:
-            label_dict = {task: -1 for task in config["task_names"]}
-
-        transformed_record["label"] = label_dict
-        transformed_dataset.append(transformed_record)
-        cell_id_mapping[idx] = record.get("unique_cell_id", idx)
-
-    return transformed_dataset, cell_id_mapping
-
 
 def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
-    """Main function to load and preprocess data."""
     try:
         dataset = load_from_disk(dataset_path)
 
-        # Setup task and column mappings
         task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
         task_to_column = dict(zip(task_names, config["task_columns"]))
         config["task_names"] = task_names
 
-        label_mappings_path = os.path.join(
-            config["results_dir"],
-            f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
-        )
-
         if not is_test:
-            …
+            available_columns = set(dataset.column_names)
+            for column in task_to_column.values():
+                if column not in available_columns:
+                    raise KeyError(
+                        f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
+                    )
+
+        label_mappings = {}
+        task_label_mappings = {}
+        cell_id_mapping = {}
+        num_labels_list = []
+
+        # Load or create task label mappings
+        if not is_test:
+            for task, column in task_to_column.items():
+                unique_values = sorted(set(dataset[column]))  # Ensure consistency
+                label_mappings[column] = {
+                    label: idx for idx, label in enumerate(unique_values)
+                }
+                task_label_mappings[task] = label_mappings[column]
+                num_labels_list.append(len(unique_values))
+
+            # Print the mappings for each task with dataset type prefix
+            for task, mapping in task_label_mappings.items():
+                print(
+                    f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
+                )  # sanity check, for train/validation splits
+
+            # Save the task label mappings as a pickle file
+            with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
+                pickle.dump(task_label_mappings, f)
         else:
-            # Load …
-            task_label_mappings …
-
-            …
-                dataset, task_to_column, task_label_mappings, config, is_test
-            )
+            # Load task label mappings from pickle file for test data
+            with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
+                task_label_mappings = pickle.load(f)
+
+            # Infer num_labels_list from task_label_mappings
+            for task, mapping in task_label_mappings.items():
+                num_labels_list.append(len(mapping))
+
+        # Store unique cell IDs in a separate dictionary
+        for idx, record in enumerate(dataset):
+            cell_id = record.get("unique_cell_id", idx)
+            cell_id_mapping[idx] = cell_id
+
+        # Transform records to the desired format
+        transformed_dataset = []
+        for idx, record in enumerate(dataset):
+            transformed_record = {}
+            transformed_record["input_ids"] = torch.tensor(
+                record["input_ids"], dtype=torch.long
+            )
 
-            …
+            # Use index-based cell ID for internal tracking
+            transformed_record["cell_id"] = idx
 
+            if not is_test:
+                # Prepare labels
+                label_dict = {}
+                for task, column in task_to_column.items():
+                    label_value = record[column]
+                    label_index = task_label_mappings[task][label_value]
+                    label_dict[task] = label_index
+                transformed_record["label"] = label_dict
+            else:
+                # Create dummy labels for test data
+                label_dict = {task: -1 for task in config["task_names"]}
+                transformed_record["label"] = label_dict
+
+            transformed_dataset.append(transformed_record)
 
+        return transformed_dataset, cell_id_mapping, num_labels_list
     except KeyError as e:
-        …
+        print(f"Missing configuration or dataset key: {e}")
     except Exception as e:
-        …
+        print(f"An error occurred while loading or preprocessing data: {e}")
+        return None, None, None
 
 
 def preload_and_process_data(config):
-    …
-
-def validate_label_mappings(config):
-    """Ensures train and validation label mappings are consistent."""
-    train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
-    val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
-    train_mappings = load_label_mappings(train_mappings_path)
-    val_mappings = load_label_mappings(val_mappings_path)
-
-    …
-            raise ValueError(
-                f"Mismatch in label mappings for task '{task_name}'.\n"
-                f"Train Mapping: {train_mappings[task_name]}\n"
-                f"Validation Mapping: {val_mappings[task_name]}"
-            )
+    # Load and preprocess data once
+    train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
+        config["train_path"], config, dataset_type="train"
+    )
+    val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
+        config["val_path"], config, dataset_type="validation"
+    )
+    return (
+        train_dataset,
+        train_cell_id_mapping,
+        val_dataset,
+        val_cell_id_mapping,
+        num_labels_list,
+    )
 
 
-…
-    """Creates a DataLoader with optimal settings."""
-    return DataLoader(
+def get_data_loader(preprocessed_dataset, batch_size):
+    nproc = os.cpu_count()  ### I/O operations
+
+    data_collator = DataCollatorForMultitaskCellClassification()
+
+    loader = DataLoader(
         preprocessed_dataset,
         batch_size=batch_size,
         shuffle=True,
-        collate_fn= …
-        num_workers= …
+        collate_fn=data_collator,
+        num_workers=nproc,
         pin_memory=True,
     )
+    return loader
 
 
 def preload_data(config):
-    …
-    train_loader = get_data_loader( …
-    val_loader = get_data_loader( …
+    # Preprocessing the data before the Optuna trials start
+    train_loader = get_data_loader("train", config)
+    val_loader = get_data_loader("val", config)
     return train_loader, val_loader
 
 
 def load_and_preprocess_test_data(config):
-    """ …
+    """
+    Load and preprocess test data, treating it as unlabeled.
+    """
     return load_and_preprocess_data(config["test_path"], config, is_test=True)
 
 
 def prepare_test_loader(config):
-    """ …
-    …
+    """
+    Prepare DataLoader for the test dataset.
+    """
+    test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
+        config
+    )
    test_loader = get_data_loader(test_dataset, config["batch_size"])
    return test_loader, cell_id_mapping, num_labels_list
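For orientation, the restored load_and_preprocess_data / get_data_loader path expects a config dict roughly like the sketch below (keys inferred from the code in this diff; all values are placeholders). One caveat worth noting: preload_data as written still calls get_data_loader("train", config), which does not match get_data_loader(preprocessed_dataset, batch_size), so the preload_and_process_data + get_data_loader route is the internally consistent one.

# hypothetical config, inferred from the keys this module reads
config = {
    "train_path": "/path/to/train.dataset",   # HF datasets saved via save_to_disk
    "val_path": "/path/to/val.dataset",
    "test_path": "/path/to/test.dataset",
    "task_columns": ["disease", "cell_type"], # one classification task per column
    "results_dir": "/path/to/results",        # task_label_mappings.pkl lands here
    "batch_size": 12,
}

train_ds, train_ids, val_ds, val_ids, num_labels_list = preload_and_process_data(config)
train_loader = get_data_loader(train_ds, config["batch_size"])
val_loader = get_data_loader(val_ds, config["batch_size"])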
geneformer/pretrainer.py
@@ -8,12 +8,13 @@ import math
 import pickle
 import warnings
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 import numpy as np
 import torch
 from datasets import Dataset
 from packaging import version
+from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler
 from transformers import (
     BatchEncoding,
@@ -23,8 +24,11 @@ from transformers import (
 )
 from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
 from transformers.trainer_pt_utils import (
+    DistributedLengthGroupedSampler,
+    DistributedSamplerWithLoop,
     LengthGroupedSampler,
 )
+from transformers.training_args import ParallelMode
 from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
 from transformers.utils.generic import _is_tensorflow, _is_torch
 
@@ -603,7 +607,7 @@ class GeneformerPretrainer(Trainer):
         )
         super().__init__(*args, **kwargs)
 
-    # …
+    # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
@@ -626,15 +630,181 @@ class GeneformerPretrainer(Trainer):
                 if self.tokenizer is not None
                 else None
             )
-            return LengthGroupedSampler(
+            if self.args.world_size <= 1:
+                return LengthGroupedSampler(
                     dataset=self.train_dataset,
                     batch_size=self.args.train_batch_size,
                     lengths=lengths,
                     model_input_name=model_input_name,
                     generator=generator,
-            )
+                )
+            else:
+                return CustomDistributedLengthGroupedSampler(
+                    dataset=self.train_dataset,
+                    batch_size=self.args.train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                    lengths=lengths,
+                    model_input_name=model_input_name,
+                    seed=self.args.seed,
+                )
+
         else:
-            …
+            if self.args.world_size <= 1:
+                if _is_torch_generator_available:
+                    return RandomSampler(self.train_dataset, generator=generator)
+                return RandomSampler(self.train_dataset)
+            elif (
+                self.args.parallel_mode
+                in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
+                and not self.args.dataloader_drop_last
+            ):
+                # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+                return DistributedSamplerWithLoop(
+                    self.train_dataset,
+                    batch_size=self.args.per_device_train_batch_size,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                    seed=self.args.seed,
+                )
+            else:
+                return DistributedSampler(
+                    self.train_dataset,
+                    num_replicas=self.args.world_size,
+                    rank=self.args.process_index,
+                    seed=self.args.seed,
+                )
+
+
+class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
+    r"""
+    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
+    length while keeping a bit of randomness.
+    """
+
+    # Copied and adapted from PyTorch DistributedSampler.
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        lengths: Optional[List[int]] = None,
+        model_input_name: Optional[str] = None,
+    ):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.num_replicas) / self.num_replicas
+            )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed
+        self.model_input_name = (
+            model_input_name if model_input_name is not None else "input_ids"
+        )
+
+        if lengths is None:
+            print("Lengths is none - calculating lengths.")
+            if (
+                not (
+                    isinstance(dataset[0], dict)
+                    or isinstance(dataset[0], BatchEncoding)
+                )
+                or self.model_input_name not in dataset[0]
+            ):
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    f"'{self.model_input_name}' key."
+                )
+            lengths = [len(feature[self.model_input_name]) for feature in dataset]
+        self.lengths = lengths
+
+    def __iter__(self) -> Iterator:
+        # Deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            indices += indices[: (self.total_size - len(indices))]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+
+def get_length_grouped_indices(
+    lengths, batch_size, mega_batch_mult=None, generator=None
+):
+    """
+    Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
+    similar lengths. To do this, the indices are:
+
+    - randomly permuted
+    - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
+    - sorted by length in each mega-batch
+
+    The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
+    maximum length placed first, so that an OOM happens sooner rather than later.
+    """
+    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+    if mega_batch_mult is None:
+        # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+        mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
+        # Just in case, for tiny datasets
+        if mega_batch_mult == 0:
+            mega_batch_mult = 1
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = mega_batch_mult * batch_size
+    megabatches = [
+        indices[i : i + megabatch_size].tolist()
+        for i in range(0, len(lengths), megabatch_size)
+    ]
+    megabatches = [
+        list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
+        for megabatch in megabatches
+    ]
+
+    # The rest is to get the biggest batch first.
+    # Since each megabatch is sorted by descending length, the longest element is the first
+    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+    # Switch to put the longest element in first position
+    megabatches[0][0], megabatches[max_idx][0] = (
+        megabatches[max_idx][0],
+        megabatches[0][0],
+    )
+
+    return [item for sublist in megabatches for item in sublist]
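The module-level helper reintroduced at the bottom of this file can be exercised on its own; a small demo of the grouping behavior (toy lengths; assumes the function is importable from geneformer.pretrainer):

import torch
from geneformer.pretrainer import get_length_grouped_indices

lengths = [5, 300, 12, 290, 7, 310, 9, 295]  # toy sequence lengths
g = torch.Generator()
g.manual_seed(0)  # deterministic, mirroring the per-epoch seeding in __iter__
indices = get_length_grouped_indices(lengths, batch_size=2, generator=g)

# each consecutive slice of batch_size indices has similar lengths, and the
# megabatch holding the globally longest sequence is moved to the front so an
# OOM surfaces on the first step rather than mid-training
print([lengths[i] for i in indices])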
geneformer/tokenizer.py
@@ -88,7 +88,6 @@ def sum_ensembl_ids(
     collapse_gene_ids,
     gene_mapping_dict,
     gene_token_dict,
-    custom_attr_name_dict,
     file_format="loom",
     chunk_size=512,
 ):
@@ -104,45 +103,33 @@ def sum_ensembl_ids(
         assert (
             "ensembl_id_collapsed" not in data.ra.keys()
         ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
-
-        assert (
-            "n_counts" in data.ca.keys()
-        ), "'n_counts' column missing from data.ca.keys()"
-
-        if custom_attr_name_dict is not None:
-            for label in custom_attr_name_dict:
-                assert label in data.ca.keys(), f"Attribute `{label}` not present in dataset features"
-
-        # Get the ensembl ids that exist in data
-        ensembl_ids = data.ra.ensembl_id
         # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
         # Comparing to gene_token_dict here, would not perform any mapping steps
-        …
+        gene_ids_in_dict = [
+            gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
+        ]
+        if collapse_gene_ids is False:
+            if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
                 return data_directory
             else:
                 raise ValueError("Error: data Ensembl IDs non-unique.")
-        …
-            data.ra["ensembl_id_collapsed"] = …
+
+        gene_ids_collapsed = [
+            gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
+        ]
+        gene_ids_collapsed_in_dict = [
+            gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
+        ]
+
+        if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
+            data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
             return data_directory
-        # Genes need to be collapsed
         else:
             dedup_filename = data_directory.with_name(
                 data_directory.stem + "__dedup.loom"
             )
-            …
-            data.ra["ensembl_id_collapsed"] = mapped_vals
+            data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
             dup_genes = [
                 idx
                 for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
@@ -216,41 +203,33 @@ def sum_ensembl_ids(
         assert (
             "ensembl_id_collapsed" not in data.var.columns
         ), "'ensembl_id_collapsed' column already exists in data.var"
-        assert (
-            "n_counts" in data.obs.columns
-        ), "'n_counts' column missing from data.obs"
-
-        if custom_attr_name_dict is not None:
-            for label in custom_attr_name_dict:
-                assert label in data.obs.columns, f"Attribute `{label}` not present in data.obs"
 
-        # Get the ensembl ids that exist in data
-        ensembl_ids = data.var.ensembl_id
         # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
         # Comparing to gene_token_dict here, would not perform any mapping steps
-        …
+        gene_ids_in_dict = [
+            gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
+        ]
+        if collapse_gene_ids is False:
+            if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
+                return data
             else:
                 raise ValueError("Error: data Ensembl IDs non-unique.")
 
-        # …
-        …
+        # Check for when if collapse_gene_ids is True
+        gene_ids_collapsed = [
+            gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
+        ]
+        gene_ids_collapsed_in_dict = [
+            gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
+        ]
+        if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
+            data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
             return data
-
+
         else:
-            data.var["ensembl_id_collapsed"] = …
-            data.var_names = …
+            data.var["ensembl_id_collapsed"] = gene_ids_collapsed
+            data.var_names = gene_ids_collapsed
             data = data[:, ~data.var.index.isna()]
             dup_genes = [
                 idx for idx, count in Counter(data.var_names).items() if count > 1
@@ -476,7 +455,6 @@ class TranscriptomeTokenizer:
             self.collapse_gene_ids,
             self.gene_mapping_dict,
             self.gene_token_dict,
-            self.custom_attr_name_dict,
             file_format="h5ad",
             chunk_size=self.chunk_size,
         )
@@ -553,7 +531,6 @@ class TranscriptomeTokenizer:
             self.collapse_gene_ids,
             self.gene_mapping_dict,
            self.gene_token_dict,
-            self.custom_attr_name_dict,
            file_format="loom",
            chunk_size=self.chunk_size,
        )
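The rewritten check in sum_ensembl_ids boils down to: map every Ensembl ID through gene_mapping_dict and only write the collapsed column directly when collapsing merges no two distinct token-dictionary genes; otherwise duplicate rows must first be summed. The same logic in isolation, with toy dictionaries:

from collections import Counter

gene_token_dict = {"ENSG000A": 1, "ENSG000B": 2, "ENSG000C": 3}  # toy vocabulary
gene_mapping_dict = {"ENSG000A": "ENSG000A", "ENSG000B": "ENSG000B",
                     "ENSG000C": "ENSG000A"}  # C collapses into A
ensembl_ids = ["ENSG000A", "ENSG000B", "ENSG000C"]

gene_ids_in_dict = [g for g in ensembl_ids if g in gene_token_dict]
gene_ids_collapsed = [gene_mapping_dict.get(g.upper()) for g in ensembl_ids]
gene_ids_collapsed_in_dict = [g for g in gene_ids_collapsed if g in gene_token_dict]

if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
    print("safe to collapse:", gene_ids_collapsed)
else:
    # two distinct known genes merged; their count rows must be summed first
    dup = [g for g, c in Counter(gene_ids_collapsed).items() if c > 1]
    print("needs dedup/summing for:", dup)  # -> ['ENSG000A']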
requirements.txt
@@ -22,4 +22,4 @@ tdigest>=0.5.2
 tensorboard>=2.15
 torch>=2.0.1
 tqdm>=4.65
-transformers>=4.…
+transformers>=4.28