.pre-commit-config.yaml ADDED
@@ -0,0 +1,26 @@
+ # See https://pre-commit.com for more information
+ # See https://pre-commit.com/hooks.html for more hooks
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v3.2.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+       - id: check-yaml
+       - id: check-added-large-files
+       - id: check-merge-conflict
+       - id: mixed-line-ending
+       - id: check-docstring-first
+   - repo: https://github.com/pycqa/isort
+     rev: 5.12.0
+     hooks:
+       - id: isort
+         args: ["--profile", "black"]
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     # Ruff version.
+     rev: v0.1.4
+     hooks:
+       # Run the Ruff linter.
+       - id: ruff
+       # Run the Ruff formatter.
+       - id: ruff-format
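+ # Typical local setup (standard pre-commit workflow, not part of this config):
+ #   pip install pre-commit
+ #   pre-commit install          # register the git hook
+ #   pre-commit run --all-files  # one-off run over the whole repo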
geneformer/cell_classifier.py ADDED
@@ -0,0 +1,874 @@
+ """
+ Geneformer cell classifier.
+
+ Usage:
+ from geneformer import classify_cells
+ classify_cells(
+ token_set=Path("geneformer/token_dictionary.pkl"),
+ median_set=Path("geneformer/gene_median_dictionary.pkl"),
+ pretrained_model=".",
+ dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
+ dataset_split=None,
+ filter_cells=0.005,
+ epochs=1,
+ cpu_cores=os.cpu_count(),
+ training_batch_size=12,
+ optimizer="adamw",
+ max_lr=5e-5,
+ num_gpus=torch.cuda.device_count(),
+ max_input_size=2**11,
+ lr_schedule_fn="linear",
+ warmup_steps=500,
+ freeze_layers=0,
+ emb_extract=False,
+ max_cells=1000,
+ emb_layer=0,
+ emb_filter=None,
+ emb_dir="embeddings",
+ overwrite=True,
+ label="cell_type",
+ data_filter=None,
+ inference_batch_size=200,
+ finetuned_model=None,
+ skip_training=False,
+ sample_data=1,
+ inference=False,
+ optimize_hyperparameters=False,
+ output_dir=None,
+ )
+ """
40
+
41
+ import ast
42
+ import datetime
43
+ import os
44
+ import pickle
45
+ import random
46
+ import subprocess
47
+ from collections import Counter
48
+ from pathlib import Path
49
+
50
+ import numpy as np
51
+ import seaborn as sns
52
+ import torch
53
+ import torch.nn.functional as F
54
+ from datasets import load_from_disk
55
+ from matplotlib import pyplot as plt
56
+ from ray import tune
57
+ from ray.tune.search.hyperopt import HyperOptSearch
58
+ from sklearn.metrics import accuracy_score
59
+ from sklearn.metrics import auc as precision_auc
60
+ from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score, roc_curve
61
+ from transformers import BertForSequenceClassification, Trainer
62
+ from transformers.training_args import TrainingArguments
63
+
64
+ from geneformer import DataCollatorForCellClassification, EmbExtractor
65
+
66
+ sns.set()
67
+
+ # Properly sets up the NCCL environment
69
+ GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
70
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
71
+ os.environ["NCCL_DEBUG"] = "INFO"
72
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
+
74
+
75
+ # Function for generating an ROC curve from data
76
+ def ROC(prediction, truth, type="GeneFormer", label=""):
77
+ fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
78
+ auc = roc_auc_score(truth, prediction[:, 1])
79
+ print(f"{type} AUC: {auc}")
80
+ plt.plot(fpr, tpr, label="AUC=" + str(auc))
81
+ plt.ylabel("True Positive Rate")
82
+ plt.xlabel("False Positive Rate")
83
+ plt.title(f"{label} ROC Curve")
84
+ plt.legend(loc=4)
85
+ plt.savefig("ROC.png")
86
+
87
+ return tpr, fpr, auc
88
+
89
+
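+ # Illustrative sketch only (not called by the pipeline): ROC() expects an (N, 2)
+ # array of class scores and binary ground-truth labels; the arrays below are made up.
+ # example_scores = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
+ # example_truth = np.array([0, 1, 0, 1])
+ # tpr, fpr, auc_value = ROC(example_scores, example_truth, label="demo")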
90
+ # Computes similarity between two embeddings: raw cosine similarity when cosine=True, otherwise a cosine score rescaled to [0, 1] (0 is perfectly dissimilar, 1 is perfectly similar)
91
+ def similarity(tensor1, tensor2, cosine=False):
92
+ if cosine is False:
93
+ if tensor1.ndimension() > 1:
94
+ tensor1 = tensor1.view(1, -1)
95
+ if tensor2.ndimension() > 1:
96
+ tensor2 = tensor2.view(1, -1)
97
+ dot_product = torch.matmul(tensor1, tensor2)
98
+ norm_tensor1 = torch.norm(tensor1)
99
+ norm_tensor2 = torch.norm(tensor2)
100
+ epsilon = 1e-8
101
+ similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
102
+ similarity = (similarity.item() + 1) / 2
103
+ else:
104
+ if tensor1.shape != tensor2.shape:
105
+ raise ValueError("Input tensors must have the same shape.")
106
+
107
+ # Compute cosine similarity using PyTorch's dot product function
108
+ dot_product = torch.dot(tensor1, tensor2)
109
+ norm_tensor1 = torch.norm(tensor1)
110
+ norm_tensor2 = torch.norm(tensor2)
111
+
112
+ # Avoid division by zero by adding a small epsilon
113
+ epsilon = 1e-8
114
+ similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
115
+
116
+ return similarity.item()
117
+
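+ # Sanity-check sketch (assumption, not used by the pipeline): with cosine=True the
+ # result should agree with torch's built-in cosine similarity for 1-D tensors.
+ # a, b = torch.randn(256), torch.randn(256)
+ # assert abs(similarity(a, b, cosine=True) - F.cosine_similarity(a, b, dim=0).item()) < 1e-5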
118
+
119
+ # Plots heatmap between different classes/labels
120
+ def plot_similarity_heatmap(similarities):
121
+ classes = list(similarities.keys())
122
+ classlen = len(classes)
123
+ arr = np.zeros((classlen, classlen))
124
+ for i, c in enumerate(classes):
125
+ for j, cc in enumerate(classes):
126
+ if cc == c:
127
+ val = 1.0
128
+ else:
129
+ val = similarities[c][cc]
130
+ arr[i][j] = val
131
+
132
+ plt.figure(figsize=(8, 6))
133
+ plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
134
+ plt.colorbar()
135
+ plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
136
+ plt.yticks(np.arange(classlen), classes)
137
+ plt.title("Similarity Heatmap")
138
+ plt.savefig("similarity_heatmap.png")
139
+
140
+
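+ # Expected input shape for plot_similarity_heatmap (illustrative values only):
+ # similarities = {
+ #     "T cell": {"B cell": 0.83, "monocyte": 0.71},
+ #     "B cell": {"T cell": 0.83, "monocyte": 0.65},
+ #     "monocyte": {"T cell": 0.71, "B cell": 0.65},
+ # }
+ # plot_similarity_heatmap(similarities)  # writes similarity_heatmap.png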
141
+ def classify_cells(
142
+ token_set=Path("./token_dictionary.pkl"),
143
+ median_set=Path("./gene_median_dictionary.pkl"),
144
+ pretrained_model="../",
145
+ dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
146
+ dataset_split=None,
147
+ filter_cells=0.005,
148
+ epochs=1,
149
+ cpu_cores=os.cpu_count(),
150
+ training_batch_size=12,
151
+ optimizer="adamw",
152
+ max_lr=5e-5,
153
+ num_gpus=torch.cuda.device_count(),
154
+ max_input_size=2**11,
155
+ lr_schedule_fn="linear",
156
+ warmup_steps=500,
157
+ freeze_layers=0,
158
+ emb_extract=False,
159
+ max_cells=None,
160
+ emb_layer=-1,
161
+ emb_filter=None,
162
+ emb_dir="embeddings",
163
+ overwrite=False,
164
+ label="cell_type",
165
+ data_filter=None,
166
+ inference_batch_size=200,
167
+ finetuned_model=None,
168
+ skip_training=False,
169
+ sample_data=1,
170
+ inference=False,
171
+ optimize_hyperparameters=True,
172
+ output_dir=None,
173
+ ):
174
+ """
175
+ Primary Parameters
176
+ -------------------
177
+ dataset: path
178
+ Path to fine-tuning dataset for training
179
+
180
+ finetuned_model: path
181
+ Path to location of fine-tuned model to use for inference and embedding extraction
182
+
183
+ pretrained_model: path
184
+ Path to pretrained Geneformer model
185
+
186
+ inference: bool
187
+ Indicates whether to perform inference and return a list of similarities. Defaults to False.
188
+
189
+ skip_training: bool
190
+ Indicates whether to skip training the model. Defaults to False.
191
+
+ emb_extract: bool
+ Indicates whether to extract embeddings and calculate similarities. Defaults to False.
+
+ optimize_hyperparameters: bool
+ Indicates whether to optimize model hyperparameters. Defaults to True.
197
+
198
+
199
+ Customization Parameters
200
+ -------------------
201
+
202
+ dataset_split: str
203
+ Indicates how the dataset should be partitioned (if at all), and what ID should be used for partitioning
204
+
205
+ data_filter: list
206
+ (For embeddings and inference) Runs analysis on subsets of the dataset based on the ID defined by dataset_split
207
+
208
+ label: str
209
+ Feature to read as a classification label.
210
+
211
+ emb_layer: int
212
+ What layer embeddings should be extracted and compared.
213
+
214
+ emb_filter: ['cell1', 'cell2'...]
215
+ Allows user to narrow down range of cells that embeddings will be extracted from.
216
+
217
+ max_cells: int
218
+ Max number of cells to use for embedding extraction.
219
+
220
+ freeze_layers: int
221
+ Number of layers that should be frozen during fine-tuning.
222
+
223
+ sample_data: float
224
+ Proportion of the dataset that should be used.
225
+
226
+ """
227
+
228
+ dataset_list = []
229
+ evalset_list = []
230
+ split_list = []
231
+ target_dict_list = []
232
+
+ train_dataset = load_from_disk(dataset)
+
+ # subsample the dataset once to the requested proportion
+ num_samples = int(len(train_dataset) * sample_data)
+ random_indices = random.sample(range(len(train_dataset)), num_samples)
+ train_dataset = train_dataset.select(random_indices)
241
+
242
+ def if_not_rare_cell_state(example):
243
+ return example[label] in cells_to_keep
244
+
245
+ # change labels to numerical ids
246
+ def classes_to_ids(example):
247
+ example["label"] = target_name_id_dict[example["label"]]
248
+ return example
249
+
250
+ def if_trained_label(example):
251
+ return example["label"] in trained_labels
252
+
253
+ if skip_training is not True:
254
+
255
+ def compute_metrics(pred):
256
+ labels = pred.label_ids
257
+ preds = pred.predictions.argmax(-1)
258
+ # calculate accuracy and macro f1 using sklearn's function
259
+ acc = accuracy_score(labels, preds)
260
+ macro_f1 = f1_score(labels, preds, average="macro")
261
+ return {"accuracy": acc, "macro_f1": macro_f1}
262
+
263
+ # Custom label groupings: split values listed as keys are folded into the corresponding value (e.g. bone_marrow with immune)
264
+ excep = {"bone_marrow": "immune"}
265
+
266
+ if dataset_split is not None:
267
+ if data_filter is not None:
268
+ split_iter = [data_filter]
269
+ else:
270
+ split_iter = Counter(train_dataset[dataset_split]).keys()
271
+ for lab in split_iter:
272
+ # collect list of tissues for fine-tuning (immune and bone marrow are included together)
273
+ if lab in list(excep.keys()):
274
+ continue
+ elif lab in list(excep.values()):
+ split_ids = list(excep.keys()) + [lab]
+ split_list += [lab]
278
+ else:
279
+ split_ids = [lab]
280
+ split_list += [lab]
281
+
282
+ # filter datasets for given organ
283
+ def if_label(example):
284
+ return example[dataset_split] == lab
285
+
286
+ trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
287
+ label_counter = Counter(trainset_label[label])
288
+ total_cells = sum(label_counter.values())
289
+
290
+ # excludes cells with a low proportion in the dataset
291
+ cells_to_keep = [
292
+ k
293
+ for k, v in label_counter.items()
294
+ if v > (filter_cells * total_cells)
295
+ ]
296
+ trainset_label_subset = trainset_label.filter(
297
+ if_not_rare_cell_state, num_proc=cpu_cores
298
+ )
299
+
300
+ # shuffle datasets and rename columns
301
+ trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
302
+ trainset_label_shuffled = trainset_label_shuffled.rename_column(
303
+ label, "label"
304
+ )
305
+ trainset_label_shuffled = trainset_label_shuffled.remove_columns(
306
+ dataset_split
307
+ )
308
+
309
+ # create dictionary of cell types : label ids
310
+ target_names = list(Counter(trainset_label_shuffled["label"]).keys())
311
+ target_name_id_dict = dict(
312
+ zip(target_names, [i for i in range(len(target_names))])
313
+ )
314
+ target_dict_list += [target_name_id_dict]
315
+
316
+ labeled_trainset = trainset_label_shuffled.map(
317
+ classes_to_ids, num_proc=cpu_cores
318
+ )
319
+
+ # create 80/20 train/eval splits (from the label-id-mapped dataset)
+ labeled_train_split = labeled_trainset.select(
+ [i for i in range(0, round(len(labeled_trainset) * 0.8))]
+ )
+ labeled_eval_split = labeled_trainset.select(
+ [
+ i
+ for i in range(
+ round(len(labeled_trainset) * 0.8), len(labeled_trainset)
+ )
+ ]
+ )
332
+
333
+ # filter dataset for cell types in corresponding training set
334
+ trained_labels = list(Counter(labeled_train_split["label"]).keys())
335
+
336
+ labeled_eval_split_subset = labeled_eval_split.filter(
337
+ if_trained_label, num_proc=cpu_cores
338
+ )
339
+
340
+ dataset_list += [labeled_train_split]
341
+ evalset_list += [labeled_eval_split_subset]
342
+
343
+ trainset_dict = dict(zip(split_list, dataset_list))
344
+ traintargetdict_dict = dict(zip(split_list, target_dict_list))
345
+ evalset_dict = dict(zip(split_list, evalset_list))
346
+
347
+ for lab in split_list:
348
+ label_trainset = trainset_dict[lab]
349
+ label_evalset = evalset_dict[lab]
350
+ label_dict = traintargetdict_dict[lab]
351
+
352
+ # set logging steps
353
+ logging_steps = round(len(label_trainset) / training_batch_size / 10)
354
+ if logging_steps == 0:
355
+ logging_steps = 1
356
+
357
+ # load pretrained model
358
+ model = BertForSequenceClassification.from_pretrained(
359
+ pretrained_model,
360
+ num_labels=len(label_dict.keys()),
361
+ output_attentions=False,
362
+ output_hidden_states=False,
363
+ ).to(device)
364
+
365
+ # define output directory path
366
+ current_date = datetime.datetime.now()
367
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
368
+
369
+ if output_dir is None:
370
+ output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
371
+
372
+ # ensure not overwriting previously saved model
373
+ saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
374
+
375
+ if os.path.isfile(saved_model_test) is True and overwrite is False:
376
+ raise Exception("Model already saved to this directory.")
377
+
378
+ # make output directory
379
+ subprocess.call(f"mkdir -p {output_dir}", shell=True)
380
+
381
+ # set training arguments
382
+ training_args = {
383
+ "learning_rate": max_lr,
384
+ "do_train": True,
385
+ "do_eval": True,
386
+ "evaluation_strategy": "epoch",
387
+ "save_strategy": "epoch",
388
+ "logging_steps": logging_steps,
389
+ "group_by_length": True,
390
+ "length_column_name": "length",
391
+ "disable_tqdm": False,
392
+ "lr_scheduler_type": lr_schedule_fn,
393
+ "warmup_steps": warmup_steps,
394
+ "weight_decay": 0.001,
395
+ "per_device_train_batch_size": training_batch_size,
396
+ "per_device_eval_batch_size": training_batch_size,
397
+ "num_train_epochs": epochs,
398
+ "load_best_model_at_end": True,
399
+ "output_dir": output_dir,
400
+ }
401
+
402
+ training_args_init = TrainingArguments(**training_args)
403
+ true_labels = label_evalset["label"]
404
+
405
+ if optimize_hyperparameters is False:
406
+ # create the trainer
407
+ trainer = Trainer(
408
+ model=model,
409
+ args=training_args_init,
410
+ data_collator=DataCollatorForCellClassification(),
411
+ train_dataset=label_trainset,
412
+ eval_dataset=label_evalset,
413
+ compute_metrics=compute_metrics,
414
+ )
415
+
416
+ # train the cell type classifier
417
+ trainer.train()
+ predictions = trainer.predict(label_evalset)
+ predicted_labels = predictions.predictions.argmax(-1)
+ print(
+ f'accuracy: {accuracy_score(label_evalset["label"], predicted_labels)}'
+ )
422
+
423
+ tpr, fpr, auc = ROC(predictions.predictions, true_labels)
424
+
425
+ metrics = compute_metrics(predictions)
426
+ with open(f"{output_dir}predictions.pickle", "wb") as fp:
427
+ pickle.dump(predictions, fp)
428
+
429
+ trainer.save_metrics("eval", predictions.metrics)
430
+
431
+ with open(f"{output_dir}/targets.txt", "w") as f:
432
+ if len(target_dict_list) == 1:
433
+ f.write(str(target_dict_list[0]))
434
+ else:
435
+ f.write(str(target_dict_list))
436
+
+ trainer.save_model(output_dir)
+
+ try:
+ precision, recall, _ = precision_recall_curve(
+ true_labels, predictions.predictions[:, 1]
+ )
+ pr_auc = precision_auc(recall, precision)
+
+ print(f"AUC: {pr_auc}")
+ return recall, precision, pr_auc
+ except Exception:
+ # precision-recall AUC only applies to the binary-label case
+ pass
449
+ else:
450
+
451
+ def model_init():
452
+ model = BertForSequenceClassification.from_pretrained(
453
+ pretrained_model,
454
+ num_labels=len(label_dict.keys()),
455
+ output_attentions=False,
456
+ output_hidden_states=False,
457
+ )
458
+ if freeze_layers is not None:
459
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
460
+ for module in modules_to_freeze:
461
+ for param in module.parameters():
462
+ param.requires_grad = False
463
+ model = model.to(device)
464
+ return model
465
+
466
+ trainer = Trainer(
467
+ model_init=model_init,
468
+ args=training_args_init,
469
+ data_collator=DataCollatorForCellClassification(),
470
+ train_dataset=label_trainset,
471
+ eval_dataset=label_evalset,
472
+ compute_metrics=compute_metrics,
473
+ )
474
+ # specify raytune hyperparameter search space
475
+ ray_config = {
476
+ "num_train_epochs": tune.choice([epochs]),
477
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
478
+ "weight_decay": tune.uniform(0.0, 0.3),
479
+ "lr_scheduler_type": tune.choice(
480
+ ["linear", "cosine", "polynomial"]
481
+ ),
482
+ "warmup_steps": tune.uniform(100, 2000),
483
+ "seed": tune.uniform(0, 100),
484
+ "per_device_train_batch_size": tune.choice(
485
+ [training_batch_size]
486
+ ),
487
+ }
488
+
489
+ hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
490
+
+ if torch.cuda.is_available():
+ resources_per_trial = {"cpu": 8, "gpu": 1}
+ else:
+ resources_per_trial = {"cpu": 8}
495
+
496
+ # optimize hyperparameters
497
+ best_trial = trainer.hyperparameter_search(
498
+ direction="maximize",
499
+ backend="ray",
500
+ resources_per_trial=resources_per_trial,
501
+ hp_space=lambda _: ray_config,
502
+ search_alg=hyperopt_search,
503
+ n_trials=100, # number of trials
504
+ progress_reporter=tune.CLIReporter(
505
+ max_report_frequency=600,
506
+ sort_by_metric=True,
507
+ max_progress_rows=100,
508
+ mode="max",
509
+ metric="eval_accuracy",
510
+ metric_columns=["loss", "eval_loss", "eval_accuracy"],
511
+ ),
512
+ )
513
+ best_hyperparameters = best_trial.hyperparameters
514
+
515
+ print("Best Hyperparameters:")
516
+ print(best_hyperparameters)
517
+
518
+ else:
519
+ trainset_label = train_dataset
520
+ label_counter = Counter(trainset_label[label])
521
+ total_cells = sum(label_counter.values())
522
+
523
+ # Excludes cells with a low proportion in the dataset
524
+ cells_to_keep = [
525
+ k for k, v in label_counter.items() if v > (filter_cells * total_cells)
526
+ ]
527
+ trainset_label_subset = trainset_label.filter(
528
+ if_not_rare_cell_state, num_proc=cpu_cores
529
+ )
530
+
531
+ # shuffle datasets and rename columns
532
+ trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
533
+ trainset_label_shuffled = trainset_label_shuffled.rename_column(
534
+ label, "label"
535
+ )
536
+
537
+ # create dictionary of cell types : label ids
538
+ target_names = list(Counter(trainset_label_shuffled["label"]).keys())
539
+ target_name_id_dict = dict(
540
+ zip(target_names, [i for i in range(len(target_names))])
541
+ )
542
+ target_dict_list = target_name_id_dict
543
+
544
+ labeled_trainset = trainset_label_shuffled.map(
545
+ classes_to_ids, num_proc=cpu_cores
546
+ )
547
+
548
+ # create 80/20 train/eval splits
549
+ labeled_train_split = labeled_trainset.select(
550
+ [i for i in range(0, round(len(labeled_trainset) * 0.8))]
551
+ )
552
+ labeled_eval_split = labeled_trainset.select(
553
+ [
554
+ i
555
+ for i in range(
556
+ round(len(labeled_trainset) * 0.8), len(labeled_trainset)
557
+ )
558
+ ]
559
+ )
560
+
561
+ # filter dataset for cell types in corresponding training set
562
+ trained_labels = list(Counter(labeled_train_split["label"]).keys())
563
+ labeled_eval_split_subset = labeled_eval_split.filter(
564
+ if_trained_label, num_proc=cpu_cores
565
+ )
566
+
567
+ # set logging steps
568
+ logging_steps = round(len(trainset_label) / training_batch_size / 10)
569
+
570
+ # load pretrained model
571
+ model = BertForSequenceClassification.from_pretrained(
572
+ pretrained_model,
573
+ num_labels=len(target_dict_list.keys()),
574
+ output_attentions=False,
575
+ output_hidden_states=False,
576
+ ).to(device)
577
+ # define output directory path
578
+ current_date = datetime.datetime.now()
579
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
580
+
581
+ if output_dir is None:
582
+ output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
583
+
584
+ # ensure not overwriting previously saved model
585
+ saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
586
+ if os.path.isfile(saved_model_test) is True and overwrite is False:
587
+ raise Exception("Model already saved to this directory.")
588
+
589
+ # make output directory
590
+ subprocess.call(f"mkdir -p {output_dir}", shell=True)
591
+
592
+ # set training arguments
593
+ training_args = {
594
+ "learning_rate": max_lr,
595
+ "do_train": True,
596
+ "do_eval": True,
597
+ "evaluation_strategy": "epoch",
598
+ "save_strategy": "epoch",
599
+ "logging_steps": logging_steps,
600
+ "group_by_length": True,
601
+ "length_column_name": "length",
602
+ "disable_tqdm": False,
603
+ "lr_scheduler_type": lr_schedule_fn,
604
+ "warmup_steps": warmup_steps,
605
+ "weight_decay": 0.001,
606
+ "per_device_train_batch_size": training_batch_size,
607
+ "per_device_eval_batch_size": training_batch_size,
608
+ "num_train_epochs": epochs,
609
+ "load_best_model_at_end": True,
610
+ "output_dir": output_dir,
611
+ }
612
+
613
+ training_args_init = TrainingArguments(**training_args)
614
+ true_labels = labeled_eval_split_subset["label"]
615
+
616
+ if optimize_hyperparameters is False:
617
+ # create the trainer
618
+ trainer = Trainer(
619
+ model=model,
620
+ args=training_args_init,
621
+ data_collator=DataCollatorForCellClassification(),
622
+ train_dataset=labeled_train_split,
623
+ eval_dataset=labeled_eval_split_subset,
624
+ compute_metrics=compute_metrics,
625
+ )
626
+
627
+ # train the cell type classifier
628
+ trainer.train()
629
+ predictions = trainer.predict(labeled_eval_split_subset)
630
+ predictions_tensor = torch.Tensor(predictions.predictions)
631
+ predicted_labels = torch.argmax(predictions_tensor, dim=1)
632
+ print(
633
+ f'accuracy: {accuracy_score(predicted_labels, labeled_eval_split_subset["label"])}'
634
+ )
635
+ metrics = compute_metrics(predictions)
636
+
637
+ with open(f"{output_dir}predictions.pickle", "wb") as fp:
638
+ pickle.dump(predictions.predictions.argmax(-1), fp)
639
+
640
+ trainer.save_metrics("eval", predictions.metrics)
641
+ trainer.save_model(output_dir)
642
+
643
+ # Saves label conversion dictionary to output directory
644
+ with open(f"{output_dir}/targets.txt", "w") as f:
645
+ f.write(str(target_dict_list))
646
+
647
+ try:
648
+ precision, recall, _ = precision_recall_curve(
649
+ true_labels, predictions.predictions[:, 1]
650
+ )
651
+ pr_auc = precision_auc(recall, precision)
652
+
653
+ print(f"AUC: {pr_auc}")
654
+ return recall, precision, pr_auc
655
+ except Exception:
656
+ pass
657
+
658
+ else:
659
+ # Optimizes hyperparameters
660
+
661
+ num_classes = len(list(set(labeled_train_split["label"])))
662
+
663
+ def model_init():
664
+ model = BertForSequenceClassification.from_pretrained(
665
+ pretrained_model,
666
+ num_labels=num_classes,
667
+ output_attentions=False,
668
+ output_hidden_states=False,
669
+ )
670
+
671
+ if freeze_layers is not None:
672
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
673
+ for module in modules_to_freeze:
674
+ for param in module.parameters():
675
+ param.requires_grad = False
676
+ model = model.to(device)
677
+ return model
678
+
679
+ # create the trainer
680
+ trainer = Trainer(
681
+ model_init=model_init,
682
+ args=training_args_init,
683
+ data_collator=DataCollatorForCellClassification(),
684
+ train_dataset=labeled_train_split,
685
+ eval_dataset=labeled_eval_split_subset,
686
+ compute_metrics=compute_metrics,
687
+ )
688
+
689
+ # specify raytune hyperparameter search space
690
+ ray_config = {
691
+ "num_train_epochs": tune.choice([epochs]),
692
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
693
+ "weight_decay": tune.uniform(0.0, 0.3),
694
+ "lr_scheduler_type": tune.choice(
695
+ ["linear", "cosine", "polynomial"]
696
+ ),
697
+ "warmup_steps": tune.uniform(100, 2000),
698
+ "seed": tune.uniform(0, 100),
699
+ "per_device_train_batch_size": tune.choice([training_batch_size]),
700
+ }
701
+
702
+ hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
703
+
+ if torch.cuda.is_available():
+ resources_per_trial = {"cpu": 8, "gpu": 1}
+ else:
+ resources_per_trial = {"cpu": 8}
708
+
709
+ # optimize hyperparameters
710
+ best_trial = trainer.hyperparameter_search(
711
+ direction="maximize",
712
+ backend="ray",
713
+ resources_per_trial=resources_per_trial,
714
+ hp_space=lambda _: ray_config,
715
+ search_alg=hyperopt_search,
716
+ n_trials=100, # number of trials
717
+ progress_reporter=tune.CLIReporter(
718
+ max_report_frequency=600,
719
+ sort_by_metric=True,
720
+ max_progress_rows=100,
721
+ mode="max",
722
+ metric="eval_accuracy",
723
+ metric_columns=["loss", "eval_loss", "eval_accuracy"],
724
+ ),
725
+ )
726
+ best_hyperparameters = best_trial.hyperparameters
727
+
728
+ print("Best Hyperparameters:")
729
+ print(best_hyperparameters)
730
+
731
+ # Performs Inference with model
732
+ if inference is True:
733
+ if dataset_split is not None and data_filter is not None:
734
+
735
+ def if_label(example):
736
+ return example[dataset_split] == data_filter
737
+
738
+ train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
739
+
740
+ trainset_label_shuffled = train_dataset
741
+ total_cells = len(trainset_label_shuffled)
742
+
743
+ # loads dictionary of all cell labels model was trained on
744
+ with open(Path(finetuned_model) / "targets.txt", "r") as f:
745
+ data = ast.literal_eval(f.read())
746
+ if dataset_split is not None and data_filter is None:
747
+ indexer = dataset_split.index(data_filter)
748
+ data = data[indexer]
749
+
750
+ target_dict_list = {key: value for key, value in enumerate(data)}
751
+
752
+ # set logging steps
753
+ logging_steps = round(len(trainset_label_shuffled) / training_batch_size / 20)
754
+
755
+ # load pretrained model
756
+ input_ids = trainset_label_shuffled["input_ids"]
757
+ inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
758
+ attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
759
+
760
+ for i, sentence in enumerate(input_ids):
761
+ sentence_length = len(sentence)
762
+ if sentence_length <= max_input_size:
763
+ inputs[i, :sentence_length] = torch.tensor(sentence)
764
+ attention[i, :sentence_length] = torch.ones(sentence_length)
765
+ else:
766
+ inputs[i, :] = torch.tensor(sentence[:max_input_size])
767
+ attention[i, :] = torch.ones(max_input_size)
768
+
769
+ model = BertForSequenceClassification.from_pretrained(
770
+ finetuned_model, num_labels=len(target_dict_list)
771
+ ).to(device)
772
+ model_outputs = model(inputs.to(device), attention_mask=attention)["logits"]
773
+ predictions = F.softmax(model_outputs, dim=-1).argmax(-1)
774
+
775
+ predictions = [target_dict_list[int(pred)] for pred in predictions]
776
+
777
+ return predictions
778
+
779
+ # Extracts embeddings from labeled data
780
+ if emb_extract is True:
781
+ if emb_filter is None:
782
+ with open(f"{finetuned_model}/targets.txt", "r") as f:
783
+ data = ast.literal_eval(f.read())
784
+ if dataset_split is not None and data_filter is None:
785
+ indexer = dataset_split.index(data_filter)
786
+ data = data[indexer]
787
+
788
+ target_dict_list = {key: value for key, value in enumerate(data)}
789
+ total_filter = None
790
+ else:
791
+ total_filter = emb_filter
792
+
793
+ train_dataset = load_from_disk(dataset)
794
+ if dataset_split is not None:
795
+
796
+ def if_label(example):
797
+ return example[dataset_split] == data_filter
798
+
799
+ train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
800
+
801
+ label_counter = Counter(train_dataset[label])
802
+ total_cells = sum(label_counter.values())
803
+ cells_to_keep = [
804
+ k for k, v in label_counter.items() if v > (filter_cells * total_cells)
805
+ ]
806
+
807
+ def if_not_rare(example):
808
+ return example[label] in cells_to_keep
809
+
810
+ train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)
811
+
812
+ true_labels = train_dataset[label]
813
+ num_classes = len(list(set(true_labels)))
814
+
815
+ embex = EmbExtractor(
816
+ model_type="CellClassifier",
817
+ num_classes=num_classes,
818
+ filter_data=total_filter,
819
+ max_ncells=max_cells,
820
+ emb_layer=emb_layer,
821
+ emb_label=[dataset_split, label],
822
+ labels_to_plot=[label],
823
+ forward_batch_size=inference_batch_size,
824
+ nproc=cpu_cores,
825
+ )
826
+
827
+ # example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
828
+ subprocess.call(f"mkdir -p {emb_dir}", shell=True)
829
+
830
+ embs = embex.extract_embs(
831
+ model_directory=finetuned_model,
832
+ input_data_file=dataset,
833
+ output_directory=emb_dir,
834
+ output_prefix=f"{label}_embeddings",
835
+ )
836
+ true_labels = embex.filtered_input_data[label]
837
+
838
+ emb_dict = {label: [] for label in list(set(true_labels))}
839
+ for num, emb in embs.iterrows():
840
+ key = emb[label]
841
+ selection = emb.iloc[:255]
842
+ emb = torch.Tensor(selection)
843
+ emb_dict[key].append(emb)
844
+
845
+ for key in list(emb_dict.keys()):
846
+ stack = torch.stack(emb_dict[key], dim=0)
847
+ emb_dict[key] = torch.mean(stack, dim=0)
848
+ similarities = {key: {} for key in list(emb_dict.keys())}
849
+
850
+ for key in list(emb_dict.keys()):
851
+ remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
852
+ for k in remaining_keys:
853
+ embedding = emb_dict[k]
854
+ sim = similarity(emb_dict[key], embedding, cosine=True)
855
+
856
+ similarities[key][k] = sim
857
+
858
+ plot_similarity_heatmap(similarities)
859
+
860
+ embex.plot_embs(
861
+ embs=embs,
862
+ plot_style="umap",
863
+ output_directory=emb_dir,
864
+ output_prefix="emb_plot",
865
+ )
866
+
867
+ embex.plot_embs(
868
+ embs=embs,
869
+ plot_style="heatmap",
870
+ output_directory=emb_dir,
871
+ output_prefix="emb_plot",
872
+ )
873
+
874
+ return similarities
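+
+ # Minimal usage sketch (paths and the fine-tuned model directory are assumptions,
+ # not files shipped with this commit):
+ # predictions = classify_cells(
+ #     dataset="path/to/cell_type_train_data.dataset/",
+ #     finetuned_model="path/to/finetuned_model/",
+ #     skip_training=True,
+ #     inference=True,
+ # )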
geneformer/gene_classifier.py ADDED
@@ -0,0 +1,935 @@
1
+ import os
2
+ import sys
3
+
4
+ GPU_NUMBER = [0] # CHANGE WITH MULTIGPU
5
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
6
+ os.environ["NCCL_DEBUG"] = "INFO"
7
+
8
+ import ast
9
+ import datetime
10
+ import math
11
+ import pickle
12
+ import subprocess
13
+ from pathlib import Path
14
+
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+ import pandas as pd
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from datasets import Dataset, load_from_disk
21
+ from sklearn import preprocessing
22
+ from sklearn.metrics import (
23
+ ConfusionMatrixDisplay,
24
+ accuracy_score,
25
+ auc,
26
+ confusion_matrix,
27
+ roc_auc_score,
28
+ roc_curve,
29
+ )
30
+
31
+ # imports
32
+ from sklearn.model_selection import StratifiedKFold, train_test_split
33
+ from tqdm.notebook import tqdm
34
+ from transformers import BertForTokenClassification, Trainer
35
+ from transformers.training_args import TrainingArguments
36
+
37
+ from geneformer import DataCollatorForGeneClassification, EmbExtractor
38
+ from geneformer.pretrainer import token_dictionary
39
+
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+ from geneformer import TranscriptomeTokenizer
42
+
43
+
44
+ def vote(logit_pair):
45
+ a, b = logit_pair
46
+ if a > b:
47
+ return 0
48
+ elif b > a:
49
+ return 1
50
+ elif a == b:
51
+ return "tie"
52
+
53
+
54
+ def py_softmax(vector):
55
+ e = np.exp(vector)
56
+ return e / e.sum()
57
+
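+ # Illustrative only: vote() picks the class with the larger logit and py_softmax()
+ # turns a logit vector into probabilities (numbers below are made up).
+ # vote([2.0, 0.5])                  # -> 0
+ # py_softmax(np.array([2.0, 0.5]))  # -> approximately [0.82, 0.18]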
58
+ # Computes similarity between two embeddings: raw cosine similarity when cosine=True, otherwise a cosine score rescaled to [0, 1] (0 is perfectly dissimilar, 1 is perfectly similar)
59
+
60
+
61
+ def similarity(tensor1, tensor2, cosine=True):
62
+ if cosine == False:
63
+ if tensor1.ndimension() > 1:
64
+ tensor1 = tensor1.view(1, -1)
65
+ if tensor2.ndimension() > 1:
66
+ tensor2 = tensor2.view(1, -1)
67
+ dot_product = torch.matmul(tensor1, tensor2)
68
+ norm_tensor1 = torch.norm(tensor1)
69
+ norm_tensor2 = torch.norm(tensor2)
70
+ epsilon = 1e-8
71
+ similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
72
+ similarity = (similarity.item() + 1) / 2
73
+ else:
74
+ if tensor1.shape != tensor2.shape:
75
+ raise ValueError("Input tensors must have the same shape.")
76
+
77
+ # Compute cosine similarity using PyTorch's dot product function
78
+ dot_product = torch.dot(tensor1, tensor2)
79
+ norm_tensor1 = torch.norm(tensor1)
80
+ norm_tensor2 = torch.norm(tensor2)
81
+
82
+ # Avoid division by zero by adding a small epsilon
83
+ epsilon = 1e-8
84
+ similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
85
+
86
+ return similarity.item()
87
+
88
+
89
+ # Plots heatmap between different classes/labels
90
+ def plot_similarity_heatmap(similarities):
91
+ classes = list(similarities.keys())
92
+ classlen = len(classes)
93
+ arr = np.zeros((classlen, classlen))
94
+ for i, c in enumerate(classes):
95
+ for j, cc in enumerate(classes):
96
+ if cc == c:
97
+ val = 1.0
98
+ else:
99
+ val = similarities[c][cc]
100
+ arr[i][j] = val
101
+
102
+ plt.figure(figsize=(8, 6))
103
+ plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
104
+ plt.colorbar()
105
+ plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
106
+ plt.yticks(np.arange(classlen), classes)
107
+ plt.title("Similarity Heatmap")
108
+ plt.savefig("similarity_heatmap.png")
109
+
110
+
111
+ # get cross-validated mean and sd metrics
112
+ def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
113
+ wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
114
+
115
+ all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
116
+ mean_tpr = np.sum(all_weighted_tpr, axis=0)
117
+ mean_tpr[-1] = 1.0
118
+ all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
119
+ roc_auc = np.sum(all_weighted_roc_auc)
120
+ roc_auc_sd = math.sqrt(np.average((np.array(all_roc_auc) - roc_auc) ** 2, weights=wts))
121
+ return mean_tpr, roc_auc, roc_auc_sd
122
+
123
+
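+ # Worked example (made-up numbers): with all_tpr_wt giving weights 0.25 and 0.75,
+ # fold AUCs of 0.80 and 0.90 combine to a weighted mean of 0.875 and a weighted SD of
+ # sqrt(0.25*(0.80-0.875)**2 + 0.75*(0.90-0.875)**2) ≈ 0.043.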
124
+ def validate(
125
+ data,
126
+ targets,
127
+ labels,
128
+ nsplits,
129
+ subsample_size,
130
+ training_args,
131
+ freeze_layers,
132
+ output_dir,
133
+ num_proc,
134
+ num_labels,
135
+ pre_model,
136
+ ):
137
+ # initiate eval metrics to return
138
+ num_classes = len(set(labels))
139
+ mean_fpr = np.linspace(0, 1, 100)
140
+
141
+ # create 75/25 train/eval splits
142
+ targets_train, targets_eval, labels_train, labels_eval = train_test_split(
143
+ targets, labels, test_size=0.25, shuffle=True
144
+ )
145
+ label_dict_train = dict(zip(targets_train, labels_train))
146
+ label_dict_eval = dict(zip(targets_eval, labels_eval))
147
+
148
+ # function to filter by whether contains train or eval labels
149
+ def if_contains_train_label(example):
150
+ a = label_dict_train.keys()
151
+ b = example["input_ids"]
152
+ return not set(a).isdisjoint(b)
153
+
154
+ def if_contains_eval_label(example):
155
+ a = label_dict_eval.keys()
156
+ b = example["input_ids"]
157
+ return not set(a).isdisjoint(b)
158
+
159
+ # filter dataset for examples containing classes for this split
+ print("Filtering training data")
+ trainset = data.filter(if_contains_train_label, num_proc=num_proc)
+ print(
+ f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n"
+ )
+ print("Filtering evaluation data")
166
+ evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
167
+ print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
168
+
169
+ # minimize to smaller training sample
170
+ training_size = min(subsample_size, len(trainset))
171
+ trainset_min = trainset.select([i for i in range(training_size)])
172
+ eval_size = min(training_size, len(evalset))
173
+ half_training_size = round(eval_size / 2)
174
+ evalset_train_min = evalset.select([i for i in range(half_training_size)])
175
+ evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
176
+
177
+ # label conversion functions
178
+ def generate_train_labels(example):
179
+ example["labels"] = [
180
+ label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
181
+ ]
182
+ return example
183
+
184
+ def generate_eval_labels(example):
185
+ example["labels"] = [
186
+ label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
187
+ ]
188
+ return example
189
+
190
+ # label datasets
191
+ print(f"Labeling training data")
192
+ trainset_labeled = trainset_min.map(generate_train_labels)
193
+ print(f"Labeling evaluation data")
194
+ evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
195
+ print(f"Labeling evaluation OOS data")
196
+ evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
197
+
198
+ # load model
199
+ model = BertForTokenClassification.from_pretrained(
200
+ pre_model,
201
+ num_labels=num_labels,
202
+ output_attentions=False,
203
+ output_hidden_states=False,
204
+ )
205
+ if freeze_layers is not None:
206
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
207
+ for module in modules_to_freeze:
208
+ for param in module.parameters():
209
+ param.requires_grad = False
210
+
211
+ model = model.to(device)
212
+
213
+ # add output directory to training args and initiate
214
+ training_args["output_dir"] = output_dir
215
+ training_args_init = TrainingArguments(**training_args)
216
+
217
+ # create the trainer
218
+ trainer = Trainer(
219
+ model=model,
220
+ args=training_args_init,
221
+ data_collator=DataCollatorForGeneClassification(),
222
+ train_dataset=trainset_labeled,
223
+ eval_dataset=evalset_train_labeled,
224
+ )
225
+
226
+ # train the gene classifier
227
+ trainer.train()
228
+ trainer.save_model(output_dir)
229
+
230
+ fpr, tpr, interp_tpr, conf_mat = classifier_predict(
231
+ trainer.model, evalset_oos_labeled, 200, mean_fpr
232
+ )
233
+ auc_score = auc(fpr, tpr)
234
+
235
+ return fpr, tpr, auc_score
236
+
237
+
238
+ # cross-validate gene classifier
239
+ def cross_validate(
240
+ data,
241
+ targets,
242
+ labels,
243
+ nsplits,
244
+ subsample_size,
245
+ training_args,
246
+ freeze_layers,
247
+ output_dir,
248
+ num_proc,
249
+ num_labels,
250
+ pre_model,
251
+ ):
252
+ # check if output directory already written to
253
+ # ensure not overwriting previously saved model
254
+ model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
255
+ # if os.path.isfile(model_dir_test) == True:
256
+ # raise Exception("Model already saved to this directory.")
257
+
258
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
259
+ # initiate eval metrics to return
260
+ num_classes = len(set(labels))
261
+ mean_fpr = np.linspace(0, 1, 100)
262
+ all_tpr = []
263
+ all_roc_auc = []
264
+ all_tpr_wt = []
265
+ label_dicts = []
266
+ confusion = np.zeros((num_classes, num_classes))
267
+
268
+ # set up cross-validation splits
269
+ skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
270
+ # train and evaluate
271
+ iteration_num = 0
272
+ for train_index, eval_index in tqdm(skf.split(targets, labels)):
273
+ if len(labels) > 500:
274
+ print("early stopping activated due to large # of training examples")
275
+ if iteration_num == 3:
276
+ break
277
+
278
+ print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
279
+
280
+ # generate cross-validation splits
281
+ targets_train, targets_eval = targets[train_index], targets[eval_index]
282
+ labels_train, labels_eval = labels[train_index], labels[eval_index]
283
+ label_dict_train = dict(zip(targets_train, labels_train))
284
+ label_dict_eval = dict(zip(targets_eval, labels_eval))
285
+ label_dicts += (
286
+ iteration_num,
287
+ targets_train,
288
+ targets_eval,
289
+ labels_train,
290
+ labels_eval,
291
+ )
292
+
293
+ # function to filter by whether contains train or eval labels
294
+ def if_contains_train_label(example):
295
+ a = label_dict_train.keys()
296
+ b = example["input_ids"]
297
+
298
+ return not set(a).isdisjoint(b)
299
+
300
+ def if_contains_eval_label(example):
301
+ a = label_dict_eval.keys()
302
+ b = example["input_ids"]
303
+
304
+ return not set(a).isdisjoint(b)
305
+
306
+ # filter dataset for examples containing classes for this split
307
+ print(f"Filtering training data")
308
+ trainset = data.filter(if_contains_train_label, num_proc=num_proc)
309
+ print(
310
+ f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n"
311
+ )
312
+ print(f"Filtering evalation data")
313
+ evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
314
+ print(
315
+ f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n"
316
+ )
317
+
318
+ # minimize to smaller training sample
319
+ training_size = min(subsample_size, len(trainset))
320
+ trainset_min = trainset.select([i for i in range(training_size)])
321
+ eval_size = min(training_size, len(evalset))
322
+ half_training_size = round(eval_size / 2)
323
+ evalset_train_min = evalset.select([i for i in range(half_training_size)])
324
+ evalset_oos_min = evalset.select(
325
+ [i for i in range(half_training_size, eval_size)]
326
+ )
327
+
328
+ # label conversion functions
329
+ def generate_train_labels(example):
330
+ example["labels"] = [
331
+ label_dict_train.get(token_id, -100)
332
+ for token_id in example["input_ids"]
333
+ ]
334
+ return example
335
+
336
+ def generate_eval_labels(example):
337
+ example["labels"] = [
338
+ label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
339
+ ]
340
+ return example
341
+
342
+ # label datasets
343
+ print(f"Labeling training data")
344
+ trainset_labeled = trainset_min.map(generate_train_labels)
345
+ print(f"Labeling evaluation data")
346
+ evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
347
+ print(f"Labeling evaluation OOS data")
348
+ evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
349
+
350
+ # create output directories
351
+ ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
352
+ ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")
353
+
354
+ # ensure not overwriting previously saved model
355
+ model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
356
+ # if os.path.isfile(model_output_file) == True:
357
+ # raise Exception("Model already saved to this directory.")
358
+
359
+ # make training and model output directories
360
+ subprocess.call(f"mkdir -p {ksplit_output_dir}", shell=True)
361
+ subprocess.call(f"mkdir -p {ksplit_model_dir}", shell=True)
362
+
363
+ # load model
364
+ model = BertForTokenClassification.from_pretrained(
365
+ pre_model,
366
+ num_labels=num_labels,
367
+ output_attentions=False,
368
+ output_hidden_states=False,
369
+ )
370
+ if freeze_layers is not None:
371
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
372
+ for module in modules_to_freeze:
373
+ for param in module.parameters():
374
+ param.requires_grad = False
375
+
376
+ model = model.to(device)
377
+
378
+ # add output directory to training args and initiate
379
+ training_args["output_dir"] = ksplit_output_dir
380
+ training_args_init = TrainingArguments(**training_args)
381
+
382
+ # create the trainer
383
+ trainer = Trainer(
384
+ model=model,
385
+ args=training_args_init,
386
+ data_collator=DataCollatorForGeneClassification(),
387
+ train_dataset=trainset_labeled,
388
+ eval_dataset=evalset_train_labeled,
389
+ )
390
+
391
+ # train the gene classifier
392
+ trainer.train()
393
+
394
+ # save model
395
+ trainer.save_model(ksplit_model_dir)
396
+
397
+ # evaluate model
398
+ fpr, tpr, interp_tpr, conf_mat = classifier_predict(
399
+ trainer.model, evalset_oos_labeled, 200, mean_fpr
400
+ )
401
+
402
+ # append to tpr and roc lists
403
+ confusion = confusion + conf_mat
404
+ all_tpr.append(interp_tpr)
405
+ all_roc_auc.append(auc(fpr, tpr))
406
+ # append number of eval examples by which to weight tpr in averaged graphs
407
+ all_tpr_wt.append(len(tpr))
408
+
409
+ iteration_num = iteration_num + 1
410
+
411
+ # get overall metrics for cross-validation
412
+ mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(
413
+ all_tpr, all_roc_auc, all_tpr_wt
414
+ )
415
+ return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
416
+
417
+
418
+ # Computes metrics
419
+ def compute_metrics(pred):
420
+ labels = pred.label_ids
421
+ preds = pred.predictions.argmax(-1)
422
+ # calculate accuracy and macro f1 using sklearn's function
423
+ acc = accuracy_score(labels, preds)
424
+ macro_f1 = f1_score(labels, preds, average="macro")
425
+
426
+ return {"accuracy": acc, "macro_f1": macro_f1}
427
+
428
+
429
+ # plot ROC curve
430
+ def plot_ROC(bundled_data, title):
431
+ plt.figure()
432
+ lw = 2
433
+ for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
434
+ plt.plot(
435
+ mean_fpr,
436
+ mean_tpr,
437
+ color=color,
438
+ lw=lw,
439
+ label=r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(
440
+ sample, roc_auc, roc_auc_sd
441
+ ),
442
+ )
443
+
444
+ plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
445
+ plt.xlim([0.0, 1.0])
446
+ plt.ylim([0.0, 1.05])
447
+ plt.xlabel("False Positive Rate")
448
+ plt.ylabel("True Positive Rate")
449
+ plt.title(title)
450
+ plt.legend(loc="lower right")
451
+ plt.savefig("ROC.png")
452
+
453
+ return mean_fpr, mean_tpr, roc_auc
454
+
455
+
456
+ # plot confusion matrix
457
+ def plot_confusion_matrix(classes_list, conf_mat, title):
458
+ display_labels = []
459
+ i = 0
460
+ for label in classes_list:
461
+ display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:, i]))]
462
+ i = i + 1
463
+ display = ConfusionMatrixDisplay(
464
+ confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
465
+ display_labels=display_labels,
466
+ )
467
+ display.plot(cmap="Blues", values_format=".2g")
468
+ plt.title(title)
469
+ plt.savefig("CM.png")
470
+
471
+
472
+ # Function to find the largest number smaller
473
+ # than or equal to N that is divisible by k
474
+ def find_largest_div(N, K):
475
+ rem = N % K
476
+ if rem == 0:
477
+ return N
478
+ else:
479
+ return N - rem
480
+
481
+
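+ # e.g. find_largest_div(1001, 200) -> 1000; classifier_predict uses this to trim the
+ # eval set so the final forward batch never contains a single leftover example.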
482
+ def preprocess_classifier_batch(cell_batch, max_len):
483
+ if max_len == None:
484
+ max_len = max([len(i) for i in cell_batch["input_ids"]])
485
+
486
+ def pad_label_example(example):
487
+ example["labels"] = np.pad(
488
+ example["labels"],
489
+ (0, max_len - len(example["input_ids"])),
490
+ mode="constant",
491
+ constant_values=-100,
492
+ )
493
+ example["input_ids"] = np.pad(
494
+ example["input_ids"],
495
+ (0, max_len - len(example["input_ids"])),
496
+ mode="constant",
497
+ constant_values=token_dictionary.get("<pad>"),
498
+ )
499
+ example["attention_mask"] = (
500
+ example["input_ids"] != token_dictionary.get("<pad>")
501
+ ).astype(int)
502
+ return example
503
+
504
+ padded_batch = cell_batch.map(pad_label_example)
505
+ return padded_batch
506
+
507
+
508
+ # forward batch size is batch size for model inference (e.g. 200)
509
+ def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
510
+ predict_logits = []
511
+ predict_labels = []
512
+ model.to("cpu")
513
+ model.eval()
514
+
515
+ # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
516
+ evalset_len = len(evalset)
517
+ max_divisible = find_largest_div(evalset_len, forward_batch_size)
518
+ if len(evalset) - max_divisible == 1:
519
+ evalset_len = max_divisible
520
+
521
+ max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
522
+
523
+ for i in range(0, evalset_len, forward_batch_size):
524
+ max_range = min(i + forward_batch_size, evalset_len)
525
+ batch_evalset = evalset.select([i for i in range(i, max_range)])
526
+ padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
527
+ padded_batch.set_format(type="torch")
528
+
529
+ input_data_batch = padded_batch["input_ids"]
530
+ attn_msk_batch = padded_batch["attention_mask"]
531
+ label_batch = padded_batch["labels"]
532
+ with torch.no_grad():
533
+ input_ids = input_data_batch
534
+ attn_mask = attn_msk_batch
535
+ labels = label_batch
536
+ outputs = model(
537
+ input_ids=input_ids, attention_mask=attn_mask, labels=labels
538
+ )
539
+ predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
540
+ predict_labels += [torch.squeeze(label_batch.to("cpu"))]
541
+
542
+ logits_by_cell = torch.cat(predict_logits)
543
+ all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
544
+ labels_by_cell = torch.cat(predict_labels)
545
+ all_labels = torch.flatten(labels_by_cell)
546
+ logit_label_paired = [
547
+ item
548
+ for item in list(zip(all_logits.tolist(), all_labels.tolist()))
549
+ if item[1] != -100
550
+ ]
551
+ y_pred = [vote(item[0]) for item in logit_label_paired]
552
+ y_true = [item[1] for item in logit_label_paired]
553
+ logits_list = [item[0] for item in logit_label_paired]
554
+ # probability of class 1
555
+ y_score = [py_softmax(item)[1] for item in logits_list]
556
+ conf_mat = confusion_matrix(y_true, y_pred)
557
+ fpr, tpr, _ = roc_curve(y_true, y_score)
558
+ # plot roc_curve for this split
559
+ plt.plot(fpr, tpr)
560
+ plt.xlim([0.0, 1.0])
561
+ plt.ylim([0.0, 1.05])
562
+ plt.xlabel("False Positive Rate")
563
+ plt.ylabel("True Positive Rate")
564
+ plt.title("ROC")
565
+ plt.show()
566
+ # interpolate to graph
567
+ interp_tpr = np.interp(mean_fpr, fpr, tpr)
568
+ interp_tpr[0] = 0.0
569
+ return fpr, tpr, interp_tpr, conf_mat
570
+
571
+
572
+ def classify_genes(
573
+ gene_info="Genecorpus-30M/example_input_files/gene_info_table.csv",
574
+ genes="Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
575
+ corpus_30M="Genecorpus-30M/genecorpus_30M_2048.dataset/",
576
+ model=".",
577
+ max_input_size=2**11,
578
+ max_lr=5e-5,
579
+ freeze_layers=4,
580
+ num_gpus=1,
581
+ num_proc=os.cpu_count(),
582
+ geneformer_batch_size=9,
583
+ epochs=1,
584
+ filter_dataset=50_000,
585
+ emb_extract=True,
586
+ emb_layer=0,
587
+ forward_batch=200,
588
+ filter_data=None,
589
+ inference=False,
590
+ k_validate=True,
591
+ model_location="230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
592
+ skip_training=False,
593
+ emb_dir="gene_emb",
594
+ output_dir=None,
595
+ max_cells=1000,
596
+ num_cpus=os.cpu_count(),
597
+ ):
598
+ """ "
599
+ Primary Parameters
600
+ -----------
601
+
602
+ gene_info: path
603
+ Path to gene mappings
604
+
605
+ corpus_30M: path
606
+ Path to 30M Gene Corpus
607
+
608
+ model: path
609
+ Path to pretrained GeneFormer model
610
+
611
+ genes: path
612
+ Path to csv file containing different columns of genes and the column labels
613
+
614
+ inference: bool
615
+ Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
616
+
617
+ k_validate: bool
618
+ Whether the model should run k-fold validation or simply perform regular training/evaluate. Defaults to True
619
+
620
+ skip_training: bool
621
+ Whether the model should skip the training portion. Defaults to False
622
+
623
+ emb_extract: bool
624
+ WHether the model should extract embeddings for a given gene (WIP)
625
+
626
+
627
+ Customization Parameters
628
+ -----------
629
+
630
+ freeze_layers: int
631
+ Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)
632
+
633
+ filter_dataset: int
634
+ Number of cells to filter from 30M dataset. Default is 50_000
635
+
636
+ emb_layer: int
637
+ What layer embeddings are extracted from. Default is 0
638
+
639
+ filter_data: str, list
640
+ Filters down embeddings to a single category. Default is None
641
+
642
+
643
+ """
644
+
645
+ # table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
646
+ gene_info = pd.read_csv(gene_info, index_col=0)
647
+ labels = gene_info.columns
648
+
649
+ # create dictionaries for corresponding attributes
650
+ gene_id_type_dict = dict(zip(gene_info["ensembl_id"], gene_info["gene_type"]))
651
+ gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"]))
652
+ gene_id_name_dict = {v: k for k, v in gene_name_id_dict.items()}
653
+
654
+ # function for preparing targets and labels
655
+ def prep_inputs(label_store, id_type):
656
+ target_list = []
657
+ if id_type == "gene_name":
658
+ for key in list(label_store.keys()):
659
+ targets = [
660
+ gene_name_id_dict[gene]
661
+ for gene in label_store[key]
662
+ if gene_name_id_dict.get(gene) in token_dictionary
663
+ ]
664
+ targets_id = [token_dictionary[gene] for gene in targets]
665
+ target_list.append(targets_id)
666
+ elif id_type == "ensembl_id":
667
+ for key in list(label_store.keys()):
668
+ targets = [
669
+ gene for gene in label_store[key] if gene in token_dictionary
670
+ ]
671
+ targets_id = [token_dictionary[gene] for gene in targets]
672
+ target_list.append(targets_id)
673
+
674
+ targets, labels = [], []
675
+ for targ in target_list:
676
+ targets = targets + targ
677
+ targets = np.array(targets)
678
+ for num, targ in enumerate(target_list):
679
+ label = [num] * len(targ)
680
+ labels = labels + label
681
+ labels = np.array(labels)
682
+ unique_labels = len(target_list)
683
+
684
+ nsplits = min(5, min([len(targ) for targ in target_list]) - 1)
685
+ assert nsplits > 2
686
+
687
+ return targets, labels, nsplits, unique_labels
688
+
689
+ if not skip_training:
690
+ # preparing targets and labels for dosage sensitive vs insensitive TFs
691
+ gene_classes = pd.read_csv(genes, header=0)
692
+ if filter_data is None:
693
+ labels = gene_classes.columns
694
+ else:
695
+ if isinstance(filter_data, list):
696
+ labels = filter_data
697
+ else:
698
+ labels = [filter_data]
699
+ label_store = {}
700
+
701
+ # Dictionary for decoding labels
702
+ decode = {i: labels[i] for i in range(len(labels))}
703
+
704
+ for label in labels:
705
+ label_store[label] = gene_classes[label].dropna()
706
+
707
+ targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")
708
+
709
+ # load training dataset
710
+ train_dataset = load_from_disk(corpus_30M)
711
+ shuffled_train_dataset = train_dataset.shuffle(seed=42)
712
+ subsampled_train_dataset = shuffled_train_dataset.select(
713
+ [i for i in range(filter_dataset)]
714
+ )
715
+ lr_schedule_fn = "linear"
716
+ warmup_steps = 500
717
+ optimizer = "adamw"
718
+ subsample_size = 10_000
719
+
720
+ training_args = {
721
+ "learning_rate": max_lr,
722
+ "do_train": True,
723
+ "evaluation_strategy": "no",
724
+ "save_strategy": "epoch",
725
+ "logging_steps": 10,
726
+ "group_by_length": True,
727
+ "length_column_name": "length",
728
+ "disable_tqdm": False,
729
+ "lr_scheduler_type": lr_schedule_fn,
730
+ "warmup_steps": warmup_steps,
731
+ "weight_decay": 0.001,
732
+ "per_device_train_batch_size": geneformer_batch_size,
733
+ "per_device_eval_batch_size": geneformer_batch_size,
734
+ "num_train_epochs": epochs,
735
+ }
736
+
737
+ # define output directory path
738
+ current_date = datetime.datetime.now()
739
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
740
+
741
+ if output_dir is None:
742
+ training_output_dir = Path(
743
+ f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/"
744
+ )
745
+ else:
746
+ training_output_dir = Path(output_dir)
747
+
748
+ # make output directory
749
+ subprocess.call(f"mkdir -p {training_output_dir}", shell=True)
750
+
751
+ # save the number of classes and the label-decoding dictionary alongside the model
752
+ num_classes = len(set(labels))
753
+ info_list = [num_classes, decode]
754
+
755
+ with open(training_output_dir / "classes.txt", "w") as f:
756
+ f.write(str(info_list))
757
+
758
+ subsampled_train_dataset.save_to_disk(training_output_dir / "dataset")
759
+
760
+ if k_validate:
761
+ ksplit_model = "ksplit0/models"
762
+ ksplit_model_test = os.path.join(training_output_dir, ksplit_model)
763
+ # if os.path.isfile(ksplit_model_test) == True:
764
+ # raise Exception("Model already saved to this directory.")
765
+ # cross-validate gene classifier
766
+ (
767
+ all_roc_auc,
768
+ roc_auc,
769
+ roc_auc_sd,
770
+ mean_fpr,
771
+ mean_tpr,
772
+ confusion,
773
+ label_dicts,
774
+ ) = cross_validate(
775
+ subsampled_train_dataset,
776
+ targets,
777
+ labels,
778
+ nsplits,
779
+ subsample_size,
780
+ training_args,
781
+ freeze_layers,
782
+ training_output_dir,
783
+ 1,
784
+ unique_labels,
785
+ model,
786
+ )
787
+
788
+ bundled_data = []
789
+ bundled_data += [
790
+ (roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")
791
+ ]
792
+ graph_title = " ".join(
793
+ [
794
+ i + " vs" if count < len(label_store) - 1 else i
795
+ for count, i in enumerate(label_store)
796
+ ]
797
+ )
798
+ fpr, tpr, auc = plot_ROC(
799
+ bundled_data, "Dosage Sensitive vs Insensitive TFs"
800
+ )
801
+ print(auc)
802
+ # plot confusion matrix
803
+ plot_confusion_matrix(label_store, confusion, "Geneformer")
804
+ else:
805
+ fpr, tpr, auc = validate(
806
+ subsampled_train_dataset,
807
+ targets,
808
+ labels,
809
+ nsplits,
810
+ subsample_size,
811
+ training_args,
812
+ freeze_layers,
813
+ training_output_dir,
814
+ 1,
815
+ unique_labels,
816
+ model,
817
+ )
818
+ print(auc)
819
+
820
+ if inference:
821
+ # preparing targets and labels for dosage sensitive vs insensitive TFs
822
+ gene_classes = pd.read_csv(genes, header=0)
823
+ targets = []
824
+ for column in gene_classes.columns:
825
+ targets += list(gene_classes[column])
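+ # look up the input token ID for each candidate gene; genes missing from the token dictionary fall back to ID 0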
826
+ tokens = []
827
+ for target in targets:
828
+ try:
829
+ tokens.append(token_dictionary[target])
830
+ except KeyError:
831
+ tokens.append(0)
832
+
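+ # pack every candidate gene token into a single input sequence for token-level classification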
833
+ targets = torch.LongTensor([tokens])
834
+
835
+ with open(f"{model_location}classes.txt", "r") as f:
836
+ info_list = ast.literal_eval(f.read())
837
+ num_classes = info_list[0]
838
+ labels = info_list[1]
839
+
840
+ model = BertForTokenClassification.from_pretrained(
841
+ model_location,
842
+ num_labels=num_classes,
843
+ output_attentions=False,
844
+ output_hidden_states=False,
845
+ local_files_only=True,
846
+ )
847
+ if freeze_layers is not None:
848
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
849
+ for module in modules_to_freeze:
850
+ for param in module.parameters():
851
+ param.requires_grad = False
852
+
853
+ model = model.to(device)
854
+
855
+ # evaluate model
856
+ predictions = F.softmax(model(targets.to(device))["logits"], dim=-1).argmax(-1)[
857
+ 0
858
+ ]
859
+ predictions = [labels[int(pred)] for pred in predictions]
860
+
861
+ return predictions
862
+
863
+ # Extracts aggregate gene embeddings for each label
864
+ if emb_extract:
865
+ with open(f"{model_location}/classes.txt", "r") as f:
866
+ data = ast.literal_eval(f.read())
867
+ num_classes = data[0]
868
+ decode = data[1]
869
+
870
+ gene_classes = pd.read_csv(genes, header=0)
871
+ labels = gene_classes.columns
872
+ tokenize = TranscriptomeTokenizer()
873
+
874
+ label_dict = {}
875
+ for label in labels:
876
+ genes = gene_classes[label]
877
+ tokenized_genes = []
878
+ for gene in genes:
879
+ try:
880
+ tokenized_genes.append(tokenize.gene_token_dict[gene])
881
+ except KeyError:
882
+ continue
883
+ label_dict[label] = tokenized_genes
884
+
885
+ embex = EmbExtractor(
886
+ model_type="GeneClassifier",
887
+ num_classes=num_classes,
888
+ emb_mode="gene",
889
+ filter_data=None,
890
+ max_ncells=max_cells,
891
+ emb_layer=emb_layer,
892
+ emb_label=label_dict,
893
+ labels_to_plot=list(labels),
894
+ forward_batch_size=forward_batch,
895
+ nproc=num_cpus,
896
+ )
897
+
898
+ subprocess.call(f"mkdir -p {emb_dir}", shell=True)
899
+
900
+ embs = embex.extract_embs(
901
+ model_directory=model_location,
902
+ input_data_file=model_location / "dataset",
903
+ output_directory=emb_dir,
904
+ output_prefix=f"{label}_embbeddings",
905
+ )
906
+
907
+ emb_dict = {label: [] for label in list(set(labels))}
908
+ similarities = {key: {} for key in list(emb_dict.keys())}
909
+
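+ # pairwise cosine similarity between the aggregate embeddings of each label column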
910
+ for column in embs.columns:
911
+ remaining_cols = [k for k in embs.columns if k != column]
912
+ for k in remaining_cols:
913
+ embedding = torch.Tensor(embs[k])
914
+ sim = similarity(torch.Tensor(embs[column]), embedding, cosine=True)
915
+ similarities[column][k] = sim
916
+
917
+ plot_similarity_heatmap(similarities)
918
+ print(similarities)
919
+
920
+ return similarities
921
+
922
+
923
+ if __name__ == "__main__":
924
+ classify_genes(
925
+ k_validate=False,
926
+ inference=False,
927
+ skip_training=False,
928
+ emb_extract=True,
929
+ output_dir=Path("gene_emb"),
930
+ model_location=Path("gene_emb"),
931
+ epochs=5,
932
+ gene_info="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv",
933
+ genes="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
934
+ corpus_30M="../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/",
935
+ )
geneformer/modular_classifier_usage.md ADDED
@@ -0,0 +1,156 @@
1
+ # Cell classifier
2
+ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
3
+ dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
4
+ dataset_split = None,
5
+ filter_cells = .005,
6
+ epochs = 1,
7
+ cpu_cores = os.cpu_count(),
8
+ geneformer_batch_size = 12,
9
+ optimizer = 'adamw',
10
+ max_lr = 5e-5,
11
+ num_gpus = torch.cuda.device_count(),
12
+ max_input_size = 2 ** 11,
13
+ lr_schedule_fn = "linear",
14
+ warmup_steps = 500,
15
+ freeze_layers = 0,
16
+ emb_extract = False,
17
+ max_cells = 1000,
18
+ emb_layer = 0,
19
+ emb_filter = None,
20
+ emb_dir = 'embeddings',
21
+ overwrite = True,
22
+ label = "cell_type",
23
+ data_filter = None,
24
+ forward_batch = 200, model_location = None,
25
+ skip_training = False,
26
+ sample_data = 1,
27
+ inference = False,
28
+ optimize_hyperparameters = False,
29
+ output_dir = None):
30
+
31
+ '''
32
+ Primary Parameters
33
+ -------------------
34
+ dataset: path
35
+ Path to fine-tuning/testing dataset for training
36
+
37
+ model_location: path
38
+ Path to location of existing model to use for inference and embedding extraction
39
+
40
+ pretrained_model: path
41
+ Path to pretrained GeneFormer 30M model before fine-tuning
42
+
43
+ inference: bool
44
+ Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False
45
+
46
+ skip_training: bool
47
+ Chooses whether to skip training the model. Defaults to False
48
+
49
+ emb_extract: bool
50
+ Choose whether to extract embeddings and calculate similarities. Defaults to True
51
+
52
+ optimize_hyperparameters: bool
53
+ Choose whether to optimize model hyperparamters. Defaults to False
54
+ label: string
55
+ The label column in the formatted dataset that contains the true class labels. Defaults to "cell_type"
56
+
57
+ Customization Parameters
58
+ -------------------
59
+
60
+ dataset_split: str
61
+ How the dataset should be partitioned (if at all), and what ID should be used for partitioning
62
+
63
+ data_filter: list
64
+ (For embeddings and inference) Restricts the analysis to subsets of the dataset selected by the ID defined by dataset_split
65
+
66
+ label: str
67
+ What feature should be read as a classification label
68
+
69
+ emb_layer: int
70
+ Which layer embeddings should be extracted and compared from.
71
+
72
+ emb_filter: ['cell1', 'cell2'...]
73
+ Allows user to narrow down range of cells that embeddings will be extracted from.
74
+
75
+ max_cells: int
76
+ Maximum number of cells to extract embeddings from.
77
+
78
+ freeze_layers: int
79
+ Number of layers, counting from the first, to permanently freeze during fine-tuning (with 4, the frozen layers keep their pretrained weights and only the remaining layers are fine-tuned).
80
+
81
+ sample_data: float
82
+ What proportion of the Hugging Face dataset should be used
83
+
84
+ '''
85
+
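+
+ # Cell classifier example
+ The call below is an illustrative sketch rather than verbatim library code: it assumes `finetune_cells` is importable from `geneformer.cell_classifier`, that the Genecorpus-30M example dataset named in the signature above has been downloaded locally, and that the output directory name is arbitrary. Unspecified parameters keep the defaults shown above.
+
+ from geneformer.cell_classifier import finetune_cells
+
+ # fine-tune the pretrained model on the example cell-type dataset for one epoch,
+ # skipping embedding extraction and hyperparameter optimization
+ finetune_cells(dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
+                pretrained_model = ".",
+                label = "cell_type",
+                epochs = 1,
+                geneformer_batch_size = 12,
+                freeze_layers = 0,
+                emb_extract = False,
+                output_dir = "cell_type_classifier")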
86
+ # Gene Classifier
87
+ def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv",
88
+ genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
89
+ corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
90
+ max_input_size = 2 ** 11,
91
+ max_lr = 5e-5,
92
+ freeze_layers = 4,
93
+ num_gpus = 1,
94
+ num_proc = os.cpu_count(),
95
+ geneformer_batch_size = 9,
96
+ epochs = 1,
97
+ filter_dataset = 50_000,
98
+ emb_extract = True,
99
+ emb_layer = 0,
100
+ forward_batch = 200,
101
+ filter_data = None,
102
+ inference = False,
103
+ k_validate = True,
104
+ model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
105
+ skip_training = False,
106
+ emb_dir = 'gene_emb',
107
+ output_dir = None,
108
+ max_cells = 1000,
109
+ num_cpus = os.cpu_count()):
110
+
111
+ """"
112
+ Primary Parameters
113
+ -----------
114
+
115
+ gene_info: path
116
+ Path to the gene info table mapping Ensembl IDs to gene names and gene types
117
+
118
+ corpus_30M: path
119
+ Path to 30M Gene Corpus
120
+
121
+ model: path
122
+ Path to pretrained GeneFormer model
123
+
124
+ genes: path
125
+ Path to a csv file whose columns are the gene lists for each class; the column headers are used as the class labels
126
+
127
+ inference: bool
128
+ Whether the model should be used to run inference. If False, the model will be trained on the labeled data instead. Defaults to False
129
+
130
+ k_validate: bool
131
+ Whether the model should run k-fold cross-validation or a single train/evaluate split. Defaults to True
132
+
133
+ skip_training: bool
134
+ Whether the model should skip the training portion. Defaults to False
135
+
136
+ emb_extract: bool
137
+ Whether the model should extract embeddings for each gene label (work in progress)
138
+
139
+
140
+ Customization Parameters
141
+ -----------
142
+
143
+ freeze_layers: int
144
+ Number of initial layers to freeze in the model during fine-tuning. Default is 4 (leaving 2 layers trainable)
145
+
146
+ filter_dataset: int
147
+ Number of cells to subsample from the 30M-cell corpus for training. Default is 50_000
148
+
149
+ emb_layer: int
150
+ Which layer embeddings are extracted from. Default is 0
151
+
152
+ filter_data: str, list
153
+ Restricts training to a single label column (or list of columns) from the genes csv. Default is None
154
+
155
+
156
+ """