Glydentify / app.py
arikat's picture
checkbox
e0dca03
raw
history blame
29.8 kB
import pandas as pd
from IPython.display import clear_output
import torch
from transformers import EsmForSequenceClassification, AdamW, AutoTokenizer
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pickle
import torch.nn.functional as F
import gradio as gr
import io
from PIL import Image
import Bio
from Bio import SeqIO
import zipfile
import os
# Family classifier: label encoder + fine-tuned ESM-2 (35M) weights.
# pickle.load / torch.load execute pickled code, so only ship trusted files here.
device = 'cpu'
with open('family_labels.pkl', 'rb') as filefam:
    yfam = pickle.load(filefam)
tokenizerfam = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")  # larger option: facebook/esm2_t33_650M_UR50D
modelfam = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=len(yfam.classes_)
).to(device)
modelfam.load_state_dict(torch.load("family.pth", map_location=torch.device(device)))
modelfam.eval()

# Start-up smoke test: classify one bundled example sequence.
x_testfam = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
LQRVSLPRMVYPQPKVLTPCRKDVLVVTPWLAPIVWEGTFNIDILNEQFRLQNTTIGLTV
FAIKKYVAFLKLFLETAEKHFMVGHRVHYYVFTDQPAAVPRVTLGTGRQLSVLEVRAYKR
WQDVSMRRMEMISDFCERRFLSEVDYLVCVDVDMEFRDHVGVEILTPLFGTLHPGFYGSS
REAFTYERRPQSQAYIPKDEGDFYYLGGFFGGSVQEVQRLTRACHQAMMVDQANGIEAVW
HDESHLNKYLLRHKPTKVLSPEYLWDQQLLGWPAVLRKLRFTAVPKNHQAVRNP
"""]
encoded_inputfam = tokenizerfam(x_testfam, padding=True, truncation=True, max_length=512, return_tensors="pt")
input_idsfam, attention_maskfam = encoded_inputfam["input_ids"], encoded_inputfam["attention_mask"]
with torch.no_grad():
    outputfam = modelfam(input_idsfam, attention_mask=attention_maskfam)
logitsfam = outputfam.logits
probabilitiesfam = F.softmax(logitsfam, dim=1)
predicted_labelsfam = logitsfam.argmax(dim=1)
decoded_labelsfam = yfam.inverse_transform(predicted_labelsfam.tolist())
# Donor classifier: ESM-2 (35M) tokenizer, label encoder and fine-tuned weights.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
with open('donor_labels.pkl', 'rb') as file:
    label_encoder = pickle.load(file)
model = EsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t12_35M_UR50D", num_labels=len(label_encoder.classes_)
).to(device)
model.load_state_dict(torch.load("best_model_35M_t12_5v5.pth", map_location=torch.device(device)))  # earlier checkpoint: model_best_35v2M.pth
model.eval()

# Start-up smoke test: predict the donor for the same bundled example sequence.
x_test = ["""MAEVLRTLAGKPKCHALRPMILFLIMLVLVLFGYGVLSPRSLMPGSLERGFCMAVREPDH
LQRVSLPRMVYPQPKVLTPCRKDVLVVTPWLAPIVWEGTFNIDILNEQFRLQNTTIGLTV
FAIKKYVAFLKLFLETAEKHFMVGHRVHYYVFTDQPAAVPRVTLGTGRQLSVLEVRAYKR
WQDVSMRRMEMISDFCERRFLSEVDYLVCVDVDMEFRDHVGVEILTPLFGTLHPGFYGSS
REAFTYERRPQSQAYIPKDEGDFYYLGGFFGGSVQEVQRLTRACHQAMMVDQANGIEAVW
HDESHLNKYLLRHKPTKVLSPEYLWDQQLLGWPAVLRKLRFTAVPKNHQAVRNP
"""]
encoded_input = tokenizer(x_test, padding=True, truncation=True, max_length=512, return_tensors="pt")
input_ids, attention_mask = encoded_input["input_ids"], encoded_input["attention_mask"]
with torch.no_grad():
    output = model(input_ids, attention_mask=attention_mask)
logits = output.logits
probabilities = F.softmax(logits, dim=1)
predicted_labels = logits.argmax(dim=1)
decoded_labels = label_encoder.inverse_transform(predicted_labels.tolist())
glycosyltransferase_db = {
"GT31-chsy" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
"GT2-CesA2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT43-arath" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
"GT8-Met1" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
"GT32-higher" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
"GT40" : {'CAZy Name': 'GT40', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT40.html'},
"GT16" : {'CAZy Name': 'GT16', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT16.html'},
"GT27" : {'CAZy Name': 'GT27', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT27.html'},
"GT55" : {'CAZy Name': 'GT55', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT55.html'},
"GT8-Glycogenin" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
"GT8-1" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
"GT25" : {'CAZy Name': 'GT25', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT25.html'},
"GT2-DPM_like" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT31-fringe" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
"GT2-Bact_puta" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT84" : {'CAZy Name': 'GT84', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT84.html'},
"GT13" : {'CAZy Name': 'GT13', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT13.html'},
"GT43-cele" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
"GT2-Bact_LPS1" : {'CAZy Name': 'GT92', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT2-Bact_Oant" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT67" : {'CAZy Name': 'GT67', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT67.html'},
"GT2-HAS" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT82" : {'CAZy Name': 'GT82', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT82.html'},
"GT24" : {'CAZy Name': 'GT24', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT24.html'},
"GT31-plant" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
"GT81-Bact" : {'CAZy Name': 'GT81', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT81.html'},
"GT2-Bact_gt25Me": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT2-B3GntL" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT49" : {'CAZy Name': 'GT49', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT49.html'},
"GT34" : {'CAZy Name': 'GT34', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT34.html'},
"GT45" : {'CAZy Name': 'GT45', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT45.html'},
"GT32-lower" : {'CAZy Name': 'GT32', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT32.html'},
"GT88" : {'CAZy Name': 'GT88', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT88.html'},
"GT21" : {'CAZy Name': 'GT21', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT21.html'},
"GT2-DPG_synt" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT43-b3gat2" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
"GT2-Chitin_synt": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT8-Bact" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
"GT8-Met2" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
"GT2-Bact_Chlor1": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT54" : {'CAZy Name': 'GT54', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '6 ', 'More Info': 'http://www.cazy.org/GT54.html'},
"GT2-Cel_bre3" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT2-Bact_Rham" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT6" : {'CAZy Name': 'GT6 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT6.html' },
"GT2-Bact_puta2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT7-1" : {'CAZy Name': 'GT7 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT7.html' },
"GT2-Csl" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT2-ExoU" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT2-Csl2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '4 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT64" : {'CAZy Name': 'GT64', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT64.html'},
"GT2-Bact_Chlor2": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT78" : {'CAZy Name': 'GT78', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT78.html'},
"GT12" : {'CAZy Name': 'GT12', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT12.html'},
"GT31-gnt" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
"GT2-Bact_CHS" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT62" : {'CAZy Name': 'GT62', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT62.html'},
"GT8-Met_Pla" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT8.html' },
"GT15" : {'CAZy Name': 'GT15', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT15.html'},
"GT43-b3gat1" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
"GT31-b3glt" : {'CAZy Name': 'GT31', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '8 ', 'More Info': 'http://www.cazy.org/GT31.html'},
"GT2-CesA1" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '1 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT60" : {'CAZy Name': 'GT60', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '5 ', 'More Info': 'http://www.cazy.org/GT60.html'},
"GT14" : {'CAZy Name': 'GT14', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT14.html'},
"GT2-Bact_DPM_sy": {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '2 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT17" : {'CAZy Name': 'GT17', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '7 ', 'More Info': 'http://www.cazy.org/GT17.html'},
"GT2-Bact_LPS2" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': '3 ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT77" : {'CAZy Name': 'GT77', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT77.html'},
"GT2-Bact_EpsO" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': ' ', 'More Info': 'http://www.cazy.org/GT2.html' },
"GT43-b3gat3" : {'CAZy Name': 'GT43', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT43.html'},
"GT8-Fun" : {'CAZy Name': 'GT8 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Retaining', 'Clade': '9 ', 'More Info': 'http://www.cazy.org/GT8.html' },
"GT75" : {'CAZy Name': 'GT75', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT75.html'},
"GT2-Bact_GlfT" : {'CAZy Name': 'GT2 ', 'Alternative Name': '', 'Fold': 'A', 'Mechanism': 'Inverting', 'Clade': 'N/A', 'More Info': 'http://www.cazy.org/GT2.html' },
}
def get_family_info(family_name, db=None):
    """Format the reference record of a family as a markdown string.

    Args:
        family_name: Key to look up (normally a label predicted by the family
            model).
        db: Mapping ``family name -> {field: value}``. Defaults to the
            module-level ``glycosyltransferase_db``.

    Returns:
        ``"**Field:** value "`` fragments concatenated in the record's field
        order. The "More Info" field is rendered as markdown link(s). An
        unknown family yields "".
    """
    if db is None:
        db = glycosyltransferase_db
    parts = []
    for key, value in db.get(family_name, {}).items():
        # Title-case only snake_case keys. Display names such as "CAZy Name"
        # are kept verbatim, because str.title() would turn them into "Cazy Name".
        label = key.replace("_", " ").title() if key.islower() else key
        if key in ("More Info", "more_info"):
            # The table stores a single URL string. Iterating a str would emit one
            # "link" per character, so wrap a lone string in a list.
            links = [value] if isinstance(value, str) else list(value)
            rendered = "".join("[{0}]({0}) ".format(link) for link in links if link)
            parts.append("**{}:** ".format(label) + rendered)
        else:
            parts.append("**{}:** {} ".format(label, value))
    return "".join(parts)
def fig_to_img(fig):
    """Render a matplotlib figure to PNG in memory and hand it back as a PIL Image.

    The figure is saved with ``bbox_inches='tight'`` to trim surrounding
    whitespace. The buffer is rewound before opening so PIL reads from the start.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight')
    png_buffer.seek(0)
    return Image.open(png_buffer)
