# Geneformer / Gene_classifier.py
# Hugging Face file-page residue kept for provenance:
# tigerdeF's picture | Upload 15 files | e562c0c | raw | history blame | 31 kB
import os
import sys
# GPU indices exposed to this process through CUDA_VISIBLE_DEVICES.
# These environment variables are set here, ahead of the `import torch`
# further down in the import section.
GPU_NUMBER = [0] # CHANGE WITH MULTIGPU
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
# Set NCCL's debug level to INFO
os.environ["NCCL_DEBUG"] = "INFO"
# imports
from sklearn.model_selection import train_test_split
import datetime
import subprocess
from pathlib import Path
import math
import matplotlib.pyplot as plt
import numpy as np
import pickle
import pandas as pd
from datasets import load_from_disk, Dataset
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
from sklearn.model_selection import StratifiedKFold
import torch
from transformers import BertForTokenClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from tqdm.notebook import tqdm
from sklearn.metrics import roc_curve, roc_auc_score
from geneformer import DataCollatorForGeneClassification, EmbExtractor
from geneformer.pretrainer import token_dictionary
import ast
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from geneformer import TranscriptomeTokenizer
def vote(logit_pair):
    """Pick the winning class from a pair of logits.

    Returns 0 when the first logit is larger, 1 when the second is larger,
    and the string "tie" when the two compare equal. If none of the three
    comparisons holds (e.g. a NaN is involved) the function falls through
    and returns None.
    """
    first, second = logit_pair
    if first == second:
        return "tie"
    if first > second:
        return 0
    if second > first:
        return 1
def py_softmax(vector):
    """Return the softmax of a sequence of logits as a NumPy array.

    The largest logit is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``np.exp`` from overflowing to
    ``inf`` (and the ratio from becoming NaN) when logits are large.

    Parameters
    ----------
    vector : array-like of numbers
        Logits to normalise; the normalisation runs over all elements.

    Returns
    -------
    numpy.ndarray
        Non-negative values summing to 1 (an empty array for empty input).
    """
    logits = np.asarray(vector, dtype=float)
    if logits.size == 0:
        return logits
    exps = np.exp(logits - np.max(logits))
    return exps / exps.sum()
# Cosine similarity between two embeddings. The raw cosine lies in [-1, 1];
# with cosine=False it is rescaled to [0, 1] (0 = opposite, 1 = same direction).
def similarity(tensor1, tensor2, cosine = True):
    """Compute the cosine similarity of two equally-shaped tensors.

    Both tensors are flattened before the dot product, so inputs of any rank
    are accepted as long as their shapes match.

    Parameters
    ----------
    tensor1, tensor2 : torch.Tensor
        Embeddings to compare; must have identical shapes.
    cosine : bool
        If True (default) return the raw cosine similarity in [-1, 1].
        If False return the cosine rescaled to [0, 1] via ``(cos + 1) / 2``.

    Returns
    -------
    float
        The similarity score as a Python float.

    Raises
    ------
    ValueError
        If the two tensors do not have the same shape.
    """
    if tensor1.shape != tensor2.shape:
        raise ValueError("Input tensors must have the same shape.")

    flat1 = tensor1.reshape(-1)
    flat2 = tensor2.reshape(-1)

    # Avoid division by zero by adding a small epsilon to the denominator
    epsilon = 1e-8
    cos = torch.dot(flat1, flat2) / (torch.norm(flat1) * torch.norm(flat2) + epsilon)
    score = cos.item()

    if not cosine:
        # map [-1, 1] onto [0, 1]
        score = (score + 1) / 2
    return score
