|
import os |
|
import sys |
|
# Indices of the GPUs this process is allowed to see.  The environment
# variables are written here, ahead of the `import torch` further down.
GPU_NUMBER = [0]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, GPU_NUMBER))
os.environ["NCCL_DEBUG"] = "INFO"
|
|
|
|
|
from sklearn.model_selection import train_test_split |
|
import datetime |
|
import subprocess |
|
from pathlib import Path |
|
import math |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import pickle |
|
import pandas as pd |
|
from datasets import load_from_disk, Dataset |
|
from sklearn import preprocessing |
|
from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve |
|
from sklearn.model_selection import StratifiedKFold |
|
import torch |
|
from transformers import BertForTokenClassification |
|
from transformers import Trainer |
|
from transformers.training_args import TrainingArguments |
|
from tqdm.notebook import tqdm |
|
from sklearn.metrics import roc_curve, roc_auc_score |
|
from geneformer import DataCollatorForGeneClassification, EmbExtractor |
|
from geneformer.pretrainer import token_dictionary |
|
import ast |
|
import torch.nn.functional as F |
|
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
from geneformer import TranscriptomeTokenizer |
|
|
|
def vote(logit_pair):
    """Pick the winning class from a pair of logits.

    Returns 0 when the first value is larger, 1 when the second is larger
    and the string "tie" when they compare equal.  If none of the
    comparisons holds (for example when a NaN is involved) the result is
    None.
    """
    first, second = logit_pair
    if first > second:
        return 0
    if second > first:
        return 1
    if first == second:
        return "tie"
    return None
|
|
|
def py_softmax(vector):
    """Return the softmax of *vector* as a NumPy array.

    The largest entry is subtracted before exponentiating.  This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    ``inf`` (and the quotient from becoming NaN) when the logits are large.
    An empty input yields an empty array.
    """
    values = np.asarray(vector, dtype=float)
    if values.size == 0:
        return values
    exponentials = np.exp(values - values.max())
    return exponentials / exponentials.sum()
|
|
|
|
|
def similarity(tensor1, tensor2, cosine = True):
    """Return a similarity score between two tensors as a Python float.

    Parameters
    ----------
    tensor1, tensor2 : torch.Tensor
        The tensors to compare.
    cosine : bool
        When true (default) the plain cosine similarity is returned.  The
        inputs must have identical shapes and are passed to ``torch.dot``,
        which only accepts 1-D tensors.
        When false both inputs are flattened first and the cosine similarity
        is rescaled with ``(cos + 1) / 2``.

    Raises
    ------
    ValueError
        If ``cosine`` is true and the two shapes differ.
    """
    # Small constant that keeps the division finite for all-zero inputs.
    epsilon = 1e-8

    if not cosine:
        # Flatten so inputs of any rank reduce to a single dot product.
        flat1 = tensor1.reshape(-1)
        flat2 = tensor2.reshape(-1)
        raw = torch.dot(flat1, flat2) / (torch.norm(flat1) * torch.norm(flat2) + epsilon)
        # Already a Python float after .item(); return it directly.
        return (raw.item() + 1) / 2

    if tensor1.shape != tensor2.shape:
        raise ValueError("Input tensors must have the same shape.")

    dot_product = torch.dot(tensor1, tensor2)
    score = dot_product / (torch.norm(tensor1) * torch.norm(tensor2) + epsilon)
    return score.item()
|
|
|
|
|
def plot_similarity_heatmap(similarities):
    """Draw the pairwise class similarities as a heatmap.

    ``similarities`` is a nested mapping ``{class: {other_class: value}}``.
    The diagonal is fixed at 1.0; every other cell is read from the mapping.
    The figure is written to ``similarity_heatmap.png``.
    """
    class_names = list(similarities.keys())
    n_classes = len(class_names)

    matrix = np.zeros((n_classes, n_classes))
    for row, row_name in enumerate(class_names):
        for col, col_name in enumerate(class_names):
            matrix[row][col] = 1.0 if col_name == row_name else similarities[row_name][col_name]

    tick_positions = np.arange(n_classes)
    plt.figure(figsize=(8, 6))
    plt.imshow(matrix, cmap='inferno', vmin=0, vmax=1)
    plt.colorbar()
    plt.xticks(tick_positions, class_names, rotation = 45, ha = 'right')
    plt.yticks(tick_positions, class_names)
    plt.title("Similarity Heatmap")
    plt.savefig("similarity_heatmap.png")