def preprocess_protein_sequence(protein_fasta):
    """Extract and validate one protein sequence from FASTA text or a raw sequence.

    The function normalizes the input before validating it:

    - Header lines (starting with '>') are dropped.
    - All whitespace is removed, including carriage returns from Windows line
      endings and spaces inside sequence lines.
    - The residues are upper-cased.

    Args:
        protein_fasta: FASTA text with at most one record, or a bare sequence.
            None is treated as empty input.

    Returns:
        ``(sequence, None)`` on success. Returns ``(None, error_message)`` when
        the input holds more than one FASTA record, contains no residues, or
        contains characters other than the 20 standard amino acids.
    """
    valid_characters = set("ACDEFGHIKLMNPQRSTVWY")  # the 20 standard amino acids
    lines = [line.strip() for line in (protein_fasta or "").splitlines()]
    headers = [line for line in lines if line.startswith('>')]
    if len(headers) > 1:
        return None, "Multiple fasta sequences detected. Please upload a fasta file with only one sequence."
    # Upper-case because the ESM-2 tokenizer vocabulary only contains upper-case
    # residue tokens; lower-case letters would be encoded as unknown tokens.
    protein_sequence = ''.join(
        ''.join(line.split()) for line in lines if not line.startswith('>')
    ).upper()
    if not protein_sequence:
        return None, "No protein sequence found. Please enter a protein sequence."
    if not set(protein_sequence).issubset(valid_characters):
        return None, "Invalid protein sequence. It contains characters that are not one of the 20 standard amino acids. Does your sequence contain gaps?"
    return protein_sequence, None