# Plots heatmap between different classes/labels
def plot_similarity_heatmap(similarities):
    """Draw the pairwise class-similarity matrix and save it as a PNG.

    ``similarities`` is a nested mapping ``{class: {other_class: value}}``.
    The diagonal is filled with 1.0; every off-diagonal cell is read from
    ``similarities[row_class][column_class]``. The figure is written to
    "similarity_heatmap.png" in the current working directory.
    """
    names = list(similarities)
    count = len(names)

    grid = np.zeros((count, count))
    for row_idx, row_name in enumerate(names):
        grid[row_idx, :] = [
            1.0 if col_name == row_name else similarities[row_name][col_name]
            for col_name in names
        ]

    tick_positions = np.arange(count)
    plt.figure(figsize=(8, 6))
    plt.imshow(grid, cmap='inferno', vmin=0, vmax=1)
    plt.colorbar()
    plt.xticks(tick_positions, names, rotation = 45, ha = 'right')
    plt.yticks(tick_positions, names)
    plt.title("Similarity Heatmap")
    plt.savefig("similarity_heatmap.png")
# get cross-validated mean and sd metrics
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    """Combine per-fold ROC results into weighted cross-validation metrics.

    Parameters
    ----------
    all_tpr : sequence of equal-length 1-D arrays
        Interpolated true-positive-rate curve for each fold.
    all_roc_auc : sequence of float
        ROC AUC for each fold.
    all_tpr_wt : sequence of numbers
        Weight of each fold; weights are normalised to sum to 1.

    Returns
    -------
    mean_tpr : numpy.ndarray
        Weighted mean TPR curve, with its last point forced to 1.0.
    roc_auc : float
        Weighted mean of the fold AUCs.
    roc_auc_sd : float
        Weighted standard deviation of the fold AUCs around ``roc_auc``.

    Raises
    ------
    ValueError
        If no folds are given or the weights sum to zero.
    """
    if len(all_tpr_wt) == 0:
        raise ValueError("No cross-validation folds to aggregate.")
    total_wt = sum(all_tpr_wt)
    if total_wt == 0:
        raise ValueError("Cross-validation fold weights sum to zero.")

    # normalise once instead of re-summing the weights for every fold
    wts = np.asarray(all_tpr_wt, dtype=float) / total_wt
    tprs = np.asarray(all_tpr, dtype=float)
    aucs = np.asarray(all_roc_auc, dtype=float)

    mean_tpr = np.sum(tprs * wts[:, None], axis=0)
    mean_tpr[-1] = 1.0
    roc_auc = np.sum(aucs * wts)
    roc_auc_sd = math.sqrt(np.average((aucs - roc_auc) ** 2, weights=wts))
    return mean_tpr, roc_auc, roc_auc_sd
def validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
    """Fine-tune a gene (token) classifier on one random train/eval split and score it.

    Parameters
    ----------
    data : dataset of tokenized cells (each example has "input_ids")
    targets : array of gene token ids that carry a label
    labels : array of class ids, aligned with ``targets``
    nsplits : unused here; kept so the signature matches ``cross_validate``
    subsample_size : maximum number of training cells to use
    training_args : dict of keyword arguments for ``TrainingArguments``
        (not modified; ``output_dir`` is added to a copy)
    freeze_layers : number of leading encoder layers to freeze, or None
    output_dir : directory for trainer output and the saved model
    num_proc : number of processes used when filtering ``data``
    num_labels : number of output classes for the classifier head
    pre_model : path/name of the pretrained model to fine-tune

    Returns
    -------
    fpr, tpr, auc_score
        ROC curve points from ``classifier_predict`` on the held-out half of
        the evaluation cells, and the area under that curve.
    """
    mean_fpr = np.linspace(0, 1, 100)

    # hold out 25% of the labelled genes for evaluation (75/25 split)
    targets_train, targets_eval, labels_train, labels_eval = train_test_split(targets, labels, test_size=0.25, shuffle=True)
    label_dict_train = dict(zip(targets_train, labels_train))
    label_dict_eval = dict(zip(targets_eval, labels_eval))

    # build the token sets once rather than once per filtered example
    train_tokens = set(label_dict_train)
    eval_tokens = set(label_dict_eval)

    # keep only cells that contain at least one train / eval gene
    def if_contains_train_label(example):
        return not train_tokens.isdisjoint(example['input_ids'])

    def if_contains_eval_label(example):
        return not eval_tokens.isdisjoint(example['input_ids'])

    print("Filtering training data")
    trainset = data.filter(if_contains_train_label, num_proc=num_proc)
    print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
    print("Filtering evalation data")
    evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
    print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

    # minimize to smaller training sample; the eval cells are split in half:
    # first half is handed to the Trainer, second half is scored out-of-sample
    training_size = min(subsample_size, len(trainset))
    trainset_min = trainset.select(list(range(training_size)))
    eval_size = min(training_size, len(evalset))
    half_training_size = round(eval_size/2)
    evalset_train_min = evalset.select(list(range(half_training_size)))
    evalset_oos_min = evalset.select(list(range(half_training_size, eval_size)))

    # per-token labels; -100 marks tokens without a label
    def generate_train_labels(example):
        example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
        return example

    def generate_eval_labels(example):
        example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
        return example

    print("Labeling training data")
    trainset_labeled = trainset_min.map(generate_train_labels)
    print("Labeling evaluation data")
    evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
    print("Labeling evaluation OOS data")
    evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

    # load model
    model = BertForTokenClassification.from_pretrained(
        pre_model,
        num_labels=num_labels,
        output_attentions = False,
        output_hidden_states = False,
    )
    if freeze_layers is not None:
        for module in model.bert.encoder.layer[:freeze_layers]:
            for param in module.parameters():
                param.requires_grad = False
    model = model.to(device)

    # add the output directory on a copy so the caller's dict is left untouched
    training_args_init = TrainingArguments(**{**training_args, "output_dir": output_dir})

    trainer = Trainer(
        model=model,
        args=training_args_init,
        data_collator=DataCollatorForGeneClassification(),
        train_dataset=trainset_labeled,
        eval_dataset=evalset_train_labeled,
    )

    # train the gene classifier and save it
    trainer.train()
    trainer.save_model(output_dir)

    fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)
    auc_score = auc(fpr, tpr)
    return fpr, tpr, auc_score
