Christina Theodoris committed on
Commit
025e1b8
1 Parent(s): e562c0c

update cell classifier module

.pre-commit-config.yaml ADDED
@@ -0,0 +1,25 @@
+ # See https://pre-commit.com for more information
+ # See https://pre-commit.com/hooks.html for more hooks
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v3.2.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+       - id: check-yaml
+       - id: check-added-large-files
+       - id: check-merge-conflict
+       - id: mixed-line-ending
+       - id: check-docstring-first
+   - repo: https://github.com/pycqa/isort
+     rev: 5.12.0
+     hooks:
+       - id: isort
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     # Ruff version.
+     rev: v0.1.4
+     hooks:
+       # Run the Ruff linter.
+       - id: ruff
+       # Run the Ruff formatter.
+       - id: ruff-format
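These hooks run automatically on git commit once pre-commit is registered with the repository. A minimal sketch of enabling them, written in Python via subprocess to match the rest of this module (the real interface is the pre-commit CLI itself):

import subprocess

# Register the hooks with the local git repo, then run them once over all files.
subprocess.run(["pre-commit", "install"], check=True)
subprocess.run(["pre-commit", "run", "--all-files"], check=True)
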
Immune_modelpredictions.pickle DELETED
Binary file (99.1 kB)
 
gene_embclasses.txt DELETED
@@ -1 +0,0 @@
1
- [2, {0: 0, 1: 0}]
 
 
gene_embdataset.pk DELETED
Binary file (1.76 kB)
 
Cell_classifier.py → geneformer/cell_classifier.py RENAMED
@@ -1,65 +1,95 @@
1
- # Package Imports
2
- import tqdm
3
- import sys
4
- import polars as pl
5
- import pysam
6
  import os
7
- from datasets import Dataset
8
- from collections import Counter
9
  import random
10
- import datetime
11
- from pathlib import Path
12
  import subprocess
13
- import seaborn as sns; sns.set()
14
- from datasets import load_from_disk
15
- import fastcluster
16
- from sklearn.metrics import accuracy_score, f1_score
17
- from transformers import BertForSequenceClassification
18
- from transformers import Trainer
19
- from transformers.training_args import TrainingArguments
20
- from geneformer import DataCollatorForCellClassification, EmbExtractor
21
- import pickle
22
- from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve
23
- from sklearn.metrics import auc as precision_auc
24
- from sklearn.preprocessing import label_binarize
25
- import pyarrow as pa
26
- import concurrent.futures
27
- from matplotlib import pyplot as plt
28
  import torch
29
  import torch.nn.functional as F
30
- from scipy.stats import ranksums
31
- import ray
32
- import ast
33
  from ray import tune
34
- from ray.tune import ExperimentAnalysis
35
  from ray.tune.search.hyperopt import HyperOptSearch
36
- import numpy as np
37
 
38
  # Properly sets up NCCL environment
39
- GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
40
  os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
41
  os.environ["NCCL_DEBUG"] = "INFO"
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
- # Function for generating a ROC curve from data
45
- def ROC(prediction, truth, type = 'GeneFormer', label = ''):
46
-
47
  fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
48
  auc = roc_auc_score(truth, prediction[:, 1])
49
- print(f'{type} AUC: {auc}')
50
- plt.plot(fpr,tpr, label="AUC="+str(auc))
51
- plt.ylabel('True Positive Rate')
52
- plt.xlabel('False Positive Rate')
53
- plt.title(f'{label} ROC Curve')
54
  plt.legend(loc=4)
55
- plt.savefig('ROC.png')
56
-
57
- return tpr, fpr, auc
58
-
 
59
  # Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
60
- def similarity(tensor1, tensor2, cosine = False):
61
-
62
- if cosine == False:
63
  if tensor1.ndimension() > 1:
64
  tensor1 = tensor1.view(1, -1)
65
  if tensor2.ndimension() > 1:
@@ -69,7 +99,7 @@ def similarity(tensor1, tensor2, cosine = False):
69
  norm_tensor2 = torch.norm(tensor2)
70
  epsilon = 1e-8
71
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
72
- similarity = (similarity.item() + 1)/2
73
  else:
74
  if tensor1.shape != tensor2.shape:
75
  raise ValueError("Input tensors must have the same shape.")
@@ -78,13 +108,14 @@ def similarity(tensor1, tensor2, cosine = False):
78
  dot_product = torch.dot(tensor1, tensor2)
79
  norm_tensor1 = torch.norm(tensor1)
80
  norm_tensor2 = torch.norm(tensor2)
81
-
82
  # Avoid division by zero by adding a small epsilon
83
  epsilon = 1e-8
84
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
85
-
86
  return similarity.item()
87
-
 
88
  # Plots heatmap between different classes/labels
89
  def plot_similarity_heatmap(similarities):
90
  classes = list(similarities.keys())
@@ -97,327 +128,256 @@ def plot_similarity_heatmap(similarities):
97
  else:
98
  val = similarities[c][cc]
99
  arr[i][j] = val
100
-
101
  plt.figure(figsize=(8, 6))
102
- plt.imshow(arr, cmap='inferno', vmin=0, vmax=1)
103
  plt.colorbar()
104
- plt.xticks(np.arange(classlen), classes, rotation = 45, ha = 'right')
105
  plt.yticks(np.arange(classlen), classes)
106
  plt.title("Similarity Heatmap")
107
  plt.savefig("similarity_heatmap.png")
108
-
109
- # Function for tokenizing genes into ranked-value encodings from Geneformer
110
- def tokenize_dataset(gene_set, type = None, token_set = 'token_dictionary.pkl', species = 'human'):
111
- token_dataset = open(token_set, 'rb')
112
- token_dict = pickle.load(token_dataset)
113
- wrap = True
114
-
115
- if isinstance(gene_set[0], list) == False:
116
- gene_set = [gene_set]
117
- wrap = False
118
-
119
- pool = Pool()
120
- converted_set = []
121
-
122
- def process_gene(gene):
123
- api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{gene}?object_type=gene"
124
- response = requests.get(api_url, headers={"Content-Type": "application/json"})
125
- try:
126
- data = response.json()
127
- gene = data[0]['id']
128
- except:
129
- gene = None
130
- return gene
131
-
132
- def process_hgnc(gene):
133
- for gene in tqdm.tqdm(genes, total = len(genes)):
134
- api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{hgnc_id}?object_type=gene"
135
- response = requests.get(api_url, headers={"Content-Type": "application/json"})
136
- try:
137
- data = response.json()
138
- gene = data[0]['id']
139
- except:
140
- gene = None
141
- return gene
142
-
143
- def process_go(gene):
144
- mg = mygene.MyGeneInfo()
145
- results = mg.query(gene, scopes="go", species=species, fields="ensembl.gene")
146
-
147
- ensembl_ids = []
148
- max_score = 0
149
- for hit_num, hit in enumerate(results["hits"]):
150
- if hit['_score'] > max_score:
151
- max_score = hit['_score']
152
- chosen_hit = hit
153
- try:
154
- try:
155
- gene = chosen_hit["ensembl"]["gene"]
156
- except:
157
- gene = chosen_hit["ensembl"][0]["gene"]
158
- except:
159
- gene = None
160
- return gene
161
-
162
- if type == None or type.upper() == 'ENSEMBL':
163
- converted_set = gene_set
164
- elif type.upper() == 'GENE':
165
- for genes in gene_set:
166
- converted_genes = []
167
- for result in tqdm.tqdm(pool.imap(process_gene, genes), total = len(genes)):
168
- converted_genes.append(result)
169
- converted_set.append(converted_genes)
170
- elif type.upper() == 'GO':
171
- for genes in gene_set:
172
- converted_genes = []
173
- for result in tqdm.tqdm(pool.imap(process_go, genes), total = len(genes)):
174
- converted_genes.append(result)
175
- converted_set.append(converted_genes)
176
- elif type.upper() == 'HGNC':
177
- for genes in gene_set:
178
- converted_genes = []
179
- for result in tqdm.tqdm(pool.imap(process_hgnc, genes), total = len(genes)):
180
- converted_genes.append(result)
181
- converted_set.append(converted_genes)
182
-
183
- Chembl = []
184
- for set_num, set in enumerate(converted_set):
185
- Chembl.append([])
186
- for gene in set:
187
- if gene == None:
188
- Chembl[set_num].append(None)
189
- else:
190
- try:
191
- Chembl[set_num].append(token_dict[gene])
192
- except:
193
- print(f'{gene} not found in tokenized dataset!')
194
- Chembl[set_num].append(None)
195
-
196
- if wrap == False:
197
- Chembl = Chembl[0]
198
-
199
- return Chembl
200
-
201
-
202
- # '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/'
203
- # '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/'
204
- '''
205
- ======================================================
206
 
207
- PRIMARY CELL - CLASSIFIER AND EMBEDDING EXTRACTOR CLASS
208
-
209
- +++++++++++++++++++++++++++++++++++++++++++++++++++++++
210
 
211
- Runs cell-level classification and embedding extraction with Geneformer
212
-
213
- '''
214
-
215
- def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
216
- dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/', dataset_split = None, filter_cells = .005, epochs = 1, cpu_cores = os.cpu_count(), geneformer_batch_size = 12, optimizer = 'adamw', max_lr = 5e-5, num_gpus = torch.cuda.device_count(), max_input_size = 2 ** 11, lr_schedule_fn = "linear", warmup_steps = 500, freeze_layers = 0, emb_extract = False, max_cells = 1000, emb_layer = 0, emb_filter = None, emb_dir = 'embeddings', overwrite = True, label = "cell_type", data_filter = None,
217
- forward_batch = 200, model_location = None, skip_training = False, sample_data = 1, inference = False, optimize_hyperparameters = False, output_dir = None):
218
-
219
- '''
220
  Primary Parameters
221
  -------------------
222
  dataset: path
223
- Path to fine-tuning/testing dataset for training
 
 
 
224
 
225
- model_location: path
226
- Path to location of existing model to use for inference and embedding extraction
227
-
228
  pretrained_model: path
229
- Path to pretrained GeneFormer 30M model before fine-tuning
230
-
231
- inference: bool
232
- Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False
233
-
234
- skip_training: bool
235
- Chooses whether to skip training the model. Defaults to False
236
-
237
  emb_extract: bool
238
- Choose whether to extract embeddings and calculate similarities. Defaults to True
239
-
240
  optimize_hyperparameters: bool
241
- Choose whether to optimize model hyperparamters. Defaults to False
242
-
243
-
244
  Customization Parameters
245
  -------------------
246
-
247
  dataset_split: str
248
- How the dataset should be partitioned (if at all), and what ID should be used for partitioning
249
-
250
  data_filter: list
251
- (For embeddings and inference) Runs analysis subsets of the dataset by the ID defined by dataset_split
252
-
253
  label: str
254
- What feature should be read as a classification label
255
-
256
  emb_layer: int
257
- What layer embeddings should be extracted and compared from.
258
-
259
  emb_filter: ['cell1', 'cell2'...]
260
  Allows user to narrow down range of cells that embeddings will be extracted from.
261
-
262
  max_cells: int
263
- How many embeddings from cells should be extracted.
264
-
265
  freeze_layers: int
266
- Number of layers should be permanently frozen during fine-tuning (starting from the first layer, 4 brings it up to the pretrained model).
267
-
268
  sample_data: float
269
- What proportion of the HF dataset should be used
270
-
271
- '''`
272
-
273
  dataset_list = []
274
  evalset_list = []
275
  split_list = []
276
  target_dict_list = []
277
-
278
- '''
279
- For loading and pretraining with custom median expressions and/or custom gene conversions
280
- -------------------------------------------------------------
281
-
282
- token set: path
283
- Path to token conversion dictionary
284
-
285
- median set: path
286
- Path to median gene dictionary (ensembl IDs as the keys)
287
-
288
-
289
- median_data = pickle.load(open(median_set, 'rb'))
290
- median_data['<pad>'] = None
291
- median_data['<mask>'] = None
292
-
293
- token_set = pickle.load(open(token_set, 'rb'))
294
- median_dict = {key:median_data[key] for key in list(token_set.keys())}
295
- '''
296
-
297
  train_dataset = load_from_disk(dataset)
298
  num_samples = int(len(train_dataset) * sample_data)
299
  random_indices = random.sample(range(len(train_dataset)), num_samples)
300
  train_dataset = train_dataset.select(random_indices)
301
-
302
- sample = int(sample_data * len(train_dataset))
303
  sample_indices = random.sample(range(len(train_dataset)), sample)
304
  train_dataset = train_dataset.select(sample_indices)
305
 
306
- def if_not_rare_celltype(example):
307
  return example[label] in cells_to_keep
308
-
309
  # change labels to numerical ids
310
  def classes_to_ids(example):
311
  example["label"] = target_name_id_dict[example["label"]]
312
  return example
313
-
314
  def if_trained_label(example):
315
- return example["label"] in trained_labels
316
-
317
- if skip_training != True:
 
318
  def compute_metrics(pred):
319
  labels = pred.label_ids
320
  preds = pred.predictions.argmax(-1)
321
  # calculate accuracy and macro f1 using sklearn's function
322
  acc = accuracy_score(labels, preds)
323
- macro_f1 = f1_score(labels, preds, average='macro')
324
- return {
325
- 'accuracy': acc,
326
- 'macro_f1': macro_f1
327
- }
328
-
329
  # Defines custom exceptions for collecting labels (default excluded)
330
- excep = {"bone_marrow":"immune"}
331
-
332
- if dataset_split != None:
333
- if data_filter != None:
334
- split_iter = [data_filter]
335
  else:
336
- split_iter = Counter(train_dataset[dataset_split]).keys()
337
  for lab in split_iter:
338
-
339
  # collect list of tissues for fine-tuning (immune and bone marrow are included together)
340
  if lab in list(excep.keys()):
341
  continue
342
  elif lab == list(excep.values()):
343
- split_ids = [excep.keys(),excep.values()]
344
  split_list += [excep.values()]
345
  else:
346
  split_ids = [lab]
347
  split_list += [lab]
348
-
349
  # filter datasets for given organ
350
  def if_label(example):
351
  return example[dataset_split] == lab
352
-
353
  trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
354
  label_counter = Counter(trainset_label[label])
355
  total_cells = sum(label_counter.values())
356
-
357
- # Throws out cells with a low proportion in the dataset (drop cell types representing <0.5% of cells per deepsort published method)
358
- cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
359
- trainset_label_subset = trainset_label.filter(if_not_rare_celltype, num_proc=cpu_cores)
360
-
361
  # shuffle datasets and rename columns
362
  trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
363
- trainset_label_shuffled = trainset_label_shuffled.rename_column(label,"label")
364
- trainset_label_shuffled = trainset_label_shuffled.remove_columns(dataset_split)
365
-
 
 
 
 
366
  # create dictionary of cell types : label ids
367
  target_names = list(Counter(trainset_label_shuffled["label"]).keys())
368
- target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
 
 
369
  target_dict_list += [target_name_id_dict]
370
-
371
- labeled_trainset = trainset_label_shuffled.map(classes_to_ids, num_proc=cpu_cores)
372
-
 
 
373
  # create 80/20 train/eval splits
374
- labeled_train_split = trainset_label_shuffled.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
375
- labeled_eval_split = trainset_label_shuffled.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
376
-
 
 
 
 
 
 
 
 
 
377
  # filter dataset for cell types in corresponding training set
378
  trained_labels = list(Counter(labeled_train_split["label"]).keys())
379
-
380
- labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=cpu_cores)
381
-
 
 
382
  dataset_list += [labeled_train_split]
383
  evalset_list += [labeled_eval_split_subset]
384
-
385
- trainset_dict = dict(zip(split_list,dataset_list))
386
- traintargetdict_dict = dict(zip(split_list,target_dict_list))
387
- evalset_dict = dict(zip(split_list,evalset_list))
388
-
389
  for lab in split_list:
390
  label_trainset = trainset_dict[lab]
391
  label_evalset = evalset_dict[lab]
392
  label_dict = traintargetdict_dict[lab]
393
-
394
  # set logging steps
395
- logging_steps = round(len(label_trainset)/geneformer_batch_size/10)
396
  if logging_steps == 0:
397
  logging_steps = 1
398
-
399
- # reload pretrained model
400
- model = BertForSequenceClassification.from_pretrained("/work/ccnr/GeneFormer/GeneFormer_repo",
401
- num_labels=len(label_dict.keys()),
402
- output_attentions = False,
403
- output_hidden_states = False).to(device)
404
-
 
 
405
  # define output directory path
406
  current_date = datetime.datetime.now()
407
  datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
408
-
409
- if output_dir == None:
410
- output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
411
-
412
  # ensure not overwriting previously saved model
413
- saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
414
-
415
- if os.path.isfile(saved_model_test) == True and overwrite == False:
416
  raise Exception("Model already saved to this directory.")
417
-
418
  # make output directory
419
- subprocess.call(f'mkdir -p {output_dir}', shell=True)
420
-
421
  # set training arguments
422
  training_args = {
423
  "learning_rate": max_lr,
@@ -432,332 +392,371 @@ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_s
432
  "lr_scheduler_type": lr_schedule_fn,
433
  "warmup_steps": warmup_steps,
434
  "weight_decay": 0.001,
435
- "per_device_train_batch_size": geneformer_batch_size,
436
- "per_device_eval_batch_size": geneformer_batch_size,
437
  "num_train_epochs": epochs,
438
  "load_best_model_at_end": True,
439
  "output_dir": output_dir,
440
  }
441
-
442
-
443
  training_args_init = TrainingArguments(**training_args)
444
- true_labels = label_evalset['label']
445
-
446
-
447
- if optimize_hyperparameters == False:
448
  # create the trainer
449
  trainer = Trainer(
450
- model=model,
451
- args=training_args_init,
452
- data_collator=DataCollatorForCellClassification(),
453
- train_dataset=label_trainset,
454
- eval_dataset=label_evalset,
455
- compute_metrics=compute_metrics
456
  )
457
-
458
  # train the cell type classifier
459
  trainer.train()
460
  predictions = trainer.predict(label_evalset)
461
- print(f'accuracy: {accuracy_score(predictions.argmax(), label_evalset["labels"])}')
462
-
 
 
463
  tpr, fpr, auc = ROC(predictions.predictions, true_labels)
464
-
465
  metrics = compute_metrics(predictions)
466
  with open(f"{output_dir}predictions.pickle", "wb") as fp:
467
  pickle.dump(predictions, fp)
468
-
469
- trainer.save_metrics("eval",predictions.metrics)
470
-
471
- with open(f'{output_dir}/targets.txt', 'w') as f:
472
  if len(target_dict_list) == 1:
473
  f.write(str(target_dict_list[0]))
474
  else:
475
  f.write(str(target_dict_list))
476
-
477
- try:
478
-
479
- precision, recall, _ = precision_recall_curve(true_labels, predictions.predictions[:, 1])
480
- pr_auc = precision_auc(recall, precision)
481
-
482
- print(f'AUC: {pr_auc}')
483
- return recall, precision, pr_auc
484
- except:
485
- pass
486
-
 
487
  trainer.save_model(output_dir)
488
  else:
489
-
490
  def model_init():
491
- model = BertForSequenceClassification.from_pretrained(pretrained_model,
492
- num_labels=len(label_dict.keys()),
493
- output_attentions = False,
494
- output_hidden_states = False)
 
 
495
  if freeze_layers is not None:
496
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
497
  for module in modules_to_freeze:
498
  for param in module.parameters():
499
  param.requires_grad = False
500
- model = model.to(device)
501
  return model
502
-
503
  trainer = Trainer(
504
  model_init=model_init,
505
  args=training_args_init,
506
  data_collator=DataCollatorForCellClassification(),
507
  train_dataset=label_trainset,
508
  eval_dataset=label_evalset,
509
- compute_metrics=compute_metrics
510
  )
511
  # specify raytune hyperparameter search space
512
  ray_config = {
513
  "num_train_epochs": tune.choice([epochs]),
514
  "learning_rate": tune.loguniform(1e-6, 1e-3),
515
  "weight_decay": tune.uniform(0.0, 0.3),
516
- "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
 
 
517
  "warmup_steps": tune.uniform(100, 2000),
518
- "seed": tune.uniform(0,100),
519
- "per_device_train_batch_size": tune.choice([geneformer_batch_size])
 
 
520
  }
521
-
522
- hyperopt_search = HyperOptSearch(
523
- metric="eval_accuracy", mode="max")
524
-
525
- if torch.device == 'cuda':
526
- resources_per_trial={"cpu":8,"gpu":1},
527
  else:
528
- resources_per_trial={"cpu":8}
529
-
530
  # optimize hyperparameters
531
  best_trial = trainer.hyperparameter_search(
532
  direction="maximize",
533
  backend="ray",
534
- resources_per_trial = resources_per_trial,
535
  hp_space=lambda _: ray_config,
536
  search_alg=hyperopt_search,
537
- n_trials=10, # number of trials
538
- progress_reporter=tune.CLIReporter(max_report_frequency=600,
539
- sort_by_metric=True,
540
- max_progress_rows=100,
541
- mode="max",
542
- metric="eval_accuracy",
543
- metric_columns=["loss", "eval_loss", "eval_accuracy"]))
 
 
 
544
  best_hyperparameters = best_trial.hyperparameters
545
-
546
  print("Best Hyperparameters:")
547
  print(best_hyperparameters)
548
-
549
-
550
-
551
  else:
552
- trainset_label = train_dataset
553
- label_counter = Counter(trainset_label[label])
554
- total_cells = sum(label_counter.values())
555
-
556
- # Throws out cells with a low proportion in the dataset
557
- cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
558
- trainset_label_subset = trainset_label.filter(if_not_rare_celltype, num_proc=cpu_cores)
559
-
560
- # shuffle datasets and rename columns
561
- trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
562
- trainset_label_shuffled = trainset_label_shuffled.rename_column(label,"label")
563
-
564
- # create dictionary of cell types : label ids
565
- target_names = list(Counter(trainset_label_shuffled["label"]).keys())
566
- target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
567
- target_dict_list = target_name_id_dict
568
-
569
- labeled_trainset = trainset_label_shuffled.map(classes_to_ids, num_proc=cpu_cores)
570
-
571
- # create 80/20 train/eval splits
572
- labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
573
- labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
574
-
575
- # filter dataset for cell types in corresponding training set
576
- trained_labels = list(Counter(labeled_train_split["label"]).keys())
577
- labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=cpu_cores)
578
-
579
- # set logging steps
580
- logging_steps = round(len(trainset_label)/geneformer_batch_size/10)
581
-
582
- # reload pretrained model
583
- model = BertForSequenceClassification.from_pretrained(pretrained_model,
584
- num_labels=len(target_dict_list.keys()),
585
- output_attentions = False,
586
- output_hidden_states = False).to(device)
587
- # define output directory path
588
- current_date = datetime.datetime.now()
589
- datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
590
-
591
- if output_dir == None:
592
- output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
593
-
594
- # ensure not overwriting previously saved model
595
- saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
596
- if os.path.isfile(saved_model_test) == True and overwrite == False:
597
- raise Exception("Model already saved to this directory.")
598
-
599
- # make output directory
600
- subprocess.call(f'mkdir -p {output_dir}', shell=True)
601
-
602
- # set training arguments
603
- training_args = {
604
- "learning_rate": max_lr,
605
- "do_train": True,
606
- "do_eval": True,
607
- "evaluation_strategy": "epoch",
608
- "save_strategy": "epoch",
609
- "logging_steps": logging_steps,
610
- "group_by_length": True,
611
- "length_column_name": "length",
612
- "disable_tqdm": False,
613
- "lr_scheduler_type": lr_schedule_fn,
614
- "warmup_steps": warmup_steps,
615
- "weight_decay": 0.001,
616
- "per_device_train_batch_size": geneformer_batch_size,
617
- "per_device_eval_batch_size": geneformer_batch_size,
618
- "num_train_epochs": epochs,
619
- "load_best_model_at_end": True,
620
- "output_dir": output_dir,}
621
-
622
- training_args_init = TrainingArguments(**training_args)
623
- true_labels = labeled_eval_split_subset['label']
624
-
625
- if optimize_hyperparameters == False:
626
-
627
- # create the trainer
628
- trainer = Trainer(
629
- model=model,
630
- args=training_args_init,
631
- data_collator=DataCollatorForCellClassification(),
632
- train_dataset=labeled_train_split,
633
- eval_dataset=labeled_eval_split_subset,
634
- compute_metrics=compute_metrics
635
- )
636
-
637
- # train the cell type classifier
638
- trainer.train()
639
- predictions = trainer.predict(labeled_eval_split_subset)
640
- predictions_tensor = torch.Tensor(predictions.predictions)
641
- predicted_labels = torch.argmax(predictions_tensor, dim=1)
642
- print(f'accuracy: {accuracy_score(predicted_labels, labeled_eval_split_subset["label"])}')
643
- metrics = compute_metrics(predictions)
644
-
645
- with open(f"{output_dir}predictions.pickle", "wb") as fp:
646
- pickle.dump(predictions.predictions.argmax(-1), fp)
647
-
648
- trainer.save_metrics("eval",predictions.metrics)
649
- trainer.save_model(output_dir)
650
-
651
- # Saves label conversion dictionary to output directory
652
- with open(f'{output_dir}/targets.txt', 'w') as f:
653
- f.write(str(target_dict_list))
654
-
655
- try:
656
-
657
- precision, recall, _ = precision_recall_curve(true_labels, predictions.predictions[:, 1])
658
- pr_auc = precision_auc(recall, precision)
659
-
660
- print(f'AUC: {pr_auc}')
661
- return recall, precision, pr_auc
662
- except:
663
- pass
664
-
665
- else:
666
- # Optimizes hyperparameters
667
-
668
- num_classes = len(list(set(labeled_train_split['label'])))
669
- def model_init():
670
- model = BertForSequenceClassification.from_pretrained(pretrained_model,
671
- num_labels=num_classes,
672
- output_attentions = False,
673
- output_hidden_states = False)
674
-
675
- if freeze_layers is not None:
676
- modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
677
- for module in modules_to_freeze:
678
- for param in module.parameters():
679
- param.requires_grad = False
680
- model = model.to(device)
681
- return model
682
-
683
-
684
- # create the trainer
685
- trainer = Trainer(
686
- model_init=model_init,
687
- args=training_args_init,
688
- data_collator=DataCollatorForCellClassification(),
689
- train_dataset=labeled_train_split,
690
- eval_dataset=labeled_eval_split_subset,
691
- compute_metrics=compute_metrics
692
- )
693
-
694
- # specify raytune hyperparameter search space
695
- ray_config = {
696
- "num_train_epochs": tune.choice([epochs]),
697
- "learning_rate": tune.loguniform(1e-6, 1e-3),
698
- "weight_decay": tune.uniform(0.0, 0.3),
699
- "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
700
- "warmup_steps": tune.uniform(100, 2000),
701
- "seed": tune.uniform(0,100),
702
- "per_device_train_batch_size": tune.choice([geneformer_batch_size])
703
- }
704
-
705
- hyperopt_search = HyperOptSearch(
706
- metric="eval_accuracy", mode="max")
707
-
708
- if torch.device == 'cuda':
709
- resources_per_trial={"cpu":8,"gpu":1},
710
- else:
711
- resources_per_trial={"cpu":8}
712
-
713
- # optimize hyperparameters
714
- best_trial = trainer.hyperparameter_search(
715
- direction="maximize",
716
- backend="ray",
717
- resources_per_trial = resources_per_trial,
718
- hp_space=lambda _: ray_config,
719
- search_alg=hyperopt_search,
720
- n_trials=10, # number of trials
721
- progress_reporter=tune.CLIReporter(max_report_frequency=600,
722
- sort_by_metric=True,
723
- max_progress_rows=100,
724
- mode="max",
725
- metric="eval_accuracy",
726
- metric_columns=["loss", "eval_loss", "eval_accuracy"]))
727
- best_hyperparameters = best_trial.hyperparameters
728
-
729
- print("Best Hyperparameters:")
730
- print(best_hyperparameters)
731
-
732
-
733
  # Performs Inference with model
734
- if inference == True:
735
- if dataset_split != None and data_filter != None:
 
736
  def if_label(example):
737
- return example[dataset_split] == data_filter
738
-
739
  train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
740
-
741
  trainset_label_shuffled = train_dataset
742
  total_cells = len(trainset_label_shuffled)
743
-
744
  # loads dictionary of all cell labels model was trained on
745
- with open(Path(model_location) / 'targets.txt', 'r') as f:
746
  data = ast.literal_eval(f.read())
747
- if dataset_split != None and data_filter == None:
748
  indexer = dataset_split.index(data_filter)
749
  data = data[indexer]
750
-
751
- target_dict_list = {key:value for key, value in enumerate(data)}
752
-
753
  # set logging steps
754
- logging_steps = round(len(trainset_label_shuffled)/geneformer_batch_size/20)
755
-
756
- # reload pretrained model
757
  input_ids = trainset_label_shuffled["input_ids"]
758
  inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
759
  attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
760
-
761
  for i, sentence in enumerate(input_ids):
762
  sentence_length = len(sentence)
763
  if sentence_length <= max_input_size:
@@ -766,59 +765,77 @@ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_s
766
  else:
767
  inputs[i, :] = torch.tensor(sentence[:max_input_size])
768
  attention[i, :] = torch.ones(max_input_size)
769
-
770
- model = BertForSequenceClassification.from_pretrained(model_location, num_labels=len(target_dict_list)).to(device)
771
- model_outputs = model(inputs.to(device), attention_mask = attention)["logits"]
772
- predictions = F.softmax(model_outputs, dim = -1).argmax(-1)
 
 
773
 
774
  predictions = [target_dict_list[int(pred)] for pred in predictions]
775
 
776
  return predictions
777
-
778
- # Extracts embeddings from labelled data
779
- if emb_extract == True:
780
- if emb_filter == None:
781
- with open(f'{model_location}/targets.txt', 'r') as f:
782
  data = ast.literal_eval(f.read())
783
- if dataset_split != None and data_filter == None:
784
  indexer = dataset_split.index(data_filter)
785
  data = data[indexer]
786
-
787
- target_dict_list = {key:value for key, value in enumerate(data)}
788
  total_filter = None
789
  else:
790
  total_filter = emb_filter
791
-
792
  train_dataset = load_from_disk(dataset)
793
- if dataset_split != None:
 
794
  def if_label(example):
795
- return example[dataset_split] == data_filter
796
-
797
  train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
798
-
799
  label_counter = Counter(train_dataset[label])
800
  total_cells = sum(label_counter.values())
801
- cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
802
-
 
 
803
  def if_not_rare(example):
804
- return example[label] in cells_to_keep
805
-
806
  train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)
807
-
808
  true_labels = train_dataset[label]
809
  num_classes = len(list(set(true_labels)))
810
-
811
- embex = EmbExtractor(model_type="CellClassifier", num_classes=num_classes,
812
- filter_data=total_filter, max_ncells=max_cells, emb_layer=emb_layer,
813
- emb_label=[dataset_split,label], labels_to_plot=[label], forward_batch_size=forward_batch, nproc=cpu_cores)
814
-
815
  # example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
816
- subprocess.call(f'mkdir -p {emb_dir}', shell = True)
817
-
818
- embs = embex.extract_embs(model_directory = model_location, input_data_file = dataset, output_directory = emb_dir, output_prefix = f"{label}_embbeddings")
819
  true_labels = embex.filtered_input_data[label]
820
 
821
- emb_dict = {label:[] for label in list(set(true_labels))}
822
  for num, emb in embs.iterrows():
823
  key = emb[label]
824
  selection = emb.iloc[:255]
@@ -826,36 +843,32 @@ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_s
826
  emb_dict[key].append(emb)
827
 
828
  for key in list(emb_dict.keys()):
829
- stack = torch.stack(emb_dict[key], dim = 0)
830
  emb_dict[key] = torch.mean(stack, dim=0)
831
- similarities = {key:{} for key in list(emb_dict.keys())}
832
-
833
  for key in list(emb_dict.keys()):
834
  remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
835
  for k in remaining_keys:
836
  embedding = emb_dict[k]
837
- sim = similarity(emb_dict[key], embedding, cosine = True)
838
-
839
  similarities[key][k] = sim
840
-
841
  plot_similarity_heatmap(similarities)
842
-
843
- embex.plot_embs(embs=embs,
844
- plot_style="umap",
845
- output_directory=emb_dir,
846
- output_prefix="emb_plot")
847
-
848
-
849
- embex.plot_embs(embs=embs,
850
- plot_style="heatmap",
851
- output_directory=emb_dir,
852
- output_prefix="emb_plot")
853
-
854
-
855
- return similarities
856
-
857
- if __name__ == '__main__':
858
- predictions = finetune_cells(skip_training = False, dataset_split = None, label = "disease", sample_data = .5, data_filter = 'hcm', epochs = 10, output_dir = 'hcm_model', model_location = 'hcm_model',
859
- emb_extract = True, geneformer_batch_size = 12, inference = False, dataset = "/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/")
860
 
861
-
1
+ """
2
+ Geneformer cell classifier.
3
+
4
+ Usage:
5
+ from geneformer import classify_cells
6
+ classify_cells(
7
+ token_set=Path("geneformer/token_dictionary.pkl"),
8
+ median_set=Path("geneformer/gene_median_dictionary.pkl"),
9
+ pretrained_model=".",
10
+ dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
11
+ dataset_split=None,
12
+ filter_cells=0.005,
13
+ epochs=1,
14
+ cpu_cores=os.cpu_count(),
15
+ geneformer_batch_size=12,
16
+ optimizer="adamw",
17
+ max_lr=5e-5,
18
+ num_gpus=torch.cuda.device_count(),
19
+ max_input_size=2**11,
20
+ lr_schedule_fn="linear",
21
+ warmup_steps=500,
22
+ freeze_layers=0,
23
+ emb_extract=False,
24
+ max_cells=1000,
25
+ emb_layer=0,
26
+ emb_filter=None,
27
+ emb_dir="embeddings",
28
+ overwrite=True,
29
+ label="cell_type",
30
+ data_filter=None,
31
+ forward_batch=200,
32
+ model_location=None,
33
+ skip_training=False,
34
+ sample_data=1,
35
+ inference=False,
36
+ optimize_hyperparameters=False,
37
+ output_dir=None,
38
+ )
39
+ """
40
+
41
+ import ast
42
+ import datetime
43
  import os
44
+ import pickle
 
45
  import random
 
 
46
  import subprocess
47
+ from collections import Counter
48
+ from pathlib import Path
49
+
50
+ import numpy as np
51
+ import seaborn as sns
52
  import torch
53
  import torch.nn.functional as F
54
+ from datasets import load_from_disk
55
+ from matplotlib import pyplot as plt
 
56
  from ray import tune
 
57
  from ray.tune.search.hyperopt import HyperOptSearch
58
+ from sklearn.metrics import accuracy_score
59
+ from sklearn.metrics import auc as precision_auc
60
+ from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score, roc_curve
61
+ from transformers import BertForSequenceClassification, Trainer
62
+ from transformers.training_args import TrainingArguments
63
+
64
+ from geneformer import DataCollatorForCellClassification, EmbExtractor
65
+
66
+ sns.set()
67
 
68
  # Properly sets up NCCL environment
69
+ GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
70
  os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
71
  os.environ["NCCL_DEBUG"] = "INFO"
72
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73
 
74
+
75
+ # Function for generating an ROC curve from data
76
+ def ROC(prediction, truth, type="GeneFormer", label=""):
77
  fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
78
  auc = roc_auc_score(truth, prediction[:, 1])
79
+ print(f"{type} AUC: {auc}")
80
+ plt.plot(fpr, tpr, label="AUC=" + str(auc))
81
+ plt.ylabel("True Positive Rate")
82
+ plt.xlabel("False Positive Rate")
83
+ plt.title(f"{label} ROC Curve")
84
  plt.legend(loc=4)
85
+ plt.savefig("ROC.png")
86
+
87
+ return tpr, fpr, auc
88
+
89
+
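A minimal usage sketch for the ROC helper above, assuming a binary classifier that outputs two-column class probabilities; fake_probs and fake_truth are invented placeholders, not data from this repo:

import numpy as np

# Illustrative inputs only: column 1 holds the positive-class probability.
fake_probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.7, 0.3]])
fake_truth = [0, 1, 1, 0]
tpr, fpr, auc = ROC(fake_probs, fake_truth, type="GeneFormer", label="demo")  # also writes ROC.png
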
90
  # Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
91
+ def similarity(tensor1, tensor2, cosine=False):
92
+ if cosine is False:
 
93
  if tensor1.ndimension() > 1:
94
  tensor1 = tensor1.view(1, -1)
95
  if tensor2.ndimension() > 1:
 
99
  norm_tensor2 = torch.norm(tensor2)
100
  epsilon = 1e-8
101
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
102
+ similarity = (similarity.item() + 1) / 2
103
  else:
104
  if tensor1.shape != tensor2.shape:
105
  raise ValueError("Input tensors must have the same shape.")
 
108
  dot_product = torch.dot(tensor1, tensor2)
109
  norm_tensor1 = torch.norm(tensor1)
110
  norm_tensor2 = torch.norm(tensor2)
111
+
112
  # Avoid division by zero by adding a small epsilon
113
  epsilon = 1e-8
114
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
115
+
116
  return similarity.item()
117
+
118
+
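A small worked example of the similarity helper with two toy 1-D vectors (values invented for illustration):

import torch

# Toy embeddings: dot product = 1.0, each norm = sqrt(2), so cosine similarity is ~0.5.
a = torch.tensor([1.0, 0.0, 1.0])
b = torch.tensor([1.0, 1.0, 0.0])
print(similarity(a, b, cosine=True))  # ~0.5
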
119
  # Plots heatmap between different classes/labels
120
  def plot_similarity_heatmap(similarities):
121
  classes = list(similarities.keys())
 
128
  else:
129
  val = similarities[c][cc]
130
  arr[i][j] = val
131
+
132
  plt.figure(figsize=(8, 6))
133
+ plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
134
  plt.colorbar()
135
+ plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
136
  plt.yticks(np.arange(classlen), classes)
137
  plt.title("Similarity Heatmap")
138
  plt.savefig("similarity_heatmap.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
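plot_similarity_heatmap expects a nested dict of pairwise scores keyed by class label, as built later in classify_cells. A toy sketch of the expected shape (labels invented for illustration):

# Hypothetical labels; each inner dict maps every *other* class to a similarity score.
toy_similarities = {
    "hcm": {"dcm": 0.82, "nf": 0.65},
    "dcm": {"hcm": 0.82, "nf": 0.71},
    "nf": {"hcm": 0.65, "dcm": 0.71},
}
plot_similarity_heatmap(toy_similarities)  # writes similarity_heatmap.png
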
 
 
 
 
140
 
141
+ def classify_cells(
142
+ token_set=Path("./token_dictionary.pkl"),
143
+ median_set=Path("./gene_median_dictionary.pkl"),
144
+ pretrained_model="../",
145
+ dataset="Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/",
146
+ dataset_split=None,
147
+ filter_cells=0.005,
148
+ epochs=1,
149
+ cpu_cores=os.cpu_count(),
150
+ training_batch_size=12,
151
+ optimizer="adamw",
152
+ max_lr=5e-5,
153
+ num_gpus=torch.cuda.device_count(),
154
+ max_input_size=2**11,
155
+ lr_schedule_fn="linear",
156
+ warmup_steps=500,
157
+ freeze_layers=0,
158
+ emb_extract=False,
159
+ max_cells=None,
160
+ emb_layer=-1,
161
+ emb_filter=None,
162
+ emb_dir="embeddings",
163
+ overwrite=False,
164
+ label="cell_type",
165
+ data_filter=None,
166
+ inference_batch_size=200,
167
+ finetuned_model=None,
168
+ skip_training=False,
169
+ sample_data=1,
170
+ inference=False,
171
+ optimize_hyperparameters=True,
172
+ output_dir=None,
173
+ ):
174
+ """
175
  Primary Parameters
176
  -------------------
177
  dataset: path
178
+ Path to fine-tuning dataset for training
179
+
180
+ finetuned_model: path
181
+ Path to location of fine-tuned model to use for inference and embedding extraction
182
 
 
 
 
183
  pretrained_model: path
184
+ Path to pretrained Geneformer model
185
+
186
+ inference: bool
187
+ Indicates whether to perform inference and return the predicted labels. Defaults to False.
188
+
189
+ skip_training: bool
190
+ Indicates whether to skip training the model. Defaults to False.
191
+
192
  emb_extract: bool
193
+ Indicates whether to extract embeddings and calculate similarities. Defaults to False.
194
+
195
  optimize_hyperparameters: bool
196
+ Indicates whether to optimize model hyperparameters. Defaults to True.
197
+
198
+
199
  Customization Parameters
200
  -------------------
201
+
202
  dataset_split: str
203
+ Indicates how the dataset should be partitioned (if at all), and what ID should be used for partitioning
204
+
205
  data_filter: list
206
+ (For embeddings and inference) Runs analysis on subsets of the dataset based on the ID defined by dataset_split
207
+
208
  label: str
209
+ Feature to read as a classification label.
210
+
211
  emb_layer: int
212
+ Which layer embeddings should be extracted from and compared.
213
+
214
  emb_filter: ['cell1', 'cell2'...]
215
  Allows user to narrow down range of cells that embeddings will be extracted from.
216
+
217
  max_cells: int
218
+ Max number of cells to use for embedding extraction.
219
+
220
  freeze_layers: int
221
+ Number of layers that should be frozen during fine-tuning.
222
+
223
  sample_data: float
224
+ Proportion of the dataset that should be used.
225
+
226
+ """
227
+
228
  dataset_list = []
229
  evalset_list = []
230
  split_list = []
231
  target_dict_list = []
232
+
233
  train_dataset = load_from_disk(dataset)
234
  num_samples = int(len(train_dataset) * sample_data)
235
  random_indices = random.sample(range(len(train_dataset)), num_samples)
236
  train_dataset = train_dataset.select(random_indices)
237
+
238
+ sample = int(sample_data * len(train_dataset))
239
  sample_indices = random.sample(range(len(train_dataset)), sample)
240
  train_dataset = train_dataset.select(sample_indices)
241
 
242
+ def if_not_rare_cell_state(example):
243
  return example[label] in cells_to_keep
244
+
245
  # change labels to numerical ids
246
  def classes_to_ids(example):
247
  example["label"] = target_name_id_dict[example["label"]]
248
  return example
249
+
250
  def if_trained_label(example):
251
+ return example["label"] in trained_labels
252
+
253
+ if skip_training is not True:
254
+
255
  def compute_metrics(pred):
256
  labels = pred.label_ids
257
  preds = pred.predictions.argmax(-1)
258
  # calculate accuracy and macro f1 using sklearn's function
259
  acc = accuracy_score(labels, preds)
260
+ macro_f1 = f1_score(labels, preds, average="macro")
261
+ return {"accuracy": acc, "macro_f1": macro_f1}
262
+
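compute_metrics receives the Trainer's EvalPrediction-style object, which exposes .predictions (logits) and .label_ids. A purely illustrative stand-in to show the expected shape (names invented, not part of this commit):

from collections import namedtuple
import numpy as np

# Stand-in for transformers' EvalPrediction; illustrative only.
FakePred = namedtuple("FakePred", ["predictions", "label_ids"])
demo = FakePred(predictions=np.array([[0.1, 2.0], [1.5, 0.2]]), label_ids=np.array([1, 0]))
print(compute_metrics(demo))  # {'accuracy': 1.0, 'macro_f1': 1.0}
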
 
 
 
263
  # Defines custom exceptions for collecting labels (default excluded)
264
+ excep = {"bone_marrow": "immune"}
265
+
266
+ if dataset_split is not None:
267
+ if data_filter is not None:
268
+ split_iter = [data_filter]
269
  else:
270
+ split_iter = Counter(train_dataset[dataset_split]).keys()
271
  for lab in split_iter:
 
272
  # collect list of tissues for fine-tuning (immune and bone marrow are included together)
273
  if lab in list(excep.keys()):
274
  continue
275
  elif lab == list(excep.values()):
276
+ split_ids = [excep.keys(), excep.values()]
277
  split_list += [excep.values()]
278
  else:
279
  split_ids = [lab]
280
  split_list += [lab]
281
+
282
  # filter datasets for given organ
283
  def if_label(example):
284
  return example[dataset_split] == lab
285
+
286
  trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
287
  label_counter = Counter(trainset_label[label])
288
  total_cells = sum(label_counter.values())
289
+
290
+ # excludes cells with a low proportion in the dataset
291
+ cells_to_keep = [
292
+ k
293
+ for k, v in label_counter.items()
294
+ if v > (filter_cells * total_cells)
295
+ ]
296
+ trainset_label_subset = trainset_label.filter(
297
+ if_not_rare_cell_state, num_proc=cpu_cores
298
+ )
299
+
300
  # shuffle datasets and rename columns
301
  trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
302
+ trainset_label_shuffled = trainset_label_shuffled.rename_column(
303
+ label, "label"
304
+ )
305
+ trainset_label_shuffled = trainset_label_shuffled.remove_columns(
306
+ dataset_split
307
+ )
308
+
309
  # create dictionary of cell types : label ids
310
  target_names = list(Counter(trainset_label_shuffled["label"]).keys())
311
+ target_name_id_dict = dict(
312
+ zip(target_names, [i for i in range(len(target_names))])
313
+ )
314
  target_dict_list += [target_name_id_dict]
315
+
316
+ labeled_trainset = trainset_label_shuffled.map(
317
+ classes_to_ids, num_proc=cpu_cores
318
+ )
319
+
320
  # create 80/20 train/eval splits
321
+ labeled_train_split = trainset_label_shuffled.select(
322
+ [i for i in range(0, round(len(labeled_trainset) * 0.8))]
323
+ )
324
+ labeled_eval_split = trainset_label_shuffled.select(
325
+ [
326
+ i
327
+ for i in range(
328
+ round(len(labeled_trainset) * 0.8), len(labeled_trainset)
329
+ )
330
+ ]
331
+ )
332
+
333
  # filter dataset for cell types in corresponding training set
334
  trained_labels = list(Counter(labeled_train_split["label"]).keys())
335
+
336
+ labeled_eval_split_subset = labeled_eval_split.filter(
337
+ if_trained_label, num_proc=cpu_cores
338
+ )
339
+
340
  dataset_list += [labeled_train_split]
341
  evalset_list += [labeled_eval_split_subset]
342
+
343
+ trainset_dict = dict(zip(split_list, dataset_list))
344
+ traintargetdict_dict = dict(zip(split_list, target_dict_list))
345
+ evalset_dict = dict(zip(split_list, evalset_list))
346
+
347
  for lab in split_list:
348
  label_trainset = trainset_dict[lab]
349
  label_evalset = evalset_dict[lab]
350
  label_dict = traintargetdict_dict[lab]
351
+
352
  # set logging steps
353
+ logging_steps = round(len(label_trainset) / training_batch_size / 10)
354
  if logging_steps == 0:
355
  logging_steps = 1
356
+
357
+ # load pretrained model
358
+ model = BertForSequenceClassification.from_pretrained(
359
+ pretrained_model,
360
+ num_labels=len(label_dict.keys()),
361
+ output_attentions=False,
362
+ output_hidden_states=False,
363
+ ).to(device)
364
+
365
  # define output directory path
366
  current_date = datetime.datetime.now()
367
  datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
368
+
369
+ if output_dir is None:
370
+ output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
371
+
372
  # ensure not overwriting previously saved model
373
+ saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
374
+
375
+ if os.path.isfile(saved_model_test) is True and overwrite is False:
376
  raise Exception("Model already saved to this directory.")
377
+
378
  # make output directory
379
+ subprocess.call(f"mkdir -p {output_dir}", shell=True)
380
+
381
  # set training arguments
382
  training_args = {
383
  "learning_rate": max_lr,
 
392
  "lr_scheduler_type": lr_schedule_fn,
393
  "warmup_steps": warmup_steps,
394
  "weight_decay": 0.001,
395
+ "per_device_train_batch_size": training_batch_size,
396
+ "per_device_eval_batch_size": training_batch_size,
397
  "num_train_epochs": epochs,
398
  "load_best_model_at_end": True,
399
  "output_dir": output_dir,
400
  }
401
+
 
402
  training_args_init = TrainingArguments(**training_args)
403
+ true_labels = label_evalset["label"]
404
+
405
+ if optimize_hyperparameters is False:
 
406
  # create the trainer
407
  trainer = Trainer(
408
+ model=model,
409
+ args=training_args_init,
410
+ data_collator=DataCollatorForCellClassification(),
411
+ train_dataset=label_trainset,
412
+ eval_dataset=label_evalset,
413
+ compute_metrics=compute_metrics,
414
  )
415
+
416
  # train the cell type classifier
417
  trainer.train()
418
  predictions = trainer.predict(label_evalset)
419
+ print(
420
+ f'accuracy: {accuracy_score(predictions.argmax(), label_evalset["labels"])}'
421
+ )
422
+
423
  tpr, fpr, auc = ROC(predictions.predictions, true_labels)
424
+
425
  metrics = compute_metrics(predictions)
426
  with open(f"{output_dir}predictions.pickle", "wb") as fp:
427
  pickle.dump(predictions, fp)
428
+
429
+ trainer.save_metrics("eval", predictions.metrics)
430
+
431
+ with open(f"{output_dir}/targets.txt", "w") as f:
432
  if len(target_dict_list) == 1:
433
  f.write(str(target_dict_list[0]))
434
  else:
435
  f.write(str(target_dict_list))
436
+
437
+ try:
438
+ precision, recall, _ = precision_recall_curve(
439
+ true_labels, predictions.predictions[:, 1]
440
+ )
441
+ pr_auc = precision_auc(recall, precision)
442
+
443
+ print(f"AUC: {pr_auc}")
444
+ return recall, precision, pr_auc
445
+ except:
446
+ pass
447
+
448
  trainer.save_model(output_dir)
449
  else:
450
+
451
  def model_init():
452
+ model = BertForSequenceClassification.from_pretrained(
453
+ pretrained_model,
454
+ num_labels=len(label_dict.keys()),
455
+ output_attentions=False,
456
+ output_hidden_states=False,
457
+ )
458
  if freeze_layers is not None:
459
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
460
  for module in modules_to_freeze:
461
  for param in module.parameters():
462
  param.requires_grad = False
463
+ model = model.to(device)
464
  return model
465
+
466
  trainer = Trainer(
467
  model_init=model_init,
468
  args=training_args_init,
469
  data_collator=DataCollatorForCellClassification(),
470
  train_dataset=label_trainset,
471
  eval_dataset=label_evalset,
472
+ compute_metrics=compute_metrics,
473
  )
474
  # specify raytune hyperparameter search space
475
  ray_config = {
476
  "num_train_epochs": tune.choice([epochs]),
477
  "learning_rate": tune.loguniform(1e-6, 1e-3),
478
  "weight_decay": tune.uniform(0.0, 0.3),
479
+ "lr_scheduler_type": tune.choice(
480
+ ["linear", "cosine", "polynomial"]
481
+ ),
482
  "warmup_steps": tune.uniform(100, 2000),
483
+ "seed": tune.uniform(0, 100),
484
+ "per_device_train_batch_size": tune.choice(
485
+ [training_batch_size]
486
+ ),
487
  }
488
+
489
+ hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
490
+
491
+ if torch.device == "cuda":
492
+ resources_per_trial = ({"cpu": 8, "gpu": 1},)
 
493
  else:
494
+ resources_per_trial = {"cpu": 8}
495
+
496
  # optimize hyperparameters
497
  best_trial = trainer.hyperparameter_search(
498
  direction="maximize",
499
  backend="ray",
500
+ resources_per_trial=resources_per_trial,
501
  hp_space=lambda _: ray_config,
502
  search_alg=hyperopt_search,
503
+ n_trials=100, # number of trials
504
+ progress_reporter=tune.CLIReporter(
505
+ max_report_frequency=600,
506
+ sort_by_metric=True,
507
+ max_progress_rows=100,
508
+ mode="max",
509
+ metric="eval_accuracy",
510
+ metric_columns=["loss", "eval_loss", "eval_accuracy"],
511
+ ),
512
+ )
513
  best_hyperparameters = best_trial.hyperparameters
514
+
515
  print("Best Hyperparameters:")
516
  print(best_hyperparameters)
517
+
 
 
518
  else:
519
+ trainset_label = train_dataset
520
+ label_counter = Counter(trainset_label[label])
521
+ total_cells = sum(label_counter.values())
522
+
523
+ # Excludes cells with a low proportion in the dataset
524
+ cells_to_keep = [
525
+ k for k, v in label_counter.items() if v > (filter_cells * total_cells)
526
+ ]
527
+ trainset_label_subset = trainset_label.filter(
528
+ if_not_rare_cell_state, num_proc=cpu_cores
529
+ )
530
+
531
+ # shuffle datasets and rename columns
532
+ trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
533
+ trainset_label_shuffled = trainset_label_shuffled.rename_column(
534
+ label, "label"
535
+ )
536
+
537
+ # create dictionary of cell types : label ids
538
+ target_names = list(Counter(trainset_label_shuffled["label"]).keys())
539
+ target_name_id_dict = dict(
540
+ zip(target_names, [i for i in range(len(target_names))])
541
+ )
542
+ target_dict_list = target_name_id_dict
543
+
544
+ labeled_trainset = trainset_label_shuffled.map(
545
+ classes_to_ids, num_proc=cpu_cores
546
+ )
547
+
548
+ # create 80/20 train/eval splits
549
+ labeled_train_split = labeled_trainset.select(
550
+ [i for i in range(0, round(len(labeled_trainset) * 0.8))]
551
+ )
552
+ labeled_eval_split = labeled_trainset.select(
553
+ [
554
+ i
555
+ for i in range(
556
+ round(len(labeled_trainset) * 0.8), len(labeled_trainset)
557
+ )
558
+ ]
559
+ )
560
+
561
+ # filter dataset for cell types in corresponding training set
562
+ trained_labels = list(Counter(labeled_train_split["label"]).keys())
563
+ labeled_eval_split_subset = labeled_eval_split.filter(
564
+ if_trained_label, num_proc=cpu_cores
565
+ )
566
+
567
+ # set logging steps
568
+ logging_steps = round(len(trainset_label) / training_batch_size / 10)
569
+
570
+ # load pretrained model
571
+ model = BertForSequenceClassification.from_pretrained(
572
+ pretrained_model,
573
+ num_labels=len(target_dict_list.keys()),
574
+ output_attentions=False,
575
+ output_hidden_states=False,
576
+ ).to(device)
577
+ # define output directory path
578
+ current_date = datetime.datetime.now()
579
+ datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
580
+
581
+ if output_dir is None:
582
+ output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{training_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
583
+
584
+ # ensure not overwriting previously saved model
585
+ saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
586
+ if os.path.isfile(saved_model_test) is True and overwrite is False:
587
+ raise Exception("Model already saved to this directory.")
588
+
589
+ # make output directory
590
+ subprocess.call(f"mkdir -p {output_dir}", shell=True)
591
+
592
+ # set training arguments
593
+ training_args = {
594
+ "learning_rate": max_lr,
595
+ "do_train": True,
596
+ "do_eval": True,
597
+ "evaluation_strategy": "epoch",
598
+ "save_strategy": "epoch",
599
+ "logging_steps": logging_steps,
600
+ "group_by_length": True,
601
+ "length_column_name": "length",
602
+ "disable_tqdm": False,
603
+ "lr_scheduler_type": lr_schedule_fn,
604
+ "warmup_steps": warmup_steps,
605
+ "weight_decay": 0.001,
606
+ "per_device_train_batch_size": training_batch_size,
607
+ "per_device_eval_batch_size": training_batch_size,
608
+ "num_train_epochs": epochs,
609
+ "load_best_model_at_end": True,
610
+ "output_dir": output_dir,
611
+ }
612
+
613
+ training_args_init = TrainingArguments(**training_args)
614
+ true_labels = labeled_eval_split_subset["label"]
615
+
616
+ if optimize_hyperparameters is False:
617
+ # create the trainer
618
+ trainer = Trainer(
619
+ model=model,
620
+ args=training_args_init,
621
+ data_collator=DataCollatorForCellClassification(),
622
+ train_dataset=labeled_train_split,
623
+ eval_dataset=labeled_eval_split_subset,
624
+ compute_metrics=compute_metrics,
625
+ )
626
+
627
+ # train the cell type classifier
628
+ trainer.train()
629
+ predictions = trainer.predict(labeled_eval_split_subset)
630
+ predictions_tensor = torch.Tensor(predictions.predictions)
631
+ predicted_labels = torch.argmax(predictions_tensor, dim=1)
632
+ print(
633
+ f'accuracy: {accuracy_score(predicted_labels, labeled_eval_split_subset["label"])}'
634
+ )
635
+ metrics = compute_metrics(predictions)
636
+
637
+ with open(f"{output_dir}predictions.pickle", "wb") as fp:
638
+ pickle.dump(predictions.predictions.argmax(-1), fp)
639
+
640
+ trainer.save_metrics("eval", predictions.metrics)
641
+ trainer.save_model(output_dir)
642
+
643
+ # Saves label conversion dictionary to output directory
644
+ with open(f"{output_dir}/targets.txt", "w") as f:
645
+ f.write(str(target_dict_list))
646
+
647
+ try:
648
+ precision, recall, _ = precision_recall_curve(
649
+ true_labels, predictions.predictions[:, 1]
650
+ )
651
+ pr_auc = precision_auc(recall, precision)
652
+
653
+ print(f"AUC: {pr_auc}")
654
+ return recall, precision, pr_auc
655
+ except:
656
+ pass
657
+
658
+ else:
659
+ # Optimizes hyperparameters
660
+
661
+ num_classes = len(list(set(labeled_train_split["label"])))
662
+
663
+ def model_init():
664
+ model = BertForSequenceClassification.from_pretrained(
665
+ pretrained_model,
666
+ num_labels=num_classes,
667
+ output_attentions=False,
668
+ output_hidden_states=False,
669
+ )
670
+
671
+ if freeze_layers is not None:
672
+ modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
673
+ for module in modules_to_freeze:
674
+ for param in module.parameters():
675
+ param.requires_grad = False
676
+ model = model.to(device)
677
+ return model
678
+
679
+ # create the trainer
680
+ trainer = Trainer(
681
+ model_init=model_init,
682
+ args=training_args_init,
683
+ data_collator=DataCollatorForCellClassification(),
684
+ train_dataset=labeled_train_split,
685
+ eval_dataset=labeled_eval_split_subset,
686
+ compute_metrics=compute_metrics,
687
+ )
688
+
689
+ # specify raytune hyperparameter search space
690
+ ray_config = {
691
+ "num_train_epochs": tune.choice([epochs]),
692
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
693
+ "weight_decay": tune.uniform(0.0, 0.3),
694
+ "lr_scheduler_type": tune.choice(
695
+ ["linear", "cosine", "polynomial"]
696
+ ),
697
+ "warmup_steps": tune.uniform(100, 2000),
698
+ "seed": tune.uniform(0, 100),
699
+ "per_device_train_batch_size": tune.choice([training_batch_size]),
700
+ }
701
+
702
+ hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
703
+
704
+ if device.type == "cuda":
705
+ resources_per_trial = {"cpu": 8, "gpu": 1}
706
+ else:
707
+ resources_per_trial = {"cpu": 8}
708
+
709
+ # optimize hyperparameters
710
+ best_trial = trainer.hyperparameter_search(
711
+ direction="maximize",
712
+ backend="ray",
713
+ resources_per_trial=resources_per_trial,
714
+ hp_space=lambda _: ray_config,
715
+ search_alg=hyperopt_search,
716
+ n_trials=100, # number of trials
717
+ progress_reporter=tune.CLIReporter(
718
+ max_report_frequency=600,
719
+ sort_by_metric=True,
720
+ max_progress_rows=100,
721
+ mode="max",
722
+ metric="eval_accuracy",
723
+ metric_columns=["loss", "eval_loss", "eval_accuracy"],
724
+ ),
725
+ )
726
+ best_hyperparameters = best_trial.hyperparameters
727
+
728
+ print("Best Hyperparameters:")
729
+ print(best_hyperparameters)
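# Hedged sketch (reuses names already in scope above; not part of this commit): the object
# returned by trainer.hyperparameter_search is a transformers BestRun, whose .hyperparameters
# dict could be folded back into the training arguments for a final run, e.g.
#   for param_name, param_value in best_hyperparameters.items():
#       training_args[param_name] = param_value
#   final_trainer = Trainer(model_init=model_init, args=TrainingArguments(**training_args),
#                           data_collator=DataCollatorForCellClassification(),
#                           train_dataset=labeled_train_split,
#                           eval_dataset=labeled_eval_split_subset,
#                           compute_metrics=compute_metrics)
#   final_trainer.train()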
730
+
731
  # Performs Inference with model
732
+ if inference is True:
733
+ if dataset_split is not None and data_filter is not None:
734
+
735
  def if_label(example):
736
+ return example[dataset_split] == data_filter
737
+
738
  train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
739
+
740
  trainset_label_shuffled = train_dataset
741
  total_cells = len(trainset_label_shuffled)
742
+
743
  # loads dictionary of all cell labels model was trained on
744
+ with open(Path(finetuned_model) / "targets.txt", "r") as f:
745
  data = ast.literal_eval(f.read())
746
+ if dataset_split is not None and data_filter is None:
747
  indexer = dataset_split.index(data_filter)
748
  data = data[indexer]
749
+
750
+ target_dict_list = {key: value for key, value in enumerate(data)}
751
+
752
  # set logging steps
753
+ logging_steps = round(len(trainset_label_shuffled) / training_batch_size / 20)
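# Worked example with made-up numbers: 10,000 training cells at batch size 12 gives
# round(10000 / 12 / 20) = 42 logging steps, i.e. roughly 20 log points per epoch.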
754
+
755
+ # prepare fixed-size input and attention tensors, then load the fine-tuned model
756
  input_ids = trainset_label_shuffled["input_ids"]
757
  inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
758
  attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
759
+
760
  for i, sentence in enumerate(input_ids):
761
  sentence_length = len(sentence)
762
  if sentence_length <= max_input_size:
 
765
  else:
766
  inputs[i, :] = torch.tensor(sentence[:max_input_size])
767
  attention[i, :] = torch.ones(max_input_size)
768
+
769
+ model = BertForSequenceClassification.from_pretrained(
770
+ finetuned_model, num_labels=len(target_dict_list)
771
+ ).to(device)
772
+ model_outputs = model(inputs.to(device), attention_mask=attention.to(device))["logits"]
773
+ predictions = F.softmax(model_outputs, dim=-1).argmax(-1)
774
 
775
  predictions = [target_dict_list[int(pred)] for pred in predictions]
776
 
777
  return predictions
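# Hedged aside on the forward pass above: pushing every cell through the model at once can
# exhaust GPU memory on large datasets; a batched variant (chunk size 64 is arbitrary) would be:
#   preds = []
#   for start in range(0, len(inputs), 64):
#       logits = model(inputs[start:start + 64].to(device),
#                      attention_mask=attention[start:start + 64].to(device))["logits"]
#       preds.extend(F.softmax(logits, dim=-1).argmax(-1).tolist())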
778
+
779
+ # Extracts embeddings from labeled data
780
+ if emb_extract is True:
781
+ if emb_filter is None:
782
+ with open(f"{finetuned_model}/targets.txt", "r") as f:
783
  data = ast.literal_eval(f.read())
784
+ if dataset_split is not None and data_filter is None:
785
  indexer = dataset_split.index(data_filter)
786
  data = data[indexer]
787
+
788
+ target_dict_list = {key: value for key, value in enumerate(data)}
789
  total_filter = None
790
  else:
791
  total_filter = emb_filter
792
+
793
  train_dataset = load_from_disk(dataset)
794
+ if dataset_split is not None:
795
+
796
  def if_label(example):
797
+ return example[dataset_split] == data_filter
798
+
799
  train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
800
+
801
  label_counter = Counter(train_dataset[label])
802
  total_cells = sum(label_counter.values())
803
+ cells_to_keep = [
804
+ k for k, v in label_counter.items() if v > (filter_cells * total_cells)
805
+ ]
806
+
807
  def if_not_rare(example):
808
+ return example[label] in cells_to_keep
809
+
810
  train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)
811
+
812
  true_labels = train_dataset[label]
813
  num_classes = len(list(set(true_labels)))
814
+
815
+ embex = EmbExtractor(
816
+ model_type="CellClassifier",
817
+ num_classes=num_classes,
818
+ filter_data=total_filter,
819
+ max_ncells=max_cells,
820
+ emb_layer=emb_layer,
821
+ emb_label=[dataset_split, label],
822
+ labels_to_plot=[label],
823
+ forward_batch_size=inference_batch_size,
824
+ nproc=cpu_cores,
825
+ )
826
+
827
  # example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
828
+ subprocess.call(f"mkdir -p {emb_dir}", shell=True)
829
+
830
+ embs = embex.extract_embs(
831
+ model_directory=finetuned_model,
832
+ input_data_file=dataset,
833
+ output_directory=emb_dir,
834
+ output_prefix=f"{label}_embeddings",
835
+ )
836
  true_labels = embex.filtered_input_data[label]
837
 
838
+ emb_dict = {label: [] for label in list(set(true_labels))}
839
  for num, emb in embs.iterrows():
840
  key = emb[label]
841
  selection = emb.iloc[:255]
 
843
  emb_dict[key].append(emb)
844
 
845
  for key in list(emb_dict.keys()):
846
+ stack = torch.stack(emb_dict[key], dim=0)
847
  emb_dict[key] = torch.mean(stack, dim=0)
848
+ similarities = {key: {} for key in list(emb_dict.keys())}
849
+
850
  for key in list(emb_dict.keys()):
851
  remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
852
  for k in remaining_keys:
853
  embedding = emb_dict[k]
854
+ sim = similarity(emb_dict[key], embedding, cosine=True)
855
+
856
  similarities[key][k] = sim
857
+
858
  plot_similarity_heatmap(similarities)
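# Hedged aside: for two 1-D class-mean embeddings, the cosine branch of similarity() is
# essentially torch.nn.functional.cosine_similarity, e.g. (keys are illustrative):
#   a, b = emb_dict["class_a"], emb_dict["class_b"]
#   sim = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()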
 
 
 
 
859
 
860
+ embex.plot_embs(
861
+ embs=embs,
862
+ plot_style="umap",
863
+ output_directory=emb_dir,
864
+ output_prefix="emb_plot",
865
+ )
866
+
867
+ embex.plot_embs(
868
+ embs=embs,
869
+ plot_style="heatmap",
870
+ output_directory=emb_dir,
871
+ output_prefix="emb_plot",
872
+ )
873
+
874
+ return similarities
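# Hedged usage note: the similarities dict returned above is keyed by cell-type label on both
# levels, e.g. similarities["Cardiomyocyte"]["Fibroblast"] -> a scalar similarity score
# (label names here are purely illustrative; in practice they come from the dataset's label column).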
Gene_classifier.py → geneformer/gene_classifier.py RENAMED
@@ -1,37 +1,47 @@
1
  import os
2
  import sys
3
- GPU_NUMBER = [0] # CHANGE WITH MULTIGPU
 
4
  os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
5
  os.environ["NCCL_DEBUG"] = "INFO"
6
 
7
- # imports
8
- from sklearn.model_selection import train_test_split
9
  import datetime
 
 
10
  import subprocess
11
  from pathlib import Path
12
- import math
13
  import matplotlib.pyplot as plt
14
  import numpy as np
15
- import pickle
16
  import pandas as pd
17
- from datasets import load_from_disk, Dataset
18
- from sklearn import preprocessing
19
- from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
20
- from sklearn.model_selection import StratifiedKFold
21
  import torch
22
- from transformers import BertForTokenClassification
23
- from transformers import Trainer
24
- from transformers.training_args import TrainingArguments
 
 
 
 
 
 
 
 
 
 
 
25
  from tqdm.notebook import tqdm
26
- from sklearn.metrics import roc_curve, roc_auc_score
 
 
27
  from geneformer import DataCollatorForGeneClassification, EmbExtractor
28
  from geneformer.pretrainer import token_dictionary
29
- import ast
30
- import torch.nn.functional as F
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
  from geneformer import TranscriptomeTokenizer
33
 
34
- def vote(logit_pair):
 
35
  a, b = logit_pair
36
  if a > b:
37
  return 0
@@ -39,13 +49,16 @@ def vote(logit_pair):
39
  return 1
40
  elif a == b:
41
  return "tie"
42
-
 
43
  def py_softmax(vector):
44
- e = np.exp(vector)
45
- return e / e.sum()
46
-
47
- # Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
48
- def similarity(tensor1, tensor2, cosine = True):
 
 
49
  if cosine == False:
50
  if tensor1.ndimension() > 1:
51
  tensor1 = tensor1.view(1, -1)
@@ -56,7 +69,7 @@ def similarity(tensor1, tensor2, cosine = True):
56
  norm_tensor2 = torch.norm(tensor2)
57
  epsilon = 1e-8
58
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
59
- similarity = (similarity.item() + 1)/2
60
  else:
61
  if tensor1.shape != tensor2.shape:
62
  raise ValueError("Input tensors must have the same shape.")
@@ -65,13 +78,14 @@ def similarity(tensor1, tensor2, cosine = True):
65
  dot_product = torch.dot(tensor1, tensor2)
66
  norm_tensor1 = torch.norm(tensor1)
67
  norm_tensor2 = torch.norm(tensor2)
68
-
69
  # Avoid division by zero by adding a small epsilon
70
  epsilon = 1e-8
71
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
72
-
73
  return similarity.item()
74
-
 
75
  # Plots heatmap between different classes/labels
76
  def plot_similarity_heatmap(similarities):
77
  classes = list(similarities.keys())
@@ -84,100 +98,122 @@ def plot_similarity_heatmap(similarities):
84
  else:
85
  val = similarities[c][cc]
86
  arr[i][j] = val
87
-
88
  plt.figure(figsize=(8, 6))
89
- plt.imshow(arr, cmap='inferno', vmin=0, vmax=1)
90
  plt.colorbar()
91
- plt.xticks(np.arange(classlen), classes, rotation = 45, ha = 'right')
92
  plt.yticks(np.arange(classlen), classes)
93
  plt.title("Similarity Heatmap")
94
  plt.savefig("similarity_heatmap.png")
95
-
 
96
  # get cross-validated mean and sd metrics
97
  def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
98
- wts = [count/sum(all_tpr_wt) for count in all_tpr_wt]
99
-
100
- all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]
101
  mean_tpr = np.sum(all_weighted_tpr, axis=0)
102
  mean_tpr[-1] = 1.0
103
- all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]
104
  roc_auc = np.sum(all_weighted_roc_auc)
105
- roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))
106
  return mean_tpr, roc_auc, roc_auc_sd
107
 
108
- def validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
 
 
 
 
 
 
109
  # initiate eval metrics to return
110
  num_classes = len(set(labels))
111
  mean_fpr = np.linspace(0, 1, 100)
112
-
113
  # create 80/20 train/eval splits
114
- targets_train, targets_eval, labels_train, labels_eval = train_test_split(targets, labels ,test_size=0.25, shuffle=True)
 
 
115
  label_dict_train = dict(zip(targets_train, labels_train))
116
  label_dict_eval = dict(zip(targets_eval, labels_eval))
117
-
118
  # function to filter by whether contains train or eval labels
119
  def if_contains_train_label(example):
120
  a = label_dict_train.keys()
121
- b = example['input_ids']
122
  return not set(a).isdisjoint(b)
123
 
124
  def if_contains_eval_label(example):
125
  a = label_dict_eval.keys()
126
- b = example['input_ids']
127
  return not set(a).isdisjoint(b)
128
-
129
  # filter dataset for examples containing classes for this split
130
  print(f"Filtering training data")
131
  trainset = data.filter(if_contains_train_label, num_proc=num_proc)
132
- print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
 
 
133
  print(f"Filtering evaluation data")
134
  evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
135
  print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
136
-
137
  # minimize to smaller training sample
138
  training_size = min(subsample_size, len(trainset))
139
  trainset_min = trainset.select([i for i in range(training_size)])
140
  eval_size = min(training_size, len(evalset))
141
- half_training_size = round(eval_size/2)
142
  evalset_train_min = evalset.select([i for i in range(half_training_size)])
143
  evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
144
-
145
  # label conversion functions
146
  def generate_train_labels(example):
147
- example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
 
 
148
  return example
149
 
150
  def generate_eval_labels(example):
151
- example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
 
 
152
  return example
153
-
154
- # label datasets
155
  print(f"Labeling training data")
156
  trainset_labeled = trainset_min.map(generate_train_labels)
157
  print(f"Labeling evaluation data")
158
  evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
159
  print(f"Labeling evaluation OOS data")
160
  evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
161
-
162
  # load model
163
  model = BertForTokenClassification.from_pretrained(
164
- pre_model,
165
- num_labels=num_labels,
166
- output_attentions = False,
167
- output_hidden_states = False,
168
  )
169
  if freeze_layers is not None:
170
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
171
  for module in modules_to_freeze:
172
  for param in module.parameters():
173
  param.requires_grad = False
174
-
175
  model = model.to(device)
176
-
177
  # add output directory to training args and initiate
178
  training_args["output_dir"] = output_dir
179
  training_args_init = TrainingArguments(**training_args)
180
-
181
  # create the trainer
182
  trainer = Trainer(
183
  model=model,
@@ -186,25 +222,40 @@ def validate(data, targets, labels, nsplits, subsample_size, training_args, free
186
  train_dataset=trainset_labeled,
187
  eval_dataset=evalset_train_labeled,
188
  )
189
-
190
  # train the gene classifier
191
  trainer.train()
192
  trainer.save_model(output_dir)
193
-
194
- fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
 
 
195
  auc_score = auc(fpr, tpr)
196
-
197
  return fpr, tpr, auc_score
198
-
 
199
  # cross-validate gene classifier
200
- def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
 
 
 
 
 
 
 
 
 
 
 
 
201
  # check if output directory already written to
202
  # ensure not overwriting previously saved model
203
  model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
204
- #if os.path.isfile(model_dir_test) == True:
205
  # raise Exception("Model already saved to this directory.")
206
-
207
- device = device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
208
  # initiate eval metrics to return
209
  num_classes = len(set(labels))
210
  mean_fpr = np.linspace(0, 1, 100)
@@ -212,8 +263,8 @@ def cross_validate(data, targets, labels, nsplits, subsample_size, training_args
212
  all_roc_auc = []
213
  all_tpr_wt = []
214
  label_dicts = []
215
- confusion = np.zeros((num_classes,num_classes))
216
-
217
  # set up cross-validation splits
218
  skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
219
  # train and evaluate
@@ -223,236 +274,280 @@ def cross_validate(data, targets, labels, nsplits, subsample_size, training_args
223
  print("early stopping activated due to large # of training examples")
224
  if iteration_num == 3:
225
  break
226
-
227
  print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
228
-
229
  # generate cross-validation splits
230
  targets_train, targets_eval = targets[train_index], targets[eval_index]
231
  labels_train, labels_eval = labels[train_index], labels[eval_index]
232
  label_dict_train = dict(zip(targets_train, labels_train))
233
  label_dict_eval = dict(zip(targets_eval, labels_eval))
234
- label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)
235
-
 
 
 
 
 
 
236
  # function to filter by whether contains train or eval labels
237
  def if_contains_train_label(example):
238
  a = label_dict_train.keys()
239
- b = example['input_ids']
240
-
241
  return not set(a).isdisjoint(b)
242
 
243
  def if_contains_eval_label(example):
244
  a = label_dict_eval.keys()
245
- b = example['input_ids']
246
-
247
  return not set(a).isdisjoint(b)
248
-
249
  # filter dataset for examples containing classes for this split
250
  print(f"Filtering training data")
251
  trainset = data.filter(if_contains_train_label, num_proc=num_proc)
252
- print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
 
 
253
  print(f"Filtering evaluation data")
254
  evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
255
- print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
 
 
256
 
257
  # minimize to smaller training sample
258
  training_size = min(subsample_size, len(trainset))
259
  trainset_min = trainset.select([i for i in range(training_size)])
260
  eval_size = min(training_size, len(evalset))
261
- half_training_size = round(eval_size/2)
262
  evalset_train_min = evalset.select([i for i in range(half_training_size)])
263
- evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
264
-
 
 
265
  # label conversion functions
266
  def generate_train_labels(example):
267
- example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
 
 
 
268
  return example
269
 
270
  def generate_eval_labels(example):
271
- example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
 
 
272
  return example
273
-
274
- # label datasets
275
  print(f"Labeling training data")
276
  trainset_labeled = trainset_min.map(generate_train_labels)
277
  print(f"Labeling evaluation data")
278
  evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
279
  print(f"Labeling evaluation OOS data")
280
  evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
281
-
282
  # create output directories
283
  ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
284
- ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")
285
-
286
  # ensure not overwriting previously saved model
287
  model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
288
- #if os.path.isfile(model_output_file) == True:
289
  # raise Exception("Model already saved to this directory.")
290
 
291
  # make training and model output directories
292
- subprocess.call(f'mkdir -p {ksplit_output_dir}', shell=True)
293
- subprocess.call(f'mkdir -p {ksplit_model_dir}', shell=True)
294
-
295
  # load model
296
  model = BertForTokenClassification.from_pretrained(
297
  pre_model,
298
  num_labels=num_labels,
299
- output_attentions = False,
300
- output_hidden_states = False,
301
  )
302
  if freeze_layers is not None:
303
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
304
  for module in modules_to_freeze:
305
  for param in module.parameters():
306
  param.requires_grad = False
307
-
308
  model = model.to(device)
309
-
310
  # add output directory to training args and initiate
311
  training_args["output_dir"] = ksplit_output_dir
312
  training_args_init = TrainingArguments(**training_args)
313
-
314
  # create the trainer
315
  trainer = Trainer(
316
  model=model,
317
  args=training_args_init,
318
  data_collator=DataCollatorForGeneClassification(),
319
  train_dataset=trainset_labeled,
320
- eval_dataset=evalset_train_labeled
321
  )
322
 
323
  # train the gene classifier
324
  trainer.train()
325
-
326
  # save model
327
  trainer.save_model(ksplit_model_dir)
328
-
329
  # evaluate model
330
- fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
331
-
 
 
332
  # append to tpr and roc lists
333
  confusion = confusion + conf_mat
334
  all_tpr.append(interp_tpr)
335
  all_roc_auc.append(auc(fpr, tpr))
336
  # append number of eval examples by which to weight tpr in averaged graphs
337
  all_tpr_wt.append(len(tpr))
338
-
339
  iteration_num = iteration_num + 1
340
-
341
  # get overall metrics for cross-validation
342
- mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
 
 
343
  return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
344
-
 
345
  # Computes metrics
346
  def compute_metrics(pred):
347
  labels = pred.label_ids
348
  preds = pred.predictions.argmax(-1)
349
  # calculate accuracy and macro f1 using sklearn's function
350
  acc = accuracy_score(labels, preds)
351
- macro_f1 = f1_score(labels, preds, average='macro')
352
-
353
- return {
354
- 'accuracy': acc,
355
- 'macro_f1': macro_f1
356
- }
357
 
358
  # plot ROC curve
359
  def plot_ROC(bundled_data, title):
360
  plt.figure()
361
  lw = 2
362
  for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
363
- plt.plot(mean_fpr, mean_tpr, color=color,
364
- lw=lw, label="{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
365
-
366
- plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
 
 
 
 
 
 
 
367
  plt.xlim([0.0, 1.0])
368
  plt.ylim([0.0, 1.05])
369
- plt.xlabel('False Positive Rate')
370
- plt.ylabel('True Positive Rate')
371
  plt.title(title)
372
  plt.legend(loc="lower right")
373
  plt.savefig("ROC.png")
374
-
375
  return mean_fpr, mean_tpr, roc_auc
376
-
 
377
  # plot confusion matrix
378
  def plot_confusion_matrix(classes_list, conf_mat, title):
379
  display_labels = []
380
  i = 0
381
  for label in classes_list:
382
- display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:,i]))]
383
  i = i + 1
384
- display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
385
- display_labels=display_labels)
386
- display.plot(cmap="Blues",values_format=".2g")
 
 
387
  plt.title(title)
388
  plt.savefig("CM.png")
389
-
 
390
  # Function to find the largest number smaller
391
  # than or equal to N that is divisible by k
392
  def find_largest_div(N, K):
393
  rem = N % K
394
- if(rem == 0):
395
  return N
396
  else:
397
  return N - rem
398
-
 
399
  def preprocess_classifier_batch(cell_batch, max_len):
400
  if max_len == None:
401
  max_len = max([len(i) for i in cell_batch["input_ids"]])
 
402
  def pad_label_example(example):
403
- example["labels"] = np.pad(example["labels"],
404
- (0, max_len-len(example["input_ids"])),
405
- mode='constant', constant_values=-100)
406
- example["input_ids"] = np.pad(example["input_ids"],
407
- (0, max_len-len(example["input_ids"])),
408
- mode='constant', constant_values=token_dictionary.get("<pad>"))
409
- example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
 
 
 
 
 
 
 
 
410
  return example
 
411
  padded_batch = cell_batch.map(pad_label_example)
412
  return padded_batch
413
 
 
414
  # forward batch size is batch size for model inference (e.g. 200)
415
  def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
416
  predict_logits = []
417
  predict_labels = []
418
- model.to('cpu')
419
  model.eval()
420
-
421
  # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
422
  evalset_len = len(evalset)
423
  max_divisible = find_largest_div(evalset_len, forward_batch_size)
424
  if len(evalset) - max_divisible == 1:
425
  evalset_len = max_divisible
426
-
427
  max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
428
-
429
  for i in range(0, evalset_len, forward_batch_size):
430
- max_range = min(i+forward_batch_size, evalset_len)
431
  batch_evalset = evalset.select([i for i in range(i, max_range)])
432
  padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
433
  padded_batch.set_format(type="torch")
434
-
435
  input_data_batch = padded_batch["input_ids"]
436
  attn_msk_batch = padded_batch["attention_mask"]
437
  label_batch = padded_batch["labels"]
438
  with torch.no_grad():
439
  input_ids = input_data_batch
440
  attn_mask = attn_msk_batch
441
- labels = label_batch
442
  outputs = model(
443
-
444
- input_ids = input_ids,
445
- attention_mask = attn_mask,
446
- labels = labels
447
  )
448
  predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
449
  predict_labels += [torch.squeeze(label_batch.to("cpu"))]
450
-
451
  logits_by_cell = torch.cat(predict_logits)
452
  all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
453
  labels_by_cell = torch.cat(predict_labels)
454
  all_labels = torch.flatten(labels_by_cell)
455
- logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
 
 
 
 
456
  y_pred = [vote(item[0]) for item in logit_label_paired]
457
  y_true = [item[1] for item in logit_label_paired]
458
  logits_list = [item[0] for item in logit_label_paired]
@@ -464,106 +559,133 @@ def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
464
  plt.plot(fpr, tpr)
465
  plt.xlim([0.0, 1.0])
466
  plt.ylim([0.0, 1.05])
467
- plt.xlabel('False Positive Rate')
468
- plt.ylabel('True Positive Rate')
469
- plt.title('ROC')
470
  plt.show()
471
  # interpolate to graph
472
  interp_tpr = np.interp(mean_fpr, fpr, tpr)
473
  interp_tpr[0] = 0.0
474
- return fpr, tpr, interp_tpr, conf_mat
475
-
476
- def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
477
- corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
478
- max_input_size = 2 ** 11, max_lr = 5e-5, freeze_layers = 4, num_gpus = 1, num_proc = os.cpu_count(), geneformer_batch_size = 9, epochs = 1, filter_dataset = 50_000,
479
- emb_extract = True, emb_layer = 0, forward_batch = 200, filter_data = None, inference = False, k_validate = True, model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/", skip_training = False, emb_dir = 'gene_emb', output_dir = None, max_cells = 1000, num_cpus = os.cpu_count()):
480
-
481
-
482
- """"
 
 
 
 
 
 
483
  Primary Parameters
484
  -----------
485
-
486
  gene_info: path
487
  Path to gene mappings
488
-
489
  corpus_30M: path
490
  Path to 30M Gene Corpus
491
-
492
  model: path
493
  Path to pretrained GeneFormer model
494
-
495
  genes: path
496
  Path to csv file containing different columns of genes and the column labels
497
-
498
  inference: bool
499
  Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
500
-
501
  k_validate: bool
502
  Whether the model should run k-fold validation or simply perform regular training/evaluation. Defaults to True
503
-
504
  skip_training: bool
505
  Whether the model should skip the training portion. Defaults to False
506
-
507
  emb_extract: bool
508
  WHether the model should extract embeddings for a given gene (WIP)
509
-
510
-
511
  Customization Parameters
512
  -----------
513
-
514
  freeze_layers: int
515
  Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)
516
-
517
  filter_dataset: int
518
  Number of cells to filter from 30M dataset. Default is 50_000
519
-
520
  emb_layer: int
521
  What layer embeddings are extracted from. Default is 4
522
-
523
  filter_data: str, list
524
  Filters down embeddings to a single category. Default is None
525
-
526
-
527
  """
528
-
529
  # table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
530
  gene_info = pd.read_csv(gene_info, index_col=0)
531
  labels = gene_info.columns
532
 
533
  # create dictionaries for corresponding attributes
534
- gene_id_type_dict = dict(zip(gene_info["ensembl_id"],gene_info["gene_type"]))
535
- gene_name_id_dict = dict(zip(gene_info["gene_name"],gene_info["ensembl_id"]))
536
- gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}
537
 
538
  # function for preparing targets and labels
539
  def prep_inputs(label_store, id_type):
540
  target_list = []
541
  if id_type == "gene_name":
542
  for key in list(label_store.keys()):
543
- targets = [gene_name_id_dict[gene] for gene in label_store[key] if gene_name_id_dict.get(gene) in token_dictionary]
 
 
 
 
544
  targets_id = [token_dictionary[gene] for gene in targets]
545
  target_list.append(targets_id)
546
  elif id_type == "ensembl_id":
547
  for key in list(label_store.keys()):
548
- targets = [gene for gene in label_store[key] if gene in token_dictionary]
 
 
549
  targets_id = [token_dictionary[gene] for gene in targets]
550
  target_list.append(targets_id)
551
-
552
  targets, labels = [], []
553
  for targ in target_list:
554
  targets = targets + targ
555
  targets = np.array(targets)
556
  for num, targ in enumerate(target_list):
557
- label = [num]*len(targ)
558
  labels = labels + label
559
  labels = np.array(labels)
560
  unique_labels = num + 1
561
-
562
- nsplits = min(5, min([len(targ) for targ in target_list])-1)
563
  assert nsplits > 2
564
-
565
  return targets, labels, nsplits, unique_labels
566
-
567
  if skip_training == False:
568
  # preparing targets and labels for dosage sensitive vs insensitive TFs
569
  gene_classes = pd.read_csv(genes, header=0)
@@ -575,26 +697,26 @@ def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_tab
575
  else:
576
  labels = [filter_data]
577
  label_store = {}
578
-
579
  # Dictionary for decoding labels
580
- decode = {i:labels[i] for i in range(len(labels))}
581
-
582
  for label in labels:
583
  label_store[label] = gene_classes[label].dropna()
584
-
585
  targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")
586
-
587
-
588
-
589
  # load training dataset
590
- train_dataset=load_from_disk(corpus_30M)
591
  shuffled_train_dataset = train_dataset.shuffle(seed=42)
592
- subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(filter_dataset)])
 
 
593
  lr_schedule_fn = "linear"
594
  warmup_steps = 500
595
  optimizer = "adamw"
596
  subsample_size = 10_000
597
-
598
  training_args = {
599
  "learning_rate": max_lr,
600
  "do_train": True,
@@ -611,52 +733,95 @@ def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_tab
611
  "per_device_eval_batch_size": geneformer_batch_size,
612
  "num_train_epochs": epochs,
613
  }
614
-
615
  # define output directory path
616
  current_date = datetime.datetime.now()
617
  datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
618
-
619
  if output_dir == None:
620
- training_output_dir = Path(f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/")
 
 
621
  else:
622
  training_output_dir = Path(output_dir)
623
-
624
  # make output directory
625
- subprocess.call(f'mkdir -p {training_output_dir}', shell=True)
626
-
627
  # Saves the number of classes and the label-decoding dictionary to the output directory
628
  num_classes = len(set(labels))
629
  info_list = [num_classes, decode]
630
-
631
- with open(training_output_dir / 'classes.txt', 'w') as f:
632
  f.write(str(info_list))
633
-
634
- subsampled_train_dataset.save_to_disk(output_dir / 'dataset')
635
-
636
  if k_validate == True:
637
- ksplit_model ="ksplit0/models"
638
  ksplit_model_test = os.path.join(training_output_dir, ksplit_model)
639
- #if os.path.isfile(ksplit_model_test) == True:
640
  # raise Exception("Model already saved to this directory.")
641
- # cross-validate gene classifier
642
- all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
643
-
 
 
 
 
 
644
  bundled_data = []
645
- bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
646
- graph_title = " ".join([i + ' vs' if count < len(label_store) - 1 else i for count, i in enumerate(label_store)])
647
- fpr, tpr, auc = plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')
 
 
 
 
 
 
 
 
 
648
  print(auc)
649
  # plot confusion matrix
650
  plot_confusion_matrix(label_store, confusion, "Geneformer")
651
- else:
652
- fpr, tpr, auc = validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
 
 
 
 
 
 
 
 
 
 
 
 
653
  print(auc)
654
-
655
  if inference == True:
656
  # preparing targets and labels for dosage sensitive vs insensitive TFs
657
  gene_classes = pd.read_csv(genes, header=0)
658
  targets = []
659
- for column in gene_classes.columns:
660
  targets += list(gene_classes[column])
661
  tokens = []
662
  for target in targets:
@@ -664,47 +829,48 @@ def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_tab
664
  tokens.append(token_dictionary[target])
665
  except:
666
  tokens.append(0)
667
-
668
  targets = torch.LongTensor([tokens])
669
-
670
-
671
- with open(f'{model_location}classes.txt', 'r') as f:
672
  info_list = ast.literal_eval(f.read())
673
  num_classes = info_list[0]
674
  labels = info_list[1]
675
-
676
  model = BertForTokenClassification.from_pretrained(
677
  model_location,
678
  num_labels=num_classes,
679
- output_attentions = False,
680
- output_hidden_states = False,
681
- local_files_only = True
682
  )
683
  if freeze_layers is not None:
684
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
685
  for module in modules_to_freeze:
686
  for param in module.parameters():
687
  param.requires_grad = False
688
-
689
  model = model.to(device)
690
-
691
  # evaluate model
692
- predictions = F.softmax(model(targets.to(device))["logits"], dim = -1).argmax(-1)[0]
 
 
693
  predictions = [labels[int(pred)] for pred in predictions]
694
-
695
  return predictions
696
-
697
- # Extracts aggregate gene embeddings for each label
698
  if emb_extract == True:
699
- with open(f'{model_location}/classes.txt', 'r') as f:
700
  data = ast.literal_eval(f.read())
701
  num_classes = data[0]
702
  decode = data[1]
703
-
704
  gene_classes = pd.read_csv(genes, header=0)
705
  labels = gene_classes.columns
706
  tokenize = TranscriptomeTokenizer()
707
-
708
  label_dict = {}
709
  for label in labels:
710
  genes = gene_classes[label]
@@ -715,32 +881,55 @@ def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_tab
715
  except:
716
  continue
717
  label_dict[label] = tokenized_genes
718
-
719
- embex = EmbExtractor(model_type="GeneClassifier", num_classes=num_classes, emb_mode = "gene",
720
- filter_data=None, max_ncells=max_cells, emb_layer=emb_layer,
721
- emb_label=label_dict, labels_to_plot=list(labels), forward_batch_size=forward_batch, nproc=num_cpus)
722
-
723
-
724
- subprocess.call(f'mkdir -p {emb_dir}', shell = True)
725
-
726
- embs = embex.extract_embs(model_directory = model_location, input_data_file = model_location / 'dataset', output_directory = emb_dir, output_prefix = f"{label}_embbeddings")
727
-
728
- emb_dict = {label:[] for label in list(set(labels))}
729
- similarities = {key:{} for key in list(emb_dict.keys())}
730
-
 
 
 
 
 
 
 
 
 
 
 
 
 
731
  for column in embs.columns:
732
  remaining_cols = [k for k in embs.columns if k != column]
733
  for k in remaining_cols:
734
  embedding = torch.Tensor(embs[k])
735
- sim = similarity(torch.Tensor(embs[column]), embedding, cosine = True)
736
  similarities[column][k] = sim
737
-
738
  plot_similarity_heatmap(similarities)
739
  print(similarities)
740
-
741
  return similarities
742
-
743
- if __name__ == '__main__':
744
- classify_genes(k_validate = False, inference = False, skip_training = False, emb_extract = True, output_dir = Path('gene_emb'), model_location = Path('gene_emb'), epochs = 5, gene_info = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", corpus_30M = "../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/")
745
 
746
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
+
4
+ GPU_NUMBER = [0] # CHANGE WITH MULTIGPU
5
  os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
6
  os.environ["NCCL_DEBUG"] = "INFO"
7
 
8
+ import ast
 
9
  import datetime
10
+ import math
11
+ import pickle
12
  import subprocess
13
  from pathlib import Path
14
+
15
  import matplotlib.pyplot as plt
16
  import numpy as np
 
17
  import pandas as pd
 
 
 
 
18
  import torch
19
+ import torch.nn.functional as F
20
+ from datasets import Dataset, load_from_disk
21
+ from sklearn import preprocessing
22
+ from sklearn.metrics import (
23
+ ConfusionMatrixDisplay,
24
+ accuracy_score,
25
+ auc,
26
+ confusion_matrix,
+ f1_score,
27
+ roc_auc_score,
28
+ roc_curve,
29
+ )
30
+
31
+ # imports
32
+ from sklearn.model_selection import StratifiedKFold, train_test_split
33
  from tqdm.notebook import tqdm
34
+ from transformers import BertForTokenClassification, Trainer
35
+ from transformers.training_args import TrainingArguments
36
+
37
  from geneformer import DataCollatorForGeneClassification, EmbExtractor
38
  from geneformer.pretrainer import token_dictionary
39
+
 
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
  from geneformer import TranscriptomeTokenizer
42
 
43
+
44
+ def vote(logit_pair):
45
  a, b = logit_pair
46
  if a > b:
47
  return 0
 
49
  return 1
50
  elif a == b:
51
  return "tie"
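# Illustrative behaviour (made-up logits): vote((2.0, 0.5)) -> 0, vote((0.5, 2.0)) -> 1,
# vote((1.0, 1.0)) -> "tie"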
52
+
53
+
54
  def py_softmax(vector):
55
+ e = np.exp(vector)
56
+ return e / e.sum()
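# e.g. py_softmax(np.array([0.0, 0.0])) -> array([0.5, 0.5]); the outputs always sum to 1.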
57
+
58
+ # Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
59
+
60
+
61
+ def similarity(tensor1, tensor2, cosine=True):
62
  if cosine == False:
63
  if tensor1.ndimension() > 1:
64
  tensor1 = tensor1.view(1, -1)
 
69
  norm_tensor2 = torch.norm(tensor2)
70
  epsilon = 1e-8
71
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
72
+ similarity = (similarity.item() + 1) / 2
73
  else:
74
  if tensor1.shape != tensor2.shape:
75
  raise ValueError("Input tensors must have the same shape.")
 
78
  dot_product = torch.dot(tensor1, tensor2)
79
  norm_tensor1 = torch.norm(tensor1)
80
  norm_tensor2 = torch.norm(tensor2)
81
+
82
  # Avoid division by zero by adding a small epsilon
83
  epsilon = 1e-8
84
  similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
85
+
86
  return similarity.item()
87
+
88
+
89
  # Plots heatmap between different classes/labels
90
  def plot_similarity_heatmap(similarities):
91
  classes = list(similarities.keys())
 
98
  else:
99
  val = similarities[c][cc]
100
  arr[i][j] = val
101
+
102
  plt.figure(figsize=(8, 6))
103
+ plt.imshow(arr, cmap="inferno", vmin=0, vmax=1)
104
  plt.colorbar()
105
+ plt.xticks(np.arange(classlen), classes, rotation=45, ha="right")
106
  plt.yticks(np.arange(classlen), classes)
107
  plt.title("Similarity Heatmap")
108
  plt.savefig("similarity_heatmap.png")
109
+
110
+
111
  # get cross-validated mean and sd metrics
112
  def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
113
+ wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
114
+
115
+ all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
116
  mean_tpr = np.sum(all_weighted_tpr, axis=0)
117
  mean_tpr[-1] = 1.0
118
+ all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
119
  roc_auc = np.sum(all_weighted_roc_auc)
120
+ roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
121
  return mean_tpr, roc_auc, roc_auc_sd
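# In effect this is a weighted mean across splits: with w_i = n_i / sum(n) (n_i from all_tpr_wt),
# roc_auc = sum_i w_i * auc_i and roc_auc_sd = sqrt(sum_i w_i * (auc_i - roc_auc)**2).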
122
 
123
+
124
+ def validate(
125
+ data,
126
+ targets,
127
+ labels,
128
+ nsplits,
129
+ subsample_size,
130
+ training_args,
131
+ freeze_layers,
132
+ output_dir,
133
+ num_proc,
134
+ num_labels,
135
+ pre_model,
136
+ ):
137
  # initiate eval metrics to return
138
  num_classes = len(set(labels))
139
  mean_fpr = np.linspace(0, 1, 100)
140
+
141
  # create 80/20 train/eval splits
142
+ targets_train, targets_eval, labels_train, labels_eval = train_test_split(
143
+ targets, labels, test_size=0.25, shuffle=True
144
+ )
145
  label_dict_train = dict(zip(targets_train, labels_train))
146
  label_dict_eval = dict(zip(targets_eval, labels_eval))
147
+
148
  # function to filter by whether contains train or eval labels
149
  def if_contains_train_label(example):
150
  a = label_dict_train.keys()
151
+ b = example["input_ids"]
152
  return not set(a).isdisjoint(b)
153
 
154
  def if_contains_eval_label(example):
155
  a = label_dict_eval.keys()
156
+ b = example["input_ids"]
157
  return not set(a).isdisjoint(b)
158
+
159
  # filter dataset for examples containing classes for this split
160
  print(f"Filtering training data")
161
  trainset = data.filter(if_contains_train_label, num_proc=num_proc)
162
+ print(
163
+ f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n"
164
+ )
165
  print(f"Filtering evaluation data")
166
  evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
167
  print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
168
+
169
  # minimize to smaller training sample
170
  training_size = min(subsample_size, len(trainset))
171
  trainset_min = trainset.select([i for i in range(training_size)])
172
  eval_size = min(training_size, len(evalset))
173
+ half_training_size = round(eval_size / 2)
174
  evalset_train_min = evalset.select([i for i in range(half_training_size)])
175
  evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
176
+
177
  # label conversion functions
178
  def generate_train_labels(example):
179
+ example["labels"] = [
180
+ label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
181
+ ]
182
  return example
183
 
184
  def generate_eval_labels(example):
185
+ example["labels"] = [
186
+ label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
187
+ ]
188
  return example
189
+
190
+ # label datasets
191
  print(f"Labeling training data")
192
  trainset_labeled = trainset_min.map(generate_train_labels)
193
  print(f"Labeling evaluation data")
194
  evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
195
  print(f"Labeling evaluation OOS data")
196
  evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
197
+
198
  # load model
199
  model = BertForTokenClassification.from_pretrained(
200
+ pre_model,
201
+ num_labels=num_labels,
202
+ output_attentions=False,
203
+ output_hidden_states=False,
204
  )
205
  if freeze_layers is not None:
206
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
207
  for module in modules_to_freeze:
208
  for param in module.parameters():
209
  param.requires_grad = False
210
+
211
  model = model.to(device)
212
+
213
  # add output directory to training args and initiate
214
  training_args["output_dir"] = output_dir
215
  training_args_init = TrainingArguments(**training_args)
216
+
217
  # create the trainer
218
  trainer = Trainer(
219
  model=model,
 
222
  train_dataset=trainset_labeled,
223
  eval_dataset=evalset_train_labeled,
224
  )
225
+
226
  # train the gene classifier
227
  trainer.train()
228
  trainer.save_model(output_dir)
229
+
230
+ fpr, tpr, interp_tpr, conf_mat = classifier_predict(
231
+ trainer.model, evalset_oos_labeled, 200, mean_fpr
232
+ )
233
  auc_score = auc(fpr, tpr)
234
+
235
  return fpr, tpr, auc_score
236
+
237
+
238
  # cross-validate gene classifier
239
+ def cross_validate(
240
+ data,
241
+ targets,
242
+ labels,
243
+ nsplits,
244
+ subsample_size,
245
+ training_args,
246
+ freeze_layers,
247
+ output_dir,
248
+ num_proc,
249
+ num_labels,
250
+ pre_model,
251
+ ):
252
  # check if output directory already written to
253
  # ensure not overwriting previously saved model
254
  model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
255
+ # if os.path.isfile(model_dir_test) == True:
256
  # raise Exception("Model already saved to this directory.")
257
+
258
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
259
  # initiate eval metrics to return
260
  num_classes = len(set(labels))
261
  mean_fpr = np.linspace(0, 1, 100)
 
263
  all_roc_auc = []
264
  all_tpr_wt = []
265
  label_dicts = []
266
+ confusion = np.zeros((num_classes, num_classes))
267
+
268
  # set up cross-validation splits
269
  skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
270
  # train and evaluate
 
274
  print("early stopping activated due to large # of training examples")
275
  if iteration_num == 3:
276
  break
277
+
278
  print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
279
+
280
  # generate cross-validation splits
281
  targets_train, targets_eval = targets[train_index], targets[eval_index]
282
  labels_train, labels_eval = labels[train_index], labels[eval_index]
283
  label_dict_train = dict(zip(targets_train, labels_train))
284
  label_dict_eval = dict(zip(targets_eval, labels_eval))
285
+ label_dicts += (
286
+ iteration_num,
287
+ targets_train,
288
+ targets_eval,
289
+ labels_train,
290
+ labels_eval,
291
+ )
292
+
293
  # function to filter by whether contains train or eval labels
294
  def if_contains_train_label(example):
295
  a = label_dict_train.keys()
296
+ b = example["input_ids"]
297
+
298
  return not set(a).isdisjoint(b)
299
 
300
  def if_contains_eval_label(example):
301
  a = label_dict_eval.keys()
302
+ b = example["input_ids"]
303
+
304
  return not set(a).isdisjoint(b)
305
+
306
  # filter dataset for examples containing classes for this split
307
  print(f"Filtering training data")
308
  trainset = data.filter(if_contains_train_label, num_proc=num_proc)
309
+ print(
310
+ f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n"
311
+ )
312
  print(f"Filtering evaluation data")
313
  evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
314
+ print(
315
+ f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n"
316
+ )
317
 
318
  # minimize to smaller training sample
319
  training_size = min(subsample_size, len(trainset))
320
  trainset_min = trainset.select([i for i in range(training_size)])
321
  eval_size = min(training_size, len(evalset))
322
+ half_training_size = round(eval_size / 2)
323
  evalset_train_min = evalset.select([i for i in range(half_training_size)])
324
+ evalset_oos_min = evalset.select(
325
+ [i for i in range(half_training_size, eval_size)]
326
+ )
327
+
328
  # label conversion functions
329
  def generate_train_labels(example):
330
+ example["labels"] = [
331
+ label_dict_train.get(token_id, -100)
332
+ for token_id in example["input_ids"]
333
+ ]
334
  return example
335
 
336
  def generate_eval_labels(example):
337
+ example["labels"] = [
338
+ label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
339
+ ]
340
  return example
341
+
342
+ # label datasets
343
  print(f"Labeling training data")
344
  trainset_labeled = trainset_min.map(generate_train_labels)
345
  print(f"Labeling evaluation data")
346
  evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
347
  print(f"Labeling evaluation OOS data")
348
  evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
349
+
350
  # create output directories
351
  ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
352
+ ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")
353
+
354
  # ensure not overwriting previously saved model
355
  model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
356
+ # if os.path.isfile(model_output_file) == True:
357
  # raise Exception("Model already saved to this directory.")
358
 
359
  # make training and model output directories
360
+ subprocess.call(f"mkdir -p {ksplit_output_dir}", shell=True)
361
+ subprocess.call(f"mkdir -p {ksplit_model_dir}", shell=True)
362
+
363
  # load model
364
  model = BertForTokenClassification.from_pretrained(
365
  pre_model,
366
  num_labels=num_labels,
367
+ output_attentions=False,
368
+ output_hidden_states=False,
369
  )
370
  if freeze_layers is not None:
371
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
372
  for module in modules_to_freeze:
373
  for param in module.parameters():
374
  param.requires_grad = False
375
+
376
  model = model.to(device)
377
+
378
  # add output directory to training args and initiate
379
  training_args["output_dir"] = ksplit_output_dir
380
  training_args_init = TrainingArguments(**training_args)
381
+
382
  # create the trainer
383
  trainer = Trainer(
384
  model=model,
385
  args=training_args_init,
386
  data_collator=DataCollatorForGeneClassification(),
387
  train_dataset=trainset_labeled,
388
+ eval_dataset=evalset_train_labeled,
389
  )
390
 
391
  # train the gene classifier
392
  trainer.train()
393
+
394
  # save model
395
  trainer.save_model(ksplit_model_dir)
396
+
397
  # evaluate model
398
+ fpr, tpr, interp_tpr, conf_mat = classifier_predict(
399
+ trainer.model, evalset_oos_labeled, 200, mean_fpr
400
+ )
401
+
402
  # append to tpr and roc lists
403
  confusion = confusion + conf_mat
404
  all_tpr.append(interp_tpr)
405
  all_roc_auc.append(auc(fpr, tpr))
406
  # append number of eval examples by which to weight tpr in averaged graphs
407
  all_tpr_wt.append(len(tpr))
408
+
409
  iteration_num = iteration_num + 1
410
+
411
  # get overall metrics for cross-validation
412
+ mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(
413
+ all_tpr, all_roc_auc, all_tpr_wt
414
+ )
415
  return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
416
+
417
+
418
  # Computes metrics
419
  def compute_metrics(pred):
420
  labels = pred.label_ids
421
  preds = pred.predictions.argmax(-1)
422
  # calculate accuracy and macro f1 using sklearn's function
423
  acc = accuracy_score(labels, preds)
424
+ macro_f1 = f1_score(labels, preds, average="macro")
425
+
426
+ return {"accuracy": acc, "macro_f1": macro_f1}
427
+
 
 
428
 
429
  # plot ROC curve
430
  def plot_ROC(bundled_data, title):
431
  plt.figure()
432
  lw = 2
433
  for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
434
+ plt.plot(
435
+ mean_fpr,
436
+ mean_tpr,
437
+ color=color,
438
+ lw=lw,
439
+ label="{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(
440
+ sample, roc_auc, roc_auc_sd
441
+ ),
442
+ )
443
+
444
+ plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
445
  plt.xlim([0.0, 1.0])
446
  plt.ylim([0.0, 1.05])
447
+ plt.xlabel("False Positive Rate")
448
+ plt.ylabel("True Positive Rate")
449
  plt.title(title)
450
  plt.legend(loc="lower right")
451
  plt.savefig("ROC.png")
452
+
453
  return mean_fpr, mean_tpr, roc_auc
454
+
455
+
456
  # plot confusion matrix
457
  def plot_confusion_matrix(classes_list, conf_mat, title):
458
  display_labels = []
459
  i = 0
460
  for label in classes_list:
461
+ display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:, i]))]
462
  i = i + 1
463
+ display = ConfusionMatrixDisplay(
464
+ confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
465
+ display_labels=display_labels,
466
+ )
467
+ display.plot(cmap="Blues", values_format=".2g")
468
  plt.title(title)
469
  plt.savefig("CM.png")
470
+
471
+
472
  # Function to find the largest number smaller
473
  # than or equal to N that is divisible by k
474
  def find_largest_div(N, K):
475
  rem = N % K
476
+ if rem == 0:
477
  return N
478
  else:
479
  return N - rem
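# e.g. find_largest_div(1050, 200) -> 1000 (illustrative numbers)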
480
+
481
+
482
  def preprocess_classifier_batch(cell_batch, max_len):
483
  if max_len == None:
484
  max_len = max([len(i) for i in cell_batch["input_ids"]])
485
+
486
  def pad_label_example(example):
487
+ example["labels"] = np.pad(
488
+ example["labels"],
489
+ (0, max_len - len(example["input_ids"])),
490
+ mode="constant",
491
+ constant_values=-100,
492
+ )
493
+ example["input_ids"] = np.pad(
494
+ example["input_ids"],
495
+ (0, max_len - len(example["input_ids"])),
496
+ mode="constant",
497
+ constant_values=token_dictionary.get("<pad>"),
498
+ )
499
+ example["attention_mask"] = (
500
+ example["input_ids"] != token_dictionary.get("<pad>")
501
+ ).astype(int)
502
  return example
503
+
504
  padded_batch = cell_batch.map(pad_label_example)
505
  return padded_batch
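# Illustrative effect of pad_label_example with max_len=5 and made-up token ids:
#   input_ids [7, 3, 9] -> [7, 3, 9, <pad>, <pad>], labels [1, 0, 1] -> [1, 0, 1, -100, -100],
#   attention_mask -> [1, 1, 1, 0, 0]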
506
 
507
+
508
  # forward batch size is batch size for model inference (e.g. 200)
509
  def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
510
  predict_logits = []
511
  predict_labels = []
512
+ model.to("cpu")
513
  model.eval()
514
+
515
  # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
516
  evalset_len = len(evalset)
517
  max_divisible = find_largest_div(evalset_len, forward_batch_size)
518
  if len(evalset) - max_divisible == 1:
519
  evalset_len = max_divisible
520
+
521
  max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
522
+
523
  for i in range(0, evalset_len, forward_batch_size):
524
+ max_range = min(i + forward_batch_size, evalset_len)
525
  batch_evalset = evalset.select([i for i in range(i, max_range)])
526
  padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
527
  padded_batch.set_format(type="torch")
528
+
529
  input_data_batch = padded_batch["input_ids"]
530
  attn_msk_batch = padded_batch["attention_mask"]
531
  label_batch = padded_batch["labels"]
532
  with torch.no_grad():
533
  input_ids = input_data_batch
534
  attn_mask = attn_msk_batch
535
+ labels = label_batch
536
  outputs = model(
537
+ input_ids=input_ids, attention_mask=attn_mask, labels=labels
 
 
 
538
  )
539
  predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
540
  predict_labels += [torch.squeeze(label_batch.to("cpu"))]
541
+
542
  logits_by_cell = torch.cat(predict_logits)
543
  all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
544
  labels_by_cell = torch.cat(predict_labels)
545
  all_labels = torch.flatten(labels_by_cell)
546
+ logit_label_paired = [
547
+ item
548
+ for item in list(zip(all_logits.tolist(), all_labels.tolist()))
549
+ if item[1] != -100
550
+ ]
551
  y_pred = [vote(item[0]) for item in logit_label_paired]
552
  y_true = [item[1] for item in logit_label_paired]
553
  logits_list = [item[0] for item in logit_label_paired]
 
559
  plt.plot(fpr, tpr)
560
  plt.xlim([0.0, 1.0])
561
  plt.ylim([0.0, 1.05])
562
+ plt.xlabel("False Positive Rate")
563
+ plt.ylabel("True Positive Rate")
564
+ plt.title("ROC")
565
  plt.show()
566
  # interpolate to graph
567
  interp_tpr = np.interp(mean_fpr, fpr, tpr)
568
  interp_tpr[0] = 0.0
569
+ return fpr, tpr, interp_tpr, conf_mat
570
+
571
+
572
+ def classify_genes(
573
+ gene_info="Genecorpus-30M/example_input_files/gene_info_table.csv",
574
+ genes="Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
575
+ corpus_30M="Genecorpus-30M/genecorpus_30M_2048.dataset/",
576
+ model=".",
577
+ max_input_size=2**11,
578
+ max_lr=5e-5,
579
+ freeze_layers=4,
580
+ num_gpus=1,
581
+ num_proc=os.cpu_count(),
582
+ geneformer_batch_size=9,
583
+ epochs=1,
584
+ filter_dataset=50_000,
585
+ emb_extract=True,
586
+ emb_layer=0,
587
+ forward_batch=200,
588
+ filter_data=None,
589
+ inference=False,
590
+ k_validate=True,
591
+ model_location="230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
592
+ skip_training=False,
593
+ emb_dir="gene_emb",
594
+ output_dir=None,
595
+ max_cells=1000,
596
+ num_cpus=os.cpu_count(),
597
+ ):
598
+ """ "
599
  Primary Parameters
600
  -----------
601
+
602
  gene_info: path
603
  Path to gene mappings
604
+
605
  corpus_30M: path
606
  Path to 30M Gene Corpus
607
+
608
  model: path
609
  Path to pretrained GeneFormer model
610
+
611
  genes: path
612
  Path to csv file containing different columns of genes and the column labels
613
+
614
  inference: bool
615
  Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
616
+
617
  k_validate: bool
618
  Whether the model should run k-fold validation or simply perform regular training/evaluation. Defaults to True
619
+
620
  skip_training: bool
621
  Whether the model should skip the training portion. Defaults to False
622
+
623
  emb_extract: bool
624
  WHether the model should extract embeddings for a given gene (WIP)
625
+
626
+
627
  Customization Parameters
628
  -----------
629
+
630
  freeze_layers: int
631
  Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)
632
+
633
  filter_dataset: int
634
  Number of cells to filter from 30M dataset. Default is 50_000
635
+
636
  emb_layer: int
637
  What layer embeddings are extracted from. Default is 4
638
+
639
  filter_data: str, list
640
  Filters down embeddings to a single category. Default is None
641
+
642
+
643
  """
644
+
  # table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
  gene_info = pd.read_csv(gene_info, index_col=0)
  labels = gene_info.columns

  # create dictionaries for corresponding attributes
+ gene_id_type_dict = dict(zip(gene_info["ensembl_id"], gene_info["gene_type"]))
+ gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"]))
+ gene_id_name_dict = {v: k for k, v in gene_name_id_dict.items()}

  # function for preparing targets and labels
  def prep_inputs(label_store, id_type):
  target_list = []
  if id_type == "gene_name":
  for key in list(label_store.keys()):
+ targets = [
+ gene_name_id_dict[gene]
+ for gene in label_store[key]
+ if gene_name_id_dict.get(gene) in token_dictionary
+ ]
  targets_id = [token_dictionary[gene] for gene in targets]
  target_list.append(targets_id)
  elif id_type == "ensembl_id":
  for key in list(label_store.keys()):
+ targets = [
+ gene for gene in label_store[key] if gene in token_dictionary
+ ]
  targets_id = [token_dictionary[gene] for gene in targets]
  target_list.append(targets_id)
+
  targets, labels = [], []
  for targ in target_list:
  targets = targets + targ
  targets = np.array(targets)
  for num, targ in enumerate(target_list):
+ label = [num] * len(targ)
  labels = labels + label
  labels = np.array(labels)
  unique_labels = num + 1
+
+ nsplits = min(5, min([len(targ) for targ in target_list]) - 1)
  assert nsplits > 2
+
  return targets, labels, nsplits, unique_labels
+
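
# Sketch of the fold-count rule above (toy class sizes, editor-chosen): k is capped
# at 5 and kept strictly below the smallest class size, which stratified k-fold
# splitting requires so every fold can hold at least one example of every class.
toy_target_list = [list(range(12)), list(range(4))]             # two classes: 12 and 4 genes
toy_nsplits = min(5, min(len(t) for t in toy_target_list) - 1)  # -> 3 folds
assert toy_nsplits > 2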
  if skip_training == False:
  # preparing targets and labels for dosage sensitive vs insensitive TFs
  gene_classes = pd.read_csv(genes, header=0)
[...]
  else:
  labels = [filter_data]
  label_store = {}
+
  # Dictionary for decoding labels
+ decode = {i: labels[i] for i in range(len(labels))}
+
  for label in labels:
  label_store[label] = gene_classes[label].dropna()
+
  targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")
+

  # load training dataset
+ train_dataset = load_from_disk(corpus_30M)
  shuffled_train_dataset = train_dataset.shuffle(seed=42)
+ subsampled_train_dataset = shuffled_train_dataset.select(
+ [i for i in range(filter_dataset)]
+ )
  lr_schedule_fn = "linear"
  warmup_steps = 500
  optimizer = "adamw"
  subsample_size = 10_000
+
  training_args = {
  "learning_rate": max_lr,
  "do_train": True,
[...]
  "per_device_eval_batch_size": geneformer_batch_size,
  "num_train_epochs": epochs,
  }
+
  # define output directory path
  current_date = datetime.datetime.now()
  datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
+
  if output_dir == None:
+ training_output_dir = Path(
+ f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/"
+ )
  else:
  training_output_dir = Path(output_dir)
+
  # make output directory
+ subprocess.call(f"mkdir -p {training_output_dir}", shell=True)
+
  # Places number of classes + decode dictionary in the output directory
  num_classes = len(set(labels))
  info_list = [num_classes, decode]
+
+ with open(training_output_dir / "classes.txt", "w") as f:
  f.write(str(info_list))
+
+ subsampled_train_dataset.save_to_disk(training_output_dir / "dataset")
+
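
# Sketch of the classes.txt round trip used here and re-read in the inference and
# embedding branches below: the [num_classes, decode] list is written with str() and
# parsed back with ast.literal_eval. Paths and label names are illustrative only.
import ast, os, tempfile
toy_info = [2, {0: "dosage_sensitive", 1: "dosage_insensitive"}]
toy_path = os.path.join(tempfile.mkdtemp(), "classes.txt")
with open(toy_path, "w") as f:
    f.write(str(toy_info))
with open(toy_path) as f:
    toy_classes, toy_decode = ast.literal_eval(f.read())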
  if k_validate == True:
+ ksplit_model = "ksplit0/models"
  ksplit_model_test = os.path.join(training_output_dir, ksplit_model)
+ # if os.path.isfile(ksplit_model_test) == True:
  # raise Exception("Model already saved to this directory.")
+ # cross-validate gene classifier
+ (
+ all_roc_auc,
+ roc_auc,
+ roc_auc_sd,
+ mean_fpr,
+ mean_tpr,
+ confusion,
+ label_dicts,
+ ) = cross_validate(
+ subsampled_train_dataset,
+ targets,
+ labels,
+ nsplits,
+ subsample_size,
+ training_args,
+ freeze_layers,
+ training_output_dir,
+ 1,
+ unique_labels,
+ model,
+ )
+
  bundled_data = []
+ bundled_data += [
+ (roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")
+ ]
+ graph_title = " ".join(
+ [
+ i + " vs" if count < len(label_store) - 1 else i
+ for count, i in enumerate(label_store)
+ ]
+ )
+ fpr, tpr, auc = plot_ROC(
+ bundled_data, "Dosage Sensitive vs Insensitive TFs"
+ )
  print(auc)
  # plot confusion matrix
  plot_confusion_matrix(label_store, confusion, "Geneformer")
+ else:
+ fpr, tpr, auc = validate(
+ subsampled_train_dataset,
+ targets,
+ labels,
+ nsplits,
+ subsample_size,
+ training_args,
+ freeze_layers,
+ training_output_dir,
+ 1,
+ unique_labels,
+ model,
+ )
  print(auc)
+
  if inference == True:
  # preparing targets and labels for dosage sensitive vs insensitive TFs
  gene_classes = pd.read_csv(genes, header=0)
  targets = []
+ for column in gene_classes.columns:
  targets += list(gene_classes[column])
  tokens = []
  for target in targets:
[...]
  tokens.append(token_dictionary[target])
  except:
  tokens.append(0)
+
  targets = torch.LongTensor([tokens])
+
+ with open(f"{model_location}classes.txt", "r") as f:
  info_list = ast.literal_eval(f.read())
  num_classes = info_list[0]
  labels = info_list[1]
+
  model = BertForTokenClassification.from_pretrained(
  model_location,
  num_labels=num_classes,
+ output_attentions=False,
+ output_hidden_states=False,
+ local_files_only=True,
  )
  if freeze_layers is not None:
  modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
  for module in modules_to_freeze:
  for param in module.parameters():
  param.requires_grad = False
+
  model = model.to(device)
+
  # evaluate model
+ predictions = F.softmax(model(targets.to(device))["logits"], dim=-1).argmax(-1)[
+ 0
+ ]
  predictions = [labels[int(pred)] for pred in predictions]
+
  return predictions
+
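
# Sketch of the prediction step above (toy tensors, editor-assumed shapes): the token
# classifier returns logits of shape (batch, seq_len, num_classes); softmax + argmax
# gives one class index per gene, which is then decoded through the label dictionary.
import torch
import torch.nn.functional as F
toy_logits = torch.randn(1, 4, 2)                        # 1 sequence, 4 genes, 2 classes
toy_pred_ids = F.softmax(toy_logits, dim=-1).argmax(-1)[0]
toy_label_map = {0: "dosage_sensitive", 1: "dosage_insensitive"}
toy_predictions = [toy_label_map[int(p)] for p in toy_pred_ids]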
+ # Extracts aggregate gene embeddings for each label
  if emb_extract == True:
+ with open(f"{model_location}/classes.txt", "r") as f:
  data = ast.literal_eval(f.read())
  num_classes = data[0]
  decode = data[1]
+
  gene_classes = pd.read_csv(genes, header=0)
  labels = gene_classes.columns
  tokenize = TranscriptomeTokenizer()
+
  label_dict = {}
  for label in labels:
  genes = gene_classes[label]
[...]
  except:
  continue
  label_dict[label] = tokenized_genes
+
+ embex = EmbExtractor(
+ model_type="GeneClassifier",
+ num_classes=num_classes,
+ emb_mode="gene",
+ filter_data=None,
+ max_ncells=max_cells,
+ emb_layer=emb_layer,
+ emb_label=label_dict,
+ labels_to_plot=list(labels),
+ forward_batch_size=forward_batch,
+ nproc=num_cpus,
+ )
+
+ subprocess.call(f"mkdir -p {emb_dir}", shell=True)
+
+ embs = embex.extract_embs(
+ model_directory=model_location,
+ input_data_file=model_location / "dataset",
+ output_directory=emb_dir,
+ output_prefix=f"{label}_embeddings",
+ )
+
+ emb_dict = {label: [] for label in list(set(labels))}
+ similarities = {key: {} for key in list(emb_dict.keys())}
+
  for column in embs.columns:
  remaining_cols = [k for k in embs.columns if k != column]
  for k in remaining_cols:
  embedding = torch.Tensor(embs[k])
+ sim = similarity(torch.Tensor(embs[column]), embedding, cosine=True)
  similarities[column][k] = sim
+
  plot_similarity_heatmap(similarities)
  print(similarities)
+
  return similarities
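
# Sketch of the pairwise comparison above, assuming the similarity() helper reduces
# to cosine similarity between two embedding vectors when cosine=True (1.0 means the
# vectors point in the same direction, 0 means they are orthogonal). Toy vectors only.
import torch
import torch.nn.functional as F
emb_a = torch.randn(256)
emb_b = torch.randn(256)
toy_sim = F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).item()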
 
 
 

+
+ if __name__ == "__main__":
+ classify_genes(
+ k_validate=False,
+ inference=False,
+ skip_training=False,
+ emb_extract=True,
+ output_dir=Path("gene_emb"),
+ model_location=Path("gene_emb"),
+ epochs=5,
+ gene_info="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv",
+ genes="../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
+ corpus_30M="../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/",
+ )
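
# Hypothetical inference-oriented call, complementing the training-oriented __main__
# block above; it assumes the renamed module is importable as
# geneformer.cell_classifier and that model_location already holds a fine-tuned gene
# classifier plus its classes.txt. Paths are placeholders, not verified settings.
from geneformer.cell_classifier import classify_genes

predictions = classify_genes(
    inference=True,
    skip_training=True,
    k_validate=False,
    emb_extract=False,
    model_location="gene_emb/",
    gene_info="gene_info_table.csv",
    genes="dosage_sens_tf_labels.csv",
)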
Modular_usage.md → geneformer/modular_classifier_usage.md RENAMED
@@ -1,33 +1,33 @@
  # Cell classifier
  def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
- dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
  dataset_split = None,
- filter_cells = .005,
- epochs = 1,
- cpu_cores = os.cpu_count(),
- geneformer_batch_size = 12,
- optimizer = 'adamw',
- max_lr = 5e-5,
- num_gpus = torch.cuda.device_count(),
  max_input_size = 2 ** 11,
- lr_schedule_fn = "linear",
- warmup_steps = 500,
- freeze_layers = 0,
- emb_extract = False,
- max_cells = 1000,
- emb_layer = 0,
- emb_filter = None,
- emb_dir = 'embeddings',
  overwrite = True,
  label = "cell_type",
  data_filter = None,
- forward_batch = 200, model_location = None,
- skip_training = False,
  sample_data = 1,
- inference = False,
- optimize_hyperparameters = False,
  output_dir = None):
-
  '''
  Primary Parameters
  -------------------
@@ -36,121 +36,121 @@ def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_s

  model_location: path
  Path to location of existing model to use for inference and embedding extraction
-
  pretrained_model: path
  Path to pretrained GeneFormer 30M model before fine-tuning
-
- inference: bool
  Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False
-
- skip_training: bool
  Chooses whether to skip training the model. Defaults to False
-
  emb_extract: bool
  Choose whether to extract embeddings and calculate similarities. Defaults to True
-
  optimize_hyperparameters: bool
  Choose whether to optimize model hyperparameters. Defaults to False
  label: string
- The label string in the formatted dataset that contains true class labels. Defaults to "label"
-
  Customization Parameters
  -------------------
-
  dataset_split: str
  How the dataset should be partitioned (if at all), and what ID should be used for partitioning
-
  data_filter: list
  (For embeddings and inference) Runs analysis on subsets of the dataset by the ID defined by dataset_split
-
  label: str
  What feature should be read as a classification label
-
  emb_layer: int
  What layer embeddings should be extracted and compared from.
-
  emb_filter: ['cell1', 'cell2'...]
  Allows user to narrow down range of cells that embeddings will be extracted from.
-
  max_cells: int
- How many embeddings from cells should be extracted.
-
  freeze_layers: int
  Number of layers that should be permanently frozen during fine-tuning (starting from the first layer, 4 brings it up to the pretrained model).
-
  sample_data: float
  What proportion of the HF dataset should be used
-
  '''
-
  # Gene Classifier
- def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv",
  genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
  corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
- max_input_size = 2 ** 11,
  max_lr = 5e-5,
- freeze_layers = 4,
- num_gpus = 1,
- num_proc = os.cpu_count(),
- geneformer_batch_size = 9,
- epochs = 1,
  filter_dataset = 50_000,
- emb_extract = True,
- emb_layer = 0,
- forward_batch = 200,
- filter_data = None,
- inference = False,
- k_validate = True,
- model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
- skip_training = False,
- emb_dir = 'gene_emb',
- output_dir = None,
- max_cells = 1000,
  num_cpus = os.cpu_count()):
-
  """
  Primary Parameters
  -----------
-
  gene_info: path
  Path to gene mappings
-
  corpus_30M: path
  Path to 30M Gene Corpus
-
  model: path
  Path to pretrained GeneFormer model
-
  genes: path
  Path to csv file containing different columns of genes and the column labels
-
  inference: bool
  Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
-
  k_validate: bool
  Whether the model should run k-fold validation or simply perform regular training/evaluation. Defaults to True
-
  skip_training: bool
  Whether the model should skip the training portion. Defaults to False
-
  emb_extract: bool
  Whether the model should extract embeddings for a given gene (WIP)
-
-
  Customization Parameters
  -----------
-
  freeze_layers: int
  Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)
-
  filter_dataset: int
  Number of cells to filter from 30M dataset. Default is 50_000
-
  emb_layer: int
  What layer embeddings are extracted from. Default is 0
-
  filter_data: str, list
  Filters down embeddings to a single category. Default is None
-
-
  """
 