|
|
|
|
|
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    """Combine per-fold ROC results into weighted summary statistics.

    Each fold is weighted by its entry in ``all_tpr_wt`` divided by the sum
    of all entries.

    Returns
    -------
    tuple
        ``(mean_tpr, roc_auc, roc_auc_sd)``: the weighted mean TPR curve
        (its last point forced to 1.0), the weighted mean AUC and the
        weighted standard deviation of the AUC.
    """
    total_weight = sum(all_tpr_wt)
    fold_weights = [count / total_weight for count in all_tpr_wt]

    weighted_curves = []
    for curve, weight in zip(all_tpr, fold_weights):
        weighted_curves.append(curve * weight)
    mean_tpr = np.sum(weighted_curves, axis=0)
    mean_tpr[-1] = 1.0

    weighted_aucs = []
    for fold_auc, weight in zip(all_roc_auc, fold_weights):
        weighted_aucs.append(fold_auc * weight)
    roc_auc = np.sum(weighted_aucs)

    squared_deviation = (np.asarray(all_roc_auc) - roc_auc) ** 2
    roc_auc_sd = math.sqrt(np.average(squared_deviation, weights=fold_weights))
    return mean_tpr, roc_auc, roc_auc_sd
|
|
|
def validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
    """Fine-tune a token classifier on a single random split of the targets.

    ``targets``/``labels`` are split 75/25 with ``train_test_split``.  Cells
    of ``data`` containing at least one training target form the training
    set.  Cells containing at least one held-out target form the evaluation
    set: its first half is handed to the ``Trainer`` as ``eval_dataset`` and
    its second half is scored with ``classifier_predict``.

    Parameters
    ----------
    data
        Dataset exposing ``filter``/``select``/``map`` whose examples carry
        an ``input_ids`` sequence (presumably a Hugging Face ``Dataset`` -
        confirm against the caller).
    targets, labels
        Parallel sequences: a token id and the class label assigned to it.
    nsplits
        Not used by this function.
    subsample_size
        Upper bound on the number of training cells.
    training_args : dict
        Keyword arguments for ``TrainingArguments``.  A copy is extended with
        ``output_dir``; the caller's dict is left untouched.
    freeze_layers
        Number of leading encoder layers to freeze, or None to freeze none.
    output_dir
        Directory used as the trainer output dir and for the saved model.
    num_proc
        Worker count passed to the dataset filtering step.
    num_labels
        Number of labels for the classification head.
    pre_model
        Name or path given to ``BertForTokenClassification.from_pretrained``.

    Returns
    -------
    tuple
        ``(fpr, tpr, auc_score)`` for the held-out half of the evaluation
        cells.
    """
    mean_fpr = np.linspace(0, 1, 100)

    targets_train, targets_eval, labels_train, labels_eval = train_test_split(
        targets, labels, test_size=0.25, shuffle=True
    )
    label_dict_train = dict(zip(targets_train, labels_train))
    label_dict_eval = dict(zip(targets_eval, labels_eval))

    # Build the token sets once instead of once per filtered example.
    train_tokens = set(label_dict_train)
    eval_tokens = set(label_dict_eval)

    def if_contains_train_label(example):
        return not train_tokens.isdisjoint(example['input_ids'])

    def if_contains_eval_label(example):
        return not eval_tokens.isdisjoint(example['input_ids'])

    print("Filtering training data")
    trainset = data.filter(if_contains_train_label, num_proc=num_proc)
    print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
    print("Filtering evalation data")
    evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
    print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

    # Cap the training set, then split the (equally capped) evaluation cells
    # into a half used during training and an out-of-sample half.
    training_size = min(subsample_size, len(trainset))
    trainset_min = trainset.select(list(range(training_size)))
    eval_size = min(training_size, len(evalset))
    half_training_size = round(eval_size/2)
    evalset_train_min = evalset.select(list(range(half_training_size)))
    evalset_oos_min = evalset.select(list(range(half_training_size, eval_size)))

    # Tokens without a label get -100 and are dropped again in
    # classifier_predict before scoring.
    def generate_train_labels(example):
        example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
        return example

    def generate_eval_labels(example):
        example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
        return example

    print("Labeling training data")
    trainset_labeled = trainset_min.map(generate_train_labels)
    print("Labeling evaluation data")
    evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
    print("Labeling evaluation OOS data")
    evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

    model = BertForTokenClassification.from_pretrained(
        pre_model,
        num_labels=num_labels,
        output_attentions = False,
        output_hidden_states = False,
    )
    if freeze_layers is not None:
        for module in model.bert.encoder.layer[:freeze_layers]:
            for param in module.parameters():
                param.requires_grad = False

    model = model.to(device)

    # Work on a copy so the caller's dict does not gain an "output_dir" key.
    run_args = dict(training_args)
    run_args["output_dir"] = output_dir
    training_args_init = TrainingArguments(**run_args)

    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForGeneClassification(),
        train_dataset=trainset_labeled,
        eval_dataset=evalset_train_labeled,
    )

    trainer.train()
    trainer.save_model(output_dir)

    fpr, tpr, _interp_tpr, _conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
    auc_score = auc(fpr, tpr)

    return fpr, tpr, auc_score