# cross-validate gene classifier
def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc, num_labels, pre_model):
    """Run stratified k-fold cross-validation of the gene (token) classifier.

    For every fold a fresh model is loaded from ``pre_model``, fine-tuned on
    cells containing the fold's training genes, saved under
    ``<output_dir>/ksplit<i>/models/`` and scored with ``classifier_predict``
    on the second half of the fold's evaluation cells.

    Parameters
    ----------
    data : dataset of tokenized cells (each example has "input_ids")
    targets : array of gene token ids that carry a label
    labels : array of class ids, aligned with ``targets``
    nsplits : number of stratified folds
    subsample_size : maximum number of training cells per fold
    training_args : dict of keyword arguments for ``TrainingArguments``
        (not modified; ``output_dir`` is added to a per-fold copy)
    freeze_layers : number of leading encoder layers to freeze, or None
    output_dir : root directory for per-fold outputs
    num_proc : number of processes used when filtering ``data``
    num_labels : number of output classes for the classifier head
    pre_model : path/name of the pretrained model to fine-tune

    Returns
    -------
    all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
        Per-fold AUCs, their weighted mean and SD, the common FPR grid, the
        weighted mean TPR curve, the summed confusion matrix, and a flat list
        holding, for each fold in turn, the fold index followed by the
        train/eval targets and train/eval labels.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # initiate eval metrics to return
    num_classes = len(set(labels))
    mean_fpr = np.linspace(0, 1, 100)
    all_tpr = []
    all_roc_auc = []
    all_tpr_wt = []
    label_dicts = []
    confusion = np.zeros((num_classes, num_classes))

    # set up cross-validation splits
    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)

    # train and evaluate
    for iteration_num, (train_index, eval_index) in enumerate(tqdm(skf.split(targets, labels))):
        # with many labelled genes only the first three folds are run
        if len(labels) > 500:
            print("early stopping activated due to large # of training examples")
            if iteration_num == 3:
                break
        print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")

        # generate cross-validation splits
        targets_train, targets_eval = targets[train_index], targets[eval_index]
        labels_train, labels_eval = labels[train_index], labels[eval_index]
        label_dict_train = dict(zip(targets_train, labels_train))
        label_dict_eval = dict(zip(targets_eval, labels_eval))
        # NOTE(review): this extends the list with five separate items per
        # fold rather than appending one tuple; kept as-is so the returned
        # shape does not change for existing callers.
        label_dicts.extend((iteration_num, targets_train, targets_eval, labels_train, labels_eval))

        # build the token sets once rather than once per filtered example
        train_tokens = set(label_dict_train)
        eval_tokens = set(label_dict_eval)

        # keep only cells that contain at least one train / eval gene
        def if_contains_train_label(example):
            return not train_tokens.isdisjoint(example['input_ids'])

        def if_contains_eval_label(example):
            return not eval_tokens.isdisjoint(example['input_ids'])

        # filter dataset for examples containing classes for this split
        print("Filtering training data")
        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
        print("Filtering evalation data")
        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

        # minimize to smaller training sample; the eval cells are split in half:
        # first half is handed to the Trainer, second half is scored out-of-sample
        training_size = min(subsample_size, len(trainset))
        trainset_min = trainset.select(list(range(training_size)))
        eval_size = min(training_size, len(evalset))
        half_training_size = round(eval_size/2)
        evalset_train_min = evalset.select(list(range(half_training_size)))
        evalset_oos_min = evalset.select(list(range(half_training_size, eval_size)))

        # per-token labels; -100 marks tokens without a label
        def generate_train_labels(example):
            example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        def generate_eval_labels(example):
            example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        # label datasets
        print("Labeling training data")
        trainset_labeled = trainset_min.map(generate_train_labels)
        print("Labeling evaluation data")
        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
        print("Labeling evaluation OOS data")
        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

        # create output directories without going through a shell; the model
        # directory lives inside the fold directory, so one call creates both
        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/")
        os.makedirs(ksplit_model_dir, exist_ok=True)

        # load model
        model = BertForTokenClassification.from_pretrained(
            pre_model,
            num_labels=num_labels,
            output_attentions = False,
            output_hidden_states = False,
        )
        if freeze_layers is not None:
            for module in model.bert.encoder.layer[:freeze_layers]:
                for param in module.parameters():
                    param.requires_grad = False
        model = model.to(device)

        # add the fold's output directory on a copy so the caller's dict is untouched
        training_args_init = TrainingArguments(**{**training_args, "output_dir": ksplit_output_dir})

        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled
        )

        # train the gene classifier
        trainer.train()

        # save model
        trainer.save_model(ksplit_model_dir)

        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)

        # accumulate confusion matrix, interpolated tpr and AUC for this fold
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # weight used for this fold when averaging (length of its tpr array)
        all_tpr_wt.append(len(tpr))

    # get overall metrics for cross-validation
    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
# Computes metrics
def compute_metrics(pred):
    """Compute accuracy and macro-F1 from a prediction object.

    ``pred`` must expose ``label_ids`` and ``predictions`` (class scores on
    the last axis). Positions whose label is -100 — the value this file uses
    for tokens without a label — are excluded before scoring, so padded
    and unlabelled tokens do not count towards either metric.

    Returns
    -------
    dict
        ``{'accuracy': ..., 'macro_f1': ...}``
    """
    # f1_score is not part of this file's top-level sklearn imports
    from sklearn.metrics import f1_score

    labels = np.asarray(pred.label_ids).reshape(-1)
    preds = np.asarray(pred.predictions).argmax(-1).reshape(-1)

    # drop ignored positions
    keep = labels != -100
    labels = labels[keep]
    preds = preds[keep]

    # calculate accuracy and macro f1 using sklearn's function
    acc = accuracy_score(labels, preds)
    macro_f1 = f1_score(labels, preds, average='macro')
    return {
        'accuracy': acc,
        'macro_f1': macro_f1
    }
# plot ROC curve
def plot_ROC(bundled_data, title):
    """Plot one ROC curve per bundle and save the figure to "ROC.png".

    Parameters
    ----------
    bundled_data : iterable of tuples
        Each tuple is ``(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color)``.
    title : str
        Figure title.

    Returns
    -------
    mean_fpr, mean_tpr, roc_auc
        The values from the LAST bundle that was plotted.

    Raises
    ------
    ValueError
        If ``bundled_data`` is empty (there would be nothing to return).
    """
    bundled_data = list(bundled_data)
    if not bundled_data:
        raise ValueError("bundled_data must contain at least one ROC curve.")

    plt.figure()
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        # raw string: "\p" is not a valid escape sequence in a normal literal
        legend_text = r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd)
        plt.plot(mean_fpr, mean_tpr, color=color, lw=lw, label=legend_text)
    # chance diagonal
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.savefig("ROC.png")
    return mean_fpr, mean_tpr, roc_auc
# plot confusion matrix
def plot_confusion_matrix(classes_list, conf_mat, title):
    """Render a normalised confusion matrix and save it to "CM.png".

    Each tick label shows the class name plus ``n=``, the sum of that class's
    column in ``conf_mat``. The plotted values are ``conf_mat`` after
    ``preprocessing.normalize(..., norm="l1")``.
    """
    tick_labels = [
        "{0}\nn={1:.0f}".format(class_name, sum(conf_mat[:, col]))
        for col, class_name in enumerate(classes_list)
    ]
    normalized = preprocessing.normalize(conf_mat, norm="l1")
    chart = ConfusionMatrixDisplay(confusion_matrix=normalized, display_labels=tick_labels)
    chart.plot(cmap="Blues", values_format=".2g")
    plt.title(title)
    plt.savefig("CM.png")
# Function to find the largest number smaller
# than or equal to N that is divisible by k
def find_largest_div(N, K):
    """Return N rounded down to a multiple of K (N itself if already one)."""
    return N - (N % K)
def preprocess_classifier_batch(cell_batch, max_len):
    """Right-pad every example in a batch to a common length.

    Parameters
    ----------
    cell_batch : dataset whose examples have "input_ids" and "labels"
    max_len : int or None
        Target length; when None the longest "input_ids" in the batch is used.

    Returns
    -------
    The mapped dataset in which, per example, "labels" is padded with -100,
    "input_ids" is padded with the ``<pad>`` token id, and "attention_mask"
    is 1 wherever "input_ids" differs from the pad id and 0 elsewhere.
    """
    if max_len is None:
        max_len = max(len(ids) for ids in cell_batch["input_ids"])

    # look the pad id up once instead of twice per example
    pad_token_id = token_dictionary.get("<pad>")

    def pad_label_example(example):
        # width is computed from the unpadded input_ids and shared by both fields
        pad_width = (0, max_len - len(example["input_ids"]))
        example["labels"] = np.pad(example["labels"], pad_width,
                                   mode='constant', constant_values=-100)
        example["input_ids"] = np.pad(example["input_ids"], pad_width,
                                      mode='constant', constant_values=pad_token_id)
        example["attention_mask"] = (example["input_ids"] != pad_token_id).astype(int)
        return example

    return cell_batch.map(pad_label_example)
# forward batch size is batch size for model inference (e.g. 200)
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    """Run the token classifier over ``evalset`` and compute ROC/confusion metrics.

    The model is moved to the CPU and put in eval mode. Examples are padded
    to the longest "length" in ``evalset`` and fed through the model in
    batches of ``forward_batch_size``. Only tokens whose label is not -100
    are scored. Scoring assumes two classes: the prediction comes from
    ``vote`` on the logit pair and the ROC score is the softmax probability
    of class 1.

    The batch dimension of logits/labels is kept as returned by the model
    (no ``squeeze``), so a final batch holding a single example is handled
    like any other and no example has to be dropped.

    Parameters
    ----------
    model : token-classification model returning ``.logits``
    evalset : labelled dataset with "input_ids", "labels" and "length"
    forward_batch_size : batch size for model inference (e.g. 200)
    mean_fpr : 1-D array of FPR values at which the TPR is interpolated

    Returns
    -------
    fpr, tpr, interp_tpr, conf_mat
        ROC curve points, the TPR interpolated at ``mean_fpr`` (first point
        forced to 0.0), and the confusion matrix.

    Raises
    ------
    ValueError
        If ``evalset`` is empty.
    """
    predict_logits = []
    predict_labels = []
    model.to('cpu')
    model.eval()

    evalset_len = len(evalset)
    if evalset_len == 0:
        raise ValueError("evalset is empty; nothing to predict.")
    max_evalset_len = max(evalset["length"])

    for start in range(0, evalset_len, forward_batch_size):
        stop = min(start + forward_batch_size, evalset_len)
        batch_evalset = evalset.select(list(range(start, stop)))
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        padded_batch.set_format(type="torch")

        input_ids = padded_batch["input_ids"]
        attn_mask = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = input_ids,
                attention_mask = attn_mask,
                labels = label_batch
            )
        predict_logits.append(outputs.logits.to("cpu"))
        predict_labels.append(label_batch.to("cpu"))

    # flatten (cells, tokens, classes) -> (tokens, classes) and drop unlabelled tokens
    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in zip(all_logits.tolist(), all_labels.tolist()) if item[1] != -100]

    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    # probability of class 1
    y_score = [py_softmax(item)[1] for item in logits_list]

    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)

    # plot roc_curve for this split
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC')
    plt.show()

    # interpolate to graph
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat
def classify_genes(gene_info = "Genecorpus-30M/example_input_files/gene_info_table.csv", genes = "Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
                   corpus_30M = "Genecorpus-30M/genecorpus_30M_2048.dataset/", model = '.',
                   max_input_size = 2 ** 11, max_lr = 5e-5, freeze_layers = 4, num_gpus = 1, num_proc = os.cpu_count(), geneformer_batch_size = 9, epochs = 1, filter_dataset = 50_000,
                   emb_extract = True, emb_layer = 0, forward_batch = 200, filter_data = None, inference = False, k_validate = True, model_location = "230917_geneformer_GeneClassifier_dosageTF_L2048_B12_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/", skip_training = False, emb_dir = 'gene_emb', output_dir = None, max_cells = 1000, num_cpus = os.cpu_count()):
    """
    Train, apply, and/or extract embeddings from a Geneformer gene classifier.

    The three stages run in this order, each only if enabled: training
    (unless ``skip_training``), inference (if ``inference``; returns the
    predictions), embedding extraction (if ``emb_extract``; returns the
    similarity dict). If neither of the last two is enabled, returns None.

    Primary Parameters
    -----------
    gene_info: path
        Path to gene mappings (CSV with "gene_name" and "ensembl_id" columns)
    corpus_30M: path
        Path to 30M Gene Corpus
    model: path
        Path to pretrained GeneFormer model
    genes: path
        Path to csv file containing different columns of genes and the column labels
    inference: bool
        Whether the saved model at model_location should be used to predict a label for every gene in the genes csv. Defaults to False
    k_validate: bool
        Whether training runs k-fold cross-validation or a single train/evaluate split. Defaults to True
    skip_training: bool
        Whether the model should skip the training portion. Defaults to False
    emb_extract: bool
        Whether the model should extract embeddings for the genes of each label (WIP)
    model_location: path
        Directory of a previously trained classifier (holds classes.txt and, for emb_extract, the saved dataset)
    output_dir: path
        Training output directory. When None, a date-stamped directory name is generated
    emb_dir: path
        Directory embeddings are written to

    Customization Parameters
    -----------
    freeze_layers: int
        Freezes x number of layers from the model. Default is 4
    filter_dataset: int
        Number of cells to subsample from the 30M dataset. Default is 50_000
    emb_layer: int
        What layer embeddings are extracted from. Default is 0
    filter_data: str, list
        Restricts training to the given label column(s) of the genes csv. Default is None (all columns)
    max_lr, geneformer_batch_size, epochs:
        Learning rate, per-device batch size and number of epochs passed to the trainer
    forward_batch, max_cells, num_cpus:
        Forward batch size, maximum number of cells and number of processes passed to EmbExtractor
    max_input_size:
        Only used in the generated output directory name
    num_gpus, num_proc:
        Currently unused by this function
    """
    # table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
    gene_info = pd.read_csv(gene_info, index_col=0)
    gene_name_id_dict = dict(zip(gene_info["gene_name"], gene_info["ensembl_id"]))

    # function for preparing targets and labels
    def prep_inputs(label_store, id_type):
        # one list of token ids per class, in label_store order
        target_list = []
        if id_type == "gene_name":
            for key in label_store:
                ensembl_ids = [gene_name_id_dict[gene] for gene in label_store[key] if gene_name_id_dict.get(gene) in token_dictionary]
                target_list.append([token_dictionary[gene] for gene in ensembl_ids])
        elif id_type == "ensembl_id":
            for key in label_store:
                target_list.append([token_dictionary[gene] for gene in label_store[key] if gene in token_dictionary])

        targets = np.array([token for targ in target_list for token in targ])
        labels = np.array([num for num, targ in enumerate(target_list) for _ in targ])
        unique_labels = len(target_list)

        # at most 5 folds, and fewer than the size of the smallest class
        nsplits = min(5, min(len(targ) for targ in target_list) - 1)
        assert nsplits > 2
        return targets, labels, nsplits, unique_labels

    if not skip_training:
        # preparing targets and labels for dosage sensitive vs insensitive TFs
        gene_classes = pd.read_csv(genes, header=0)
        if filter_data is None:
            labels = gene_classes.columns
        elif isinstance(filter_data, list):
            labels = filter_data
        else:
            labels = [filter_data]

        # Dictionary for decoding labels (class index -> label name)
        decode = dict(enumerate(labels))
        label_store = {label: gene_classes[label].dropna() for label in labels}

        targets, labels, nsplits, unique_labels = prep_inputs(label_store, "ensembl_id")

        # load training dataset and subsample it (never more cells than exist)
        train_dataset = load_from_disk(corpus_30M)
        shuffled_train_dataset = train_dataset.shuffle(seed=42)
        n_cells = min(filter_dataset, len(shuffled_train_dataset))
        subsampled_train_dataset = shuffled_train_dataset.select(list(range(n_cells)))

        lr_schedule_fn = "linear"
        warmup_steps = 500
        optimizer = "adamw"
        subsample_size = 10_000
        training_args = {
            "learning_rate": max_lr,
            "do_train": True,
            "evaluation_strategy": "no",
            "save_strategy": "epoch",
            "logging_steps": 10,
            "group_by_length": True,
            "length_column_name": "length",
            "disable_tqdm": False,
            "lr_scheduler_type": lr_schedule_fn,
            "warmup_steps": warmup_steps,
            "weight_decay": 0.001,
            "per_device_train_batch_size": geneformer_batch_size,
            "per_device_eval_batch_size": geneformer_batch_size,
            "num_train_epochs": epochs,
        }

        # define output directory path
        current_date = datetime.datetime.now()
        datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
        if output_dir is None:
            training_output_dir = Path(f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/")
        else:
            training_output_dir = Path(output_dir)

        # make output directory (no shell involved)
        os.makedirs(training_output_dir, exist_ok=True)

        # Store the number of classes and the label decoder next to the model
        num_classes = len(set(labels))
        info_list = [num_classes, decode]
        with open(training_output_dir / 'classes.txt', 'w') as f:
            f.write(str(info_list))

        # save under the resolved training directory so this also works
        # when output_dir was left as None
        subsampled_train_dataset.save_to_disk(str(training_output_dir / 'dataset'))

        # NOTE: filtering inside cross_validate/validate is run with a single
        # process (the literal 1 below), not with num_proc.
        if k_validate:
            # cross-validate gene classifier
            _, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, _ = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)

            bundled_data = [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
            fpr, tpr, auc_value = plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')
            print(auc_value)

            # plot confusion matrix
            plot_confusion_matrix(label_store, confusion, "Geneformer")
        else:
            fpr, tpr, auc_value = validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1, unique_labels, model)
            print(auc_value)

    if inference:
        # every gene in every column of the csv becomes one token of a single input sequence
        gene_classes = pd.read_csv(genes, header=0)
        targets = []
        for column in gene_classes.columns:
            targets += list(gene_classes[column])
        # genes missing from the token dictionary are mapped to token id 0
        tokens = [token_dictionary.get(target, 0) for target in targets]
        targets = torch.LongTensor([tokens])

        # os.path.join copes with str or Path, with or without a trailing slash
        with open(os.path.join(model_location, 'classes.txt'), 'r') as f:
            info_list = ast.literal_eval(f.read())
        num_classes = info_list[0]
        labels = info_list[1]

        classifier = BertForTokenClassification.from_pretrained(
            model_location,
            num_labels=num_classes,
            output_attentions = False,
            output_hidden_states = False,
            local_files_only = True
        )
        classifier = classifier.to(device)
        classifier.eval()

        # evaluate model; no gradients are needed for prediction
        with torch.no_grad():
            logits = classifier(targets.to(device))["logits"]
        predictions = F.softmax(logits, dim = -1).argmax(-1)[0]
        predictions = [labels[int(pred)] for pred in predictions]
        return predictions

    # Extracts aggregate gene embeddings for each label
    if emb_extract:
        model_dir = Path(model_location)
        with open(model_dir / 'classes.txt', 'r') as f:
            data = ast.literal_eval(f.read())
        num_classes = data[0]

        gene_classes = pd.read_csv(genes, header=0)
        labels = gene_classes.columns
        tokenize = TranscriptomeTokenizer()

        # token ids of the genes listed under each label; unknown genes are skipped
        label_dict = {}
        for label in labels:
            label_dict[label] = [
                tokenize.gene_token_dict[gene]
                for gene in gene_classes[label]
                if gene in tokenize.gene_token_dict
            ]
        # NOTE(review): the output prefix is built from the LAST label of the
        # loop above, as in the original code — confirm this is intended.
        output_prefix = f"{label}_embbeddings"

        embex = EmbExtractor(model_type="GeneClassifier", num_classes=num_classes, emb_mode = "gene",
                             filter_data=None, max_ncells=max_cells, emb_layer=emb_layer,
                             emb_label=label_dict, labels_to_plot=list(labels), forward_batch_size=forward_batch, nproc=num_cpus)

        os.makedirs(emb_dir, exist_ok=True)

        embs = embex.extract_embs(model_directory = model_location, input_data_file = model_dir / 'dataset', output_directory = emb_dir, output_prefix = output_prefix)

        # one row of similarities per label, in csv column order
        similarities = {label_name: {} for label_name in dict.fromkeys(labels)}
        for column in embs.columns:
            column_embedding = torch.Tensor(embs[column])
            for other in embs.columns:
                if other == column:
                    continue
                similarities[column][other] = similarity(column_embedding, torch.Tensor(embs[other]), cosine = True)

        plot_similarity_heatmap(similarities)
        print(similarities)
        return similarities
if __name__ == '__main__':
    # Example run: single train/eval split (no k-fold), then embedding
    # extraction, reading from and writing to the 'gene_emb' directory.
    run_options = dict(
        k_validate = False,
        inference = False,
        skip_training = False,
        emb_extract = True,
        output_dir = Path('gene_emb'),
        model_location = Path('gene_emb'),
        epochs = 5,
        gene_info = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_info_table.csv",
        genes = "../GeneFormer_repo/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv",
        corpus_30M = "../GeneFormer_repo/Genecorpus-30M/genecorpus_30M_2048.dataset/",
    )
    classify_genes(**run_options)