Upload 15 files
Browse filesDeployable version of GeneFormer gene/cell classification and embedding extraction in a single function. Function parameters explained in the markdown file, example usage at the bottom of each python file. Let me know if anything is needed or if there are unresolved issues, and I can get to fixing them!
- Cell_classifier.py +861 -0
- Gene_classifier.py +746 -0
- Immune_modelpredictions.pickle +0 -0
- Modular_usage.md +156 -0
- gene_embclasses.txt +1 -0
- gene_embdataset.pk +0 -0
Cell_classifier.py
ADDED
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Package Imports
|
2 |
+
import tqdm
|
3 |
+
import sys
|
4 |
+
import polars as pl
|
5 |
+
import pysam
|
6 |
+
import os
|
7 |
+
from datasets import Dataset
|
8 |
+
from collections import Counter
|
9 |
+
import random
|
10 |
+
import datetime
|
11 |
+
from pathlib import Path
|
12 |
+
import subprocess
|
13 |
+
import seaborn as sns; sns.set()
|
14 |
+
from datasets import load_from_disk
|
15 |
+
import fastcluster
|
16 |
+
from sklearn.metrics import accuracy_score, f1_score
|
17 |
+
from transformers import BertForSequenceClassification
|
18 |
+
from transformers import Trainer
|
19 |
+
from transformers.training_args import TrainingArguments
|
20 |
+
from geneformer import DataCollatorForCellClassification, EmbExtractor
|
21 |
+
import pickle
|
22 |
+
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve
|
23 |
+
from sklearn.metrics import auc as precision_auc
|
24 |
+
from sklearn.preprocessing import label_binarize
|
25 |
+
import pyarrow as pa
|
26 |
+
import concurrent.futures
|
27 |
+
from matplotlib import pyplot as plt
|
28 |
+
import torch
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from scipy.stats import ranksums
|
31 |
+
import ray
|
32 |
+
import ast
|
33 |
+
from ray import tune
|
34 |
+
from ray.tune import ExperimentAnalysis
|
35 |
+
from ray.tune.search.hyperopt import HyperOptSearch
|
36 |
+
import numpy as np
|
37 |
+
|
38 |
+
# Properly sets up NCCV environment
|
39 |
+
GPU_NUMBER = [i for i in range(torch.cuda.device_count())]
|
40 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
|
41 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
42 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
43 |
+
|
44 |
+
# Function for generating a ROC curve from data
|
45 |
+
def ROC(prediction, truth, type = 'GeneFormer', label = ''):
|
46 |
+
|
47 |
+
fpr, tpr, _ = roc_curve(truth, prediction[:, 1])
|
48 |
+
auc = roc_auc_score(truth, prediction[:, 1])
|
49 |
+
print(f'{type} AUC: {auc}')
|
50 |
+
plt.plot(fpr,tpr, label="AUC="+str(auc))
|
51 |
+
plt.ylabel('True Positive Rate')
|
52 |
+
plt.xlabel('False Positive Rate')
|
53 |
+
plt.title(f'{label} ROC Curve')
|
54 |
+
plt.legend(loc=4)
|
55 |
+
plt.savefig('ROC.png')
|
56 |
+
|
57 |
+
return tpr, fpr, auc
|
58 |
+
|
59 |
+
# Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
|
60 |
+
def similarity(tensor1, tensor2, cosine = False):
|
61 |
+
|
62 |
+
if cosine == False:
|
63 |
+
if tensor1.ndimension() > 1:
|
64 |
+
tensor1 = tensor1.view(1, -1)
|
65 |
+
if tensor2.ndimension() > 1:
|
66 |
+
tensor2 = tensor2.view(1, -1)
|
67 |
+
dot_product = torch.matmul(tensor1, tensor2)
|
68 |
+
norm_tensor1 = torch.norm(tensor1)
|
69 |
+
norm_tensor2 = torch.norm(tensor2)
|
70 |
+
epsilon = 1e-8
|
71 |
+
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
|
72 |
+
similarity = (similarity.item() + 1)/2
|
73 |
+
else:
|
74 |
+
if tensor1.shape != tensor2.shape:
|
75 |
+
raise ValueError("Input tensors must have the same shape.")
|
76 |
+
|
77 |
+
# Compute cosine similarity using PyTorch's dot product function
|
78 |
+
dot_product = torch.dot(tensor1, tensor2)
|
79 |
+
norm_tensor1 = torch.norm(tensor1)
|
80 |
+
norm_tensor2 = torch.norm(tensor2)
|
81 |
+
|
82 |
+
# Avoid division by zero by adding a small epsilon
|
83 |
+
epsilon = 1e-8
|
84 |
+
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
|
85 |
+
|
86 |
+
return similarity.item()
|
87 |
+
|
88 |
+
# Plots heatmap between different classes/labels
|
89 |
+
def plot_similarity_heatmap(similarities):
|
90 |
+
classes = list(similarities.keys())
|
91 |
+
classlen = len(classes)
|
92 |
+
arr = np.zeros((classlen, classlen))
|
93 |
+
for i, c in enumerate(classes):
|
94 |
+
for j, cc in enumerate(classes):
|
95 |
+
if cc == c:
|
96 |
+
val = 1.0
|
97 |
+
else:
|
98 |
+
val = similarities[c][cc]
|
99 |
+
arr[i][j] = val
|
100 |
+
|
101 |
+
plt.figure(figsize=(8, 6))
|
102 |
+
plt.imshow(arr, cmap='inferno', vmin=0, vmax=1)
|
103 |
+
plt.colorbar()
|
104 |
+
plt.xticks(np.arange(classlen), classes, rotation = 45, ha = 'right')
|
105 |
+
plt.yticks(np.arange(classlen), classes)
|
106 |
+
plt.title("Similarity Heatmap")
|
107 |
+
plt.savefig("similarity_heatmap.png")
|
108 |
+
|
109 |
+
# Function for tokenizing genes into ranked-value encodings from Geneformer
|
110 |
+
def tokenize_dataset(gene_set, type = None, token_set = 'token_dictionary.pkl', species = 'human'):
|
111 |
+
token_dataset = open(token_set, 'rb')
|
112 |
+
token_dict = pickle.load(token_dataset)
|
113 |
+
wrap = True
|
114 |
+
|
115 |
+
if isinstance(gene_set[0], list) == False:
|
116 |
+
gene_set = [gene_set]
|
117 |
+
wrap = False
|
118 |
+
|
119 |
+
pool = Pool()
|
120 |
+
converted_set = []
|
121 |
+
|
122 |
+
def process_gene(gene):
|
123 |
+
api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{gene}?object_type=gene"
|
124 |
+
response = requests.get(api_url, headers={"Content-Type": "application/json"})
|
125 |
+
try:
|
126 |
+
data = response.json()
|
127 |
+
gene = data[0]['id']
|
128 |
+
except:
|
129 |
+
gene = None
|
130 |
+
return gene
|
131 |
+
|
132 |
+
def process_hgnc(gene):
|
133 |
+
for gene in tqdm.tqdm(genes, total = len(genes)):
|
134 |
+
api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{hgnc_id}?object_type=gene"
|
135 |
+
response = requests.get(api_url, headers={"Content-Type": "application/json"})
|
136 |
+
try:
|
137 |
+
data = response.json()
|
138 |
+
gene = data[0]['id']
|
139 |
+
except:
|
140 |
+
gene = None
|
141 |
+
return gene
|
142 |
+
|
143 |
+
def process_go(gene):
|
144 |
+
mg = mygene.MyGeneInfo()
|
145 |
+
results = mg.query(gene, scopes="go", species=species, fields="ensembl.gene")
|
146 |
+
|
147 |
+
ensembl_ids = []
|
148 |
+
max_score = 0
|
149 |
+
for hit_num, hit in enumerate(results["hits"]):
|
150 |
+
if hit['_score'] > max_score:
|
151 |
+
max_score = hit['_score']
|
152 |
+
chosen_hit = hit
|
153 |
+
try:
|
154 |
+
try:
|
155 |
+
gene = chosen_hit["ensembl"]["gene"]
|
156 |
+
except:
|
157 |
+
gene = chosen_hit["ensembl"][0]["gene"]
|
158 |
+
except:
|
159 |
+
gene = None
|
160 |
+
return gene
|
161 |
+
|
162 |
+
if type == None or type.upper() == 'ENSEMBL':
|
163 |
+
converted_set = gene_set
|
164 |
+
elif type.upper() == 'GENE':
|
165 |
+
for genes in gene_set:
|
166 |
+
converted_genes = []
|
167 |
+
for result in tqdm.tqdm(pool.imap(process_gene, genes), total = len(genes)):
|
168 |
+
converted_genes.append(result)
|
169 |
+
converted_set.append(converted_genes)
|
170 |
+
elif type.upper() == 'GO':
|
171 |
+
for genes in gene_set:
|
172 |
+
converted_genes = []
|
173 |
+
for result in tqdm.tqdm(pool.imap(process_go, genes), total = len(genes)):
|
174 |
+
converted_genes.append(result)
|
175 |
+
converted_set.append(converted_genes)
|
176 |
+
elif type.upper() == 'HGNC':
|
177 |
+
for genes in gene_set:
|
178 |
+
converted_genes = []
|
179 |
+
for result in tqdm.tqdm(pool.imap(process_hgnc, genes), total = len(genes)):
|
180 |
+
converted_genes.append(result)
|
181 |
+
converted_set.append(converted_genes)
|
182 |
+
|
183 |
+
Chembl = []
|
184 |
+
for set_num, set in enumerate(converted_set):
|
185 |
+
Chembl.append([])
|
186 |
+
for gene in set:
|
187 |
+
if gene == None:
|
188 |
+
Chembl[set_num].append(None)
|
189 |
+
else:
|
190 |
+
try:
|
191 |
+
Chembl[set_num].append(token_dict[gene])
|
192 |
+
except:
|
193 |
+
print(f'{gene} not found in tokenized dataset!')
|
194 |
+
Chembl[set_num].append(None)
|
195 |
+
|
196 |
+
if wrap == False:
|
197 |
+
Chembl = Chembl[0]
|
198 |
+
|
199 |
+
return Chembl
|
200 |
+
|
201 |
+
|
202 |
+
# '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/'
|
203 |
+
# '/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/'
|
204 |
+
'''
|
205 |
+
======================================================
|
206 |
+
|
207 |
+
PRIMARY CELL - CLASSIFIER AND EMBEDDING EXTRACTOR CLASS
|
208 |
+
|
209 |
+
+++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
210 |
+
|
211 |
+
Runs cell-level classification and embedding extraction with Geneformer
|
212 |
+
|
213 |
+
'''
|
214 |
+
|
215 |
+
def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
|
216 |
+
dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/', dataset_split = None, filter_cells = .005, epochs = 1, cpu_cores = os.cpu_count(), geneformer_batch_size = 12, optimizer = 'adamw', max_lr = 5e-5, num_gpus = torch.cuda.device_count(), max_input_size = 2 ** 11, lr_schedule_fn = "linear", warmup_steps = 500, freeze_layers = 0, emb_extract = False, max_cells = 1000, emb_layer = 0, emb_filter = None, emb_dir = 'embeddings', overwrite = True, label = "cell_type", data_filter = None,
|
217 |
+
forward_batch = 200, model_location = None, skip_training = False, sample_data = 1, inference = False, optimize_hyperparameters = False, output_dir = None):
|
218 |
+
|
219 |
+
'''
|
220 |
+
Primary Parameters
|
221 |
+
-------------------
|
222 |
+
dataset: path
|
223 |
+
Path to fine-tuning/testing dataset for training
|
224 |
+
|
225 |
+
model_location: path
|
226 |
+
Path to location of existing model to use for inference and embedding extraction
|
227 |
+
|
228 |
+
pretrained_model: path
|
229 |
+
Path to pretrained GeneFormer 30M model before fine-tuning
|
230 |
+
|
231 |
+
inference: bool
|
232 |
+
Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False
|
233 |
+
|
234 |
+
skip_training: bool
|
235 |
+
Chooses whether to skip training the model. Defaults to False
|
236 |
+
|
237 |
+
emb_extract: bool
|
238 |
+
Choose whether to extract embeddings and calculate similarities. Defaults to True
|
239 |
+
|
240 |
+
optimize_hyperparameters: bool
|
241 |
+
Choose whether to optimize model hyperparamters. Defaults to False
|
242 |
+
|
243 |
+
|
244 |
+
Customization Parameters
|
245 |
+
-------------------
|
246 |
+
|
247 |
+
dataset_split: str
|
248 |
+
How the dataset should be partitioned (if at all), and what ID should be used for partitioning
|
249 |
+
|
250 |
+
data_filter: list
|
251 |
+
(For embeddings and inference) Runs analysis subsets of the dataset by the ID defined by dataset_split
|
252 |
+
|
253 |
+
label: str
|
254 |
+
What feature should be read as a classification label
|
255 |
+
|
256 |
+
emb_layer: int
|
257 |
+
What layer embeddings should be extracted and compared from.
|
258 |
+
|
259 |
+
emb_filter: ['cell1', 'cell2'...]
|
260 |
+
Allows user to narrow down range of cells that embeddings will be extracted from.
|
261 |
+
|
262 |
+
max_cells: int
|
263 |
+
How many embeddings from cells should be extracted.
|
264 |
+
|
265 |
+
freeze_layers: int
|
266 |
+
Number of layers should be permanently frozen during fine-tuning (starting from the first layer, 4 brings it up to the pretrained model).
|
267 |
+
|
268 |
+
sample_data: float
|
269 |
+
What proportion of the HF dataset should be used
|
270 |
+
|
271 |
+
'''`
|
272 |
+
|
273 |
+
dataset_list = []
|
274 |
+
evalset_list = []
|
275 |
+
split_list = []
|
276 |
+
target_dict_list = []
|
277 |
+
|
278 |
+
'''
|
279 |
+
For loading and pretraining with custom median expressions and/or custom gene conversions
|
280 |
+
-------------------------------------------------------------
|
281 |
+
|
282 |
+
token set: path
|
283 |
+
Path to token conversion dictionary
|
284 |
+
|
285 |
+
median set: path
|
286 |
+
Path to median gene dictionary (ensembl IDs as the keys)
|
287 |
+
|
288 |
+
|
289 |
+
median_data = pickle.load(open(median_set, 'rb'))
|
290 |
+
median_data['<pad>'] = None
|
291 |
+
median_data['<mask>'] = None
|
292 |
+
|
293 |
+
token_set = pickle.load(open(token_set, 'rb'))
|
294 |
+
median_dict = {key:median_data[key] for key in list(token_set.keys())}
|
295 |
+
'''
|
296 |
+
|
297 |
+
train_dataset = load_from_disk(dataset)
|
298 |
+
num_samples = int(len(train_dataset) * sample_data)
|
299 |
+
random_indices = random.sample(range(len(train_dataset)), num_samples)
|
300 |
+
train_dataset = train_dataset.select(random_indices)
|
301 |
+
|
302 |
+
sample = int(sample_data * len(train_dataset))
|
303 |
+
sample_indices = random.sample(range(len(train_dataset)), sample)
|
304 |
+
train_dataset = train_dataset.select(sample_indices)
|
305 |
+
|
306 |
+
def if_not_rare_celltype(example):
|
307 |
+
return example[label] in cells_to_keep
|
308 |
+
|
309 |
+
# change labels to numerical ids
|
310 |
+
def classes_to_ids(example):
|
311 |
+
example["label"] = target_name_id_dict[example["label"]]
|
312 |
+
return example
|
313 |
+
|
314 |
+
def if_trained_label(example):
|
315 |
+
return example["label"] in trained_labels
|
316 |
+
|
317 |
+
if skip_training != True:
|
318 |
+
def compute_metrics(pred):
|
319 |
+
labels = pred.label_ids
|
320 |
+
preds = pred.predictions.argmax(-1)
|
321 |
+
# calculate accuracy and macro f1 using sklearn's function
|
322 |
+
acc = accuracy_score(labels, preds)
|
323 |
+
macro_f1 = f1_score(labels, preds, average='macro')
|
324 |
+
return {
|
325 |
+
'accuracy': acc,
|
326 |
+
'macro_f1': macro_f1
|
327 |
+
}
|
328 |
+
|
329 |
+
# Defines custom exceptions for collecting labels (default excluded)
|
330 |
+
excep = {"bone_marrow":"immune"}
|
331 |
+
|
332 |
+
if dataset_split != None:
|
333 |
+
if data_filter != None:
|
334 |
+
split_iter = [data_filter]
|
335 |
+
else:
|
336 |
+
split_iter = Counter(train_dataset[dataset_split]).keys()
|
337 |
+
for lab in split_iter:
|
338 |
+
|
339 |
+
# collect list of tissues for fine-tuning (immune and bone marrow are included together)
|
340 |
+
if lab in list(excep.keys()):
|
341 |
+
continue
|
342 |
+
elif lab == list(excep.values()):
|
343 |
+
split_ids = [excep.keys(),excep.values()]
|
344 |
+
split_list += [excep.values()]
|
345 |
+
else:
|
346 |
+
split_ids = [lab]
|
347 |
+
split_list += [lab]
|
348 |
+
|
349 |
+
# filter datasets for given organ
|
350 |
+
def if_label(example):
|
351 |
+
return example[dataset_split] == lab
|
352 |
+
|
353 |
+
trainset_label = train_dataset.filter(if_label, num_proc=cpu_cores)
|
354 |
+
label_counter = Counter(trainset_label[label])
|
355 |
+
total_cells = sum(label_counter.values())
|
356 |
+
|
357 |
+
# Throws out cells with a low proportion in the dataset (drop cell types representing <0.5% of cells per deepsort published method)
|
358 |
+
cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
|
359 |
+
trainset_label_subset = trainset_label.filter(if_not_rare_celltype, num_proc=cpu_cores)
|
360 |
+
|
361 |
+
# shuffle datasets and rename columns
|
362 |
+
trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
|
363 |
+
trainset_label_shuffled = trainset_label_shuffled.rename_column(label,"label")
|
364 |
+
trainset_label_shuffled = trainset_label_shuffled.remove_columns(dataset_split)
|
365 |
+
|
366 |
+
# create dictionary of cell types : label ids
|
367 |
+
target_names = list(Counter(trainset_label_shuffled["label"]).keys())
|
368 |
+
target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
|
369 |
+
target_dict_list += [target_name_id_dict]
|
370 |
+
|
371 |
+
labeled_trainset = trainset_label_shuffled.map(classes_to_ids, num_proc=cpu_cores)
|
372 |
+
|
373 |
+
# create 80/20 train/eval splits
|
374 |
+
labeled_train_split = trainset_label_shuffled.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
|
375 |
+
labeled_eval_split = trainset_label_shuffled.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
|
376 |
+
|
377 |
+
# filter dataset for cell types in corresponding training set
|
378 |
+
trained_labels = list(Counter(labeled_train_split["label"]).keys())
|
379 |
+
|
380 |
+
labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=cpu_cores)
|
381 |
+
|
382 |
+
dataset_list += [labeled_train_split]
|
383 |
+
evalset_list += [labeled_eval_split_subset]
|
384 |
+
|
385 |
+
trainset_dict = dict(zip(split_list,dataset_list))
|
386 |
+
traintargetdict_dict = dict(zip(split_list,target_dict_list))
|
387 |
+
evalset_dict = dict(zip(split_list,evalset_list))
|
388 |
+
|
389 |
+
for lab in split_list:
|
390 |
+
label_trainset = trainset_dict[lab]
|
391 |
+
label_evalset = evalset_dict[lab]
|
392 |
+
label_dict = traintargetdict_dict[lab]
|
393 |
+
|
394 |
+
# set logging steps
|
395 |
+
logging_steps = round(len(label_trainset)/geneformer_batch_size/10)
|
396 |
+
if logging_steps == 0:
|
397 |
+
logging_steps = 1
|
398 |
+
|
399 |
+
# reload pretrained model
|
400 |
+
model = BertForSequenceClassification.from_pretrained("/work/ccnr/GeneFormer/GeneFormer_repo",
|
401 |
+
num_labels=len(label_dict.keys()),
|
402 |
+
output_attentions = False,
|
403 |
+
output_hidden_states = False).to(device)
|
404 |
+
|
405 |
+
# define output directory path
|
406 |
+
current_date = datetime.datetime.now()
|
407 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
408 |
+
|
409 |
+
if output_dir == None:
|
410 |
+
output_dir = f"{datestamp}_geneformer_CellClassifier_{lab}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
|
411 |
+
|
412 |
+
# ensure not overwriting previously saved model
|
413 |
+
saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
|
414 |
+
|
415 |
+
if os.path.isfile(saved_model_test) == True and overwrite == False:
|
416 |
+
raise Exception("Model already saved to this directory.")
|
417 |
+
|
418 |
+
# make output directory
|
419 |
+
subprocess.call(f'mkdir -p {output_dir}', shell=True)
|
420 |
+
|
421 |
+
# set training arguments
|
422 |
+
training_args = {
|
423 |
+
"learning_rate": max_lr,
|
424 |
+
"do_train": True,
|
425 |
+
"do_eval": True,
|
426 |
+
"evaluation_strategy": "epoch",
|
427 |
+
"save_strategy": "epoch",
|
428 |
+
"logging_steps": logging_steps,
|
429 |
+
"group_by_length": True,
|
430 |
+
"length_column_name": "length",
|
431 |
+
"disable_tqdm": False,
|
432 |
+
"lr_scheduler_type": lr_schedule_fn,
|
433 |
+
"warmup_steps": warmup_steps,
|
434 |
+
"weight_decay": 0.001,
|
435 |
+
"per_device_train_batch_size": geneformer_batch_size,
|
436 |
+
"per_device_eval_batch_size": geneformer_batch_size,
|
437 |
+
"num_train_epochs": epochs,
|
438 |
+
"load_best_model_at_end": True,
|
439 |
+
"output_dir": output_dir,
|
440 |
+
}
|
441 |
+
|
442 |
+
|
443 |
+
training_args_init = TrainingArguments(**training_args)
|
444 |
+
true_labels = label_evalset['label']
|
445 |
+
|
446 |
+
|
447 |
+
if optimize_hyperparameters == False:
|
448 |
+
# create the trainer
|
449 |
+
trainer = Trainer(
|
450 |
+
model=model,
|
451 |
+
args=training_args_init,
|
452 |
+
data_collator=DataCollatorForCellClassification(),
|
453 |
+
train_dataset=label_trainset,
|
454 |
+
eval_dataset=label_evalset,
|
455 |
+
compute_metrics=compute_metrics
|
456 |
+
)
|
457 |
+
|
458 |
+
# train the cell type classifier
|
459 |
+
trainer.train()
|
460 |
+
predictions = trainer.predict(label_evalset)
|
461 |
+
print(f'accuracy: {accuracy_score(predictions.argmax(), label_evalset["labels"])}')
|
462 |
+
|
463 |
+
tpr, fpr, auc = ROC(predictions.predictions, true_labels)
|
464 |
+
|
465 |
+
metrics = compute_metrics(predictions)
|
466 |
+
with open(f"{output_dir}predictions.pickle", "wb") as fp:
|
467 |
+
pickle.dump(predictions, fp)
|
468 |
+
|
469 |
+
trainer.save_metrics("eval",predictions.metrics)
|
470 |
+
|
471 |
+
with open(f'{output_dir}/targets.txt', 'w') as f:
|
472 |
+
if len(target_dict_list) == 1:
|
473 |
+
f.write(str(target_dict_list[0]))
|
474 |
+
else:
|
475 |
+
f.write(str(target_dict_list))
|
476 |
+
|
477 |
+
try:
|
478 |
+
|
479 |
+
precision, recall, _ = precision_recall_curve(true_labels, predictions.predictions[:, 1])
|
480 |
+
pr_auc = precision_auc(recall, precision)
|
481 |
+
|
482 |
+
print(f'AUC: {pr_auc}')
|
483 |
+
return recall, precision, pr_auc
|
484 |
+
except:
|
485 |
+
pass
|
486 |
+
|
487 |
+
trainer.save_model(output_dir)
|
488 |
+
else:
|
489 |
+
|
490 |
+
def model_init():
|
491 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_model,
|
492 |
+
num_labels=len(label_dict.keys()),
|
493 |
+
output_attentions = False,
|
494 |
+
output_hidden_states = False)
|
495 |
+
if freeze_layers is not None:
|
496 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
497 |
+
for module in modules_to_freeze:
|
498 |
+
for param in module.parameters():
|
499 |
+
param.requires_grad = False
|
500 |
+
model = model.to(device)
|
501 |
+
return model
|
502 |
+
|
503 |
+
trainer = Trainer(
|
504 |
+
model_init=model_init,
|
505 |
+
args=training_args_init,
|
506 |
+
data_collator=DataCollatorForCellClassification(),
|
507 |
+
train_dataset=label_trainset,
|
508 |
+
eval_dataset=label_evalset,
|
509 |
+
compute_metrics=compute_metrics
|
510 |
+
)
|
511 |
+
# specify raytune hyperparameter search space
|
512 |
+
ray_config = {
|
513 |
+
"num_train_epochs": tune.choice([epochs]),
|
514 |
+
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
515 |
+
"weight_decay": tune.uniform(0.0, 0.3),
|
516 |
+
"lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
|
517 |
+
"warmup_steps": tune.uniform(100, 2000),
|
518 |
+
"seed": tune.uniform(0,100),
|
519 |
+
"per_device_train_batch_size": tune.choice([geneformer_batch_size])
|
520 |
+
}
|
521 |
+
|
522 |
+
hyperopt_search = HyperOptSearch(
|
523 |
+
metric="eval_accuracy", mode="max")
|
524 |
+
|
525 |
+
if torch.device == 'cuda':
|
526 |
+
resources_per_trial={"cpu":8,"gpu":1},
|
527 |
+
else:
|
528 |
+
resources_per_trial={"cpu":8}
|
529 |
+
|
530 |
+
# optimize hyperparameters
|
531 |
+
best_trial = trainer.hyperparameter_search(
|
532 |
+
direction="maximize",
|
533 |
+
backend="ray",
|
534 |
+
resources_per_trial = resources_per_trial,
|
535 |
+
hp_space=lambda _: ray_config,
|
536 |
+
search_alg=hyperopt_search,
|
537 |
+
n_trials=10, # number of trials
|
538 |
+
progress_reporter=tune.CLIReporter(max_report_frequency=600,
|
539 |
+
sort_by_metric=True,
|
540 |
+
max_progress_rows=100,
|
541 |
+
mode="max",
|
542 |
+
metric="eval_accuracy",
|
543 |
+
metric_columns=["loss", "eval_loss", "eval_accuracy"]))
|
544 |
+
best_hyperparameters = best_trial.hyperparameters
|
545 |
+
|
546 |
+
print("Best Hyperparameters:")
|
547 |
+
print(best_hyperparameters)
|
548 |
+
|
549 |
+
|
550 |
+
|
551 |
+
else:
|
552 |
+
trainset_label = train_dataset
|
553 |
+
label_counter = Counter(trainset_label[label])
|
554 |
+
total_cells = sum(label_counter.values())
|
555 |
+
|
556 |
+
# Throws out cells with a low proportion in the dataset
|
557 |
+
cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
|
558 |
+
trainset_label_subset = trainset_label.filter(if_not_rare_celltype, num_proc=cpu_cores)
|
559 |
+
|
560 |
+
# shuffle datasets and rename columns
|
561 |
+
trainset_label_shuffled = trainset_label_subset.shuffle(seed=42)
|
562 |
+
trainset_label_shuffled = trainset_label_shuffled.rename_column(label,"label")
|
563 |
+
|
564 |
+
# create dictionary of cell types : label ids
|
565 |
+
target_names = list(Counter(trainset_label_shuffled["label"]).keys())
|
566 |
+
target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
|
567 |
+
target_dict_list = target_name_id_dict
|
568 |
+
|
569 |
+
labeled_trainset = trainset_label_shuffled.map(classes_to_ids, num_proc=cpu_cores)
|
570 |
+
|
571 |
+
# create 80/20 train/eval splits
|
572 |
+
labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])
|
573 |
+
labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])
|
574 |
+
|
575 |
+
# filter dataset for cell types in corresponding training set
|
576 |
+
trained_labels = list(Counter(labeled_train_split["label"]).keys())
|
577 |
+
labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=cpu_cores)
|
578 |
+
|
579 |
+
# set logging steps
|
580 |
+
logging_steps = round(len(trainset_label)/geneformer_batch_size/10)
|
581 |
+
|
582 |
+
# reload pretrained model
|
583 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_model,
|
584 |
+
num_labels=len(target_dict_list.keys()),
|
585 |
+
output_attentions = False,
|
586 |
+
output_hidden_states = False).to(device)
|
587 |
+
# define output directory path
|
588 |
+
current_date = datetime.datetime.now()
|
589 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
590 |
+
|
591 |
+
if output_dir == None:
|
592 |
+
output_dir = f"{datestamp}_geneformer_CellClassifier_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/"
|
593 |
+
|
594 |
+
# ensure not overwriting previously saved model
|
595 |
+
saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
|
596 |
+
if os.path.isfile(saved_model_test) == True and overwrite == False:
|
597 |
+
raise Exception("Model already saved to this directory.")
|
598 |
+
|
599 |
+
# make output directory
|
600 |
+
subprocess.call(f'mkdir -p {output_dir}', shell=True)
|
601 |
+
|
602 |
+
# set training arguments
|
603 |
+
training_args = {
|
604 |
+
"learning_rate": max_lr,
|
605 |
+
"do_train": True,
|
606 |
+
"do_eval": True,
|
607 |
+
"evaluation_strategy": "epoch",
|
608 |
+
"save_strategy": "epoch",
|
609 |
+
"logging_steps": logging_steps,
|
610 |
+
"group_by_length": True,
|
611 |
+
"length_column_name": "length",
|
612 |
+
"disable_tqdm": False,
|
613 |
+
"lr_scheduler_type": lr_schedule_fn,
|
614 |
+
"warmup_steps": warmup_steps,
|
615 |
+
"weight_decay": 0.001,
|
616 |
+
"per_device_train_batch_size": geneformer_batch_size,
|
617 |
+
"per_device_eval_batch_size": geneformer_batch_size,
|
618 |
+
"num_train_epochs": epochs,
|
619 |
+
"load_best_model_at_end": True,
|
620 |
+
"output_dir": output_dir,}
|
621 |
+
|
622 |
+
training_args_init = TrainingArguments(**training_args)
|
623 |
+
true_labels = labeled_eval_split_subset['label']
|
624 |
+
|
625 |
+
if optimize_hyperparameters == False:
|
626 |
+
|
627 |
+
# create the trainer
|
628 |
+
trainer = Trainer(
|
629 |
+
model=model,
|
630 |
+
args=training_args_init,
|
631 |
+
data_collator=DataCollatorForCellClassification(),
|
632 |
+
train_dataset=labeled_train_split,
|
633 |
+
eval_dataset=labeled_eval_split_subset,
|
634 |
+
compute_metrics=compute_metrics
|
635 |
+
)
|
636 |
+
|
637 |
+
# train the cell type classifier
|
638 |
+
trainer.train()
|
639 |
+
predictions = trainer.predict(labeled_eval_split_subset)
|
640 |
+
predictions_tensor = torch.Tensor(predictions.predictions)
|
641 |
+
predicted_labels = torch.argmax(predictions_tensor, dim=1)
|
642 |
+
print(f'accuracy: {accuracy_score(predicted_labels, labeled_eval_split_subset["label"])}')
|
643 |
+
metrics = compute_metrics(predictions)
|
644 |
+
|
645 |
+
with open(f"{output_dir}predictions.pickle", "wb") as fp:
|
646 |
+
pickle.dump(predictions.predictions.argmax(-1), fp)
|
647 |
+
|
648 |
+
trainer.save_metrics("eval",predictions.metrics)
|
649 |
+
trainer.save_model(output_dir)
|
650 |
+
|
651 |
+
# Saves label conversion dictionary to output directory
|
652 |
+
with open(f'{output_dir}/targets.txt', 'w') as f:
|
653 |
+
f.write(str(target_dict_list))
|
654 |
+
|
655 |
+
try:
|
656 |
+
|
657 |
+
precision, recall, _ = precision_recall_curve(true_labels, predictions.predictions[:, 1])
|
658 |
+
pr_auc = precision_auc(recall, precision)
|
659 |
+
|
660 |
+
print(f'AUC: {pr_auc}')
|
661 |
+
return recall, precision, pr_auc
|
662 |
+
except:
|
663 |
+
pass
|
664 |
+
|
665 |
+
else:
|
666 |
+
# Optimizes hyperparameters
|
667 |
+
|
668 |
+
num_classes = len(list(set(labeled_train_split['label'])))
|
669 |
+
def model_init():
|
670 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_model,
|
671 |
+
num_labels=num_classes,
|
672 |
+
output_attentions = False,
|
673 |
+
output_hidden_states = False)
|
674 |
+
|
675 |
+
if freeze_layers is not None:
|
676 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
677 |
+
for module in modules_to_freeze:
|
678 |
+
for param in module.parameters():
|
679 |
+
param.requires_grad = False
|
680 |
+
model = model.to(device)
|
681 |
+
return model
|
682 |
+
|
683 |
+
|
684 |
+
# create the trainer
|
685 |
+
trainer = Trainer(
|
686 |
+
model_init=model_init,
|
687 |
+
args=training_args_init,
|
688 |
+
data_collator=DataCollatorForCellClassification(),
|
689 |
+
train_dataset=labeled_train_split,
|
690 |
+
eval_dataset=labeled_eval_split_subset,
|
691 |
+
compute_metrics=compute_metrics
|
692 |
+
)
|
693 |
+
|
694 |
+
# specify raytune hyperparameter search space
|
695 |
+
ray_config = {
|
696 |
+
"num_train_epochs": tune.choice([epochs]),
|
697 |
+
"learning_rate": tune.loguniform(1e-6, 1e-3),
|
698 |
+
"weight_decay": tune.uniform(0.0, 0.3),
|
699 |
+
"lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
|
700 |
+
"warmup_steps": tune.uniform(100, 2000),
|
701 |
+
"seed": tune.uniform(0,100),
|
702 |
+
"per_device_train_batch_size": tune.choice([geneformer_batch_size])
|
703 |
+
}
|
704 |
+
|
705 |
+
hyperopt_search = HyperOptSearch(
|
706 |
+
metric="eval_accuracy", mode="max")
|
707 |
+
|
708 |
+
if torch.device == 'cuda':
|
709 |
+
resources_per_trial={"cpu":8,"gpu":1},
|
710 |
+
else:
|
711 |
+
resources_per_trial={"cpu":8}
|
712 |
+
|
713 |
+
# optimize hyperparameters
|
714 |
+
best_trial = trainer.hyperparameter_search(
|
715 |
+
direction="maximize",
|
716 |
+
backend="ray",
|
717 |
+
resources_per_trial = resources_per_trial,
|
718 |
+
hp_space=lambda _: ray_config,
|
719 |
+
search_alg=hyperopt_search,
|
720 |
+
n_trials=10, # number of trials
|
721 |
+
progress_reporter=tune.CLIReporter(max_report_frequency=600,
|
722 |
+
sort_by_metric=True,
|
723 |
+
max_progress_rows=100,
|
724 |
+
mode="max",
|
725 |
+
metric="eval_accuracy",
|
726 |
+
metric_columns=["loss", "eval_loss", "eval_accuracy"]))
|
727 |
+
best_hyperparameters = best_trial.hyperparameters
|
728 |
+
|
729 |
+
print("Best Hyperparameters:")
|
730 |
+
print(best_hyperparameters)
|
731 |
+
|
732 |
+
|
733 |
+
# Performs Inference with model
|
734 |
+
if inference == True:
|
735 |
+
if dataset_split != None and data_filter != None:
|
736 |
+
def if_label(example):
|
737 |
+
return example[dataset_split] == data_filter
|
738 |
+
|
739 |
+
train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
|
740 |
+
|
741 |
+
trainset_label_shuffled = train_dataset
|
742 |
+
total_cells = len(trainset_label_shuffled)
|
743 |
+
|
744 |
+
# loads dictionary of all cell labels model was trained on
|
745 |
+
with open(Path(model_location) / 'targets.txt', 'r') as f:
|
746 |
+
data = ast.literal_eval(f.read())
|
747 |
+
if dataset_split != None and data_filter == None:
|
748 |
+
indexer = dataset_split.index(data_filter)
|
749 |
+
data = data[indexer]
|
750 |
+
|
751 |
+
target_dict_list = {key:value for key, value in enumerate(data)}
|
752 |
+
|
753 |
+
# set logging steps
|
754 |
+
logging_steps = round(len(trainset_label_shuffled)/geneformer_batch_size/20)
|
755 |
+
|
756 |
+
# reload pretrained model
|
757 |
+
input_ids = trainset_label_shuffled["input_ids"]
|
758 |
+
inputs = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
|
759 |
+
attention = torch.zeros(len(input_ids), max_input_size, dtype=torch.int64)
|
760 |
+
|
761 |
+
for i, sentence in enumerate(input_ids):
|
762 |
+
sentence_length = len(sentence)
|
763 |
+
if sentence_length <= max_input_size:
|
764 |
+
inputs[i, :sentence_length] = torch.tensor(sentence)
|
765 |
+
attention[i, :sentence_length] = torch.ones(sentence_length)
|
766 |
+
else:
|
767 |
+
inputs[i, :] = torch.tensor(sentence[:max_input_size])
|
768 |
+
attention[i, :] = torch.ones(max_input_size)
|
769 |
+
|
770 |
+
model = BertForSequenceClassification.from_pretrained(model_location, num_labels=len(target_dict_list)).to(device)
|
771 |
+
model_outputs = model(inputs.to(device), attention_mask = attention)["logits"]
|
772 |
+
predictions = F.softmax(model_outputs, dim = -1).argmax(-1)
|
773 |
+
|
774 |
+
predictions = [target_dict_list[int(pred)] for pred in predictions]
|
775 |
+
|
776 |
+
return predictions
|
777 |
+
|
778 |
+
# Extracts embeddings from labelled data
|
779 |
+
if emb_extract == True:
|
780 |
+
if emb_filter == None:
|
781 |
+
with open(f'{model_location}/targets.txt', 'r') as f:
|
782 |
+
data = ast.literal_eval(f.read())
|
783 |
+
if dataset_split != None and data_filter == None:
|
784 |
+
indexer = dataset_split.index(data_filter)
|
785 |
+
data = data[indexer]
|
786 |
+
|
787 |
+
target_dict_list = {key:value for key, value in enumerate(data)}
|
788 |
+
total_filter = None
|
789 |
+
else:
|
790 |
+
total_filter = emb_filter
|
791 |
+
|
792 |
+
train_dataset = load_from_disk(dataset)
|
793 |
+
if dataset_split != None:
|
794 |
+
def if_label(example):
|
795 |
+
return example[dataset_split] == data_filter
|
796 |
+
|
797 |
+
train_dataset = train_dataset.filter(if_label, num_proc=cpu_cores)
|
798 |
+
|
799 |
+
label_counter = Counter(train_dataset[label])
|
800 |
+
total_cells = sum(label_counter.values())
|
801 |
+
cells_to_keep = [k for k,v in label_counter.items() if v>(filter_cells*total_cells)]
|
802 |
+
|
803 |
+
def if_not_rare(example):
|
804 |
+
return example[label] in cells_to_keep
|
805 |
+
|
806 |
+
train_dataset = train_dataset.filter(if_not_rare, num_proc=cpu_cores)
|
807 |
+
|
808 |
+
true_labels = train_dataset[label]
|
809 |
+
num_classes = len(list(set(true_labels)))
|
810 |
+
|
811 |
+
embex = EmbExtractor(model_type="CellClassifier", num_classes=num_classes,
|
812 |
+
filter_data=total_filter, max_ncells=max_cells, emb_layer=emb_layer,
|
813 |
+
emb_label=[dataset_split,label], labels_to_plot=[label], forward_batch_size=forward_batch, nproc=cpu_cores)
|
814 |
+
|
815 |
+
# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
|
816 |
+
subprocess.call(f'mkdir -p {emb_dir}', shell = True)
|
817 |
+
|
818 |
+
embs = embex.extract_embs(model_directory = model_location, input_data_file = dataset, output_directory = emb_dir, output_prefix = f"{label}_embbeddings")
|
819 |
+
true_labels = embex.filtered_input_data[label]
|
820 |
+
|
821 |
+
emb_dict = {label:[] for label in list(set(true_labels))}
|
822 |
+
for num, emb in embs.iterrows():
|
823 |
+
key = emb[label]
|
824 |
+
selection = emb.iloc[:255]
|
825 |
+
emb = torch.Tensor(selection)
|
826 |
+
emb_dict[key].append(emb)
|
827 |
+
|
828 |
+
for key in list(emb_dict.keys()):
|
829 |
+
stack = torch.stack(emb_dict[key], dim = 0)
|
830 |
+
emb_dict[key] = torch.mean(stack, dim=0)
|
831 |
+
similarities = {key:{} for key in list(emb_dict.keys())}
|
832 |
+
|
833 |
+
for key in list(emb_dict.keys()):
|
834 |
+
remaining_keys = [k for k in list(emb_dict.keys()) if k != key]
|
835 |
+
for k in remaining_keys:
|
836 |
+
embedding = emb_dict[k]
|
837 |
+
sim = similarity(emb_dict[key], embedding, cosine = True)
|
838 |
+
|
839 |
+
similarities[key][k] = sim
|
840 |
+
|
841 |
+
plot_similarity_heatmap(similarities)
|
842 |
+
|
843 |
+
embex.plot_embs(embs=embs,
|
844 |
+
plot_style="umap",
|
845 |
+
output_directory=emb_dir,
|
846 |
+
output_prefix="emb_plot")
|
847 |
+
|
848 |
+
|
849 |
+
embex.plot_embs(embs=embs,
|
850 |
+
plot_style="heatmap",
|
851 |
+
output_directory=emb_dir,
|
852 |
+
output_prefix="emb_plot")
|
853 |
+
|
854 |
+
|
855 |
+
return similarities
|
856 |
+
|
857 |
+
if __name__ == '__main__':
|
858 |
+
predictions = finetune_cells(skip_training = False, dataset_split = None, label = "disease", sample_data = .5, data_filter = 'hcm', epochs = 10, output_dir = 'hcm_model', model_location = 'hcm_model',
|
859 |
+
emb_extract = True, geneformer_batch_size = 12, inference = False, dataset = "/work/ccnr/GeneFormer/GeneFormer_repo/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/")
|
860 |
+
|
861 |
+
|
Gene_classifier.py
ADDED
@@ -0,0 +1,746 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
GPU_NUMBER = [0] # CHANGE WITH MULTIGPU
|
4 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
|
5 |
+
os.environ["NCCL_DEBUG"] = "INFO"
|
6 |
+
|
7 |
+
# imports
|
8 |
+
from sklearn.model_selection import train_test_split
|
9 |
+
import datetime
|
10 |
+
import subprocess
|
11 |
+
from pathlib import Path
|
12 |
+
import math
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import numpy as np
|
15 |
+
import pickle
|
16 |
+
import pandas as pd
|
17 |
+
from datasets import load_from_disk, Dataset
|
18 |
+
from sklearn import preprocessing
|
19 |
+
from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
|
20 |
+
from sklearn.model_selection import StratifiedKFold
|
21 |
+
import torch
|
22 |
+
from transformers import BertForTokenClassification
|
23 |
+
from transformers import Trainer
|
24 |
+
from transformers.training_args import TrainingArguments
|
25 |
+
from tqdm.notebook import tqdm
|
26 |
+
from sklearn.metrics import roc_curve, roc_auc_score
|
27 |
+
from geneformer import DataCollatorForGeneClassification, EmbExtractor
|
28 |
+
from geneformer.pretrainer import token_dictionary
|
29 |
+
import ast
|
30 |
+
import torch.nn.functional as F
|
31 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
+
from geneformer import TranscriptomeTokenizer
|
33 |
+
|
34 |
+
def vote(logit_pair):
|
35 |
+
a, b = logit_pair
|
36 |
+
if a > b:
|
37 |
+
return 0
|
38 |
+
elif b > a:
|
39 |
+
return 1
|
40 |
+
elif a == b:
|
41 |
+
return "tie"
|
42 |
+
|
43 |
+
def py_softmax(vector):
|
44 |
+
e = np.exp(vector)
|
45 |
+
return e / e.sum()
|
46 |
+
|
47 |
+
# Identifies cosine similarity between two embeddings. 0 is perfectly dissimilar and 1 is perfectly similar
|
48 |
+
def similarity(tensor1, tensor2, cosine = True):
|
49 |
+
if cosine == False:
|
50 |
+
if tensor1.ndimension() > 1:
|
51 |
+
tensor1 = tensor1.view(1, -1)
|
52 |
+
if tensor2.ndimension() > 1:
|
53 |
+
tensor2 = tensor2.view(1, -1)
|
54 |
+
dot_product = torch.matmul(tensor1, tensor2)
|
55 |
+
norm_tensor1 = torch.norm(tensor1)
|
56 |
+
norm_tensor2 = torch.norm(tensor2)
|
57 |
+
epsilon = 1e-8
|
58 |
+
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
|
59 |
+
similarity = (similarity.item() + 1)/2
|
60 |
+
else:
|
61 |
+
if tensor1.shape != tensor2.shape:
|
62 |
+
raise ValueError("Input tensors must have the same shape.")
|
63 |
+
|
64 |
+
# Compute cosine similarity using PyTorch's dot product function
|
65 |
+
dot_product = torch.dot(tensor1, tensor2)
|
66 |
+
norm_tensor1 = torch.norm(tensor1)
|
67 |
+
norm_tensor2 = torch.norm(tensor2)
|
68 |
+
|
69 |
+
# Avoid division by zero by adding a small epsilon
|
70 |
+
epsilon = 1e-8
|
71 |
+
similarity = dot_product / (norm_tensor1 * norm_tensor2 + epsilon)
|
72 |
+
|
73 |
+
return similarity.item()
|
74 |
+
|
75 |
+
# Plots heatmap between different classes/labels
|
76 |
+
def plot_similarity_heatmap(similarities):
|
77 |
+
classes = list(similarities.keys())
|
78 |
+
classlen = len(classes)
|
79 |
+
arr = np.zeros((classlen, classlen))
|
80 |
+
for i, c in enumerate(classes):
|
81 |
+
for j, cc in enumerate(classes):
|
82 |
+
if cc == c:
|
83 |
+
val = 1.0
|
84 |
+
else:
|
85 |
+
val = similarities[c][cc]
|
86 |
+
arr[i][j] = val
|
87 |
+
|
88 |
+
plt.figure(figsize=(8, 6))
|
89 |
+
plt.imshow(arr, cmap='inferno', vmin=0, vmax=1)
|
90 |
+
plt.colorbar()
|
91 |
+
plt.xticks(np.arange(classlen), classes, rotation = 45, ha = 'right')
|
92 |
+
plt.yticks(np.arange(classlen), classes)
|
93 |
+
plt.title("Similarity Heatmap")
|
94 |
+
plt.savefig("similarity_heatmap.png")
|
95 |
+
|
96 |
+
# get cross-validated mean and sd metrics
|
97 |
+
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
|
98 |
+
wts = [count/sum(all_tpr_wt) for count in all_tpr_wt]
|
99 |
+
|
100 |
+
all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]
|
101 |
+
mean_tpr = np.sum(all_weighted_tpr, axis=0)
|
102 |
+
mean_tpr[-1] = 1.0
|
103 |
+
all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]
|
104 |
+
roc_auc = np.sum(all_weighted_roc_auc)
|
105 |
+
roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))
|
106 |
+
return mean_tpr, roc_auc, roc_auc_sd
|
107 |
+
|
108 |
+
def validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
|
109 |
+
# initiate eval metrics to return
|
110 |
+
num_classes = len(set(labels))
|
111 |
+
mean_fpr = np.linspace(0, 1, 100)
|
112 |
+
|
113 |
+
# create 80/20 train/eval splits
|
114 |
+
targets_train, targets_eval, labels_train, labels_eval = train_test_split(targets, labels ,test_size=0.25, shuffle=True)
|
115 |
+
label_dict_train = dict(zip(targets_train, labels_train))
|
116 |
+
label_dict_eval = dict(zip(targets_eval, labels_eval))
|
117 |
+
|
118 |
+
# function to filter by whether contains train or eval labels
|
119 |
+
def if_contains_train_label(example):
|
120 |
+
a = label_dict_train.keys()
|
121 |
+
b = example['input_ids']
|
122 |
+
return not set(a).isdisjoint(b)
|
123 |
+
|
124 |
+
def if_contains_eval_label(example):
|
125 |
+
a = label_dict_eval.keys()
|
126 |
+
b = example['input_ids']
|
127 |
+
return not set(a).isdisjoint(b)
|
128 |
+
|
129 |
+
# filter dataset for examples containing classes for this split
|
130 |
+
print(f"Filtering training data")
|
131 |
+
trainset = data.filter(if_contains_train_label, num_proc=num_proc)
|
132 |
+
print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
|
133 |
+
print(f"Filtering evalation data")
|
134 |
+
evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
|
135 |
+
print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
|
136 |
+
|
137 |
+
# minimize to smaller training sample
|
138 |
+
training_size = min(subsample_size, len(trainset))
|
139 |
+
trainset_min = trainset.select([i for i in range(training_size)])
|
140 |
+
eval_size = min(training_size, len(evalset))
|
141 |
+
half_training_size = round(eval_size/2)
|
142 |
+
evalset_train_min = evalset.select([i for i in range(half_training_size)])
|
143 |
+
evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
|
144 |
+
|
145 |
+
# label conversion functions
|
146 |
+
def generate_train_labels(example):
|
147 |
+
example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
|
148 |
+
return example
|
149 |
+
|
150 |
+
def generate_eval_labels(example):
|
151 |
+
example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
|
152 |
+
return example
|
153 |
+
|
154 |
+
# label datasets
|
155 |
+
print(f"Labeling training data")
|
156 |
+
trainset_labeled = trainset_min.map(generate_train_labels)
|
157 |
+
print(f"Labeling evaluation data")
|
158 |
+
evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
|
159 |
+
print(f"Labeling evaluation OOS data")
|
160 |
+
evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
|
161 |
+
|
162 |
+
# load model
|
163 |
+
model = BertForTokenClassification.from_pretrained(
|
164 |
+
pre_model,
|
165 |
+
num_labels=num_labels,
|
166 |
+
output_attentions = False,
|
167 |
+
output_hidden_states = False,
|
168 |
+
)
|
169 |
+
if freeze_layers is not None:
|
170 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
171 |
+
for module in modules_to_freeze:
|
172 |
+
for param in module.parameters():
|
173 |
+
param.requires_grad = False
|
174 |
+
|
175 |
+
model = model.to(device)
|
176 |
+
|
177 |
+
# add output directory to training args and initiate
|
178 |
+
training_args["output_dir"] = output_dir
|
179 |
+
training_args_init = TrainingArguments(**training_args)
|
180 |
+
|
181 |
+
# create the trainer
|
182 |
+
trainer = Trainer(
|
183 |
+
model=model,
|
184 |
+
args=training_args_init,
|
185 |
+
data_collator=DataCollatorForGeneClassification(),
|
186 |
+
train_dataset=trainset_labeled,
|
187 |
+
eval_dataset=evalset_train_labeled,
|
188 |
+
)
|
189 |
+
|
190 |
+
# train the gene classifier
|
191 |
+
trainer.train()
|
192 |
+
trainer.save_model(output_dir)
|
193 |
+
|
194 |
+
fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
|
195 |
+
auc_score = auc(fpr, tpr)
|
196 |
+
|
197 |
+
return fpr, tpr, auc_score
|
198 |
+
|
199 |
+
# cross-validate gene classifier
|
200 |
+
def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
|
201 |
+
# check if output directory already written to
|
202 |
+
# ensure not overwriting previously saved model
|
203 |
+
model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
|
204 |
+
#if os.path.isfile(model_dir_test) == True:
|
205 |
+
# raise Exception("Model already saved to this directory.")
|
206 |
+
|
207 |
+
device = device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
208 |
+
# initiate eval metrics to return
|
209 |
+
num_classes = len(set(labels))
|
210 |
+
mean_fpr = np.linspace(0, 1, 100)
|
211 |
+
all_tpr = []
|
212 |
+
all_roc_auc = []
|
213 |
+
all_tpr_wt = []
|
214 |
+
label_dicts = []
|
215 |
+
confusion = np.zeros((num_classes,num_classes))
|
216 |
+
|
217 |
+
# set up cross-validation splits
|
218 |
+
skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
|
219 |
+
# train and evaluate
|
220 |
+
iteration_num = 0
|
221 |
+
for train_index, eval_index in tqdm(skf.split(targets, labels)):
|
222 |
+
if len(labels) > 500:
|
223 |
+
print("early stopping activated due to large # of training examples")
|
224 |
+
if iteration_num == 3:
|
225 |
+
break
|
226 |
+
|
227 |
+
print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
|
228 |
+
|
229 |
+
# generate cross-validation splits
|
230 |
+
targets_train, targets_eval = targets[train_index], targets[eval_index]
|
231 |
+
labels_train, labels_eval = labels[train_index], labels[eval_index]
|
232 |
+
label_dict_train = dict(zip(targets_train, labels_train))
|
233 |
+
label_dict_eval = dict(zip(targets_eval, labels_eval))
|
234 |
+
label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)
|
235 |
+
|
236 |
+
# function to filter by whether contains train or eval labels
|
237 |
+
def if_contains_train_label(example):
|
238 |
+
a = label_dict_train.keys()
|
239 |
+
b = example['input_ids']
|
240 |
+
|
241 |
+
return not set(a).isdisjoint(b)
|
242 |
+
|
243 |
+
def if_contains_eval_label(example):
|
244 |
+
a = label_dict_eval.keys()
|
245 |
+
b = example['input_ids']
|
246 |
+
|
247 |
+
return not set(a).isdisjoint(b)
|
248 |
+
|
249 |
+
# filter dataset for examples containing classes for this split
|
250 |
+
print(f"Filtering training data")
|
251 |
+
trainset = data.filter(if_contains_train_label, num_proc=num_proc)
|
252 |
+
print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
|
253 |
+
print(f"Filtering evalation data")
|
254 |
+
evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
|
255 |
+
print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")
|
256 |
+
|
257 |
+
# minimize to smaller training sample
|
258 |
+
training_size = min(subsample_size, len(trainset))
|
259 |
+
trainset_min = trainset.select([i for i in range(training_size)])
|
260 |
+
eval_size = min(training_size, len(evalset))
|
261 |
+
half_training_size = round(eval_size/2)
|
262 |
+
evalset_train_min = evalset.select([i for i in range(half_training_size)])
|
263 |
+
evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])
|
264 |
+
|
265 |
+
# label conversion functions
|
266 |
+
def generate_train_labels(example):
|
267 |
+
example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
|
268 |
+
return example
|
269 |
+
|
270 |
+
def generate_eval_labels(example):
|
271 |
+
example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
|
272 |
+
return example
|
273 |
+
|
274 |
+
# label datasets
|
275 |
+
print(f"Labeling training data")
|
276 |
+
trainset_labeled = trainset_min.map(generate_train_labels)
|
277 |
+
print(f"Labeling evaluation data")
|
278 |
+
evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
|
279 |
+
print(f"Labeling evaluation OOS data")
|
280 |
+
evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)
|
281 |
+
|
282 |
+
# create output directories
|
283 |
+
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
284 |
+
ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")
|
285 |
+
|
286 |
+
# ensure not overwriting previously saved model
|
287 |
+
model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
|
288 |
+
#if os.path.isfile(model_output_file) == True:
|
289 |
+
# raise Exception("Model already saved to this directory.")
|
290 |
+
|
291 |
+
# make training and model output directories
|
292 |
+
subprocess.call(f'mkdir -p {ksplit_output_dir}', shell=True)
|
293 |
+
subprocess.call(f'mkdir -p {ksplit_model_dir}', shell=True)
|
294 |
+
|
295 |
+
# load model
|
296 |
+
model = BertForTokenClassification.from_pretrained(
|
297 |
+
pre_model,
|
298 |
+
num_labels=num_labels,
|
299 |
+
output_attentions = False,
|
300 |
+
output_hidden_states = False,
|
301 |
+
)
|
302 |
+
if freeze_layers is not None:
|
303 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
304 |
+
for module in modules_to_freeze:
|
305 |
+
for param in module.parameters():
|
306 |
+
param.requires_grad = False
|
307 |
+
|
308 |
+
model = model.to(device)
|
309 |
+
|
310 |
+
# add output directory to training args and initiate
|
311 |
+
training_args["output_dir"] = ksplit_output_dir
|
312 |
+
training_args_init = TrainingArguments(**training_args)
|
313 |
+
|
314 |
+
# create the trainer
|
315 |
+
trainer = Trainer(
|
316 |
+
model=model,
|
317 |
+
args=training_args_init,
|
318 |
+
data_collator=DataCollatorForGeneClassification(),
|
319 |
+
train_dataset=trainset_labeled,
|
320 |
+
eval_dataset=evalset_train_labeled
|
321 |
+
)
|
322 |
+
|
323 |
+
# train the gene classifier
|
324 |
+
trainer.train()
|
325 |
+
|
326 |
+
# save model
|
327 |
+
trainer.save_model(ksplit_model_dir)
|
328 |
+
|
329 |
+
# evaluate model
|
330 |
+
fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
|
331 |
+
|
332 |
+
# append to tpr and roc lists
|
333 |
+
confusion = confusion + conf_mat
|
334 |
+
all_tpr.append(interp_tpr)
|
335 |
+
all_roc_auc.append(auc(fpr, tpr))
|
336 |
+
# append number of eval examples by which to weight tpr in averaged graphs
|
337 |
+
all_tpr_wt.append(len(tpr))
|
338 |
+
|
339 |
+
iteration_num = iteration_num + 1
|
340 |
+
|
341 |
+
# get overall metrics for cross-validation
|
342 |
+
mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
|
343 |
+
return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
|
344 |
+
|
345 |
+
# Computes metrics
|
346 |
+
def compute_metrics(pred):
|
347 |
+
labels = pred.label_ids
|
348 |
+
preds = pred.predictions.argmax(-1)
|
349 |
+
# calculate accuracy and macro f1 using sklearn's function
|
350 |
+
acc = accuracy_score(labels, preds)
|
351 |
+
macro_f1 = f1_score(labels, preds, average='macro')
|
352 |
+
|
353 |
+
return {
|
354 |
+
'accuracy': acc,
|
355 |
+
'macro_f1': macro_f1
|
356 |
+
}
|
357 |
+
|
358 |
+
# plot ROC curve
|
359 |
+
def plot_ROC(bundled_data, title):
|
360 |
+
plt.figure()
|
361 |
+
lw = 2
|
362 |
+
for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
|
363 |
+
plt.plot(mean_fpr, mean_tpr, color=color,
|
364 |
+
lw=lw, label="{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
|
365 |
+
|
366 |
+
plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
|
367 |
+
plt.xlim([0.0, 1.0])
|
368 |
+
plt.ylim([0.0, 1.05])
|
369 |
+
plt.xlabel('False Positive Rate')
|
370 |
+
plt.ylabel('True Positive Rate')
|
371 |
+
plt.title(title)
|
372 |
+
plt.legend(loc="lower right")
|
373 |
+
plt.savefig("ROC.png")
|
374 |
+
|
375 |
+
return mean_fpr, mean_tpr, roc_auc
|
376 |
+
|
377 |
+
# plot confusion matrix
|
378 |
+
def plot_confusion_matrix(classes_list, conf_mat, title):
|
379 |
+
display_labels = []
|
380 |
+
i = 0
|
381 |
+
for label in classes_list:
|
382 |
+
display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:,i]))]
|
383 |
+
i = i + 1
|
384 |
+
display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"),
|
385 |
+
display_labels=display_labels)
|
386 |
+
display.plot(cmap="Blues",values_format=".2g")
|
387 |
+
plt.title(title)
|
388 |
+
plt.savefig("CM.png")
|
389 |
+
|
390 |
+
# Function to find the largest number smaller
|
391 |
+
# than or equal to N that is divisible by k
|
392 |
+
def find_largest_div(N, K):
|
393 |
+
rem = N % K
|
394 |
+
if(rem == 0):
|
395 |
+
return N
|
396 |
+
else:
|
397 |
+
return N - rem
|
398 |
+
|
399 |
+
def preprocess_classifier_batch(cell_batch, max_len):
|
400 |
+
if max_len == None:
|
401 |
+
max_len = max([len(i) for i in cell_batch["input_ids"]])
|
402 |
+
def pad_label_example(example):
|
403 |
+
example["labels"] = np.pad(example["labels"],
|
404 |
+
(0, max_len-len(example["input_ids"])),
|
405 |
+
mode='constant', constant_values=-100)
|
406 |
+
example["input_ids"] = np.pad(example["input_ids"],
|
407 |
+
(0, max_len-len(example["input_ids"])),
|
408 |
+
mode='constant', constant_values=token_dictionary.get("<pad>"))
|
409 |
+
example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
|
410 |
+
return example
|
411 |
+
padded_batch = cell_batch.map(pad_label_example)
|
412 |
+
return padded_batch
|
413 |
+
|
414 |
+
# forward batch size is batch size for model inference (e.g. 200)
|
415 |
+
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
|
416 |
+
predict_logits = []
|
417 |
+
predict_labels = []
|
418 |
+
model.to('cpu')
|
419 |
+
model.eval()
|
420 |
+
|
421 |
+
# ensure there is at least 2 examples in each batch to avoid incorrect tensor dims
|
422 |
+
evalset_len = len(evalset)
|
423 |
+
max_divisible = find_largest_div(evalset_len, forward_batch_size)
|
424 |
+
if len(evalset) - max_divisible == 1:
|
425 |
+
evalset_len = max_divisible
|
426 |
+
|
427 |
+
max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
|
428 |
+
|
429 |
+
for i in range(0, evalset_len, forward_batch_size):
|
430 |
+
max_range = min(i+forward_batch_size, evalset_len)
|
431 |
+
batch_evalset = evalset.select([i for i in range(i, max_range)])
|
432 |
+
padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
|
433 |
+
padded_batch.set_format(type="torch")
|
434 |
+
|
435 |
+
input_data_batch = padded_batch["input_ids"]
|
436 |
+
attn_msk_batch = padded_batch["attention_mask"]
|
437 |
+
label_batch = padded_batch["labels"]
|
438 |
+
with torch.no_grad():
|
439 |
+
input_ids = input_data_batch
|
440 |
+
attn_mask = attn_msk_batch
|
441 |
+
labels = label_batch
|
442 |
+
outputs = model(
|
443 |
+
|
444 |
+
input_ids = input_ids,
|
445 |
+
attention_mask = attn_mask,
|
446 |
+
labels = labels
|
447 |
+
)
|
448 |
+
predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
|
449 |
+
predict_labels += [torch.squeeze(label_batch.to("cpu"))]
|
450 |
+
|
451 |
+
logits_by_cell = torch.cat(predict_logits)
|
452 |
+
all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
|
453 |
+
labels_by_cell = torch.cat(predict_labels)
|
454 |
+
all_labels = torch.flatten(labels_by_cell)
|
455 |
+
logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
|
456 |
+
y_pred = [vote(item[0]) for item in logit_label_paired]
|
457 |
+
y_true = [item[1] for item in logit_label_paired]
|
458 |
+
logits_list = [item[0] for item in logit_label_paired]
|
459 |
+
# probability of class 1
|
460 |
+
y_score = [py_softmax(item)[1] for item in logits_list]
|
461 |
+
conf_mat = confusion_matrix(y_true, y_pred)
|
462 |
+
fpr, tpr, _ = roc_curve(y_true, y_score)
|
463 |
+
# plot roc_curve for this split
|
464 |
+
plt.plot(fpr, tpr)
|
465 |
+
plt.xlim([0.0, 1.0])
|
466 |
+
plt.ylim([0.0, 1.05])
|
467 |
+
plt.xlabel('False Positive Rate')
|
468 |
+
plt.ylabel('True Positive Rate')
|
469 |
+
plt.title('ROC')
|
470 |
+
plt.show()
|
471 |
+
# interpolate to graph
|
472 |
+
interp_tpr = np.interp(mean_fpr, fpr, tpr)
|
473 |
+
interp_tpr[0] = 0.0
|
474 |
+
return fpr, tpr, interp_tpr, conf_mat
|
475 |
+
|
476 |
+
def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
|
477 |
+
corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
|
478 |
+
max_input_size = 2 ** 11, max_lr = 5e-5, freeze_layers = 4, num_gpus = 1, num_proc = os.cpu_count(), geneformer_batch_size = 9, epochs = 1, filter_dataset = 50_000,
|
479 |
+
emb_extract = True, emb_layer = 0, forward_batch = 200, filter_data = None, inference = False, k_validate = True, model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/", skip_training = False, emb_dir = 'gene_emb', output_dir = None, max_cells = 1000, num_cpus = os.cpu_count()):
|
480 |
+
|
481 |
+
|
482 |
+
""""
|
483 |
+
Primary Parameters
|
484 |
+
-----------
|
485 |
+
|
486 |
+
gene_info: path
|
487 |
+
Path to gene mappings
|
488 |
+
|
489 |
+
corpus_30M: path
|
490 |
+
Path to 30M Gene Corpus
|
491 |
+
|
492 |
+
model: path
|
493 |
+
Path to pretrained GeneFormer model
|
494 |
+
|
495 |
+
genes: path
|
496 |
+
Path to csv file containing different columns of genes and the column labels
|
497 |
+
|
498 |
+
inference: bool
|
499 |
+
Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
|
500 |
+
|
501 |
+
k_validate: bool
|
502 |
+
Whether the model should run k-fold validation or simply perform regular training/evaluate. Defaults to True
|
503 |
+
|
504 |
+
skip_training: bool
|
505 |
+
Whether the model should skip the training portion. Defaults to False
|
506 |
+
|
507 |
+
emb_extract: bool
|
508 |
+
WHether the model should extract embeddings for a given gene (WIP)
|
509 |
+
|
510 |
+
|
511 |
+
Customization Parameters
|
512 |
+
-----------
|
513 |
+
|
514 |
+
freeze_layers: int
|
515 |
+
Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)
|
516 |
+
|
517 |
+
filter_dataset: int
|
518 |
+
Number of cells to filter from 30M dataset. Default is 50_000
|
519 |
+
|
520 |
+
emb_layer: int
|
521 |
+
What layer embeddings are extracted from. Default is 4
|
522 |
+
|
523 |
+
filter_data: str, list
|
524 |
+
Filters down embeddings to a single category. Default is None
|
525 |
+
|
526 |
+
|
527 |
+
"""
|
528 |
+
|
529 |
+
# table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
|
530 |
+
gene_info = pd.read_csv(gene_info, index_col=0)
|
531 |
+
labels = gene_info.columns
|
532 |
+
|
533 |
+
# create dictionaries for corresponding attributes
|
534 |
+
gene_id_type_dict = dict(zip(gene_info["ensembl_id"],gene_info["gene_type"]))
|
535 |
+
gene_name_id_dict = dict(zip(gene_info["gene_name"],gene_info["ensembl_id"]))
|
536 |
+
gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}
|
537 |
+
|
538 |
+
# function for preparing targets and labels
|
539 |
+
def prep_inputs(label_store, id_type):
|
540 |
+
target_list = []
|
541 |
+
if id_type == "gene_name":
|
542 |
+
for key in list(label_store.keys()):
|
543 |
+
targets = [gene_name_id_dict[gene] for gene in label_store[key] if gene_name_id_dict.get(gene) in token_dictionary]
|
544 |
+
targets_id = [token_dictionary[gene] for gene in targets]
|
545 |
+
target_list.append(targets_id)
|
546 |
+
elif id_type == "ensembl_id":
|
547 |
+
for key in list(label_store.keys()):
|
548 |
+
targets = [gene for gene in label_store[key] if gene in token_dictionary]
|
549 |
+
targets_id = [token_dictionary[gene] for gene in targets]
|
550 |
+
target_list.append(targets_id)
|
551 |
+
|
552 |
+
targets, labels = [], []
|
553 |
+
for targ in target_list:
|
554 |
+
targets = targets + targ
|
555 |
+
targets = np.array(targets)
|
556 |
+
for num, targ in enumerate(target_list):
|
557 |
+
label = [num]*len(targ)
|
558 |
+
labels = labels + label
|
559 |
+
labels = np.array(labels)
|
560 |
+
unique_labels = num + 1
|
561 |
+
|
562 |
+
nsplits = min(5, min([len(targ) for targ in target_list])-1)
|
563 |
+
assert nsplits > 2
|
564 |
+
|
565 |
+
return targets, labels, nsplits, unique_labels
|
566 |
+
|
567 |
+
if skip_training == False:
|
568 |
+
# preparing targets and labels for dosage sensitive vs insensitive TFs
|
569 |
+
gene_classes = pd.read_csv(genes, header=0)
|
570 |
+
if filter_data == None:
|
571 |
+
labels = gene_classes.columns
|
572 |
+
else:
|
573 |
+
if isinstance(filter_data, list):
|
574 |
+
labels = filter_data
|
575 |
+
else:
|
576 |
+
labels = [filter_data]
|
577 |
+
label_store = {}
|
578 |
+
|
579 |
+
# Dictionary for decoding labels
|
580 |
+
decode = {i:labels[i] for i in range(len(labels))}
|
581 |
+
|
582 |
+
for label in labels:
|
583 |
+
label_store[label] = gene_classes[label].dropna()
|
584 |
+
|
585 |
+
targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")
|
586 |
+
|
587 |
+
|
588 |
+
|
589 |
+
# load training dataset
|
590 |
+
train_dataset=load_from_disk(corpus_30M)
|
591 |
+
shuffled_train_dataset = train_dataset.shuffle(seed=42)
|
592 |
+
subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(filter_dataset)])
|
593 |
+
lr_schedule_fn = "linear"
|
594 |
+
warmup_steps = 500
|
595 |
+
optimizer = "adamw"
|
596 |
+
subsample_size = 10_000
|
597 |
+
|
598 |
+
training_args = {
|
599 |
+
"learning_rate": max_lr,
|
600 |
+
"do_train": True,
|
601 |
+
"evaluation_strategy": "no",
|
602 |
+
"save_strategy": "epoch",
|
603 |
+
"logging_steps": 10,
|
604 |
+
"group_by_length": True,
|
605 |
+
"length_column_name": "length",
|
606 |
+
"disable_tqdm": False,
|
607 |
+
"lr_scheduler_type": lr_schedule_fn,
|
608 |
+
"warmup_steps": warmup_steps,
|
609 |
+
"weight_decay": 0.001,
|
610 |
+
"per_device_train_batch_size": geneformer_batch_size,
|
611 |
+
"per_device_eval_batch_size": geneformer_batch_size,
|
612 |
+
"num_train_epochs": epochs,
|
613 |
+
}
|
614 |
+
|
615 |
+
# define output directory path
|
616 |
+
current_date = datetime.datetime.now()
|
617 |
+
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
|
618 |
+
|
619 |
+
if output_dir == None:
|
620 |
+
training_output_dir = Path(f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/")
|
621 |
+
else:
|
622 |
+
training_output_dir = Path(output_dir)
|
623 |
+
|
624 |
+
# make output directory
|
625 |
+
subprocess.call(f'mkdir -p {training_output_dir}', shell=True)
|
626 |
+
|
627 |
+
# Places number of classes + in directory
|
628 |
+
num_classes = len(set(labels))
|
629 |
+
info_list = [num_classes, decode]
|
630 |
+
|
631 |
+
with open(training_output_dir / 'classes.txt', 'w') as f:
|
632 |
+
f.write(str(info_list))
|
633 |
+
|
634 |
+
subsampled_train_dataset.save_to_disk(output_dir / 'dataset')
|
635 |
+
|
636 |
+
if k_validate == True:
|
637 |
+
ksplit_model ="ksplit0/models"
|
638 |
+
ksplit_model_test = os.path.join(training_output_dir, ksplit_model)
|
639 |
+
#if os.path.isfile(ksplit_model_test) == True:
|
640 |
+
# raise Exception("Model already saved to this directory.")
|
641 |
+
# cross-validate gene classifier
|
642 |
+
all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
|
643 |
+
|
644 |
+
bundled_data = []
|
645 |
+
bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
|
646 |
+
graph_title = " ".join([i + ' vs' if count < len(label_store) - 1 else i for count, i in enumerate(label_store)])
|
647 |
+
fpr, tpr, auc = plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')
|
648 |
+
print(auc)
|
649 |
+
# plot confusion matrix
|
650 |
+
plot_confusion_matrix(label_store, confusion, "Geneformer")
|
651 |
+
else:
|
652 |
+
fpr, tpr, auc = validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
|
653 |
+
print(auc)
|
654 |
+
|
655 |
+
if inference == True:
|
656 |
+
# preparing targets and labels for dosage sensitive vs insensitive TFs
|
657 |
+
gene_classes = pd.read_csv(genes, header=0)
|
658 |
+
targets = []
|
659 |
+
for column in gene_classes.columns:
|
660 |
+
targets += list(gene_classes[column])
|
661 |
+
tokens = []
|
662 |
+
for target in targets:
|
663 |
+
try:
|
664 |
+
tokens.append(token_dictionary[target])
|
665 |
+
except:
|
666 |
+
tokens.append(0)
|
667 |
+
|
668 |
+
targets = torch.LongTensor([tokens])
|
669 |
+
|
670 |
+
|
671 |
+
with open(f'{model_location}classes.txt', 'r') as f:
|
672 |
+
info_list = ast.literal_eval(f.read())
|
673 |
+
num_classes = info_list[0]
|
674 |
+
labels = info_list[1]
|
675 |
+
|
676 |
+
model = BertForTokenClassification.from_pretrained(
|
677 |
+
model_location,
|
678 |
+
num_labels=num_classes,
|
679 |
+
output_attentions = False,
|
680 |
+
output_hidden_states = False,
|
681 |
+
local_files_only = True
|
682 |
+
)
|
683 |
+
if freeze_layers is not None:
|
684 |
+
modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
|
685 |
+
for module in modules_to_freeze:
|
686 |
+
for param in module.parameters():
|
687 |
+
param.requires_grad = False
|
688 |
+
|
689 |
+
model = model.to(device)
|
690 |
+
|
691 |
+
# evaluate model
|
692 |
+
predictions = F.softmax(model(targets.to(device))["logits"], dim = -1).argmax(-1)[0]
|
693 |
+
predictions = [labels[int(pred)] for pred in predictions]
|
694 |
+
|
695 |
+
return predictions
|
696 |
+
|
697 |
+
# Extracts aggregate gene embeddings for each label
|
698 |
+
if emb_extract == True:
|
699 |
+
with open(f'{model_location}/classes.txt', 'r') as f:
|
700 |
+
data = ast.literal_eval(f.read())
|
701 |
+
num_classes = data[0]
|
702 |
+
decode = data[1]
|
703 |
+
|
704 |
+
gene_classes = pd.read_csv(genes, header=0)
|
705 |
+
labels = gene_classes.columns
|
706 |
+
tokenize = TranscriptomeTokenizer()
|
707 |
+
|
708 |
+
label_dict = {}
|
709 |
+
for label in labels:
|
710 |
+
genes = gene_classes[label]
|
711 |
+
tokenized_genes = []
|
712 |
+
for gene in genes:
|
713 |
+
try:
|
714 |
+
tokenized_genes.append(tokenize.gene_token_dict[gene])
|
715 |
+
except:
|
716 |
+
continue
|
717 |
+
label_dict[label] = tokenized_genes
|
718 |
+
|
719 |
+
embex = EmbExtractor(model_type="GeneClassifier", num_classes=num_classes, emb_mode = "gene",
|
720 |
+
filter_data=None, max_ncells=max_cells, emb_layer=emb_layer,
|
721 |
+
emb_label=label_dict, labels_to_plot=list(labels), forward_batch_size=forward_batch, nproc=num_cpus)
|
722 |
+
|
723 |
+
|
724 |
+
subprocess.call(f'mkdir -p {emb_dir}', shell = True)
|
725 |
+
|
726 |
+
embs = embex.extract_embs(model_directory = model_location, input_data_file = model_location / 'dataset', output_directory = emb_dir, output_prefix = f"{label}_embbeddings")
|
727 |
+
|
728 |
+
emb_dict = {label:[] for label in list(set(labels))}
|
729 |
+
similarities = {key:{} for key in list(emb_dict.keys())}
|
730 |
+
|
731 |
+
for column in embs.columns:
|
732 |
+
remaining_cols = [k for k in embs.columns if k != column]
|
733 |
+
for k in remaining_cols:
|
734 |
+
embedding = torch.Tensor(embs[k])
|
735 |
+
sim = similarity(torch.Tensor(embs[column]), embedding, cosine = True)
|
736 |
+
similarities[column][k] = sim
|
737 |
+
|
738 |
+
plot_similarity_heatmap(similarities)
|
739 |
+
print(similarities)
|
740 |
+
|
741 |
+
return similarities
|
742 |
+
|
743 |
+
if __name__ == '__main__':
|
744 |
+
classify_genes(k_validate = False, inference = False, skip_training = False, emb_extract = True, output_dir = Path('gene_emb'), model_location = Path('gene_emb'), epochs = 5, gene_info = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", corpus_30M = "../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/")
|
745 |
+
|
746 |
+
|
Immune_modelpredictions.pickle
ADDED
Binary file (99.1 kB). View file
|
|
Modular_usage.md
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cell classifier
|
2 |
+
def finetune_cells(token_set = Path('geneformer/token_dictionary.pkl'), median_set = Path('geneformer/gene_median_dictionary.pkl'), pretrained_model = ".",
|
3 |
+
dataset = 'Genecorpus-30M/example_input_files/cell_classification/cell_type_annotation/cell_type_train_data.dataset/',
|
4 |
+
dataset_split = None,
|
5 |
+
filter_cells = .005,
|
6 |
+
epochs = 1,
|
7 |
+
cpu_cores = os.cpu_count(),
|
8 |
+
geneformer_batch_size = 12,
|
9 |
+
optimizer = 'adamw',
|
10 |
+
max_lr = 5e-5,
|
11 |
+
num_gpus = torch.cuda.device_count(),
|
12 |
+
max_input_size = 2 ** 11,
|
13 |
+
lr_schedule_fn = "linear",
|
14 |
+
warmup_steps = 500,
|
15 |
+
freeze_layers = 0,
|
16 |
+
emb_extract = False,
|
17 |
+
max_cells = 1000,
|
18 |
+
emb_layer = 0,
|
19 |
+
emb_filter = None,
|
20 |
+
emb_dir = 'embeddings',
|
21 |
+
overwrite = True,
|
22 |
+
label = "cell_type",
|
23 |
+
data_filter = None,
|
24 |
+
forward_batch = 200, model_location = None,
|
25 |
+
skip_training = False,
|
26 |
+
sample_data = 1,
|
27 |
+
inference = False,
|
28 |
+
optimize_hyperparameters = False,
|
29 |
+
output_dir = None):
|
30 |
+
|
31 |
+
'''
|
32 |
+
Primary Parameters
|
33 |
+
-------------------
|
34 |
+
dataset: path
|
35 |
+
Path to fine-tuning/testing dataset for training
|
36 |
+
|
37 |
+
model_location: path
|
38 |
+
Path to location of existing model to use for inference and embedding extraction
|
39 |
+
|
40 |
+
pretrained_model: path
|
41 |
+
Path to pretrained GeneFormer 30M model before fine-tuning
|
42 |
+
|
43 |
+
inference: bool
|
44 |
+
Chooses whether to perform inference (which causes the function to return the list of similarities). Defaults to False
|
45 |
+
|
46 |
+
skip_training: bool
|
47 |
+
Chooses whether to skip training the model. Defaults to False
|
48 |
+
|
49 |
+
emb_extract: bool
|
50 |
+
Choose whether to extract embeddings and calculate similarities. Defaults to True
|
51 |
+
|
52 |
+
optimize_hyperparameters: bool
|
53 |
+
Choose whether to optimize model hyperparamters. Defaults to False
|
54 |
+
label: string
|
55 |
+
The label string in the formatted dataset that contains true class labels. Defaults to "label"
|
56 |
+
|
57 |
+
Customization Parameters
|
58 |
+
-------------------
|
59 |
+
|
60 |
+
dataset_split: str
|
61 |
+
How the dataset should be partitioned (if at all), and what ID should be used for partitioning
|
62 |
+
|
63 |
+
data_filter: list
|
64 |
+
(For embeddings and inference) Runs analysis subsets of the dataset by the ID defined by dataset_split
|
65 |
+
|
66 |
+
label: str
|
67 |
+
What feature should be read as a classification label
|
68 |
+
|
69 |
+
emb_layer: int
|
70 |
+
What layer embeddings should be extracted and compared from.
|
71 |
+
|
72 |
+
emb_filter: ['cell1', 'cell2'...]
|
73 |
+
Allows user to narrow down range of cells that embeddings will be extracted from.
|
74 |
+
|
75 |
+
max_cells: int
|
76 |
+
How many embeddings from cells should be extracted.
|
77 |
+
|
78 |
+
freeze_layers: int
|
79 |
+
Number of layers should be permanently frozen during fine-tuning (starting from the first layer, 4 brings it up to the pretrained model).
|
80 |
+
|
81 |
+
sample_data: float
|
82 |
+
What proportion of the HF dataset should be used
|
83 |
+
|
84 |
+
'''
|
85 |
+
|
86 |
+
# Gene Classifier
|
87 |
+
def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv",
|
88 |
+
genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
|
89 |
+
corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
|
90 |
+
max_input_size = 2 ** 11,
|
91 |
+
max_lr = 5e-5,
|
92 |
+
freeze_layers = 4,
|
93 |
+
num_gpus = 1,
|
94 |
+
num_proc = os.cpu_count(),
|
95 |
+
geneformer_batch_size = 9,
|
96 |
+
epochs = 1,
|
97 |
+
filter_dataset = 50_000,
|
98 |
+
emb_extract = True,
|
99 |
+
emb_layer = 0,
|
100 |
+
forward_batch = 200,
|
101 |
+
filter_data = None,
|
102 |
+
inference = False,
|
103 |
+
k_validate = True,
|
104 |
+
model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/",
|
105 |
+
skip_training = False,
|
106 |
+
emb_dir = 'gene_emb',
|
107 |
+
output_dir = None,
|
108 |
+
max_cells = 1000,
|
109 |
+
num_cpus = os.cpu_count()):
|
110 |
+
|
111 |
+
""""
|
112 |
+
Primary Parameters
|
113 |
+
-----------
|
114 |
+
|
115 |
+
gene_info: path
|
116 |
+
Path to gene mappings
|
117 |
+
|
118 |
+
corpus_30M: path
|
119 |
+
Path to 30M Gene Corpus
|
120 |
+
|
121 |
+
model: path
|
122 |
+
Path to pretrained GeneFormer model
|
123 |
+
|
124 |
+
genes: path
|
125 |
+
Path to csv file containing different columns of genes and the column labels
|
126 |
+
|
127 |
+
inference: bool
|
128 |
+
Whether the model should be used to run inference. If False, model will train with labeled data instead. Defaults to False
|
129 |
+
|
130 |
+
k_validate: bool
|
131 |
+
Whether the model should run k-fold validation or simply perform regular training/evaluate. Defaults to True
|
132 |
+
|
133 |
+
skip_training: bool
|
134 |
+
Whether the model should skip the training portion. Defaults to False
|
135 |
+
|
136 |
+
emb_extract: bool
|
137 |
+
WHether the model should extract embeddings for a given gene (WIP)
|
138 |
+
|
139 |
+
|
140 |
+
Customization Parameters
|
141 |
+
-----------
|
142 |
+
|
143 |
+
freeze_layers: int
|
144 |
+
Freezes x number of layers from the model. Default is 4 (2 non-frozen layers)
|
145 |
+
|
146 |
+
filter_dataset: int
|
147 |
+
Number of cells to filter from 30M dataset. Default is 50_000
|
148 |
+
|
149 |
+
emb_layer: int
|
150 |
+
What layer embeddings are extracted from. Default is 4
|
151 |
+
|
152 |
+
filter_data: str, list
|
153 |
+
Filters down embeddings to a single category. Default is None
|
154 |
+
|
155 |
+
|
156 |
+
"""
|
gene_embclasses.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[2, {0: 0, 1: 0}]
|
gene_embdataset.pk
ADDED
Binary file (1.76 kB). View file
|
|