Spaces:
Sleeping
Sleeping
import pandas as pd | |
from IPython.display import clear_output | |
import torch | |
from transformers import EsmForSequenceClassification, AdamW, AutoTokenizer | |
from torch.utils.data import DataLoader, TensorDataset, random_split | |
from sklearn.preprocessing import LabelEncoder | |
from tqdm import tqdm | |
import numpy as np | |
import seaborn as sns | |
from sklearn.model_selection import train_test_split | |
import matplotlib.pyplot as plt | |
import pickle | |
import torch.nn.functional as F | |
import gradio as gr | |
import io | |
from PIL import Image | |
import Bio | |
from Bio import SeqIO | |
import zipfile | |
import os | |
# Family classifier set-up: load the pickled label object, the ESM-2 tokenizer
# and the fine-tuned classification weights, then run one prediction on a
# hard-coded sequence. Everything below is a module-level global that the
# request handlers further down read (yfam, tokenizerfam, modelfam).
# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# only ship a 'family_labels.pkl' that comes from a trusted source.
with open('family_labels.pkl', 'rb') as filefam:
    # presumably a fitted sklearn LabelEncoder (classes_ / inverse_transform are used) -- confirm
    yfam = pickle.load(filefam)
tokenizerfam = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")  # alternative checkpoint noted by the author: facebook/esm2_t33_650M_UR50D
device = 'cpu'
device  # bare expression (looks like a notebook leftover); has no effect in a script
# One output logit per class known to the label object.
modelfam = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(yfam.classes_))
modelfam = modelfam.to('cpu')
# Overwrite the freshly initialised weights with the fine-tuned checkpoint.
modelfam.load_state_dict(torch.load("family.pth", map_location=torch.device('cpu')))
modelfam.eval()
# Hard-coded example input; the literal keeps its embedded newlines.
x_testfam = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
LQRVSLPRMVYPQPKVLTPCRKDVLVVTPWLAPIVWEGTFNIDILNEQFRLQNTTIGLTV
FAIKKYVAFLKLFLETAEKHFMVGHRVHYYVFTDQPAAVPRVTLGTGRQLSVLEVRAYKR
WQDVSMRRMEMISDFCERRFLSEVDYLVCVDVDMEFRDHVGVEILTPLFGTLHPGFYGSS
REAFTYERRPQSQAYIPKDEGDFYYLGGFFGGSVQEVQRLTRACHQAMMVDQANGIEAVW
HDESHLNKYLLRHKPTKVLSPEYLWDQQLLGWPAVLRKLRFTAVPKNHQAVRNP
"""]
encoded_inputfam = tokenizerfam(x_testfam, padding=True, truncation=True, max_length=512, return_tensors="pt")
input_idsfam = encoded_inputfam["input_ids"]
attention_maskfam = encoded_inputfam["attention_mask"]
# Inference only: no gradients needed.
with torch.no_grad():
    outputfam = modelfam(input_idsfam, attention_mask=attention_maskfam)
logitsfam = outputfam.logits
probabilitiesfam = F.softmax(logitsfam, dim=1)
# torch.max over dim=1 returns (values, indices); only the class indices are kept.
_, predicted_labelsfam = torch.max(logitsfam, dim=1)
probabilitiesfam[0]  # bare expression; result is discarded
# Map class indices back to the original label values.
decoded_labelsfam = yfam.inverse_transform(predicted_labelsfam.tolist())
decoded_labelsfam  # bare expression; result is discarded
# Donor classifier set-up: same steps as the family classifier above, but with
# 'donor_labels.pkl' and the 'best_model_35M_t12_5v5.pth' checkpoint. The
# globals tokenizer, label_encoder and model are read by the handlers below.
# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# only ship a 'donor_labels.pkl' that comes from a trusted source.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
with open('donor_labels.pkl', 'rb') as file:
    # presumably a fitted sklearn LabelEncoder (classes_ / inverse_transform are used) -- confirm
    label_encoder = pickle.load(file)
# encoded_labels = label_encoder.fit(y)
# labels = torch.tensor(encoded_labels)
# One output logit per class known to the label object.
model = EsmForSequenceClassification.from_pretrained("facebook/esm2_t12_35M_UR50D", num_labels=len(label_encoder.classes_))
model = model.to('cpu')
# Overwrite the freshly initialised weights with the fine-tuned checkpoint.
model.load_state_dict(torch.load("best_model_35M_t12_5v5.pth", map_location=torch.device('cpu')))  # earlier checkpoint name noted by the author: model_best_35v2M.pth
model.eval()
# Hard-coded example input; the literal keeps its embedded newlines.
x_test = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
LQRVSLPRMVYPQPKVLTPCRKDVLVVTPWLAPIVWEGTFNIDILNEQFRLQNTTIGLTV
FAIKKYVAFLKLFLETAEKHFMVGHRVHYYVFTDQPAAVPRVTLGTGRQLSVLEVRAYKR
WQDVSMRRMEMISDFCERRFLSEVDYLVCVDVDMEFRDHVGVEILTPLFGTLHPGFYGSS
REAFTYERRPQSQAYIPKDEGDFYYLGGFFGGSVQEVQRLTRACHQAMMVDQANGIEAVW
HDESHLNKYLLRHKPTKVLSPEYLWDQQLLGWPAVLRKLRFTAVPKNHQAVRNP
"""]
encoded_input = tokenizer(x_test, padding=True, truncation=True, max_length=512, return_tensors="pt")
input_ids = encoded_input["input_ids"]
attention_mask = encoded_input["attention_mask"]
# Inference only: no gradients needed.
with torch.no_grad():
    output = model(input_ids, attention_mask=attention_mask)