def process_family_sequence(protein_fasta):
    """Predict the glycosyltransferase family of one protein sequence.

    Args:
        protein_fasta: FASTA text (at most one '>' header) or a raw sequence.

    Returns:
        A 4-tuple ``(label, image, None, info)``.

        On invalid input, the first three items are None and ``info`` holds the
        validation error. Otherwise:

        - ``label`` is the top-scoring family.
        - ``image`` is a PIL bar chart of the 5 most probable families.
        - ``info`` is the markdown record from ``get_family_info``, prefixed
          with a warning for sequences shorter than 100 residues.
    """
    protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
    if error_msg:
        return None, None, None, error_msg

    # Use the tokenizer loaded for the family model. It previously borrowed the
    # donor `tokenizer`, which only worked because both load the same checkpoint.
    # NOTE: truncation=True / max_length=512 silently cuts long sequences, and
    # the tokenizer's special tokens count toward that limit.
    encoded_input = tokenizerfam([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        outputfam = modelfam(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
    logitsfam = outputfam.logits
    probabilities = F.softmax(logitsfam, dim=1)[0].tolist()
    _, predicted_labelsfam = torch.max(logitsfam, dim=1)
    predicted_family = yfam.inverse_transform(predicted_labelsfam.tolist())[0]
    family_info = get_family_info(predicted_family)

    # Five most probable families, highest first.
    top5 = sorted(zip(yfam.classes_, probabilities), key=lambda lp: lp[1], reverse=True)[:5]
    labels_top5 = [label for label, _prob in top5]
    percents_top5 = [prob * 100 for _label, prob in top5]

    # Draw on this figure's own Axes rather than pyplot's global "current
    # figure", so overlapping requests (if handlers run concurrently) cannot
    # draw into each other's plots.
    figfam, ax = plt.subplots(figsize=(10, 5))
    y_pos = np.arange(len(labels_top5))
    ax.barh(y_pos, percents_top5, align='center', alpha=0.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels_top5)
    ax.set_xlabel('Probability (%)')
    ax.set_title('Top 5 Family Class Probabilities')
    ax.set_xlim(0, 100)
    img = fig_to_img(figfam)
    plt.close(figfam)

    if len(protein_sequence) < 100:
        return predicted_family, img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
    return predicted_family, img, None, family_info
def process_single_sequence(protein_fasta):  # , protein_file
    """Predict the donor class of one protein sequence with the donor model.

    Args:
        protein_fasta: FASTA text (at most one '>' header) or a raw sequence.

    Returns:
        A 4-tuple ``(label, image, None, info)``.

        On invalid input, the first three items are None and ``info`` holds the
        validation error. Otherwise:

        - ``label`` is the top-scoring donor class.
        - ``image`` is a PIL bar chart of the 3 most probable donors.
        - ``info`` is always a markdown string, prefixed with a warning for
          sequences shorter than 100 residues.

        ``info`` used to be None for normal-length sequences, which
        process_sequence_file then wrote into result files as the text "None".
    """
    protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
    if error_msg:
        return None, None, None, error_msg

    # NOTE: truncation=True / max_length=512 silently cuts long sequences, and
    # the tokenizer's special tokens count toward that limit.
    encoded_input = tokenizer([protein_sequence], padding=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        output = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"])
    logits = output.logits
    dprobabilities = F.softmax(logits, dim=1)[0].tolist()
    _, predicted_labels = torch.max(logits, dim=1)
    predicted_donor = label_encoder.inverse_transform(predicted_labels.tolist())[0]
    # NOTE(review): this looks a *donor* label up in the family table used by
    # get_family_info. Unless donor labels coincide with family keys, it yields "".
    # Kept for compatibility; confirm whether it is intended.
    family_info = get_family_info(predicted_donor)

    # Three most probable donors, highest first.
    top3 = sorted(zip(label_encoder.classes_, dprobabilities), key=lambda lp: lp[1], reverse=True)[:3]
    labels_top3 = [label for label, _prob in top3]
    percents_top3 = [prob * 100 for _label, prob in top3]

    # Draw on this figure's own Axes rather than pyplot's global "current
    # figure", so overlapping requests (if handlers run concurrently) cannot
    # draw into each other's plots.
    fig, ax = plt.subplots(figsize=(10, 5))
    y_pos = np.arange(len(labels_top3))
    ax.barh(y_pos, percents_top3, align='center', alpha=0.5)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels_top3)
    ax.set_xlabel('Probability (%)')
    ax.set_title('Top 3 Donor Class Probabilities')
    ax.set_xlim(0, 100)
    img = fig_to_img(fig)
    plt.close(fig)

    if len(protein_sequence) < 100:
        return predicted_donor, img, None, f"**Warning:** The sequence is relatively short. Fragmentary and partial sequences may result in incorrect predictions. \n\n {family_info}"
    return predicted_donor, img, None, family_info
def process_sequence_file(protein_file):  # progress=gr.Progress() could be added to report per-record progress
    """Predict donors for every record of an uploaded FASTA file and zip the results.

    For record i (1-based), ``results/result_{i}.txt`` receives either the
    predicted donor (plus any warning) or the validation error. When
    prediction succeeded, ``results/result_{i}.png`` receives the top-3 bar
    chart. The results are archived as ``predicted_results.zip``.

    Each call writes into a fresh temporary directory. The old shared
    ``results/`` folder was never cleared, so files from earlier (longer)
    uploads, or from other users, leaked into new archives.
    NOTE(review): the temporary directories are not cleaned up afterwards.

    Args:
        protein_file: The uploaded file (an object with ``.name``) or a path.
            None when nothing was uploaded.

    Returns:
        The path to the zip archive. Returns None if no file was supplied. If the
        FASTA could not be parsed, returns the path to a text file holding the
        parser's error message.
    """
    import tempfile

    if protein_file is None:
        return None
    work_dir = tempfile.mkdtemp(prefix="glydentify_")
    fasta_path = getattr(protein_file, "name", protein_file)
    try:
        records = list(SeqIO.parse(fasta_path, "fasta"))
    except Exception as e:  # surface any parser failure to the user as a file
        error_path = os.path.join(work_dir, "error.txt")
        with open(error_path, "w") as fh:
            fh.write(str(e))
        return error_path

    results_dir = os.path.join(work_dir, "results")
    os.makedirs(results_dir)
    for idx, record in enumerate(records, start=1):
        # process_single_sequence validates the residues itself and returns a
        # None image with the error message in `info` when they are invalid.
        label, img, _, info = process_single_sequence(str(record.seq))
        base = os.path.join(results_dir, f"result_{idx}")
        with open(base + ".txt", "w") as fh:
            if img is None:
                fh.write(info or "")
            else:
                fh.write(f'Predicted Donor: {label}\n\n{info or ""}')
        if img is not None:
            img.save(base + ".png")

    # TODO: improve compression for very large result sets.
    zip_path = os.path.join(work_dir, "predicted_results.zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, _dirs, files in os.walk(results_dir):
            for name in files:
                full_path = os.path.join(root, name)
                # Keep archive entries as "results/result_N.*" as before.
                zipf.write(full_path, arcname=os.path.relpath(full_path, work_dir))
    return zip_path
def mask_residue(sequence, position, mask_char='X'):
    """Return ``sequence`` with the residue at ``position`` replaced by ``mask_char``.

    Negative positions count from the end, as in normal Python indexing.

    Raises:
        IndexError: If ``position`` lies outside the sequence. Plain slicing
            would otherwise append a residue or duplicate part of the sequence.
    """
    length = len(sequence)
    if not -length <= position < length:
        raise IndexError(f"position {position} out of range for sequence of length {length}")
    position %= length  # normalize negative indices
    return sequence[:position] + mask_char + sequence[position + 1:]
def generate_heatmap(protein_fasta, group_size=10):
    """Explain the donor prediction by masking consecutive residue windows.

    Each window of ``group_size`` residues is replaced by 'X' and the donor
    model is re-run. The absolute change in every class probability is
    recorded as that window's importance score.

    Args:
        protein_fasta: FASTA text (at most one '>' header) or a raw sequence.
        group_size: Number of residues masked together (must be >= 1).

    Returns:
        A PIL image of the heatmap (rows are residue windows with 0-based
        inclusive positions; columns are donor classes). Returns None when the
        input fails validation or contains no residues.
    """
    if group_size < 1:
        raise ValueError("group_size must be >= 1")
    protein_sequence, error_msg = preprocess_protein_sequence(protein_fasta)
    if error_msg or not protein_sequence:
        # Previously the None sequence was passed on and crashed in the tokenizer.
        return None
    seq_len = len(protein_sequence)

    def _donor_probabilities(seq):
        """Softmax donor-class probabilities for one sequence."""
        encoded = tokenizer([seq], padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            out = model(encoded["input_ids"], attention_mask=encoded["attention_mask"])
        return F.softmax(out.logits, dim=1).cpu().numpy()[0]

    original_probabilities = _donor_probabilities(protein_sequence)
    starts = range(0, seq_len, group_size)
    importance_scores = np.zeros((len(starts), len(original_probabilities)))
    for group_index, start in enumerate(starts):
        end = min(start + group_size, seq_len)
        masked_sequence = protein_sequence[:start] + 'X' * (end - start) + protein_sequence[end:]
        importance_scores[group_index, :] = np.abs(original_probabilities - _donor_probabilities(masked_sequence))
        print(f"Progress: {(group_index + 1) / len(starts) * 100:.2f}%")

    figmap, ax = plt.subplots(figsize=(20, 20))
    # Clamp the last window's end so its label never runs past the sequence.
    window_labels = [f"{s}-{min(s + group_size, seq_len) - 1}" for s in starts]
    sns.heatmap(importance_scores, annot=True, cmap="coolwarm", xticklabels=label_encoder.classes_, yticklabels=window_labels, ax=ax)
    ax.set_xlabel("Predicted Labels")
    ax.set_ylabel("Residue Position Groups")
    img = fig_to_img(figmap)
    plt.close(figmap)  # the figure was never closed before, leaking one per request
    return img
def main_function_single(sequence, show_explanation):
    """Gradio handler for the "Predict" button on the single-sequence tab.

    Args:
        sequence: The text from the sequence box (FASTA or raw).
        show_explanation: Whether the "Show Explanation" checkbox is ticked.

    Returns:
        A 6-tuple ``(family_label, family_img, family_info, donor_label,
        donor_img, explanation_img)`` matching the outputs wired in the UI.
        ``explanation_img`` is None unless requested and the input is valid.
    """
    family_label, family_img, _, family_info = process_family_sequence(sequence)
    donor_label, donor_img, *_ = process_single_sequence(sequence)
    figmap = None
    # process_family_sequence returns a None label when the input failed
    # validation (the error is in family_info). Skip the heatmap then, since
    # it would otherwise run on an invalid sequence.
    if show_explanation and family_label is not None:
        figmap = generate_heatmap(sequence)
    return family_label, family_img, family_info, donor_label, donor_img, figmap
def main_function_upload(protein_file):  # progress=gr.Progress() could be threaded through here
    """Gradio handler for the "Process" button on the multi-sequence tab.

    Delegates to process_sequence_file and returns its result unchanged.
    """
    download_target = process_sequence_file(protein_file)
    return download_target
# Output images are created up front and rendered inside the layout below, so
# the Predict handler can be wired to them before they appear on the page.
prediction_imagefam = gr.outputs.Image(type='pil', label="Family prediction graph")
prediction_imagedonor = gr.outputs.Image(type='pil', label="Donor prediction graph")
prediction_explain = gr.outputs.Image(type='pil', label="Donor prediction explanation")

with gr.Blocks() as app:
    gr.Markdown("# Glydentify")

    with gr.Tab("Single Sequence Prediction"):
        with gr.Row().style(equal_height=True):
            with gr.Column():
                sequence = gr.inputs.Textbox(lines=16, placeholder='Enter Protein Sequence Here...', label="Protein Sequence")
                explanation_checkbox = gr.inputs.Checkbox(label="Show Explanation", default=False)
            with gr.Column():
                with gr.Accordion("Example:"):
                    # Raw string: "\>" escapes markdown's blockquote marker. In a
                    # normal string it is an invalid escape sequence (with the
                    # same value), which Python warns about.
                    gr.Markdown(r"""
\>Q9LTZ9|GALS2_ARATH Galactan beta-1,4-galactosyltransferase GALS2
MAKERDQNTKDKNLLICFLWNFSAELKLALMALLVLCTLATLLPFLPSSFSISASELRFC
ISRIAVNSTSVNFTTVVEKPVLDNAVKLTEKPVLDNGVTKQPLTEEKVLNNGVIKRTFTG
YGWAAYNFVLMNAYRGGVNTFAVIGLSSKPLHVYSHPTYRCEWIPLNQSDNRILTDGTKI
LTDWGYGRVYTTVVVNCTFPSNTVINPKNTGGTLLLHATTGDTDRNITDSIPVLTETPNT
VDFALYESNLRRREKYDYLYCGSSLYGNLSPQRIREWIAYHVRFFGERSHFVLHDAGGIT
EEVFEVLKPWIELGRVTVHDIREQERFDGYYHNQFMVVNDCLHRYRFMAKWMFFFDVDEF
IYVPAKSSISSVMVSLEEYSQFTIEQMPMSSQLCYDGDGPARTYRKWGFEKLAYRDVKKV
PRRDRKYAVQPRNVFATGVHMSQHLQGKTYHRAEGKIRYFHYHGSISQRREPCRHLYNGT
RIVHENN
""")
                family_prediction = gr.outputs.Textbox(label="Predicted family")
                donor_prediction = gr.outputs.Textbox(label="Predicted donor")
                info_markdown = gr.Markdown()

        # Predict button
        with gr.Row().style(equal_height=True):
            with gr.Column():
                predict_button = gr.Button("Predict")
                predict_button.click(main_function_single, inputs=[sequence, explanation_checkbox],
                                     outputs=[family_prediction, prediction_imagefam, info_markdown,
                                              donor_prediction, prediction_imagedonor, prediction_explain])

        # Family & donor bar graphs, plus the donor explanation heatmap
        with gr.Row().style(equal_height=True):
            with gr.Column():
                with gr.Accordion("Prediction Bar Graphs:"):
                    prediction_imagefam.render()
                    prediction_imagedonor.render()
            with gr.Column():
                # Always laid out; main_function_single returns None for this
                # image when "Show Explanation" is unchecked. The former
                # `if explanation_checkbox:` tested the component object at
                # build time, not the user's choice.
                with gr.Accordion("Donor explanation"):
                    prediction_explain.render()

    with gr.Tab("Multiple Sequence Prediction"):
        with gr.Row().style(equal_height=True):
            with gr.Column():
                protein_file = gr.inputs.File(label="Upload FASTA file")
            with gr.Column():
                result_file = gr.outputs.File(label="Download predictions of uploaded sequences")
        with gr.Row().style(equal_height=True):
            with gr.Column():
                process_button = gr.Button("Process")
                process_button.click(main_function_upload, inputs=protein_file, outputs=[result_file])
            with gr.Column():
                clear = gr.Button("Clear")
                # Previously `clear.click(lambda: None)` had no outputs and did nothing.
                clear.click(lambda: (None, None), inputs=None, outputs=[protein_file, result_file])

app.launch(show_error=True)