|
|
|
|
|
def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
    """Train and score one token classifier per stratified fold.

    For every fold produced by ``StratifiedKFold(n_splits=nsplits)`` a fresh
    model is loaded from ``pre_model``, fine-tuned on cells containing the
    fold's training targets, saved under ``<output_dir>/ksplit<k>/models/``
    and scored on the out-of-sample half of the cells containing the fold's
    held-out targets.  When more than 500 labels are supplied the loop stops
    once three folds have been processed.

    Parameters
    ----------
    data
        Dataset exposing ``filter``/``select``/``map`` whose examples carry
        an ``input_ids`` sequence (presumably a Hugging Face ``Dataset`` -
        confirm against the caller).
    targets, labels
        Parallel arrays (indexed with the fold index arrays): token ids and
        their class labels.
    nsplits
        Number of folds.
    subsample_size
        Upper bound on the number of training cells per fold.
    training_args : dict
        Keyword arguments for ``TrainingArguments``.  A copy is extended with
        the fold's ``output_dir``; the caller's dict is left untouched.
    freeze_layers
        Number of leading encoder layers to freeze, or None to freeze none.
    output_dir
        Root directory that receives one ``ksplit<k>`` folder per fold.
    num_proc
        Worker count passed to the dataset filtering step.
    num_labels
        Number of labels for the classification head.
    pre_model
        Name or path given to ``BertForTokenClassification.from_pretrained``.

    Returns
    -------
    tuple
        ``(all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion,
        label_dicts)``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_classes = len(set(labels))
    mean_fpr = np.linspace(0, 1, 100)
    all_tpr = []
    all_roc_auc = []
    all_tpr_wt = []
    label_dicts = []
    confusion = np.zeros((num_classes, num_classes))

    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)

    for iteration_num, (train_index, eval_index) in enumerate(tqdm(skf.split(targets, labels))):
        if len(labels) > 500:
            print("early stopping activated due to large # of training examples")
            if iteration_num == 3:
                break

        print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")

        targets_train, targets_eval = targets[train_index], targets[eval_index]
        labels_train, labels_eval = labels[train_index], labels[eval_index]
        label_dict_train = dict(zip(targets_train, labels_train))
        label_dict_eval = dict(zip(targets_eval, labels_eval))
        # NOTE(review): this extends the flat list with five separate items
        # per fold rather than appending one record; kept as-is so the
        # returned value does not change shape for existing callers.
        label_dicts.extend((iteration_num, targets_train, targets_eval, labels_train, labels_eval))

        # Build the token sets once per fold instead of once per example.
        train_tokens = set(label_dict_train)
        eval_tokens = set(label_dict_eval)

        def if_contains_train_label(example, train_tokens=train_tokens):
            return not train_tokens.isdisjoint(example['input_ids'])

        def if_contains_eval_label(example, eval_tokens=eval_tokens):
            return not eval_tokens.isdisjoint(example['input_ids'])

        print("Filtering training data")
        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
        print("Filtering evalation data")
        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

        # Cap the training set, then split the (equally capped) evaluation
        # cells into a half used during training and an out-of-sample half.
        training_size = min(subsample_size, len(trainset))
        trainset_min = trainset.select(list(range(training_size)))
        eval_size = min(training_size, len(evalset))
        half_training_size = round(eval_size/2)
        evalset_train_min = evalset.select(list(range(half_training_size)))
        evalset_oos_min = evalset.select(list(range(half_training_size, eval_size)))

        # Tokens without a label get -100 and are dropped again in
        # classifier_predict before scoring.
        def generate_train_labels(example, label_dict_train=label_dict_train):
            example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        def generate_eval_labels(example, label_dict_eval=label_dict_eval):
            example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        print("Labeling training data")
        trainset_labeled = trainset_min.map(generate_train_labels)
        print("Labeling evaluation data")
        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
        print("Labeling evaluation OOS data")
        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")

        # Creates the fold directory and its models/ sub-directory in one go,
        # without routing the path through a shell.
        os.makedirs(ksplit_model_dir, exist_ok=True)

        model = BertForTokenClassification.from_pretrained(
            pre_model,
            num_labels=num_labels,
            output_attentions = False,
            output_hidden_states = False,
        )
        if freeze_layers is not None:
            for module in model.bert.encoder.layer[:freeze_layers]:
                for param in module.parameters():
                    param.requires_grad = False

        model = model.to(device)

        # Work on a copy so the caller's dict is not modified.
        run_args = dict(training_args)
        run_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**run_args)

        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled,
        )

        trainer.train()
        trainer.save_model(ksplit_model_dir)

        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)

        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # Each fold is later weighted by the number of points on its ROC curve.
        all_tpr_wt.append(len(tpr))

    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
|
|
|
|
|
def compute_metrics(pred):
    """Compute accuracy and macro-averaged F1 for a prediction object.

    ``pred`` must expose ``label_ids`` and ``predictions``; the predicted
    class is the arg-max over the last axis of ``predictions``.

    Returns
    -------
    dict
        ``{'accuracy': ..., 'macro_f1': ...}``.
    """
    # f1_score is not among the names this file imports from sklearn.metrics
    # at module level, so it is imported here to avoid a NameError.
    from sklearn.metrics import f1_score

    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')

    return {
        'accuracy': acc,
        'macro_f1': macro_f1,
    }
|
|
|
|
|
def plot_ROC(bundled_data, title):
    """Plot one ROC curve per entry of ``bundled_data`` and save ``ROC.png``.

    Parameters
    ----------
    bundled_data
        Iterable of ``(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample,
        color)`` tuples; ``sample`` is used in the legend label.
    title
        Title of the figure.

    Returns
    -------
    tuple
        ``(mean_fpr, mean_tpr, roc_auc)`` of the LAST entry that was plotted.
        An empty ``bundled_data`` therefore fails at the return statement.
    """
    plt.figure()
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        # Raw string: "\p" is not a valid escape sequence, and the backslash
        # has to reach matplotlib untouched for the "$\pm$" math text.
        curve_label = r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd)
        plt.plot(mean_fpr, mean_tpr, color=color, lw=lw, label=curve_label)

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.savefig("ROC.png")

    return mean_fpr, mean_tpr, roc_auc
|
|
|
|
|
def plot_confusion_matrix(classes_list, conf_mat, title):
    """Render an L1-normalised confusion matrix and save it as ``CM.png``.

    Each tick label shows the class name followed by ``n=<sum of that
    class's column in conf_mat>``.
    """
    display_labels = [
        "{0}\nn={1:.0f}".format(class_name, sum(conf_mat[:, column]))
        for column, class_name in enumerate(classes_list)
    ]
    normalized = preprocessing.normalize(conf_mat, norm="l1")
    display = ConfusionMatrixDisplay(confusion_matrix=normalized,
                                     display_labels=display_labels)
    display.plot(cmap="Blues", values_format=".2g")
    plt.title(title)
    plt.savefig("CM.png")
|
|
|
|
|
|
|
def find_largest_div(N, K):
    """Return the largest multiple of K that does not exceed N."""
    remainder = N % K
    return N - remainder if remainder else N
|
|
|
def preprocess_classifier_batch(cell_batch, max_len):
    """Right-pad every example of ``cell_batch`` to a common length.

    Parameters
    ----------
    cell_batch
        Dataset-like object supporting ``["input_ids"]`` and ``map``; each
        example needs ``input_ids`` and ``labels``.
    max_len
        Target length.  When None, the longest ``input_ids`` in the batch is
        used.

    Returns
    -------
    The mapped batch in which ``labels`` is padded with -100, ``input_ids``
    with the ``<pad>`` token id, and ``attention_mask`` is 1 wherever
    ``input_ids`` differs from the pad id.
    """
    if max_len is None:
        max_len = max(len(ids) for ids in cell_batch["input_ids"])

    # Looked up once rather than twice per example.
    pad_token_id = token_dictionary.get("<pad>")

    def pad_label_example(example):
        # Both fields share the same amount of padding, measured before
        # input_ids is replaced by its padded version.
        pad_width = (0, max_len - len(example["input_ids"]))
        example["labels"] = np.pad(example["labels"], pad_width,
                                   mode='constant', constant_values=-100)
        example["input_ids"] = np.pad(example["input_ids"], pad_width,
                                      mode='constant', constant_values=pad_token_id)
        example["attention_mask"] = (example["input_ids"] != pad_token_id).astype(int)
        return example

    return cell_batch.map(pad_label_example)
|
|
|
|
|
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    """Run ``model`` over ``evalset`` in batches and derive ROC / confusion data.

    The model is moved to the CPU and put in eval mode.  Every batch is
    padded to the longest ``length`` in the evaluated cells; positions whose
    label is -100 are discarded before scoring.  The ROC curve is also shown
    with matplotlib.

    Returns
    -------
    tuple
        ``(fpr, tpr, interp_tpr, conf_mat)`` where ``interp_tpr`` is ``tpr``
        interpolated onto ``mean_fpr`` with its first point forced to 0.0.
    """
    logits_per_batch = []
    labels_per_batch = []
    model.to('cpu')
    model.eval()

    # A trailing batch holding exactly one cell is dropped (presumably
    # because torch.squeeze below would also remove its batch axis - confirm).
    n_cells = len(evalset)
    largest_multiple = find_largest_div(n_cells, forward_batch_size)
    if len(evalset) - largest_multiple == 1:
        n_cells = largest_multiple

    longest_cell = max(evalset.select(list(range(n_cells)))["length"])

    for start in range(0, n_cells, forward_batch_size):
        stop = min(start + forward_batch_size, n_cells)
        raw_batch = evalset.select(list(range(start, stop)))
        batch = preprocess_classifier_batch(raw_batch, longest_cell)
        batch.set_format(type="torch")

        batch_ids = batch["input_ids"]
        batch_mask = batch["attention_mask"]
        batch_labels = batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = batch_ids,
                attention_mask = batch_mask,
                labels = batch_labels
            )
            logits_per_batch.append(torch.squeeze(outputs.logits.to("cpu")))
            labels_per_batch.append(torch.squeeze(batch_labels.to("cpu")))

    logits_by_cell = torch.cat(logits_per_batch)
    flat_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(labels_per_batch)
    flat_labels = torch.flatten(labels_by_cell)

    # Keep only the positions that actually carry a label.
    labeled_pairs = [
        (token_logits, token_label)
        for token_logits, token_label in zip(flat_logits.tolist(), flat_labels.tolist())
        if token_label != -100
    ]
    y_pred = [vote(token_logits) for token_logits, _label in labeled_pairs]
    y_true = [token_label for _logits, token_label in labeled_pairs]
    # Score used for the ROC curve: softmax value at index 1.
    y_score = [py_softmax(token_logits)[1] for token_logits, _label in labeled_pairs]

    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)

    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC')
    plt.show()

    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat
|
|
|
def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
                   corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
                   max_input_size = 2 ** 11, max_lr = 5e-5, freeze_layers = 4, num_gpus = 1, num_proc = os.cpu_count(), geneformer_batch_size = 9, epochs = 1, filter_dataset = 50_000,
                   emb_extract = True, emb_layer = 0, forward_batch = 200, filter_data = None, inference = False, k_validate = True, model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/", skip_training = False, emb_dir = 'gene_emb', output_dir = None, max_cells = 1000, num_cpus = os.cpu_count()):
    """Train, apply and/or inspect a Geneformer gene classifier.

    The function runs up to three stages in order: training (unless
    ``skip_training``), inference (if ``inference``; returns immediately
    afterwards) and embedding extraction (if ``emb_extract``).

    Primary Parameters
    -----------
    gene_info: path
        Path to the gene mapping CSV (needs ``gene_name`` and ``ensembl_id``
        columns).
    corpus_30M: path
        Path to the tokenized cell dataset loaded with ``load_from_disk``.
    model: path
        Path to the pretrained model that is fine-tuned during training.
    genes: path
        Path to a CSV whose columns are the classes and whose cells are the
        genes belonging to each class.
    inference: bool
        Whether to run the model stored in ``model_location`` on the genes of
        the ``genes`` CSV and return the predicted class names.
        Defaults to False.
    k_validate: bool
        Whether training uses k-fold cross-validation (True) or a single
        train/evaluate split (False). Defaults to True.
    skip_training: bool
        Whether to skip the training stage. Defaults to False.
    emb_extract: bool
        Whether to extract gene embeddings and compare them between classes
        (WIP). Defaults to True.

    Customization Parameters
    -----------
    freeze_layers: int
        Number of leading encoder layers to freeze. Default is 4.
    filter_dataset: int
        Number of cells sampled from the shuffled corpus (capped at the
        corpus size). Default is 50_000.
    emb_layer: int
        Layer option forwarded to ``EmbExtractor``. Default is 0.
    filter_data: str, list
        Column name, or list of column names, of the ``genes`` CSV to train
        on. Default is None, which uses every column.
    max_lr, geneformer_batch_size, epochs
        Learning rate, per-device batch size and epoch count for training.
    max_input_size
        Only used in the name of the default output directory.
    forward_batch, max_cells, num_cpus
        Forwarded to ``EmbExtractor`` as ``forward_batch_size``,
        ``max_ncells`` and ``nproc``.
    model_location: path
        Directory holding the fine-tuned model, ``classes.txt`` and (for
        embedding extraction) the ``dataset`` folder.
    emb_dir: path
        Output directory for the extracted embeddings.
    output_dir: path
        Training output directory. Default is None, which derives a name from
        the date and the hyper-parameters.
    num_gpus, num_proc
        Accepted but not used by this function.

    Returns
    -------
    The list of predicted class names when ``inference`` is true, otherwise
    the nested similarity dict when ``emb_extract`` is true, otherwise None.
    """
    gene_info = pd.read_csv(gene_info, index_col=0)
    gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"]))

    def prep_inputs(label_store, id_type):
        """Turn ``{class: genes}`` into token / label arrays and a fold count."""
        target_list = []
        if id_type == "gene_name":
            for key in label_store:
                gene_ids = [gene_name_id_dict[gene] for gene in label_store[key]
                            if gene_name_id_dict.get(gene) in token_dictionary]
                target_list.append([token_dictionary[gene_id] for gene_id in gene_ids])
        elif id_type == "ensembl_id":
            for key in label_store:
                target_list.append([token_dictionary[gene] for gene in label_store[key]
                                    if gene in token_dictionary])

        # Class index = position of the class in label_store.
        targets = np.array([token for group in target_list for token in group])
        labels = np.array([class_idx for class_idx, group in enumerate(target_list) for _ in group])
        unique_labels = len(target_list)

        nsplits = min(5, min(len(group) for group in target_list) - 1)
        assert nsplits > 2, "every class needs at least 4 genes present in the token dictionary"

        return targets, labels, nsplits, unique_labels

    if not skip_training:
        gene_classes = pd.read_csv(genes, header=0)
        if filter_data is None:
            labels = gene_classes.columns
        elif isinstance(filter_data, list):
            labels = filter_data
        else:
            labels = [filter_data]

        # Maps the numeric class index back to its column name.
        decode = dict(enumerate(labels))
        label_store = {label: gene_classes[label].dropna() for label in labels}

        targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")

        train_dataset = load_from_disk(corpus_30M)
        shuffled_train_dataset = train_dataset.shuffle(seed=42)
        # Never ask for more cells than the corpus holds.
        n_cells = min(filter_dataset, len(shuffled_train_dataset))
        subsampled_train_dataset = shuffled_train_dataset.select(list(range(n_cells)))
        lr_schedule_fn = "linear"
        warmup_steps = 500
        optimizer = "adamw"
        subsample_size = 10_000

        training_args = {
            "learning_rate": max_lr,
            "do_train": True,
            "evaluation_strategy": "no",
            "save_strategy": "epoch",
            "logging_steps": 10,
            "group_by_length": True,
            "length_column_name": "length",
            "disable_tqdm": False,
            "lr_scheduler_type": lr_schedule_fn,
            "warmup_steps": warmup_steps,
            "weight_decay": 0.001,
            "per_device_train_batch_size": geneformer_batch_size,
            "per_device_eval_batch_size": geneformer_batch_size,
            "num_train_epochs": epochs,
        }

        datestamp = datetime.datetime.now().strftime("%y%m%d")

        if output_dir is None:
            training_output_dir = Path(f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/")
        else:
            training_output_dir = Path(output_dir)

        training_output_dir.mkdir(parents=True, exist_ok=True)

        # classes.txt is read back by the inference and embedding stages.
        num_classes = len(set(labels))
        info_list = [num_classes, decode]
        with open(training_output_dir / 'classes.txt', 'w') as f:
            f.write(str(info_list))

        # Use the resolved directory: output_dir itself may be None or a str.
        subsampled_train_dataset.save_to_disk(training_output_dir / 'dataset')

        if k_validate:
            all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)

            bundled_data = [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
            fpr, tpr, auc_value = plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')
            print(auc_value)

            plot_confusion_matrix(label_store, confusion, "Geneformer")
        else:
            fpr, tpr, auc_value = validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
            print(auc_value)

    if inference:
        gene_classes = pd.read_csv(genes, header=0)
        gene_ids = []
        for column in gene_classes.columns:
            gene_ids += list(gene_classes[column])
        # Genes missing from the token dictionary are replaced by token 0.
        tokens = [token_dictionary.get(gene_id, 0) for gene_id in gene_ids]

        targets = torch.LongTensor([tokens])

        with open(os.path.join(model_location, 'classes.txt'), 'r') as f:
            info_list = ast.literal_eval(f.read())
        num_classes = info_list[0]
        labels = info_list[1]

        model = BertForTokenClassification.from_pretrained(
            model_location,
            num_labels=num_classes,
            output_attentions = False,
            output_hidden_states = False,
            local_files_only = True
        )
        model = model.to(device)

        # Forward pass only; no gradients are needed.
        with torch.no_grad():
            logits = model(targets.to(device))["logits"]
        predictions = F.softmax(logits, dim = -1).argmax(-1)[0]
        predictions = [labels[int(pred)] for pred in predictions]

        return predictions

    if emb_extract:
        with open(os.path.join(model_location, 'classes.txt'), 'r') as f:
            data = ast.literal_eval(f.read())
        num_classes = data[0]

        gene_classes = pd.read_csv(genes, header=0)
        labels = gene_classes.columns
        tokenize = TranscriptomeTokenizer()

        # Per class: token ids of the genes known to the tokenizer; unknown
        # genes are skipped.
        label_dict = {}
        for label in labels:
            label_dict[label] = [tokenize.gene_token_dict[gene]
                                 for gene in gene_classes[label]
                                 if gene in tokenize.gene_token_dict]

        embex = EmbExtractor(model_type="GeneClassifier", num_classes=num_classes, emb_mode = "gene",
                             filter_data=None, max_ncells=max_cells, emb_layer=emb_layer,
                             emb_label=label_dict, labels_to_plot=list(labels), forward_batch_size=forward_batch, nproc=num_cpus)

        os.makedirs(emb_dir, exist_ok=True)

        # NOTE(review): `label` is whatever the loop above left behind, i.e.
        # the last column of the genes CSV - confirm this prefix is intended.
        embs = embex.extract_embs(model_directory = model_location, input_data_file = Path(model_location) / 'dataset', output_directory = emb_dir, output_prefix = f"{label}_embbeddings")

        similarities = {class_name: {} for class_name in set(labels)}

        for column in embs.columns:
            column_embedding = torch.Tensor(embs[column])
            for other in embs.columns:
                if other == column:
                    continue
                similarities[column][other] = similarity(column_embedding, torch.Tensor(embs[other]), cosine = True)

        plot_similarity_heatmap(similarities)
        print(similarities)

        return similarities
|
|
|
if __name__ == '__main__':
    # Train on a single split (no k-fold), then extract embeddings from the
    # model that was just written to gene_emb/.
    run_options = dict(
        k_validate = False,
        inference = False,
        skip_training = False,
        emb_extract = True,
        output_dir = Path('gene_emb'),
        model_location = Path('gene_emb'),
        epochs = 5,
        gene_info = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv",
        genes = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
        corpus_30M = "../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/",
    )
    classify_genes(**run_options)
|
|
|
|