logits = output.logits
probabilities = F.softmax(logits, dim=1)
# torch.max over dim=1 returns (values, indices); only the class indices are kept.
_, predicted_labels = torch.max(logits, dim=1)
probabilities[0]  # bare expression; result is discarded
# Map class indices back to the original label values.
decoded_labels = label_encoder.inverse_transform(predicted_labels.tolist())
decoded_labels  # bare expression; result is discarded
# Static lookup table of per-family annotations, read by get_family_info().
# Each record carries the same six string fields: 'CAZy Name',
# 'Alternative Name', 'Fold', 'Mechanism', 'Clade' and 'More Info' (a URL).
# Keys are presumably the labels the family classifier can emit -- verify
# against the classes stored in 'family_labels.pkl'.
# Several values keep trailing padding spaces (e.g. 'GT2 ', '8 ') and some
# 'Clade' values are a single blank; they are emitted verbatim.
glycosyltransferase_db = {
    "GT31-chsy" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
    "GT2-CesA2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT43-arath" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
    "GT8-Met1" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT32-higher" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
    "GT40" : {'CAZy Name': 'GT40', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT40.html'},
    "GT16" : {'CAZy Name': 'GT16', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT16.html'},
    "GT27" : {'CAZy Name': 'GT27', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT27.html'},
    "GT55" : {'CAZy Name': 'GT55', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT55.html'},
    "GT8-Glycogenin" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT8-1" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT25" : {'CAZy Name': 'GT25', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT25.html'},
    "GT2-DPM_like" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT31-fringe" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
    "GT2-Bact_puta" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT84" : {'CAZy Name': 'GT84', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT84.html'},
    "GT13" : {'CAZy Name': 'GT13', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT13.html'},
    "GT43-cele" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
    # NOTE(review): 'CAZy Name' reads 'GT92' here although the key and the URL both say GT2 -- possible typo, confirm.
    "GT2-Bact_LPS1" : {'CAZy Name': 'GT92', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT2-Bact_Oant" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT67" : {'CAZy Name': 'GT67', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT67.html'},
    "GT2-HAS" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT82" : {'CAZy Name': 'GT82', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT82.html'},
    "GT24" : {'CAZy Name': 'GT24', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT24.html'},
    "GT31-plant" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
    "GT81-Bact" : {'CAZy Name': 'GT81', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT81.html'},
    "GT2-Bact_gt25Me": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT2-B3GntL" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT49" : {'CAZy Name': 'GT49', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT49.html'},
    "GT34" : {'CAZy Name': 'GT34', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT34.html'},
    "GT45" : {'CAZy Name': 'GT45', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT45.html'},
    "GT32-lower" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
    "GT88" : {'CAZy Name': 'GT88', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT88.html'},
    "GT21" : {'CAZy Name': 'GT21', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT21.html'},
    "GT2-DPG_synt" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT43-b3gat2" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
    "GT2-Chitin_synt": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT8-Bact" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT8-Met2" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT2-Bact_Chlor1": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT54" : {'CAZy Name': 'GT54', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT54.html'},
    "GT2-Cel_bre3" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT2-Bact_Rham" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT6" : {'CAZy Name': 'GT6 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT6.html' },
    "GT2-Bact_puta2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT7-1" : {'CAZy Name': 'GT7 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT7.html' },
    "GT2-Csl" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT2-ExoU" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT2-Csl2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT64" : {'CAZy Name': 'GT64', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT64.html'},
    "GT2-Bact_Chlor2": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT78" : {'CAZy Name': 'GT78', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT78.html'},
    "GT12" : {'CAZy Name': 'GT12', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT12.html'},
    "GT31-gnt" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
    "GT2-Bact_CHS" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT62" : {'CAZy Name': 'GT62', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT62.html'},
    "GT8-Met_Pla" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT15" : {'CAZy Name': 'GT15', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT15.html'},
    "GT43-b3gat1" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
    "GT31-b3glt" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
    "GT2-CesA1" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT60" : {'CAZy Name': 'GT60', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT60.html'},
    "GT14" : {'CAZy Name': 'GT14', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT14.html'},
    "GT2-Bact_DPM_sy": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT17" : {'CAZy Name': 'GT17', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT17.html'},
    "GT2-Bact_LPS2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT77" : {'CAZy Name': 'GT77', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT77.html'},
    "GT2-Bact_EpsO" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
    "GT43-b3gat3" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
    "GT8-Fun" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
    "GT75" : {'CAZy Name': 'GT75', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT75.html'},
    "GT2-Bact_GlfT" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
}
def get_family_info(family_name, db=None):
    """Return a one-line Markdown summary of a glycosyltransferase family.

    Args:
        family_name: Key to look up, e.g. ``"GT2-HAS"``.
        db: Optional mapping of family name -> record dict. When omitted the
            module-level ``glycosyltransferase_db`` is used.

    Returns:
        The record rendered as consecutive ``**Field:** value `` fragments in
        the record's own field order, or an empty string when the family is
        not in the table.
    """
    if db is None:
        db = glycosyltransferase_db
    family_info = db.get(family_name, {})

    fragments = []
    for key, value in family_info.items():
        heading = key.title().replace("_", " ")
        # The table spells this field 'More Info'; the previous check for the
        # literal "more_info" never matched, so URLs were never turned into
        # links. Both spellings are accepted here.
        if key.lower().replace(" ", "_") == "more_info":
            # A single URL is stored as a plain string; iterating it directly
            # would yield one "link" per character.
            links = [value] if isinstance(value, str) else list(value)
            rendered = " ".join("[{}]({})".format(link, link) for link in links)
            fragments.append("**{}:** {} ".format(heading, rendered))
        else:
            fragments.append("**{}:** {} ".format(heading, value))
    return "".join(fragments)
def fig_to_img(fig):
    """Render a matplotlib figure to an in-memory PNG and open it as a PIL Image."""
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight')
    # Rewind so the reader starts at the first PNG byte. The buffer is left
    # open on purpose: the returned image is backed by it.
    png_buffer.seek(0)
    return Image.open(png_buffer)
def preprocess_protein_sequence(protein_fasta):
    """Extract and validate a single protein sequence from FASTA-style text.

    Header lines (starting with ``>``) are dropped and the remaining lines are
    joined into one sequence. Surrounding whitespace on each line, including
    the ``\\r`` of Windows line endings, is ignored so it is not misreported
    as an invalid residue.

    Args:
        protein_fasta: Raw text containing at most one FASTA record, or a bare
            sequence without a header.

    Returns:
        ``(sequence, None)`` on success, with the sequence upper-cased, or
        ``(None, error_message)`` when more than one header is present or a
        character outside the 20 standard amino acids is found.
    """
    lines = [line.strip() for line in protein_fasta.splitlines()]

    header_count = sum(1 for line in lines if line.startswith('>'))
    if header_count > 1:
        return None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."

    protein_sequence = ''.join(line for line in lines if not line.startswith('>'))

    # Validate before upper-casing: str.upper() can expand some non-ASCII
    # characters into ASCII letters that would then look valid.
    valid_characters = frozenset("ACDEFGHIKLMNPQRSTVWYacdefghiklmnpqrstvwy")  # the 20 standard amino acids
    if not valid_characters.issuperset(protein_sequence):
        return None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids. Does your sequence contain gaps?"

    # Lower-case residues are accepted above, so normalise them here.
    # NOTE(review): the ESM-2 vocabulary is assumed to contain upper-case
    # residue tokens only -- confirm against the tokenizer's vocab.
    return protein_sequence.upper(), None
def process_family_sequence(protein_fasta):
    """Predict the glycosyltransferase family for one sequence.

    Args:
        protein_fasta: FASTA-style text or a bare sequence; validated with
            ``preprocess_protein_sequence``.

    Returns:
        A 4-tuple ``(label, image, None, info)``:
        * on validation failure: ``(None, None, None, error_message)``;
        * otherwise the predicted family label, a PIL image with a bar chart
          of the five most probable families, ``None``, and the Markdown
          family description (prefixed with a warning when the sequence is
          shorter than 100 residues).
    """
    protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
    if error_msg:
        return None, None, None, error_msg

    # Use the tokenizer that was loaded alongside the family model rather than
    # the donor model's one, so the two stay paired if either checkpoint changes.
    encoded_input = tokenizerfam([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    input_idsfam = encoded_input["input_ids"]
    attention_maskfam = encoded_input["attention_mask"]
    with torch.no_grad():
        outputfam = modelfam(input_idsfam, attention_mask=attention_maskfam)
    logitsfam = outputfam.logits
    # Single input sequence, so row 0 holds its class probabilities.
    probabilities_flat = F.softmax(logitsfam, dim=1)[0].tolist()
    _, predicted_labelsfam = torch.max(logitsfam, dim=1)
    decoded_labelsfam = yfam.inverse_transform(predicted_labelsfam.tolist())
    family_info = get_family_info(decoded_labelsfam[0])

    # Five most probable classes, highest first.
    ranked = sorted(zip(yfam.classes_, probabilities_flat), key=lambda pair: pair[1], reverse=True)
    labels_top5, probabilities_top5 = zip(*ranked[:5])
    y_posfam = np.arange(len(labels_top5))

    # Draw on an explicit Axes instead of pyplot's implicit "current figure",
    # and close the figure only after it has been rendered.
    figfam, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.barh(y_posfam, [prob * 100 for prob in probabilities_top5], align='center', alpha=0.5)
        ax.set_yticks(y_posfam)
        ax.set_yticklabels(labels_top5)
        ax.set_xlabel('Probability (%)')
        ax.set_title('Top 5 Family Class Probabilities')
        ax.set_xlim(0, 100)
        img = fig_to_img(figfam)
    finally:
        plt.close(figfam)

    if len(protein_sequence) < 100:
        return decoded_labelsfam[0], img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
    return decoded_labelsfam[0], img, None, family_info
def process_single_sequence(protein_fasta):  # , protein_file
    """Predict the sugar donor class for one sequence.

    Args:
        protein_fasta: FASTA-style text or a bare sequence; validated with
            ``preprocess_protein_sequence``.

    Returns:
        A 4-tuple ``(label, image, None, info)``:
        * on validation failure: ``(None, None, None, error_message)``;
        * otherwise the predicted donor label, a PIL image with a bar chart of
          the three most probable donors, ``None``, and either ``None`` or,
          for sequences shorter than 100 residues, a Markdown warning.
    """
    protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
    if error_msg:
        return None, None, None, error_msg

    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    input_ids = encoded_input["input_ids"]
    attention_mask = encoded_input["attention_mask"]
    with torch.no_grad():
        output = model(input_ids, attention_mask=attention_mask)
    logits = output.logits
    # Single input sequence, so row 0 holds its class probabilities.
    dprobabilities = F.softmax(logits, dim=1)[0].tolist()
    _, predicted_labels = torch.max(logits, dim=1)
    decoded_labels = label_encoder.inverse_transform(predicted_labels.tolist())

    # Three most probable classes, highest first.
    ranked = sorted(zip(label_encoder.classes_, dprobabilities), key=lambda pair: pair[1], reverse=True)
    labels_top3, probabilities_top3 = zip(*ranked[:3])
    y_pos = np.arange(len(labels_top3))

    # Draw on an explicit Axes instead of pyplot's implicit "current figure",
    # and close the figure only after it has been rendered.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.barh(y_pos, [prob * 100 for prob in probabilities_top3], align='center', alpha=0.5)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels_top3)
        ax.set_xlabel('Probability (%)')
        ax.set_title('Top 3 Donor Class Probabilities')
        ax.set_xlim(0, 100)
        img = fig_to_img(fig)
    finally:
        plt.close(fig)

    if len(protein_sequence) < 100:
        # NOTE(review): this looks up the *donor* label in the family table;
        # presumably it never matches and contributes an empty string -- confirm.
        family_info = get_family_info(decoded_labels[0])
        return decoded_labels[0], img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
    return decoded_labels[0], img, None, None
def process_sequence_file(protein_file):  # , progress=gr.Progress()
    """Run donor prediction on every record of an uploaded FASTA file.

    For record ``k`` (1-based) a ``result_k.txt`` report is written and, when
    the prediction succeeds, a ``result_k.png`` chart. All results are bundled
    into a zip archive under a ``results/`` folder.

    Each call works in its own temporary directory. The previous shared
    ``results/`` folder was never cleared, so files from earlier uploads ended
    up in later archives.

    Args:
        protein_file: Uploaded file object; only its ``.name`` (a path) is used.

    Returns:
        Path of the generated ``predicted_results.zip``, or the exception text
        if the FASTA file could not be parsed.
    """
    # Local imports: these modules are not part of the file's import block.
    import shutil
    import tempfile

    try:
        records = list(SeqIO.parse(protein_file.name, "fasta"))
    except Exception as e:
        # NOTE(review): the caller wires this return value to a file output,
        # so an error string here is presumably treated as a path -- confirm
        # whether raising a UI error would be more appropriate.
        return str(e)

    work_dir = tempfile.mkdtemp(prefix='glydentify_')
    results_dir = os.path.join(work_dir, 'results')
    os.makedirs(results_dir)

    for idx, record in enumerate(records, start=1):
        stem = os.path.join(results_dir, f'result_{idx}')
        label, img, _, info = process_single_sequence(str(record.seq))

        if img is None:
            # Validation failed: ``info`` carries the error message.
            with open(f'{stem}.txt', 'w') as report:
                report.write(info or '')
            continue

        img.save(f'{stem}.png')
        report_text = f'Predicted Donor: {label}'
        if info:
            # Only append details when there are any; avoids writing "None".
            report_text += f'\n\n{info}'
        with open(f'{stem}.txt', 'w') as report:
            report.write(report_text)

    # TODO: figure out how to improve compression for large files.
    # TODO: provide an indication of how to interpret the downloaded zip file.
    zip_path = os.path.join(work_dir, 'predicted_results.zip')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for file_name in sorted(os.listdir(results_dir)):
            # Keep the archive layout 'results/<file>' regardless of where the
            # temporary directory lives.
            zipf.write(os.path.join(results_dir, file_name), arcname=os.path.join('results', file_name))

    # The individual files are inside the archive now; only the zip must stay.
    shutil.rmtree(results_dir, ignore_errors=True)
    return zip_path
# Replace the residue at a given position with the placeholder 'X'
def mask_residue(sequence, position):
    """Return ``sequence`` with the character at index ``position`` swapped for 'X'."""
    before = sequence[:position]
    after = sequence[position + 1:]
    return "".join((before, 'X', after))
def generate_heatmap(protein_fasta, group_size=10):
    """Build an occlusion heatmap for the donor prediction.

    The sequence is split into consecutive windows of ``group_size`` residues.
    Each window in turn is replaced by 'X' characters, the donor model is
    re-run, and the absolute change of every class probability relative to the
    unmasked sequence is recorded as that window's importance.

    Args:
        protein_fasta: FASTA-style text or a bare sequence; validated with
            ``preprocess_protein_sequence``.
        group_size: Number of residues masked together (default 10).

    Returns:
        A PIL image of the heatmap (rows: residue windows, columns: donor
        classes), or ``None`` when the input fails validation or is empty.

    Raises:
        ValueError: If ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be at least 1")

    protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
    # Previously the error was ignored and a None sequence reached the tokenizer.
    if error_msg or not protein_sequence:
        return None

    # Tokenize and predict for the original sequence
    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        original_output = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
    original_probabilities = F.softmax(original_output.logits, dim=1).cpu().numpy()[0]

    seq_len = len(protein_sequence)
    group_starts = range(0, seq_len, group_size)
    num_groups = len(group_starts)
    importance_scores = np.zeros((num_groups, len(original_probabilities)))
    row_labels = []

    # NOTE(review): inputs are truncated to max_length=512 tokens, so windows
    # beyond that point presumably cannot change the prediction -- confirm.
    for group_index, start in enumerate(group_starts):
        stop = min(start + group_size, seq_len)
        # Mask the residues at positions [start, stop)
        masked_sequence = protein_sequence[:start] + 'X' * (stop - start) + protein_sequence[stop:]

        encoded_input = tokenizer([masked_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            masked_output = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
        masked_probabilities = F.softmax(masked_output.logits, dim=1).cpu().numpy()[0]

        # Change in probabilities is the importance score for this window
        importance_scores[group_index, :] = np.abs(original_probabilities - masked_probabilities)
        # Clamp the upper bound so the last label does not run past the sequence.
        row_labels.append(f"{start}-{stop - 1}")

        progress = (group_index + 1) / num_groups * 100
        print(f"Progress: {progress:.2f}%")

    figmap, ax = plt.subplots(figsize=(20, 20))
    try:
        sns.heatmap(importance_scores, annot=True, cmap="coolwarm", xticklabels=label_encoder.classes_, yticklabels=row_labels, ax=ax)
        ax.set_xlabel("Predicted Labels")
        ax.set_ylabel("Residue Position Groups")
        img = fig_to_img(figmap)
    finally:
        # Release the figure; it was previously left open on every call.
        plt.close(figmap)
    return img
def main_function_single(sequence, show_explanation):
    """Gradio callback for the single-sequence tab.

    Runs the family and donor predictions on ``sequence`` and, when requested,
    the donor explanation heatmap.

    Args:
        sequence: Text from the sequence box (FASTA-style or bare sequence).
        show_explanation: Whether to compute the explanation heatmap.

    Returns:
        ``(family_label, family_img, family_info, donor_label, donor_img,
        figmap)``. When the input fails validation the labels and images are
        ``None`` and ``family_info`` holds the error message; ``figmap`` is
        ``None`` unless an explanation was requested and could be computed.
    """
    # Process seq, and return outputs for both fam and don
    family_label, family_img, _, family_info = process_family_sequence(sequence)
    donor_label, donor_img, *_ = process_single_sequence(sequence)

    figmap = None
    # A None donor label means validation failed. Skip the heatmap in that
    # case: it would repeat the failing validation and cannot produce a chart.
    if show_explanation and donor_label is not None:
        figmap = generate_heatmap(sequence)

    return family_label, family_img, family_info, donor_label, donor_img, figmap
def main_function_upload(protein_file):  # , progress=gr.Progress()
    """Gradio callback for the upload tab: delegate to ``process_sequence_file``."""
    archive_path = process_sequence_file(protein_file)  # , progress
    return archive_path
# Gradio user interface: one tab for single-sequence prediction, one for
# batch prediction from an uploaded FASTA file.
# The three image outputs are created up front and placed into the layout
# later via .render(), so the Predict handler can reference them before the
# rows that display them are built.
prediction_imagefam = gr.outputs.Image(type='pil', label="Family prediction graph")
prediction_imagedonor = gr.outputs.Image(type='pil', label="Donor prediction graph")
prediction_explain = gr.outputs.Image(type='pil', label="Donor prediction explanation")
with gr.Blocks() as app:
    gr.Markdown("# Glydentify")
    with gr.Tab("Single Sequence Prediction"):
        with gr.Row().style(equal_height=True):
            with gr.Column():
                sequence = gr.inputs.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
                explanation_checkbox = gr.inputs.Checkbox(label="Show Explanation", default=False)
            with gr.Column():
                with gr.Accordion("Example:"):
                    # NOTE(review): '\>' is not a recognised escape in a normal
                    # string literal; newer Python versions warn about it.
                    gr.Markdown("""
\>Q9LTZ9|GALS2_ARATH Galactan beta-1,4-galactosyltransferase GALS2
MAKERDQNTKDKNLLICFLWNFSAELKLALMALLVLCTLATLLPFLPSSFSISASELRFC
ISRIAVNSTSVNFTTVVEKPVLDNAVKLTEKPVLDNGVTKQPLTEEKVLNNGVIKRTFTG
YGWAAYNFVLMNAYRGGVNTFAVIGLSSKPLHVYSHPTYRCEWIPLNQSDNRILTDGTKI
LTDWGYGRVYTTVVVNCTFPSNTVINPKNTGGTLLLHATTGDTDRNITDSIPVLTETPNT
VDFALYESNLRRREKYDYLYCGSSLYGNLSPQRIREWIAYHVRFFGERSHFVLHDAGGIT
EEVFEVLKPWIELGRVTVHDIREQERFDGYYHNQFMVVNDCLHRYRFMAKWMFFFDVDEF
IYVPAKSSISSVMVSLEEYSQFTIEQMPMSSQLCYDGDGPARTYRKWGFEKLAYRDVKKV
PRRDRKYAVQPRNVFATGVHMSQHLQGKTYHRAEGKIRYFHYHGSISQRREPCRHLYNGT
RIVHENN
""")
                family_prediction = gr.outputs.Textbox(label="Predicted family")
                donor_prediction = gr.outputs.Textbox(label="Predicted donor")
                info_markdown = gr.Markdown()
        # Predict button (no Clear button is defined on this tab)
        with gr.Row().style(equal_height=True):
            with gr.Column():
                predict_button = gr.Button("Predict")
                # Output order must match the tuple returned by main_function_single.
                predict_button.click(main_function_single, inputs=[sequence, explanation_checkbox],
                                     outputs=[family_prediction, prediction_imagefam, info_markdown,
                                              donor_prediction, prediction_imagedonor, prediction_explain])
        # Family & Donor Section
        with gr.Row().style(equal_height=True):
            with gr.Column():
                with gr.Accordion("Prediction Bar Graphs:"):
                    prediction_imagefam.render()  # place the family chart created above
                    prediction_imagedonor.render()  # place the donor chart created above
            # Explain Section
            with gr.Column():
                # NOTE(review): this tests the checkbox *component object* once,
                # while the layout is built, not its value at click time; it is
                # presumably always truthy, so the accordion is always rendered
                # -- confirm.
                if explanation_checkbox:
                    with gr.Accordion("Donor explanation"):
                        prediction_explain.render()  # place the explanation image created above
    with gr.Tab("Multiple Sequence Prediction"):
        with gr.Row().style(equal_height=True):
            with gr.Column():
                protein_file = gr.inputs.File(label="Upload FASTA file")
            with gr.Column():
                result_file = gr.outputs.File(label="Download predictions of uploaded sequences")
        with gr.Row().style(equal_height=True):
            with gr.Column():
                process_button = gr.Button("Process")
                process_button.click(main_function_upload, inputs=protein_file, outputs=[result_file])
            with gr.Column():
                clear = gr.Button("Clear")
                # NOTE(review): the handler returns None and is wired to no
                # outputs, so the button appears to clear nothing -- confirm.
                clear.click(lambda: None)
                # clear.click()
app.launch(show_error=True)