README.md CHANGED
@@ -1,9 +1,6 @@
  ---
  datasets: ctheodoris/Genecorpus-30M
  license: apache-2.0
- tags:
- - single-cell
- - genomics
  ---
  # Geneformer
  Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
examples/cell_classification.ipynb CHANGED
@@ -68,10 +68,6 @@
  " \"per_device_train_batch_size\": 12,\n",
  " \"seed\": 73,\n",
  "}\n",
- "\n",
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
- "# (otherwise the Classifier will use the current default model dictionary)\n",
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "cc = Classifier(classifier=\"cell\",\n",
  " cell_state_dict = {\"state_key\": \"disease\", \"states\": \"all\"},\n",
  " filter_data=filter_data_dict,\n",
@@ -129,7 +125,7 @@
  " \"train\": train_ids+eval_ids,\n",
  " \"test\": test_ids}\n",
  "\n",
- "# Example input_data_file for 30M model: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
+ "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
  "cc.prepare_data(input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n",
  " output_directory=output_dir,\n",
  " output_prefix=output_prefix,\n",
@@ -264,7 +260,7 @@
  " \"train\": train_ids,\n",
  " \"eval\": eval_ids}\n",
  "\n",
- "# Example 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
+ "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
  "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
@@ -450,7 +446,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.15"
+ "version": "3.11.5"
  }
  },
  "nbformat": 4,
examples/extract_and_plot_cell_embeddings.ipynb CHANGED
@@ -18,8 +18,6 @@
  "outputs": [],
  "source": [
  "# initiate EmbExtractor\n",
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
- "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
  "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
  " num_classes=3,\n",
  " filter_data={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]},\n",
@@ -28,13 +26,12 @@
  " emb_label=[\"disease\",\"cell_type\"],\n",
  " labels_to_plot=[\"disease\"],\n",
  " forward_batch_size=200,\n",
- " nproc=16,\n",
- " token_dictionary_file=\"./gene_dictionaries_30m/token_dictionary_gc30M.pkl\") # change from current default dictionary for 30M model series\n",
+ " nproc=16)\n",
  "\n",
  "# extracts embedding from input data\n",
  "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
- "# example dataset for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
- "embs = embex.extract_embs(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
+ "# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
+ "embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
  " \"path/to/input_data/\",\n",
  " \"path/to/output_directory/\",\n",
  " \"output_prefix\")\n"
@@ -132,7 +129,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.15"
+ "version": "3.11.5"
  }
  },
  "nbformat": 4,
examples/gene_classification.ipynb CHANGED
@@ -71,9 +71,6 @@
  }
  ],
  "source": [
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
- "# (otherwise the Classifier will use the current default model dictionary)\n",
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "cc = Classifier(classifier=\"gene\",\n",
  " gene_class_dict = gene_class_dict,\n",
  " max_ncells = 10_000,\n",
@@ -105,7 +102,7 @@
  }
  ],
  "source": [
- "# Example input_data_file for 30M model series: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
+ "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/gene_classification/dosage_sensitive_tfs/gc-30M_sample50k.dataset\n",
  "cc.prepare_data(input_data_file=\"/path/to/gc-30M_sample50k.dataset\",\n",
  " output_directory=output_dir,\n",
  " output_prefix=output_prefix)"
@@ -843,7 +840,7 @@
  }
  ],
  "source": [
- "# 6 layer 30M Geneformer model: https://huggingface.co/ctheodoris/Geneformer/blob/main/gf-6L-30M-i2048/model.safetensors\n",
+ "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
  "all_metrics = cc.validate(model_directory=\"/path/to/Geneformer\",\n",
  " prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled.dataset\",\n",
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
@@ -1243,7 +1240,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.15"
+ "version": "3.11.5"
  }
  },
  "nbformat": 4,
examples/in_silico_perturbation.ipynb CHANGED
@@ -39,10 +39,7 @@
  "\n",
  "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
  "\n",
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
- "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
- "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
+ "embex = EmbExtractor(model_type=\"CellClassifier\",\n",
  " num_classes=3,\n",
  " filter_data=filter_data_dict,\n",
  " max_ncells=1000,\n",
@@ -52,7 +49,7 @@
  " nproc=16)\n",
  "\n",
  "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
- " \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
+ " \"path/to/model\",\n",
  " \"path/to/input_data\",\n",
  " \"path/to/output_directory\",\n",
  " \"output_prefix\")"
@@ -67,15 +64,12 @@
  },
  "outputs": [],
  "source": [
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
- "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
  " perturb_rank_shift=None,\n",
  " genes_to_perturb=\"all\",\n",
  " combos=0,\n",
  " anchor_gene=None,\n",
- " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
+ " model_type=\"CellClassifier\",\n",
  " num_classes=3,\n",
  " emb_mode=\"cell\",\n",
  " cell_emb_style=\"mean_pool\",\n",
@@ -96,10 +90,9 @@
  "outputs": [],
  "source": [
  "# outputs intermediate files from in silico perturbation\n",
- "\n",
- "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
+ "isp.perturb_data(\"path/to/model\",\n",
  " \"path/to/input_data\",\n",
- " \"path/to/isp_output_directory\",\n",
+ " \"path/to/output_directory\",\n",
  " \"output_prefix\")"
  ]
  },
@@ -110,9 +103,6 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
- "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
  "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
  " genes_perturbed=\"all\",\n",
  " combos=0,\n",
@@ -128,9 +118,9 @@
  "outputs": [],
  "source": [
  "# extracts data from intermediate files and processes stats to output in final .csv\n",
- "ispstats.get_stats(\"path/to/isp_output_directory\", # this should be the directory \n",
+ "ispstats.get_stats(\"path/to/input_data\",\n",
  " None,\n",
- " \"path/to/isp_stats_output_directory\",\n",
+ " \"path/to/output_directory\",\n",
  " \"output_prefix\")"
  ]
  }
@@ -151,7 +141,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.15"
+ "version": "3.10.11"
  }
  },
  "nbformat": 4,
examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -12,7 +12,7 @@
  },
  {
  "cell_type": "markdown",
- "id": "1fe86f48-5578-47df-b373-58c21ec170ab",
+ "id": "350e6252-b783-494b-9767-f087eb868a15",
  "metadata": {},
  "source": [
  "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
@@ -25,21 +25,11 @@
  "\n",
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
  "\n",
- "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
- "metadata": {},
- "source": [
- "**********************************************************************************************************\n",
+ "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.\n",
+ "\n",
  "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILE IS USED FOR THE CORRECT MODEL.\n",
- "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
  "\n",
- "#### ADDITIONALLY:\n",
- "#### The 95M model series require the special_token argument to be set to True and model_input_size to be 4096. (current defaults)\n",
- "#### The 30M model series require the special_token argument to be set to False and the model_input_size to be 2048."
+ "#### The 95M model series also require the special_token argument to be set to True and model_input_size to be 4096."
  ]
  },
  {
@@ -83,7 +73,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.15"
+ "version": "3.10.11"
  }
  },
  "nbformat": 4,
geneformer/emb_extractor.py CHANGED
@@ -411,7 +411,7 @@ class EmbExtractor:
  self,
  model_type="Pretrained",
  num_classes=0,
- emb_mode="cls",
+ emb_mode="cell",
  cell_emb_style="mean_pool",
  gene_emb_style="mean_pool",
  filter_data=None,
@@ -596,12 +596,6 @@ class EmbExtractor:
  filtered_input_data = pu.load_and_filter(
  self.filter_data, self.nproc, input_data_file
  )
-
- # Check to make sure that all the labels exist in the tokenized data:
- if self.emb_label is not None:
- for label in self.emb_label:
- assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
-
  if cell_state is not None:
  filtered_input_data = pu.filter_by_dict(
  filtered_input_data, cell_state, self.nproc
@@ -725,12 +719,6 @@
  )
  raise

- if self.emb_label is not None:
- logger.error(
- "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
- )
- raise
-
  state_embs_dict = dict()
  state_key = cell_states_to_model["state_key"]
  for k, v in cell_states_to_model.items():
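The default emb_mode moves from "cls" back to "cell" here, so pipelines that relied on the previous default should pin it explicitly. A minimal sketch of that, with no claims beyond the parameters visible in this hunk:

```python
from geneformer import EmbExtractor

# Pin emb_mode explicitly so behavior does not depend on the changed default;
# "cls" restores the previous default, "cell" matches the new one.
embex = EmbExtractor(
    model_type="Pretrained",
    emb_mode="cell",
    cell_emb_style="mean_pool",
)
```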
geneformer/gene_name_id_dict_gc95M.pkl CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:fabfa0c2f49c598c59ae432a32c3499a5908c033756c663b5e0cddf58deea8e1
- size 1660882
+ oid sha256:8b0fd0521406ed18b2e341ef0acb5f53aa1a62457a07ca5840e1c142f46dd326
+ size 2038812
geneformer/in_silico_perturber.py CHANGED
@@ -40,7 +40,7 @@ import pickle
  from collections import defaultdict

  import torch
- from datasets import Dataset
+ from datasets import Dataset, disable_progress_bars
  from multiprocess import set_start_method
  from tqdm.auto import trange

@@ -48,9 +48,7 @@ from . import TOKEN_DICTIONARY_FILE
  from . import perturber_utils as pu
  from .emb_extractor import get_embs

- import datasets
- datasets.logging.disable_progress_bar()
-
+ disable_progress_bars()

  logger = logging.getLogger(__name__)

@@ -86,7 +84,7 @@ class InSilicoPerturber:
  anchor_gene=None,
  model_type="Pretrained",
  num_classes=0,
- emb_mode="cls",
+ emb_mode="cell",
  cell_emb_style="mean_pool",
  filter_data=None,
  cell_states_to_model=None,
@@ -796,8 +794,6 @@
  return example

  total_batch_length = len(filtered_input_data)
-
-
  if self.cell_states_to_model is None:
  cos_sims_dict = defaultdict(list)
  else:
@@ -882,7 +878,7 @@
  )

  ##### CLS and Gene Embedding Mode #####
- elif self.emb_mode == "cls_and_gene":
+ elif self.emb_mode == "cls_and_gene":
  full_original_emb = get_embs(
  model,
  minibatch,
@@ -895,7 +891,6 @@
  silent=True,
  )
  indices_to_perturb = perturbation_batch["perturb_index"]
-
  # remove indices that were perturbed
  original_emb = pu.remove_perturbed_indices_set(
  full_original_emb,
@@ -904,7 +899,6 @@
  self.tokens_to_perturb,
  minibatch["length"],
  )
-
  full_perturbation_emb = get_embs(
  model,
  perturbation_batch,
@@ -916,7 +910,7 @@
  summary_stat=None,
  silent=True,
  )
-
+
  # remove special tokens and padding
  original_emb = original_emb[:, 1:-1, :]
  if self.perturb_type == "overexpress":
@@ -927,25 +921,9 @@
  perturbation_emb = full_perturbation_emb[
  :, 1 : max(perturbation_batch["length"]) - 1, :
  ]
-
- n_perturbation_genes = perturbation_emb.size()[1]

+ n_perturbation_genes = perturbation_emb.size()[1]
- # truncate the original embedding as necessary
- if self.perturb_type == "overexpress":
- def calc_perturbation_length(ids):
- if ids == [-100]:
- return 0
- else:
- return len(ids)
-
- max_tensor_size = max([length - calc_perturbation_length(ids) - 2 for length, ids in zip(minibatch["length"], indices_to_perturb)])

- max_n_overflow = max(minibatch["n_overflow"])
- if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
- original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
- elif perturbation_emb.size()[1] < original_emb.size()[1]:
- original_emb = original_emb[:, 0:max_tensor_size, :]
-
  gene_cos_sims = pu.quant_cos_sims(
  perturbation_emb,
  original_emb,
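The progress-bar silencing moves from datasets.logging.disable_progress_bar() to the top-level disable_progress_bars() helper, which only exists in newer releases of the datasets library. A version-tolerant sketch of the same behavior:

```python
# Silence Hugging Face `datasets` progress bars regardless of library version;
# the plural helper used in this diff is only available in newer releases.
try:
    from datasets import disable_progress_bars
    disable_progress_bars()
except ImportError:  # older datasets releases
    import datasets
    datasets.logging.disable_progress_bar()
```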
geneformer/in_silico_perturber_stats.py CHANGED
@@ -640,16 +640,10 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
  cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])

  # quantify number of detections of each gene
- if anchor_token is None:
- cos_sims_full_df["N_Detections"] = [
- n_detections(i, dict_list, "cell", anchor_token)
- for i in cos_sims_full_df["Gene"]
- ]
- else:
- cos_sims_full_df["N_Detections"] = [
- n_detections(i, dict_list, "gene", anchor_token)
- for i in cos_sims_full_df["Gene"]
- ]
+ cos_sims_full_df["N_Detections"] = [
+ n_detections(i, dict_list, "gene", anchor_token)
+ for i in cos_sims_full_df["Gene"]
+ ]

  if combos == 0:
  cos_sims_full_df = cos_sims_full_df.sort_values(
geneformer/mtl/data.py CHANGED
@@ -1,162 +1,150 @@
  import os
+
  from .collators import DataCollatorForMultitaskCellClassification
  from .imports import *

- def validate_columns(dataset, required_columns, dataset_type):
- """Ensures required columns are present in the dataset."""
- missing_columns = [col for col in required_columns if col not in dataset.column_names]
- if missing_columns:
- raise KeyError(
- f"Missing columns in {dataset_type} dataset: {missing_columns}. "
- f"Available columns: {dataset.column_names}"
- )
-
-
- def create_label_mappings(dataset, task_to_column):
- """Creates label mappings for the dataset."""
- task_label_mappings = {}
- num_labels_list = []
- for task, column in task_to_column.items():
- unique_values = sorted(set(dataset[column]))
- mapping = {label: idx for idx, label in enumerate(unique_values)}
- task_label_mappings[task] = mapping
- num_labels_list.append(len(unique_values))
- return task_label_mappings, num_labels_list
-
-
- def save_label_mappings(mappings, path):
- """Saves label mappings to a pickle file."""
- with open(path, "wb") as f:
- pickle.dump(mappings, f)
-
-
- def load_label_mappings(path):
- """Loads label mappings from a pickle file."""
- with open(path, "rb") as f:
- return pickle.load(f)
-
-
- def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
- """Transforms the dataset to the required format."""
- transformed_dataset = []
- cell_id_mapping = {}
-
- for idx, record in enumerate(dataset):
- transformed_record = {
- "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
- "cell_id": idx, # Index-based cell ID
- }
-
- if not is_test:
- label_dict = {
- task: task_label_mappings[task][record[column]]
- for task, column in task_to_column.items()
- }
- else:
- label_dict = {task: -1 for task in config["task_names"]}
-
- transformed_record["label"] = label_dict
- transformed_dataset.append(transformed_record)
- cell_id_mapping[idx] = record.get("unique_cell_id", idx)
-
- return transformed_dataset, cell_id_mapping
-

  def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
- """Main function to load and preprocess data."""
  try:
  dataset = load_from_disk(dataset_path)

- # Setup task and column mappings
  task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
  task_to_column = dict(zip(task_names, config["task_columns"]))
  config["task_names"] = task_names

- label_mappings_path = os.path.join(
- config["results_dir"],
- f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
- )
-
  if not is_test:
- validate_columns(dataset, task_to_column.values(), dataset_type)
-
- # Create and save label mappings
- task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
- save_label_mappings(task_label_mappings, label_mappings_path)
+ available_columns = set(dataset.column_names)
+ for column in task_to_column.values():
+ if column not in available_columns:
+ raise KeyError(
+ f"Column {column} not found in the dataset. Available columns: {list(available_columns)}"
+ )
+
+ label_mappings = {}
+ task_label_mappings = {}
+ cell_id_mapping = {}
+ num_labels_list = []
+
+ # Load or create task label mappings
+ if not is_test:
+ for task, column in task_to_column.items():
+ unique_values = sorted(set(dataset[column])) # Ensure consistency
+ label_mappings[column] = {
+ label: idx for idx, label in enumerate(unique_values)
+ }
+ task_label_mappings[task] = label_mappings[column]
+ num_labels_list.append(len(unique_values))
+
+ # Print the mappings for each task with dataset type prefix
+ for task, mapping in task_label_mappings.items():
+ print(
+ f"{dataset_type.capitalize()} mapping for {task}: {mapping}"
+ ) # sanity check, for train/validation splits
+
+ # Save the task label mappings as a pickle file
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f:
+ pickle.dump(task_label_mappings, f)
  else:
- # Load existing mappings for test data
- task_label_mappings = load_label_mappings(label_mappings_path)
- num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
+ # Load task label mappings from pickle file for test data
+ with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
+ task_label_mappings = pickle.load(f)
+
+ # Infer num_labels_list from task_label_mappings
+ for task, mapping in task_label_mappings.items():
+ num_labels_list.append(len(mapping))
+
+ # Store unique cell IDs in a separate dictionary
+ for idx, record in enumerate(dataset):
+ cell_id = record.get("unique_cell_id", idx)
+ cell_id_mapping[idx] = cell_id
+
+ # Transform records to the desired format
+ transformed_dataset = []
+ for idx, record in enumerate(dataset):
+ transformed_record = {}
+ transformed_record["input_ids"] = torch.tensor(
+ record["input_ids"], dtype=torch.long
+ )

- # Transform dataset
- transformed_dataset, cell_id_mapping = transform_dataset(
- dataset, task_to_column, task_label_mappings, config, is_test
- )
+ # Use index-based cell ID for internal tracking
+ transformed_record["cell_id"] = idx

- return transformed_dataset, cell_id_mapping, num_labels_list
+ if not is_test:
+ # Prepare labels
+ label_dict = {}
+ for task, column in task_to_column.items():
+ label_value = record[column]
+ label_index = task_label_mappings[task][label_value]
+ label_dict[task] = label_index
+ transformed_record["label"] = label_dict
+ else:
+ # Create dummy labels for test data
+ label_dict = {task: -1 for task in config["task_names"]}
+ transformed_record["label"] = label_dict
+
+ transformed_dataset.append(transformed_record)

+ return transformed_dataset, cell_id_mapping, num_labels_list
  except KeyError as e:
- raise ValueError(f"Configuration error or dataset key missing: {e}")
+ print(f"Missing configuration or dataset key: {e}")
  except Exception as e:
- raise RuntimeError(f"Error during data loading or preprocessing: {e}")
+ print(f"An error occurred while loading or preprocessing data: {e}")
+ return None, None, None


  def preload_and_process_data(config):
- """Preloads and preprocesses train and validation datasets."""
- # Process train data and save mappings
- train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
-
- # Process validation data and save mappings
- val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
-
- # Validate that the mappings match
- validate_label_mappings(config)
-
- return (*train_data[:2], *val_data) # Return train and val data along with mappings
-
+ # Load and preprocess data once
+ train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(
+ config["train_path"], config, dataset_type="train"
+ )
+ val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(
+ config["val_path"], config, dataset_type="validation"
+ )
+ return (
+ train_dataset,
+ train_cell_id_mapping,
+ val_dataset,
+ val_cell_id_mapping,
+ num_labels_list,
+ )

- def validate_label_mappings(config):
- """Ensures train and validation label mappings are consistent."""
- train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
- val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
- train_mappings = load_label_mappings(train_mappings_path)
- val_mappings = load_label_mappings(val_mappings_path)

- for task_name in config["task_names"]:
- if train_mappings[task_name] != val_mappings[task_name]:
- raise ValueError(
- f"Mismatch in label mappings for task '{task_name}'.\n"
- f"Train Mapping: {train_mappings[task_name]}\n"
- f"Validation Mapping: {val_mappings[task_name]}"
- )

- def get_data_loader(preprocessed_dataset, batch_size):
- """Creates a DataLoader with optimal settings."""
- return DataLoader(
+ def get_data_loader(preprocessed_dataset, batch_size):
+ nproc = os.cpu_count() ### I/O operations
+
+ data_collator = DataCollatorForMultitaskCellClassification()
+
+ loader = DataLoader(
  preprocessed_dataset,
  batch_size=batch_size,
  shuffle=True,
- collate_fn=DataCollatorForMultitaskCellClassification(),
- num_workers=os.cpu_count(),
+ collate_fn=data_collator,
+ num_workers=nproc,
  pin_memory=True,
  )
+ return loader


  def preload_data(config):
- """Preprocesses train and validation data for trials."""
- train_loader = get_data_loader(*preload_and_process_data(config)[:2], config["batch_size"])
- val_loader = get_data_loader(*preload_and_process_data(config)[2:4], config["batch_size"])
+ # Preprocessing the data before the Optuna trials start
+ train_loader = get_data_loader("train", config)
+ val_loader = get_data_loader("val", config)
  return train_loader, val_loader


  def load_and_preprocess_test_data(config):
- """Loads and preprocesses test data."""
+ """
+ Load and preprocess test data, treating it as unlabeled.
+ """
  return load_and_preprocess_data(config["test_path"], config, is_test=True)


  def prepare_test_loader(config):
- """Prepares DataLoader for test data."""
- test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
+ """
+ Prepare DataLoader for the test dataset.
+ """
+ test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(
+ config
+ )
  test_loader = get_data_loader(test_dataset, config["batch_size"])
  return test_loader, cell_id_mapping, num_labels_list
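For reference, the restored helpers expect a plain config dict. The sketch below lists only the keys actually referenced in data.py; the paths and column names are placeholders.

```python
from geneformer.mtl.data import get_data_loader, preload_and_process_data

# Placeholder configuration: task_columns must name label columns present in
# the tokenized train/validation datasets, and results_dir is where
# task_label_mappings.pkl is written and later re-read for test data.
config = {
    "train_path": "/path/to/train.dataset",
    "val_path": "/path/to/val.dataset",
    "test_path": "/path/to/test.dataset",
    "task_columns": ["disease", "cell_type"],
    "results_dir": "/path/to/results",
    "batch_size": 16,
}

(train_dataset, train_cell_id_mapping,
 val_dataset, val_cell_id_mapping, num_labels_list) = preload_and_process_data(config)
train_loader = get_data_loader(train_dataset, config["batch_size"])
val_loader = get_data_loader(val_dataset, config["batch_size"])
```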
geneformer/pretrainer.py CHANGED
@@ -8,12 +8,13 @@ import math
  import pickle
  import warnings
  from enum import Enum
- from typing import Dict, List, Optional, Union
+ from typing import Dict, Iterator, List, Optional, Union

  import numpy as np
  import torch
  from datasets import Dataset
  from packaging import version
+ from torch.utils.data.distributed import DistributedSampler
  from torch.utils.data.sampler import RandomSampler
  from transformers import (
  BatchEncoding,
@@ -23,8 +24,11 @@ from transformers import (
  )
  from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
  from transformers.trainer_pt_utils import (
+ DistributedLengthGroupedSampler,
+ DistributedSamplerWithLoop,
  LengthGroupedSampler,
  )
+ from transformers.training_args import ParallelMode
  from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
  from transformers.utils.generic import _is_tensorflow, _is_torch

@@ -603,7 +607,7 @@ class GeneformerPretrainer(Trainer):
  )
  super().__init__(*args, **kwargs)

- # updated to not use distributed sampler since Trainer now distributes with accelerate
+ # modify LengthGroupedSampler to avoid dataset[length_column_name] hanging
  def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
  if not isinstance(self.train_dataset, collections.abc.Sized):
  return None
@@ -626,15 +630,181 @@
  if self.tokenizer is not None
  else None
  )
- return LengthGroupedSampler(
+ if self.args.world_size <= 1:
+ return LengthGroupedSampler(
  dataset=self.train_dataset,
  batch_size=self.args.train_batch_size,
  lengths=lengths,
  model_input_name=model_input_name,
  generator=generator,
+ )
+ else:
+ return CustomDistributedLengthGroupedSampler(
+ dataset=self.train_dataset,
+ batch_size=self.args.train_batch_size,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ seed=self.args.seed,
+ )
+
+ else:
+ if self.args.world_size <= 1:
+ if _is_torch_generator_available:
+ return RandomSampler(self.train_dataset, generator=generator)
+ return RandomSampler(self.train_dataset)
+ elif (
+ self.args.parallel_mode
+ in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
+ and not self.args.dataloader_drop_last
+ ):
+ # Use a loop for TPUs when drop_last is False to have all batches have the same size.
+ return DistributedSamplerWithLoop(
+ self.train_dataset,
+ batch_size=self.args.per_device_train_batch_size,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=self.args.seed,
+ )
+ else:
+ return DistributedSampler(
+ self.train_dataset,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=self.args.seed,
+ )
+
+
+ class CustomDistributedLengthGroupedSampler(DistributedLengthGroupedSampler):
+ r"""
+ Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
+ length while keeping a bit of randomness.
+ """
+
+ # Copied and adapted from PyTorch DistributedSampler.
+ def __init__(
+ self,
+ dataset: Dataset,
+ batch_size: int,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ seed: int = 0,
+ drop_last: bool = False,
+ lengths: Optional[List[int]] = None,
+ model_input_name: Optional[str] = None,
+ ):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.drop_last = drop_last
+ # If the dataset length is evenly divisible by # of replicas, then there
+ # is no need to drop any data, since the dataset will be split equally.
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0:
+ # Split to nearest available length that is evenly divisible.
+ # This is to ensure each rank receives the same amount of data when
+ # using this Sampler.
+ self.num_samples = math.ceil(
+ (len(self.dataset) - self.num_replicas) / self.num_replicas
  )
+ else:
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+ self.seed = seed
+ self.model_input_name = (
+ model_input_name if model_input_name is not None else "input_ids"
+ )
+
+ if lengths is None:
+ print("Lengths is none - calculating lengths.")
+ if (
+ not (
+ isinstance(dataset[0], dict)
+ or isinstance(dataset[0], BatchEncoding)
+ )
+ or self.model_input_name not in dataset[0]
+ ):
+ raise ValueError(
+ "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+ f"'{self.model_input_name}' key."
+ )
+ lengths = [len(feature[self.model_input_name]) for feature in dataset]
+ self.lengths = lengths
+
+ def __iter__(self) -> Iterator:
+ # Deterministically shuffle based on epoch and seed
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)

+ if not self.drop_last:
+ # add extra samples to make it evenly divisible
+ indices += indices[: (self.total_size - len(indices))]
  else:
- if _is_torch_generator_available:
- return RandomSampler(self.train_dataset, generator=generator)
- return RandomSampler(self.train_dataset)
+ # remove tail of data to make it evenly divisible.
+ indices = indices[: self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+
+ def get_length_grouped_indices(
+ lengths, batch_size, mega_batch_mult=None, generator=None
+ ):
+ """
+ Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
+ similar lengths. To do this, the indices are:
+
+ - randomly permuted
+ - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
+ - sorted by length in each mega-batch
+
+ The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
+ maximum length placed first, so that an OOM happens sooner rather than later.
+ """
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+ if mega_batch_mult is None:
+ # mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 1000)
+ # Just in case, for tiny datasets
+ if mega_batch_mult == 0:
+ mega_batch_mult = 1
+
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = mega_batch_mult * batch_size
+ megabatches = [
+ indices[i : i + megabatch_size].tolist()
+ for i in range(0, len(lengths), megabatch_size)
+ ]
+ megabatches = [
+ list(sorted(megabatch, key=lambda i: lengths[i], reverse=True))
+ for megabatch in megabatches
+ ]
+
+ # The rest is to get the biggest batch first.
+ # Since each megabatch is sorted by descending length, the longest element is the first
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+ max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+ # Switch to put the longest element in first position
+ megabatches[0][0], megabatches[max_idx][0] = (
+ megabatches[max_idx][0],
+ megabatches[0][0],
+ )
+
+ return [item for sublist in megabatches for item in sublist]
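A toy illustration of the length-grouped shuffling restored above, calling the module-level helper on made-up sequence lengths:

```python
import torch

from geneformer.pretrainer import get_length_grouped_indices

# Made-up lengths: indices are permuted, grouped into megabatches, and sorted
# by length within each megabatch, so each batch holds similarly sized inputs
# and the longest batch is placed first.
lengths = [2048, 512, 1024, 256, 2048, 128, 768, 1536]
generator = torch.Generator().manual_seed(0)
indices = get_length_grouped_indices(lengths, batch_size=2, generator=generator)
batches = [indices[i : i + 2] for i in range(0, len(indices), 2)]
print(batches)
```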
geneformer/tokenizer.py CHANGED
@@ -88,7 +88,6 @@ def sum_ensembl_ids(
  collapse_gene_ids,
  gene_mapping_dict,
  gene_token_dict,
- custom_attr_name_dict,
  file_format="loom",
  chunk_size=512,
  ):
@@ -104,45 +103,33 @@
  assert (
  "ensembl_id_collapsed" not in data.ra.keys()
  ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
-
- assert (
- "n_counts" in data.ca.keys()
- ), "'n_counts' column missing from data.ca.keys()"
-
- if custom_attr_name_dict is not None:
- for label in custom_attr_name_dict:
- assert label in data.ca.keys(), f"Attribute `{label}` not present in dataset features"
-
- # Get the ensembl ids that exist in data
- ensembl_ids = data.ra.ensembl_id
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
  # Comparing to gene_token_dict here, would not perform any mapping steps
- if not collapse_gene_ids:
- ensembl_id_check = [
- gene for gene in ensembl_ids if gene in gene_token_dict.keys()
- ]
- if len(ensembl_id_check) == len(set(ensembl_id_check)):
+ gene_ids_in_dict = [
+ gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
+ ]
+ if collapse_gene_ids is False:
+
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
  return data_directory
  else:
  raise ValueError("Error: data Ensembl IDs non-unique.")
-
- # Get the genes that exist in the mapping dictionary and the value of those genes
- genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
- vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
-
- # if the genes in the mapping dict and the value of those genes are of the same length,
- # simply return the mapped values
- if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
- mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
- data.ra["ensembl_id_collapsed"] = mapped_vals
+
+ gene_ids_collapsed = [
+ gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
+ ]
+ gene_ids_collapsed_in_dict = [
+ gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
+ ]
+
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
  return data_directory
- # Genes need to be collapsed
  else:
  dedup_filename = data_directory.with_name(
  data_directory.stem + "__dedup.loom"
  )
- mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
- data.ra["ensembl_id_collapsed"] = mapped_vals
+ data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
  dup_genes = [
  idx
  for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
@@ -216,41 +203,33 @@
  assert (
  "ensembl_id_collapsed" not in data.var.columns
  ), "'ensembl_id_collapsed' column already exists in data.var"
- assert (
- "n_counts" in data.obs.columns
- ), "'n_counts' column missing from data.obs"
-
- if custom_attr_name_dict is not None:
- for label in custom_attr_name_dict:
- assert label in data.obs.columns, f"Attribute `{label}` not present in data.obs"

-
- # Get the ensembl ids that exist in data
- ensembl_ids = data.var.ensembl_id
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
  # Comparing to gene_token_dict here, would not perform any mapping steps
- if not collapse_gene_ids:
- ensembl_id_check = [
- gene for gene in ensembl_ids if gene in gene_token_dict.keys()
- ]
- if len(ensembl_id_check) == len(set(ensembl_id_check)):
- return data_directory
+ gene_ids_in_dict = [
+ gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
+ ]
+ if collapse_gene_ids is False:
+
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
+ return data
  else:
  raise ValueError("Error: data Ensembl IDs non-unique.")

- # Get the genes that exist in the mapping dictionary and the value of those genes
- genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
- vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
-
- # if the genes in the mapping dict and the value of those genes are of the same length,
- # simply return the mapped values
- if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
- data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
+ # Check for when if collapse_gene_ids is True
+ gene_ids_collapsed = [
+ gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
+ ]
+ gene_ids_collapsed_in_dict = [
+ gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
+ ]
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
+ data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
  return data
- # Genes need to be collapsed
+
  else:
- data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
- data.var_names = data.var["ensembl_id_collapsed"]
+ data.var["ensembl_id_collapsed"] = gene_ids_collapsed
+ data.var_names = gene_ids_collapsed
  data = data[:, ~data.var.index.isna()]
  dup_genes = [
  idx for idx, count in Counter(data.var_names).items() if count > 1
@@ -476,7 +455,6 @@ class TranscriptomeTokenizer:
  self.collapse_gene_ids,
  self.gene_mapping_dict,
  self.gene_token_dict,
- self.custom_attr_name_dict,
  file_format="h5ad",
  chunk_size=self.chunk_size,
  )
@@ -553,7 +531,6 @@ class TranscriptomeTokenizer:
  self.collapse_gene_ids,
  self.gene_mapping_dict,
  self.gene_token_dict,
- self.custom_attr_name_dict,
  file_format="loom",
  chunk_size=self.chunk_size,
  )
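The rewritten collapse check compares how many distinct in-dictionary Ensembl IDs exist before and after mapping; only when mapping merges in-dictionary IDs does the tokenizer write a deduplicated file and sum counts. A standalone toy version of that check, with placeholder dictionaries:

```python
from collections import Counter

# Toy stand-ins for gene_token_dict / gene_mapping_dict and the input IDs.
gene_token_dict = {"ENSGA": 1, "ENSGB": 2, "ENSGC": 3}
gene_mapping_dict = {"ENSGA": "ENSGA", "ENSGB": "ENSGA", "ENSGC": "ENSGC"}
ensembl_ids = ["ENSGA", "ENSGB", "ENSGC"]

gene_ids_in_dict = [g for g in ensembl_ids if g in gene_token_dict]
gene_ids_collapsed = [gene_mapping_dict.get(g.upper()) for g in ensembl_ids]
gene_ids_collapsed_in_dict = [g for g in gene_ids_collapsed if g in gene_token_dict]

# Mapping merged two in-dictionary IDs into one, so counts must be summed.
needs_dedup = len(set(gene_ids_in_dict)) != len(set(gene_ids_collapsed_in_dict))
dup_genes = [g for g, n in Counter(gene_ids_collapsed).items() if n > 1]
print(needs_dedup, dup_genes)  # True ['ENSGA']
```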
requirements.txt CHANGED
@@ -22,4 +22,4 @@ tdigest>=0.5.2
  tensorboard>=2.15
  torch>=2.0.1
  tqdm>=4.65
- transformers>=4.40
+ transformers>=4.28