Upload 15 files
#267
by tigerdeF - opened
- .pre-commit-config.yaml +26 -0
- geneformer/cell_classifier.py +874 -0
- geneformer/gene_classifier.py +935 -0
- geneformer/modular_classifier_usage.md +156 -0
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,26 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.2.0
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-added-large-files
+      - id: check-merge-conflict
+      - id: mixed-line-ending
+      - id: check-docstring-first
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.1.4
+    hooks:
+      # Run the Ruff linter.
+      - id: ruff
+      # Run the Ruff formatter.
+      - id: ruff-format
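With this config in place, the hooks are enabled per clone with `pre-commit install` and can be applied to the whole existing tree with `pre-commit run --all-files`; both are standard pre-commit commands.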
geneformer/cell_classifier.py
ADDED
@@ -0,0 +1,874 @@
+"""
+Geneformer cell classifier.
+
+Usage:
+  from geneformer import classify_cells
+  classify_cells(
+      token_set=Path("geneformer/token_dictionary.pkl"),
+      median_set=Path("geneformer/gene_median_dictionary.pkl"),
+      pretrained_model=".",
+      dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
+      dataset_split=None,
+      filter_cells=0.005,
+      epochs=1,
+      cpu_cores=os.cpu_count(),
+      training_batch_size=12,
+      optimizer="adamw",
+      max_lr=5e-5,
+      num_gpus=torch.cuda.device_count(),
+      max_input_size=2**11,
+      lr_schedule_fn="linear",
+      warmup_steps=500,
+      freeze_layers=0,
+      emb_extract=False,
+      max_cells=1000,
+      emb_layer=0,
+      emb_filter=None,
+      emb_dir="embeddings",
+      overwrite=True,
+      label="cell_type",
+      data_filter=None,
+      inference_batch_size=200,
+      finetuned_model=None,
+      skip_training=False,
+      sample_data=1,
+      inference=False,
+      optimize_hyperparameters=False,
+      output_dir=None,
+  )
+"""
+
+import ast
+import datetime
+import os
+import pickle
+import random
+import subprocess
+from collections import Counter
+from pathlib import Path
+
+import numpy as np
+import seaborn as sns
+import torch
+import torch.nn.functional as F
+from datasets import load_from_disk
+from matplotlib import pyplot as plt
+from ray import tune
+from ray.tune.search.hyperopt import HyperOptSearch
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import auc as precision_auc
+from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score, roc_curve
+from transformers import BertForSequenceClassification, Trainer
+from transformers.training_args import TrainingArguments
+
+from geneformer import DataCollatorForCellClassification, EmbExtractor
+
+sns.set()
+
+# Properly sets up the NCCL environment
+GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
+os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
+os.environ["NCCL_DEBUG"] = "INFO"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# Generates an ROC curve from predictions (assumes binary classification)
+def ROC(prediction, truth, type="GeneFormer", label=""):
+    fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
+    auc = roc_auc_score(truth, prediction[:, 1])
+    print(f"{type} AUC: {auc}")
+    plt.plot(fpr, tpr, label="AUC=" + str(auc))
+    plt.ylabel("True Positive Rate")
+    plt.xlabel("False Positive Rate")
+    plt.title(f"{label} ROC Curve")
+    plt.legend(loc=4)
+    plt.savefig("ROC.png")
+
+    return tpr, fpr, auc
+
+
+# Computes similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
+def similarity(tensor1, tensor2, cosine=False):
+    if cosine is False:
+        if tensor1.ndimension() > 1:
+            tensor1 = tensor1.view(1, -1)
+        if tensor2.ndimension() > 1:
+            tensor2 = tensor2.view(1, -1)
+        dot_product = torch.matmul(tensor1, tensor2)
+        norm_tensor1 = torch.norm(tensor1)
+        norm_tensor2 = torch.norm(tensor2)
+        epsilon = 1e-8
+        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+        # rescale from [-1, 1] to [0, 1] and return as a float
+        return (similarity.item() + 1) / 2
+    else:
+        if tensor1.shape != tensor2.shape:
+            raise ValueError("Input tensors must have the same shape.")
+
+        # Compute cosine similarity using PyTorch's dot product function
+        dot_product = torch.dot(tensor1, tensor2)
+        norm_tensor1 = torch.norm(tensor1)
+        norm_tensor2 = torch.norm(tensor2)
+
+        # Avoid division by zero by adding a small epsilon
+        epsilon = 1e-8
+        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+
+        return similarity.item()
+
+
+# Plots heatmap of similarities between different classes/labels
+def plot_similarity_heatmap(similarities):
+    classes = list(similarities.keys())
+    classlen = len(classes)
+    arr = np.zeros((classlen, classlen))
+    for i, c in enumerate(classes):
+        for j, cc in enumerate(classes):
+            if cc == c:
+                val = 1.0
+            else:
+                val = similarities[c][cc]
+            arr[i][j] = val
+
+    plt.figure(figsize=(8, 6))
+    plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
+    plt.colorbar()
+    plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
+    plt.yticks(np.arange(classlen), classes)
+    plt.title("Similarity Heatmap")
+    plt.savefig("similarity_heatmap.png")
+
+
+def classify_cells(
+    token_set=Path("./token_dictionary.pkl"),
+    median_set=Path("./gene_median_dictionary.pkl"),
+    pretrained_model="../",
+    dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
+    dataset_split=None,
+    filter_cells=0.005,
+    epochs=1,
+    cpu_cores=os.cpu_count(),
+    training_batch_size=12,
+    optimizer="adamw",
+    max_lr=5e-5,
+    num_gpus=torch.cuda.device_count(),
+    max_input_size=2**11,
+    lr_schedule_fn="linear",
+    warmup_steps=500,
+    freeze_layers=0,
+    emb_extract=False,
+    max_cells=None,
+    emb_layer=-1,
+    emb_filter=None,
+    emb_dir="embeddings",
+    overwrite=False,
+    label="cell_type",
+    data_filter=None,
+    inference_batch_size=200,
+    finetuned_model=None,
+    skip_training=False,
+    sample_data=1,
+    inference=False,
+    optimize_hyperparameters=True,
+    output_dir=None,
+):
+    """
+    Primary Parameters
+    ------------------
+    dataset: path
+        Path to fine-tuning dataset for training
+
+    finetuned_model: path
+        Path to location of fine-tuned model to use for inference and embedding extraction
+
+    pretrained_model: path
+        Path to pretrained Geneformer model
+
+    inference: bool
+        Indicates whether to perform inference and return a list of similarities. Defaults to False.
+
+    skip_training: bool
+        Indicates whether to skip training the model. Defaults to False.
+
+    emb_extract: bool
+        Indicates whether to extract embeddings and calculate similarities. Defaults to False.
+
+    optimize_hyperparameters: bool
+        Indicates whether to optimize model hyperparameters. Defaults to True.
+
+
+    Customization Parameters
+    ------------------------
+
+    dataset_split: str
+        Indicates how the dataset should be partitioned (if at all), and which ID should be used for partitioning
+
+    data_filter: list
+        (For embeddings and inference) Runs analysis on subsets of the dataset based on the ID defined by dataset_split
+
+    label: str
+        Feature to read as a classification label.
+
+    emb_layer: int
+        Layer from which embeddings should be extracted and compared.
+
+    emb_filter: ['cell1', 'cell2'...]
+        Allows the user to narrow down the range of cells that embeddings will be extracted from.
+
+    max_cells: int
+        Maximum number of cells to use for embedding extraction.
+
+    freeze_layers: int
+        Number of layers to freeze during fine-tuning.
+
+    sample_data: float
+        Proportion of the dataset that should be used.
+
+    """
+
+    dataset_list = []
+    evalset_list = []
+    split_list = []
+    target_dict_list = []
+
+    # subsample the dataset once by the requested proportion
+    train_dataset = load_from_disk(dataset)
+    num_samples = int(len(train_dataset) * sample_data)
+    random_indices = random.sample(range(len(train_dataset)), num_samples)
+    train_dataset = train_dataset.select(random_indices)
+
+    def if_not_rare_cell_state(example):
+        return example[label] in cells_to_keep
+
+    # change labels to numerical ids
+    def classes_to_ids(example):
+        example["label"] = target_name_id_dict[example["label"]]
+        return example
+
+    def if_trained_label(example):
+        return example["label"] in trained_labels
+
+    if skip_training is not True:
+
+        def compute_metrics(pred):
+            labels = pred.label_ids
+            preds = pred.predictions.argmax(-1)
+            # calculate accuracy and macro f1 using sklearn's functions
+            acc = accuracy_score(labels, preds)
+            macro_f1 = f1_score(labels, preds, average="macro")
+            return {"accuracy": acc, "macro_f1": macro_f1}
+
+        # Defines custom exceptions for collecting labels (default excluded)
+        excep = {"bone_marrow": "immune"}
+
+        if dataset_split is not None:
+            if data_filter is not None:
+                split_iter = [data_filter]
+            else:
+                split_iter = Counter(train_dataset[dataset_split]).keys()
+            for lab in split_iter:
+                # collect list of tissues for fine-tuning (immune and bone marrow are included together)
+                if lab in list(excep.keys()):
+                    continue
+                elif lab in list(excep.values()):
+                    split_ids = list(excep.keys()) + [lab]
+                    split_list += [lab]
+                else:
+                    split_ids = [lab]
+                    split_list += [lab]
+
+                # filter datasets for given organ
+                def if_label(example):
+                    return example[dataset_split] in split_ids
+
+                trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
+                label_counter = Counter(trainset_label[label])
+                total_cells = sum(label_counter.values())
+
+                # excludes cells with a low proportion in the dataset
+                cells_to_keep = [
+                    k
+                    for k, v in label_counter.items()
+                    if v > (filter_cells * total_cells)
+                ]
+                trainset_label_subset = trainset_label.filter(
+                    if_not_rare_cell_state, num_proc=cpu_cores
+                )
+
+                # shuffle datasets and rename columns
+                trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
+                trainset_label_shuffled = trainset_label_shuffled.rename_column(
+                    label, "label"
+                )
+                trainset_label_shuffled = trainset_label_shuffled.remove_columns(
+                    dataset_split
+                )
+
+                # create dictionary of cell types : label ids
+                target_names = list(Counter(trainset_label_shuffled["label"]).keys())
+                target_name_id_dict = dict(
+                    zip(target_names, [i for i in range(len(target_names))])
+                )
+                target_dict_list += [target_name_id_dict]
+
+                labeled_trainset = trainset_label_shuffled.map(
+                    classes_to_ids, num_proc=cpu_cores
+                )
+
+                # create 80/20 train/eval splits
+                labeled_train_split = labeled_trainset.select(
+                    [i for i in range(0, round(len(labeled_trainset) * 0.8))]
+                )
+                labeled_eval_split = labeled_trainset.select(
+                    [
+                        i
+                        for i in range(
+                            round(len(labeled_trainset) * 0.8), len(labeled_trainset)
+                        )
+                    ]
+                )
+
+                # filter dataset for cell types in corresponding training set
+                trained_labels = list(Counter(labeled_train_split["label"]).keys())
+
+                labeled_eval_split_subset = labeled_eval_split.filter(
+                    if_trained_label, num_proc=cpu_cores
+                )
+
+                dataset_list += [labeled_train_split]
+                evalset_list += [labeled_eval_split_subset]
+
+            trainset_dict = dict(zip(split_list, dataset_list))
+            traintargetdict_dict = dict(zip(split_list, target_dict_list))
+            evalset_dict = dict(zip(split_list, evalset_list))
+
+            for lab in split_list:
+                label_trainset = trainset_dict[lab]
+                label_evalset = evalset_dict[lab]
+                label_dict = traintargetdict_dict[lab]
+
+                # set logging steps
+                logging_steps = round(len(label_trainset) / training_batch_size / 10)
+                if logging_steps == 0:
+                    logging_steps = 1
+
+                # load pretrained model
+                model = BertForSequenceClassification.from_pretrained(
+                    pretrained_model,
+                    num_labels=len(label_dict.keys()),
+                    output_attentions=False,
+                    output_hidden_states=False,
+                ).to(device)
+
+                # define output directory path
+                current_date = datetime.datetime.now()
+                datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
+
+                if output_dir is None:
+                    output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
+
+                # ensure not overwriting previously saved model
+                saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
+
+                if os.path.isfile(saved_model_test) is True and overwrite is False:
+                    raise Exception("Model already saved to this directory.")
+
+                # make output directory
+                subprocess.call(f"mkdir -p {output_dir}", shell=True)
+
+                # set training arguments
+                training_args = {
+                    "learning_rate": max_lr,
+                    "do_train": True,
+                    "do_eval": True,
+                    "evaluation_strategy": "epoch",
+                    "save_strategy": "epoch",
+                    "logging_steps": logging_steps,
+                    "group_by_length": True,
+                    "length_column_name": "length",
+                    "disable_tqdm": False,
+                    "lr_scheduler_type": lr_schedule_fn,
+                    "warmup_steps": warmup_steps,
+                    "weight_decay": 0.001,
+                    "per_device_train_batch_size": training_batch_size,
+                    "per_device_eval_batch_size": training_batch_size,
+                    "num_train_epochs": epochs,
+                    "load_best_model_at_end": True,
+                    "output_dir": output_dir,
+                }
+
+                training_args_init = TrainingArguments(**training_args)
+                true_labels = label_evalset["label"]
+
+                if optimize_hyperparameters is False:
+                    # create the trainer
+                    trainer = Trainer(
+                        model=model,
+                        args=training_args_init,
+                        data_collator=DataCollatorForCellClassification(),
+                        train_dataset=label_trainset,
+                        eval_dataset=label_evalset,
+                        compute_metrics=compute_metrics,
+                    )
+
+                    # train the cell type classifier
+                    trainer.train()
+                    predictions = trainer.predict(label_evalset)
+                    print(
+                        f"accuracy: {accuracy_score(true_labels, predictions.predictions.argmax(-1))}"
+                    )
+
+                    tpr, fpr, auc = ROC(predictions.predictions, true_labels)
+
+                    metrics = compute_metrics(predictions)
+                    with open(f"{output_dir}predictions.pickle", "wb") as fp:
+                        pickle.dump(predictions, fp)
+
+                    trainer.save_metrics("eval", predictions.metrics)
+                    trainer.save_model(output_dir)
+
+                    with open(f"{output_dir}/targets.txt", "w") as f:
+                        if len(target_dict_list) == 1:
+                            f.write(str(target_dict_list[0]))
+                        else:
+                            f.write(str(target_dict_list))
+
+                    # precision-recall curve only applies to binary labels;
+                    # skipped for multiclass classification
+                    try:
+                        precision, recall, _ = precision_recall_curve(
+                            true_labels, predictions.predictions[:, 1]
+                        )
+                        pr_auc = precision_auc(recall, precision)
+
+                        print(f"AUC: {pr_auc}")
+                        return recall, precision, pr_auc
+                    except Exception:
+                        pass
+                else:
+
+                    def model_init():
+                        model = BertForSequenceClassification.from_pretrained(
+                            pretrained_model,
+                            num_labels=len(label_dict.keys()),
+                            output_attentions=False,
+                            output_hidden_states=False,
+                        )
+                        if freeze_layers is not None:
+                            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
+                            for module in modules_to_freeze:
+                                for param in module.parameters():
+                                    param.requires_grad = False
+                        model = model.to(device)
+                        return model
+
+                    trainer = Trainer(
+                        model_init=model_init,
+                        args=training_args_init,
+                        data_collator=DataCollatorForCellClassification(),
+                        train_dataset=label_trainset,
+                        eval_dataset=label_evalset,
+                        compute_metrics=compute_metrics,
+                    )
+                    # specify raytune hyperparameter search space
+                    ray_config = {
+                        "num_train_epochs": tune.choice([epochs]),
+                        "learning_rate": tune.loguniform(1e-6, 1e-3),
+                        "weight_decay": tune.uniform(0.0, 0.3),
+                        "lr_scheduler_type": tune.choice(
+                            ["linear", "cosine", "polynomial"]
+                        ),
+                        "warmup_steps": tune.uniform(100, 2000),
+                        "seed": tune.uniform(0, 100),
+                        "per_device_train_batch_size": tune.choice(
+                            [training_batch_size]
+                        ),
+                    }
+
+                    hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
+
+                    if torch.cuda.is_available():
+                        resources_per_trial = {"cpu": 8, "gpu": 1}
+                    else:
+                        resources_per_trial = {"cpu": 8}
+
+                    # optimize hyperparameters
+                    best_trial = trainer.hyperparameter_search(
+                        direction="maximize",
+                        backend="ray",
+                        resources_per_trial=resources_per_trial,
+                        hp_space=lambda _: ray_config,
+                        search_alg=hyperopt_search,
+                        n_trials=100,  # number of trials
+                        progress_reporter=tune.CLIReporter(
+                            max_report_frequency=600,
+                            sort_by_metric=True,
+                            max_progress_rows=100,
+                            mode="max",
+                            metric="eval_accuracy",
+                            metric_columns=["loss", "eval_loss", "eval_accuracy"],
+                        ),
+                    )
+                    best_hyperparameters = best_trial.hyperparameters
+
+                    print("Best Hyperparameters:")
+                    print(best_hyperparameters)
+
+        else:
+            trainset_label = train_dataset
+            label_counter = Counter(trainset_label[label])
+            total_cells = sum(label_counter.values())
+
+            # excludes cells with a low proportion in the dataset
+            cells_to_keep = [
+                k for k, v in label_counter.items() if v > (filter_cells * total_cells)
+            ]
+            trainset_label_subset = trainset_label.filter(
+                if_not_rare_cell_state, num_proc=cpu_cores
+            )
+
+            # shuffle datasets and rename columns
+            trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
+            trainset_label_shuffled = trainset_label_shuffled.rename_column(
+                label, "label"
+            )
+
+            # create dictionary of cell types : label ids
+            target_names = list(Counter(trainset_label_shuffled["label"]).keys())
+            target_name_id_dict = dict(
+                zip(target_names, [i for i in range(len(target_names))])
+            )
+            target_dict_list = target_name_id_dict
+
+            labeled_trainset = trainset_label_shuffled.map(
+                classes_to_ids, num_proc=cpu_cores
+            )
+
+            # create 80/20 train/eval splits
+            labeled_train_split = labeled_trainset.select(
+                [i for i in range(0, round(len(labeled_trainset) * 0.8))]
+            )
+            labeled_eval_split = labeled_trainset.select(
+                [
+                    i
+                    for i in range(
+                        round(len(labeled_trainset) * 0.8), len(labeled_trainset)
+                    )
+                ]
+            )
+
+            # filter dataset for cell types in corresponding training set
+            trained_labels = list(Counter(labeled_train_split["label"]).keys())
+            labeled_eval_split_subset = labeled_eval_split.filter(
+                if_trained_label, num_proc=cpu_cores
+            )
+
+            # set logging steps
+            logging_steps = round(len(trainset_label) / training_batch_size / 10)
+            if logging_steps == 0:
+                logging_steps = 1
+
+            # load pretrained model
+            model = BertForSequenceClassification.from_pretrained(
+                pretrained_model,
+                num_labels=len(target_dict_list.keys()),
+                output_attentions=False,
+                output_hidden_states=False,
+            ).to(device)
+            # define output directory path
+            current_date = datetime.datetime.now()
+            datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
+
+            if output_dir is None:
+                output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
+
+            # ensure not overwriting previously saved model
+            saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
+            if os.path.isfile(saved_model_test) is True and overwrite is False:
+                raise Exception("Model already saved to this directory.")
+
+            # make output directory
+            subprocess.call(f"mkdir -p {output_dir}", shell=True)
+
+            # set training arguments
+            training_args = {
+                "learning_rate": max_lr,
+                "do_train": True,
+                "do_eval": True,
+                "evaluation_strategy": "epoch",
+                "save_strategy": "epoch",
+                "logging_steps": logging_steps,
+                "group_by_length": True,
+                "length_column_name": "length",
+                "disable_tqdm": False,
+                "lr_scheduler_type": lr_schedule_fn,
+                "warmup_steps": warmup_steps,
+                "weight_decay": 0.001,
+                "per_device_train_batch_size": training_batch_size,
+                "per_device_eval_batch_size": training_batch_size,
+                "num_train_epochs": epochs,
+                "load_best_model_at_end": True,
+                "output_dir": output_dir,
+            }
+
+            training_args_init = TrainingArguments(**training_args)
+            true_labels = labeled_eval_split_subset["label"]
+
+            if optimize_hyperparameters is False:
+                # create the trainer
+                trainer = Trainer(
+                    model=model,
+                    args=training_args_init,
+                    data_collator=DataCollatorForCellClassification(),
+                    train_dataset=labeled_train_split,
+                    eval_dataset=labeled_eval_split_subset,
+                    compute_metrics=compute_metrics,
+                )
+
+                # train the cell type classifier
+                trainer.train()
+                predictions = trainer.predict(labeled_eval_split_subset)
+                predictions_tensor = torch.Tensor(predictions.predictions)
+                predicted_labels = torch.argmax(predictions_tensor, dim=1)
+                print(
+                    f"accuracy: {accuracy_score(labeled_eval_split_subset['label'], predicted_labels)}"
+                )
+                metrics = compute_metrics(predictions)
+
+                with open(f"{output_dir}predictions.pickle", "wb") as fp:
+                    pickle.dump(predictions.predictions.argmax(-1), fp)
+
+                trainer.save_metrics("eval", predictions.metrics)
+                trainer.save_model(output_dir)
+
+                # Saves label conversion dictionary to output directory
+                with open(f"{output_dir}/targets.txt", "w") as f:
+                    f.write(str(target_dict_list))
+
+                # precision-recall curve only applies to binary labels;
+                # skipped for multiclass classification
+                try:
+                    precision, recall, _ = precision_recall_curve(
+                        true_labels, predictions.predictions[:, 1]
+                    )
+                    pr_auc = precision_auc(recall, precision)
+
+                    print(f"AUC: {pr_auc}")
+                    return recall, precision, pr_auc
+                except Exception:
+                    pass
+
+            else:
+                # Optimizes hyperparameters
+                num_classes = len(list(set(labeled_train_split["label"])))
+
+                def model_init():
+                    model = BertForSequenceClassification.from_pretrained(
+                        pretrained_model,
+                        num_labels=num_classes,
+                        output_attentions=False,
+                        output_hidden_states=False,
+                    )
+
+                    if freeze_layers is not None:
+                        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
+                        for module in modules_to_freeze:
+                            for param in module.parameters():
+                                param.requires_grad = False
+                    model = model.to(device)
+                    return model
+
+                # create the trainer
+                trainer = Trainer(
+                    model_init=model_init,
+                    args=training_args_init,
+                    data_collator=DataCollatorForCellClassification(),
+                    train_dataset=labeled_train_split,
+                    eval_dataset=labeled_eval_split_subset,
+                    compute_metrics=compute_metrics,
+                )
+
+                # specify raytune hyperparameter search space
+                ray_config = {
+                    "num_train_epochs": tune.choice([epochs]),
+                    "learning_rate": tune.loguniform(1e-6, 1e-3),
+                    "weight_decay": tune.uniform(0.0, 0.3),
+                    "lr_scheduler_type": tune.choice(
+                        ["linear", "cosine", "polynomial"]
+                    ),
+                    "warmup_steps": tune.uniform(100, 2000),
+                    "seed": tune.uniform(0, 100),
+                    "per_device_train_batch_size": tune.choice([training_batch_size]),
+                }
+
+                hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
+
+                if torch.cuda.is_available():
+                    resources_per_trial = {"cpu": 8, "gpu": 1}
+                else:
+                    resources_per_trial = {"cpu": 8}
+
+                # optimize hyperparameters
+                best_trial = trainer.hyperparameter_search(
+                    direction="maximize",
+                    backend="ray",
+                    resources_per_trial=resources_per_trial,
+                    hp_space=lambda _: ray_config,
+                    search_alg=hyperopt_search,
+                    n_trials=100,  # number of trials
+                    progress_reporter=tune.CLIReporter(
+                        max_report_frequency=600,
+                        sort_by_metric=True,
+                        max_progress_rows=100,
+                        mode="max",
+                        metric="eval_accuracy",
+                        metric_columns=["loss", "eval_loss", "eval_accuracy"],
+                    ),
+                )
+                best_hyperparameters = best_trial.hyperparameters
+
+                print("Best Hyperparameters:")
+                print(best_hyperparameters)
+
+    # Performs inference with the model
+    if inference is True:
+        if dataset_split is not None and data_filter is not None:
+
+            def if_label(example):
+                return example[dataset_split] == data_filter
+
+            train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
+
+        trainset_label_shuffled = train_dataset
+        total_cells = len(trainset_label_shuffled)
+
+        # loads dictionary of all cell labels the model was trained on
+        with open(Path(finetuned_model) / "targets.txt", "r") as f:
+            data = ast.literal_eval(f.read())
+        if dataset_split is not None and data_filter is not None:
+            # NOTE: assumes targets.txt holds one label dict per training split
+            # and that dataset_split is indexable by the filter value
+            indexer = dataset_split.index(data_filter)
+            data = data[indexer]
+
+        target_dict_list = {key: value for key, value in enumerate(data)}
+
+        # set logging steps
+        logging_steps = round(len(trainset_label_shuffled) / training_batch_size / 20)
+
+        # pad or truncate inputs to the maximum input size
+        input_ids = trainset_label_shuffled["input_ids"]
+        inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
+        attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
+
+        for i, sentence in enumerate(input_ids):
+            sentence_length = len(sentence)
+            if sentence_length <= max_input_size:
+                inputs[i, :sentence_length] = torch.tensor(sentence)
+                attention[i, :sentence_length] = torch.ones(sentence_length)
+            else:
+                inputs[i, :] = torch.tensor(sentence[:max_input_size])
+                attention[i, :] = torch.ones(max_input_size)
+
+        # load fine-tuned model and predict without tracking gradients
+        model = BertForSequenceClassification.from_pretrained(
+            finetuned_model, num_labels=len(target_dict_list)
+        ).to(device)
+        with torch.no_grad():
+            model_outputs = model(
+                inputs.to(device), attention_mask=attention.to(device)
+            )["logits"]
+        predictions = F.softmax(model_outputs, dim=-1).argmax(-1)
+
+        predictions = [target_dict_list[int(pred)] for pred in predictions]
+
+        return predictions
+
+
# Extracts embeddings from labeled data
|
780 |
+
if emb_extract is True:
|
781 |
+
if emb_filter is None:
|
782 |
+
with open(f"{finetuned_model}/targets.txt", "r") as f:
|
783 |
+
data = ast.literal_eval(f.read())
|
784 |
+
if dataset_split is not None and data_filter is None:
|
785 |
+
indexer = dataset_split.index(data_filter)
|
786 |
+
data = data[indexer]
|
787 |
+
|
788 |
+
target_dict_list = {key: value for key, value in enumerate(data)}
|
789 |
+
total_filter = None
|
790 |
+
else:
|
791 |
+
total_filter = emb_filter
|
792 |
+
|
793 |
+
train_dataset = load_from_disk(dataset)
|
794 |
+
if dataset_split is not None:
|
795 |
+
|
796 |
+
def if_label(example):
|
797 |
+
return example[dataset_split] == data_filter
|
798 |
+
|
799 |
+
train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
|
800 |
+
|
801 |
+
label_counter = Counter(train_dataset[label])
|
802 |
+
total_cells = sum(label_counter.values())
|
803 |
+
cells_to_keep = [
|
804 |
+
k for k, v in label_counter.items() if v > (filter_cells * total_cells)
|
805 |
+
]
|
806 |
+
|
807 |
+
def if_not_rare(example):
|
808 |
+
return example[label] in cells_to_keep
|
809 |
+
|
810 |
+
train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)
|
811 |
+
|
812 |
+
true_labels = train_dataset[label]
|
813 |
+
num_classes = len(list(set(true_labels)))
|
814 |
+
|
815 |
+
embex = EmbExtractor(
|
816 |
+
model_type="CellClassifier",
|
817 |
+
num_classes=num_classes,
|
818 |
+
filter_data=total_filter,
|
819 |
+
max_ncells=max_cells,
|
820 |
+
emb_layer=emb_layer,
|
821 |
+
emb_label=[dataset_split, label],
|
822 |
+
labels_to_plot=[label],
|
823 |
+
forward_batch_size=inference_batch_size,
|
824 |
+
nproc=cpu_cores,
|
825 |
+
)
|
826 |
+
|
827 |
+
# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
|
828 |
+
subprocess.call(f"mkdir -p {emb_dir}", shell=True)
|
829 |
+
|
830 |
+
embs = embex.extract_embs(
|
831 |
+
model_directory=finetuned_model,
|
832 |
+
input_data_file=dataset,
|
833 |
+
output_directory=emb_dir,
|
834 |
+
output_prefix=f"{label}_embeddings",
|
835 |
+
)
|
836 |
+
true_labels = embex.filtered_input_data[label]
|
837 |
+
|
838 |
+
emb_dict = {label: [] for label in list(set(true_labels))}
|
839 |
+
for num, emb in embs.iterrows():
|
840 |
+
key = emb[label]
|
841 |
+
selection = emb.iloc[:255]
|
842 |
+
emb = torch.Tensor(selection)
|
843 |
+
emb_dict[key].append(emb)
|
844 |
+
|
845 |
+
for key in list(emb_dict.keys()):
|
846 |
+
stack = torch.stack(emb_dict[key], dim=0)
|
847 |
+
emb_dict[key] = torch.mean(stack, dim=0)
|
848 |
+
similarities = {key: {} for key in list(emb_dict.keys())}
|
849 |
+
|
850 |
+
for key in list(emb_dict.keys()):
|
851 |
+
remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
|
852 |
+
for k in remaining_keys:
|
853 |
+
embedding = emb_dict[k]
|
854 |
+
sim = similarity(emb_dict[key], embedding, cosine=True)
|
855 |
+
|
856 |
+
similarities[key][k] = sim
|
857 |
+
|
858 |
+
plot_similarity_heatmap(similarities)
|
859 |
+
|
860 |
+
embex.plot_embs(
|
861 |
+
embs=embs,
|
862 |
+
plot_style="umap",
|
863 |
+
output_directory=emb_dir,
|
864 |
+
output_prefix="emb_plot",
|
865 |
+
)
|
866 |
+
|
867 |
+
embex.plot_embs(
|
868 |
+
embs=embs,
|
869 |
+
plot_style="heatmap",
|
870 |
+
output_directory=emb_dir,
|
871 |
+
output_prefix="emb_plot",
|
872 |
+
)
|
873 |
+
|
874 |
+
return similarities
|
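For orientation, a minimal driver for the two main code paths of classify_cells might look like the sketch below. It assumes the function is exported from geneformer as the module docstring shows; the dataset and model paths are placeholders, not files shipped with this PR.

    from geneformer import classify_cells

    # fine-tune the pretrained model on a labeled cell dataset (placeholder paths)
    classify_cells(
        pretrained_model="Geneformer/",            # placeholder checkpoint directory
        dataset="cell_type_train_data.dataset/",   # placeholder dataset directory
        epochs=1,
        optimize_hyperparameters=False,
        output_dir="finetuned_cell_classifier/",
    )

    # reuse the fine-tuned model for inference only, skipping training
    predicted_cell_types = classify_cells(
        dataset="cell_type_train_data.dataset/",
        finetuned_model="finetuned_cell_classifier/",
        skip_training=True,
        inference=True,
    )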
geneformer/gene_classifier.py
ADDED
@@ -0,0 +1,935 @@
+import os
+import sys
+
+GPU_NUMBER = [0]  # CHANGE WITH MULTIGPU
+os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
+os.environ["NCCL_DEBUG"] = "INFO"
+
+import ast
+import datetime
+import math
+import pickle
+import subprocess
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from datasets import Dataset, load_from_disk
+from sklearn import preprocessing
+from sklearn.metrics import (
+    ConfusionMatrixDisplay,
+    accuracy_score,
+    auc,
+    confusion_matrix,
+    f1_score,
+    roc_auc_score,
+    roc_curve,
+)
+from sklearn.model_selection import StratifiedKFold, train_test_split
+from tqdm.notebook import tqdm
+from transformers import BertForTokenClassification, Trainer
+from transformers.training_args import TrainingArguments
+
+from geneformer import (
+    DataCollatorForGeneClassification,
+    EmbExtractor,
+    TranscriptomeTokenizer,
+)
+from geneformer.pretrainer import token_dictionary
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def vote(logit_pair):
+    a, b = logit_pair
+    if a > b:
+        return 0
+    elif b > a:
+        return 1
+    elif a == b:
+        return "tie"
+
+
+def py_softmax(vector):
+    e = np.exp(vector)
+    return e / e.sum()
+
+
+# Computes similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
+def similarity(tensor1, tensor2, cosine=True):
+    if cosine is False:
+        if tensor1.ndimension() > 1:
+            tensor1 = tensor1.view(1, -1)
+        if tensor2.ndimension() > 1:
+            tensor2 = tensor2.view(1, -1)
+        dot_product = torch.matmul(tensor1, tensor2)
+        norm_tensor1 = torch.norm(tensor1)
+        norm_tensor2 = torch.norm(tensor2)
+        epsilon = 1e-8
+        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+        # rescale from [-1, 1] to [0, 1] and return as a float
+        return (similarity.item() + 1) / 2
+    else:
+        if tensor1.shape != tensor2.shape:
+            raise ValueError("Input tensors must have the same shape.")
+
+        # Compute cosine similarity using PyTorch's dot product function
+        dot_product = torch.dot(tensor1, tensor2)
+        norm_tensor1 = torch.norm(tensor1)
+        norm_tensor2 = torch.norm(tensor2)
+
+        # Avoid division by zero by adding a small epsilon
+        epsilon = 1e-8
+        similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
+
+        return similarity.item()
+
+
+# Plots heatmap of similarities between different classes/labels
+def plot_similarity_heatmap(similarities):
+    classes = list(similarities.keys())
+    classlen = len(classes)
+    arr = np.zeros((classlen, classlen))
+    for i, c in enumerate(classes):
+        for j, cc in enumerate(classes):
+            if cc == c:
+                val = 1.0
+            else:
+                val = similarities[c][cc]
+            arr[i][j] = val
+
+    plt.figure(figsize=(8, 6))
+    plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
+    plt.colorbar()
+    plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
+    plt.yticks(np.arange(classlen), classes)
+    plt.title("Similarity Heatmap")
+    plt.savefig("similarity_heatmap.png")
+
+
+# get cross-validated mean and sd metrics (weights proportional to fold size)
+def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
+    wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
+
+    all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
+    mean_tpr = np.sum(all_weighted_tpr, axis=0)
+    mean_tpr[-1] = 1.0
+    all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
+    roc_auc = np.sum(all_weighted_roc_auc)
+    roc_auc_sd = math.sqrt(
+        np.average((np.array(all_roc_auc) - roc_auc) ** 2, weights=wts)
+    )
+    return mean_tpr, roc_auc, roc_auc_sd
+
+
+def validate(
+    data,
+    targets,
+    labels,
+    nsplits,
+    subsample_size,
+    training_args,
+    freeze_layers,
+    output_dir,
+    num_proc,
+    num_labels,
+    pre_model,
+):
+    # initiate eval metrics to return
+    num_classes = len(set(labels))
+    mean_fpr = np.linspace(0, 1, 100)
+
+    # create 75/25 train/eval splits
+    targets_train, targets_eval, labels_train, labels_eval = train_test_split(
+        targets, labels, test_size=0.25, shuffle=True
+    )
+    label_dict_train = dict(zip(targets_train, labels_train))
+    label_dict_eval = dict(zip(targets_eval, labels_eval))
+
+    # functions to filter by whether examples contain train or eval labels
+    def if_contains_train_label(example):
+        a = label_dict_train.keys()
+        b = example["input_ids"]
+        return not set(a).isdisjoint(b)
+
+    def if_contains_eval_label(example):
+        a = label_dict_eval.keys()
+        b = example["input_ids"]
+        return not set(a).isdisjoint(b)
+
+    # filter dataset for examples containing classes for this split
+    print("Filtering training data")
+    trainset = data.filter(if_contains_train_label, num_proc=num_proc)
+    print(
+        f"Filtered {round((1 - len(trainset) / len(data)) * 100)}%; {len(trainset)} remain\n"
+    )
+    print("Filtering evaluation data")
+    evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
+    print(
+        f"Filtered {round((1 - len(evalset) / len(data)) * 100)}%; {len(evalset)} remain\n"
+    )
+
+    # minimize to smaller training sample
+    training_size = min(subsample_size, len(trainset))
+    trainset_min = trainset.select([i for i in range(training_size)])
+    eval_size = min(training_size, len(evalset))
+    half_training_size = round(eval_size / 2)
+    evalset_train_min = evalset.select([i for i in range(half_training_size)])
+    evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
+
+    # label conversion functions
+    def generate_train_labels(example):
+        example["labels"] = [
+            label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
+        ]
+        return example
+
+    def generate_eval_labels(example):
+        example["labels"] = [
+            label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
+        ]
+        return example
+
+    # label datasets
+    print("Labeling training data")
+    trainset_labeled = trainset_min.map(generate_train_labels)
+    print("Labeling evaluation data")
+    evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
+    print("Labeling evaluation OOS data")
+    evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
+
+    # load model
+    model = BertForTokenClassification.from_pretrained(
+        pre_model,
+        num_labels=num_labels,
+        output_attentions=False,
+        output_hidden_states=False,
+    )
+    if freeze_layers is not None:
+        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
+        for module in modules_to_freeze:
+            for param in module.parameters():
+                param.requires_grad = False
+
+    model = model.to(device)
+
+    # add output directory to training args and initiate
+    training_args["output_dir"] = output_dir
+    training_args_init = TrainingArguments(**training_args)
+
+    # create the trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args_init,
+        data_collator=DataCollatorForGeneClassification(),
+        train_dataset=trainset_labeled,
+        eval_dataset=evalset_train_labeled,
+    )
+
+    # train the gene classifier
+    trainer.train()
+    trainer.save_model(output_dir)
+
+    # evaluate on held-out examples (classifier_predict is defined later in this file)
+    fpr, tpr, interp_tpr, conf_mat = classifier_predict(
+        trainer.model, evalset_oos_labeled, 200, mean_fpr
+    )
+    auc_score = auc(fpr, tpr)
+
+    return fpr, tpr, auc_score
+
+
+# cross-validate gene classifier
+def cross_validate(
+    data,
+    targets,
+    labels,
+    nsplits,
+    subsample_size,
+    training_args,
+    freeze_layers,
+    output_dir,
+    num_proc,
+    num_labels,
+    pre_model,
+):
+    # check if output directory already written to
+    # ensure not overwriting previously saved model
+    model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
+    # if os.path.isfile(model_dir_test) is True:
+    #     raise Exception("Model already saved to this directory.")
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # initiate eval metrics to return
+    num_classes = len(set(labels))
+    mean_fpr = np.linspace(0, 1, 100)
+    all_tpr = []
+    all_roc_auc = []
+    all_tpr_wt = []
+    label_dicts = []
+    confusion = np.zeros((num_classes, num_classes))
+
+    # set up cross-validation splits
+    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
+    # train and evaluate
+    iteration_num = 0
+    for train_index, eval_index in tqdm(skf.split(targets, labels)):
+        # early stopping after 3 splits for large numbers of training examples
+        if len(labels) > 500:
+            print("early stopping activated due to large # of training examples")
+            if iteration_num == 3:
+                break
+
+        print(f"****** Crossval split: {iteration_num}/{nsplits - 1} ******\n")
+
+        # generate cross-validation splits
+        targets_train, targets_eval = targets[train_index], targets[eval_index]
+        labels_train, labels_eval = labels[train_index], labels[eval_index]
+        label_dict_train = dict(zip(targets_train, labels_train))
+        label_dict_eval = dict(zip(targets_eval, labels_eval))
+        label_dicts.append(
+            (
+                iteration_num,
+                targets_train,
+                targets_eval,
+                labels_train,
+                labels_eval,
+            )
+        )
+
+        # functions to filter by whether examples contain train or eval labels
+        def if_contains_train_label(example):
+            a = label_dict_train.keys()
+            b = example["input_ids"]
+            return not set(a).isdisjoint(b)
+
+        def if_contains_eval_label(example):
+            a = label_dict_eval.keys()
+            b = example["input_ids"]
+            return not set(a).isdisjoint(b)
+
+        # filter dataset for examples containing classes for this split
+        print("Filtering training data")
+        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
+        print(
+            f"Filtered {round((1 - len(trainset) / len(data)) * 100)}%; {len(trainset)} remain\n"
+        )
+        print("Filtering evaluation data")
+        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
+        print(
+            f"Filtered {round((1 - len(evalset) / len(data)) * 100)}%; {len(evalset)} remain\n"
+        )
+
+        # minimize to smaller training sample
+        training_size = min(subsample_size, len(trainset))
+        trainset_min = trainset.select([i for i in range(training_size)])
+        eval_size = min(training_size, len(evalset))
+        half_training_size = round(eval_size / 2)
+        evalset_train_min = evalset.select([i for i in range(half_training_size)])
+        evalset_oos_min = evalset.select(
+            [i for i in range(half_training_size, eval_size)]
+        )
+
+        # label conversion functions
+        def generate_train_labels(example):
+            example["labels"] = [
+                label_dict_train.get(token_id, -100)
+                for token_id in example["input_ids"]
+            ]
+            return example
+
+        def generate_eval_labels(example):
+            example["labels"] = [
+                label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
+            ]
+            return example
+
+        # label datasets
+        print("Labeling training data")
+        trainset_labeled = trainset_min.map(generate_train_labels)
+        print("Labeling evaluation data")
+        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
+        print("Labeling evaluation OOS data")
+        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
+
+        # create output directories
+        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
+        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")
+
+        # ensure not overwriting previously saved model
+        model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
+        # if os.path.isfile(model_output_file) is True:
+        #     raise Exception("Model already saved to this directory.")
+
+        # make training and model output directories
+        subprocess.call(f"mkdir -p {ksplit_output_dir}", shell=True)
+        subprocess.call(f"mkdir -p {ksplit_model_dir}", shell=True)
+
+        # load model
+        model = BertForTokenClassification.from_pretrained(
+            pre_model,
+            num_labels=num_labels,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+        if freeze_layers is not None:
+            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
+            for module in modules_to_freeze:
+                for param in module.parameters():
+                    param.requires_grad = False
+
+        model = model.to(device)
+
+        # add output directory to training args and initiate
+        training_args["output_dir"] = ksplit_output_dir
+        training_args_init = TrainingArguments(**training_args)
+
+        # create the trainer
+        trainer = Trainer(
+            model=model,
+            args=training_args_init,
+            data_collator=DataCollatorForGeneClassification(),
+            train_dataset=trainset_labeled,
+            eval_dataset=evalset_train_labeled,
+        )
+
+        # train the gene classifier
+        trainer.train()
+
+        # save model
+        trainer.save_model(ksplit_model_dir)
+
+        # evaluate model
+        fpr, tpr, interp_tpr, conf_mat = classifier_predict(
+            trainer.model, evalset_oos_labeled, 200, mean_fpr
+        )
+
+        # append to tpr and roc lists
+        confusion = confusion + conf_mat
+        all_tpr.append(interp_tpr)
+        all_roc_auc.append(auc(fpr, tpr))
+        # append number of eval examples by which to weight tpr in averaged graphs
+        all_tpr_wt.append(len(tpr))
+
+        iteration_num = iteration_num + 1
+
+    # get overall metrics for cross-validation
+    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(
+        all_tpr, all_roc_auc, all_tpr_wt
+    )
+    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
+
+
418 |
+
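# Each split's TPR is interpolated onto mean_fpr and weighted by the number of
# evaluation examples (all_tpr_wt), so get_cross_valid_metrics can average the
# per-split ROC curves that plot_ROC draws below.
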
# Computes metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average="macro")

    return {"accuracy": acc, "macro_f1": macro_f1}

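# compute_metrics is shaped for Hugging Face's Trainer hook (it can be passed as
# Trainer(..., compute_metrics=compute_metrics)); note it is not wired into the
# Trainer instantiated in cross_validate above.
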
# plot ROC curve
def plot_ROC(bundled_data, title):
    plt.figure()
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        plt.plot(
            mean_fpr,
            mean_tpr,
            color=color,
            lw=lw,
            label=r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(
                sample, roc_auc, roc_auc_sd
            ),
        )

    plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title)
    plt.legend(loc="lower right")
    plt.savefig("ROC.png")

    return mean_fpr, mean_tpr, roc_auc

# plot confusion matrix
def plot_confusion_matrix(classes_list, conf_mat, title):
    display_labels = []
    for i, label in enumerate(classes_list):
        display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:, i]))]
    display = ConfusionMatrixDisplay(
        confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
        display_labels=display_labels,
    )
    display.plot(cmap="Blues", values_format=".2g")
    plt.title(title)
    plt.savefig("CM.png")

# Function to find the largest number smaller
# than or equal to N that is divisible by K
def find_largest_div(N, K):
    rem = N % K
    if rem == 0:
        return N
    else:
        return N - rem

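# For example, find_largest_div(1001, 200) returns 1000; classifier_predict below
# uses this to drop a leftover batch of a single example before inference.
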
def preprocess_classifier_batch(cell_batch, max_len):
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])

    def pad_label_example(example):
        example["labels"] = np.pad(
            example["labels"],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=-100,
        )
        example["input_ids"] = np.pad(
            example["input_ids"],
            (0, max_len - len(example["input_ids"])),
            mode="constant",
            constant_values=token_dictionary.get("<pad>"),
        )
        example["attention_mask"] = (
            example["input_ids"] != token_dictionary.get("<pad>")
        ).astype(int)
        return example

    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch

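# Note: labels are padded with -100 (ignored downstream) and input_ids with the
# <pad> token, so every example in a batch shares the same max_len.
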
# forward batch size is batch size for model inference (e.g. 200)
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    predict_logits = []
    predict_labels = []
    model.to("cpu")
    model.eval()

    # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible

    max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])

    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i + forward_batch_size, evalset_len)
        batch_evalset = evalset.select([j for j in range(i, max_range)])
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")

        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids=input_data_batch,
                attention_mask=attn_msk_batch,
                labels=label_batch,
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]

    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [
        item
        for item in list(zip(all_logits.tolist(), all_labels.tolist()))
        if item[1] != -100
    ]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    # probability of class 1
    y_score = [py_softmax(item)[1] for item in logits_list]
    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    # plot roc_curve for this split
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC")
    plt.show()
    # interpolate to graph
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat


def classify_genes(
    gene_info="Genecorpus-30M/example_input_files/gene_info_table.csv",
    genes="Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
    corpus_30M="Genecorpus-30M/genecorpus_30M_2048.dataset/",
    model=".",
    max_input_size=2**11,
    max_lr=5e-5,
    freeze_layers=4,
    num_gpus=1,
    num_proc=os.cpu_count(),
    geneformer_batch_size=9,
    epochs=1,
    filter_dataset=50_000,
    emb_extract=True,
    emb_layer=0,
    forward_batch=200,
    filter_data=None,
    inference=False,
    k_validate=True,
    model_location="230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
    skip_training=False,
    emb_dir="gene_emb",
    output_dir=None,
    max_cells=1000,
    num_cpus=os.cpu_count(),
):
    """
    Primary Parameters
    -----------

    gene_info: path
        Path to gene mappings

    corpus_30M: path
        Path to 30M Gene Corpus

    model: path
        Path to pretrained GeneFormer model

    genes: path
        Path to csv file containing different columns of genes and the column labels

    inference: bool
        Whether the model should be used to run inference. If False, the model will train on labeled data instead. Defaults to False

    k_validate: bool
        Whether the model should run k-fold validation or simply perform regular training/evaluation. Defaults to True

    skip_training: bool
        Whether the model should skip the training portion. Defaults to False

    emb_extract: bool
        Whether the model should extract embeddings for a given gene (WIP)


    Customization Parameters
    -----------

    freeze_layers: int
        Freezes the first x layers of the model. Default is 4 (2 non-frozen layers)

    filter_dataset: int
        Number of cells to filter from the 30M dataset. Default is 50_000

    emb_layer: int
        Which layer embeddings are extracted from. Default is 0

    filter_data: str, list
        Filters down embeddings to a single category. Default is None

    """
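    # Example call (illustrative values; the __main__ block at the bottom of this
    # file shows the paths used in practice):
    #   classify_genes(genes="my_gene_labels.csv", k_validate=True, epochs=1)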

    # table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
    gene_info = pd.read_csv(gene_info, index_col=0)
    labels = gene_info.columns

    # create dictionaries for corresponding attributes
    gene_id_type_dict = dict(zip(gene_info["ensembl_id"], gene_info["gene_type"]))
    gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"]))
    gene_id_name_dict = {v: k for k, v in gene_name_id_dict.items()}

    # function for preparing targets and labels
    def prep_inputs(label_store, id_type):
        target_list = []
        if id_type == "gene_name":
            for key in list(label_store.keys()):
                targets = [
                    gene_name_id_dict[gene]
                    for gene in label_store[key]
                    if gene_name_id_dict.get(gene) in token_dictionary
                ]
                targets_id = [token_dictionary[gene] for gene in targets]
                target_list.append(targets_id)
        elif id_type == "ensembl_id":
            for key in list(label_store.keys()):
                targets = [
                    gene for gene in label_store[key] if gene in token_dictionary
                ]
                targets_id = [token_dictionary[gene] for gene in targets]
                target_list.append(targets_id)

        targets, labels = [], []
        for targ in target_list:
            targets = targets + targ
        targets = np.array(targets)
        for num, targ in enumerate(target_list):
            label = [num] * len(targ)
            labels = labels + label
        labels = np.array(labels)
        unique_labels = num + 1

        nsplits = min(5, min([len(targ) for targ in target_list]) - 1)
        assert nsplits > 2

        return targets, labels, nsplits, unique_labels

    if not skip_training:
        # preparing targets and labels for dosage sensitive vs insensitive TFs
        gene_classes = pd.read_csv(genes, header=0)
        if filter_data is None:
            labels = gene_classes.columns
        else:
            if isinstance(filter_data, list):
                labels = filter_data
            else:
                labels = [filter_data]
        label_store = {}

        # dictionary for decoding labels
        decode = {i: labels[i] for i in range(len(labels))}

        for label in labels:
            label_store[label] = gene_classes[label].dropna()

        targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")

        # load training dataset
        train_dataset = load_from_disk(corpus_30M)
        shuffled_train_dataset = train_dataset.shuffle(seed=42)
        subsampled_train_dataset = shuffled_train_dataset.select(
            [i for i in range(filter_dataset)]
        )
        lr_schedule_fn = "linear"
        warmup_steps = 500
        optimizer = "adamw"
        subsample_size = 10_000

        training_args = {
            "learning_rate": max_lr,
            "do_train": True,
            "evaluation_strategy": "no",
            "save_strategy": "epoch",
            "logging_steps": 10,
            "group_by_length": True,
            "length_column_name": "length",
            "disable_tqdm": False,
            "lr_scheduler_type": lr_schedule_fn,
            "warmup_steps": warmup_steps,
            "weight_decay": 0.001,
            "per_device_train_batch_size": geneformer_batch_size,
            "per_device_eval_batch_size": geneformer_batch_size,
            "num_train_epochs": epochs,
        }

        # define output directory path
        current_date = datetime.datetime.now()
        datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"

        if output_dir is None:
            training_output_dir = Path(
                f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/"
            )
        else:
            training_output_dir = Path(output_dir)

        # make output directory
        subprocess.call(f"mkdir -p {training_output_dir}", shell=True)

        # save the number of classes and the label decoder to the output directory
        num_classes = len(set(labels))
        info_list = [num_classes, decode]

        with open(training_output_dir / "classes.txt", "w") as f:
            f.write(str(info_list))

        subsampled_train_dataset.save_to_disk(training_output_dir / "dataset")

        if k_validate:
            ksplit_model = "ksplit0/models"
            ksplit_model_test = os.path.join(training_output_dir, ksplit_model)
            # if os.path.isfile(ksplit_model_test):
            #     raise Exception("Model already saved to this directory.")
            # cross-validate gene classifier
            (
                all_roc_auc,
                roc_auc,
                roc_auc_sd,
                mean_fpr,
                mean_tpr,
                confusion,
                label_dicts,
            ) = cross_validate(
                subsampled_train_dataset,
                targets,
                labels,
                nsplits,
                subsample_size,
                training_args,
                freeze_layers,
                training_output_dir,
                1,
                unique_labels,
                model,
            )

            bundled_data = []
            bundled_data += [
                (roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")
            ]
            graph_title = " ".join(
                [
                    i + " vs" if count < len(label_store) - 1 else i
                    for count, i in enumerate(label_store)
                ]
            )
            fpr, tpr, auc = plot_ROC(
                bundled_data, "Dosage Sensitive vs Insensitive TFs"
            )
            print(auc)
            # plot confusion matrix
            plot_confusion_matrix(label_store, confusion, "Geneformer")
        else:
            fpr, tpr, auc = validate(
                subsampled_train_dataset,
                targets,
                labels,
                nsplits,
                subsample_size,
                training_args,
                freeze_layers,
                training_output_dir,
                1,
                unique_labels,
                model,
            )
            print(auc)

    if inference:
        # preparing targets and labels for dosage sensitive vs insensitive TFs
        gene_classes = pd.read_csv(genes, header=0)
        targets = []
        for column in gene_classes.columns:
            targets += list(gene_classes[column])
        tokens = []
        for target in targets:
            try:
                tokens.append(token_dictionary[target])
            except KeyError:
                tokens.append(0)

        targets = torch.LongTensor([tokens])

        with open(os.path.join(model_location, "classes.txt"), "r") as f:
            info_list = ast.literal_eval(f.read())
        num_classes = info_list[0]
        labels = info_list[1]

        model = BertForTokenClassification.from_pretrained(
            model_location,
            num_labels=num_classes,
            output_attentions=False,
            output_hidden_states=False,
            local_files_only=True,
        )
        if freeze_layers is not None:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False

        model = model.to(device)

        # evaluate model
        predictions = F.softmax(model(targets.to(device))["logits"], dim=-1).argmax(-1)[
            0
        ]
        predictions = [labels[int(pred)] for pred in predictions]

        return predictions

    # Extracts aggregate gene embeddings for each label
    if emb_extract:
        with open(os.path.join(model_location, "classes.txt"), "r") as f:
            data = ast.literal_eval(f.read())
        num_classes = data[0]
        decode = data[1]

        gene_classes = pd.read_csv(genes, header=0)
        labels = gene_classes.columns
        tokenize = TranscriptomeTokenizer()

        label_dict = {}
        for label in labels:
            genes = gene_classes[label]
            tokenized_genes = []
            for gene in genes:
                try:
                    tokenized_genes.append(tokenize.gene_token_dict[gene])
                except KeyError:
                    continue
            label_dict[label] = tokenized_genes

        embex = EmbExtractor(
            model_type="GeneClassifier",
            num_classes=num_classes,
            emb_mode="gene",
            filter_data=None,
            max_ncells=max_cells,
            emb_layer=emb_layer,
            emb_label=label_dict,
            labels_to_plot=list(labels),
            forward_batch_size=forward_batch,
            nproc=num_cpus,
        )

        subprocess.call(f"mkdir -p {emb_dir}", shell=True)

        embs = embex.extract_embs(
            model_directory=model_location,
            input_data_file=os.path.join(model_location, "dataset"),
            output_directory=emb_dir,
            output_prefix=f"{label}_embeddings",
        )

        emb_dict = {label: [] for label in list(set(labels))}
        similarities = {key: {} for key in list(emb_dict.keys())}

        for column in embs.columns:
            remaining_cols = [k for k in embs.columns if k != column]
            for k in remaining_cols:
                embedding = torch.Tensor(embs[k])
                sim = similarity(torch.Tensor(embs[column]), embedding, cosine=True)
                similarities[column][k] = sim

        plot_similarity_heatmap(similarities)
        print(similarities)

        return similarities


if __name__ == "__main__":
    classify_genes(
        k_validate=False,
        inference=False,
        skip_training=False,
        emb_extract=True,
        output_dir=Path("gene_emb"),
        model_location=Path("gene_emb"),
        epochs=5,
        gene_info="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv",
        genes="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
        corpus_30M="../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/",
    )
geneformer/modular_classifier_usage.md
ADDED
@@ -0,0 +1,156 @@
# Cell classifier

def finetune_cells(
    token_set=Path("geneformer/token_dictionary.pkl"),
    median_set=Path("geneformer/gene_median_dictionary.pkl"),
    pretrained_model=".",
    dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
    dataset_split=None,
    filter_cells=0.005,
    epochs=1,
    cpu_cores=os.cpu_count(),
    geneformer_batch_size=12,
    optimizer="adamw",
    max_lr=5e-5,
    num_gpus=torch.cuda.device_count(),
    max_input_size=2**11,
    lr_schedule_fn="linear",
    warmup_steps=500,
    freeze_layers=0,
    emb_extract=False,
    max_cells=1000,
    emb_layer=0,
    emb_filter=None,
    emb_dir="embeddings",
    overwrite=True,
    label="cell_type",
    data_filter=None,
    forward_batch=200,
    model_location=None,
    skip_training=False,
    sample_data=1,
    inference=False,
    optimize_hyperparameters=False,
    output_dir=None,
):

    '''
    Primary Parameters
    -------------------
    dataset: path
        Path to fine-tuning/testing dataset for training

    model_location: path
        Path to location of existing model to use for inference and embedding extraction

    pretrained_model: path
        Path to pretrained GeneFormer 30M model before fine-tuning

    inference: bool
        Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False

    skip_training: bool
        Chooses whether to skip training the model. Defaults to False

    emb_extract: bool
        Chooses whether to extract embeddings and calculate similarities. Defaults to False

    optimize_hyperparameters: bool
        Chooses whether to optimize model hyperparameters. Defaults to False

    label: str
        The feature in the formatted dataset that contains the true class labels. Defaults to "cell_type"

    Customization Parameters
    -------------------

    dataset_split: str
        How the dataset should be partitioned (if at all), and what ID should be used for partitioning

    data_filter: list
        (For embeddings and inference) Runs analysis on subsets of the dataset, selected by the ID defined by dataset_split

    emb_layer: int
        Which layer embeddings should be extracted and compared from.

    emb_filter: ['cell1', 'cell2'...]
        Allows the user to narrow down the range of cells that embeddings will be extracted from.

    max_cells: int
        How many cell embeddings should be extracted.

    freeze_layers: int
        Number of layers that should be permanently frozen during fine-tuning (starting from the first layer, 4 brings it up to the pretrained model).

    sample_data: float
        What proportion of the HF dataset should be used

    '''

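A minimal invocation might look like the sketch below. The dataset path and output
directory here are illustrative placeholders, not shipped defaults:

```python
# Hypothetical usage: fine-tune a cell-type classifier for one epoch, then
# extract embeddings from the fine-tuned model.
finetune_cells(
    dataset="cell_type_train_data.dataset/",  # placeholder path to a tokenized dataset
    pretrained_model=".",
    label="cell_type",
    epochs=1,
    emb_extract=True,
    output_dir="cell_ft_run",  # assumed writable output directory
)
```
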
# Gene Classifier

def classify_genes(
    gene_info="Genecorpus-30M/example_input_files/gene_info_table.csv",
    genes="Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
    corpus_30M="Genecorpus-30M/genecorpus_30M_2048.dataset/",
    model=".",
    max_input_size=2**11,
    max_lr=5e-5,
    freeze_layers=4,
    num_gpus=1,
    num_proc=os.cpu_count(),
    geneformer_batch_size=9,
    epochs=1,
    filter_dataset=50_000,
    emb_extract=True,
    emb_layer=0,
    forward_batch=200,
    filter_data=None,
    inference=False,
    k_validate=True,
    model_location="230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
    skip_training=False,
    emb_dir="gene_emb",
    output_dir=None,
    max_cells=1000,
    num_cpus=os.cpu_count(),
):

    """
    Primary Parameters
    -----------

    gene_info: path
        Path to gene mappings

    corpus_30M: path
        Path to 30M Gene Corpus

    model: path
        Path to pretrained GeneFormer model

    genes: path
        Path to csv file containing different columns of genes and the column labels

    inference: bool
        Whether the model should be used to run inference. If False, the model will train on labeled data instead. Defaults to False

    k_validate: bool
        Whether the model should run k-fold validation or simply perform regular training/evaluation. Defaults to True

    skip_training: bool
        Whether the model should skip the training portion. Defaults to False

    emb_extract: bool
        Whether the model should extract embeddings for a given gene (WIP)


    Customization Parameters
    -----------

    freeze_layers: int
        Freezes the first x layers of the model. Default is 4 (2 non-frozen layers)

    filter_dataset: int
        Number of cells to filter from the 30M dataset. Default is 50_000

    emb_layer: int
        Which layer embeddings are extracted from. Default is 0

    filter_data: str, list
        Filters down embeddings to a single category. Default is None

    """
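
As with the cell classifier, a minimal call might look like the sketch below; the
label csv and output directory are placeholders, and k_validate=True triggers the
cross-validation path described above:

```python
# Hypothetical usage: cross-validate a dosage-sensitivity gene classifier.
classify_genes(
    genes="dosage_sens_tf_labels.csv",  # placeholder csv of gene label columns
    k_validate=True,
    epochs=1,
    output_dir="gene_run",  # assumed writable output directory
)
```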