# ProtNLA / app.py
# Author: khairi
# Last change: aggregate predictions and add term names (commit 1f4905f)
import obonet
import pandas as pd
import streamlit as st
from transformers import EncoderDecoderModel, PreTrainedTokenizerFast
from io import StringIO
from Bio import SeqIO
from inference import run_inference, prepare_inputs, post_process_outputs, join_predictions_with_terms
@st.cache_resource
def load_model(model_path):
    """Load a pretrained encoder-decoder model ready for inference.

    Cached by Streamlit so the weights are loaded only once per server
    process. The model is put in eval mode with gradients disabled.
    """
    model = EncoderDecoderModel.from_pretrained(model_path)
    model = model.eval()
    # requires_grad_(False) returns the model itself, so this is the same object.
    return model.requires_grad_(False)
@st.cache_data
def get_terms_df():
    """Download the GO basic ontology and return a (term, name) DataFrame.

    Fetches go-basic.obo over HTTP (cached by Streamlit so the download
    happens once) and flattens every ontology node into one row with its
    GO identifier and human-readable name.
    """
    graph = obonet.read_obo('http://purl.obolibrary.org/obo/go/go-basic.obo')
    rows = [(term, attrs['name']) for term, attrs in graph.nodes(data=True)]
    return pd.DataFrame(rows, columns=['term', 'name'])
def load_tokenizers(protein_tokenizer_path, term_tokenizer_path):
    """Load the protein-sequence and GO-term tokenizers.

    Returns a (protein_tokenizer, term_tokenizer) tuple, both loaded as
    fast tokenizers from the given hub paths.
    """
    return (
        PreTrainedTokenizerFast.from_pretrained(protein_tokenizer_path),
        PreTrainedTokenizerFast.from_pretrained(term_tokenizer_path),
    )
# --- UI defaults -------------------------------------------------------------
textarea_placeholder = "Input your sequences in fasta format"

# Example input shown by default in the text area (MOTS-c, a short human peptide).
sample_protein = """>sp|A0A0C5B5G6|MOTSC_HUMAN Mitochondrial-derived peptide MOTS-c OS=Homo sapiens OX=9606 GN=MT-RNR1 PE=1 SV=1
MRWQEMGYIFYPRKLR
"""

# (protein tokenizer hub path, term tokenizer hub path)
tokenizers_path = ('khairi/ProtFormer', 'khairi/ProtNLA-V1')

# Maps the model's aspect prompt tokens to human-readable labels for the UI.
aspects = {
    '<molecular_function>': 'Molecular Function',
    '<cellular_component>': 'Cellular Component',
    '<biological_process>': 'Biological Process',
}

# --- Shared resources (cached across reruns) ---------------------------------
model = load_model("khairi/ProtNLA-V1")
protein_tokenizer, text_tokenizer = load_tokenizers(*tokenizers_path)
# Keep the decoder vocabulary size in sync with the term tokenizer.
model.config.vocab_size = len(text_tokenizer)
terms_df = get_terms_df()

# --- Page widgets ------------------------------------------------------------
fasta_sequences = st.text_area(
    "Input Sequences (fasta format):",
    value=sample_protein,
    placeholder=textarea_placeholder,
)
aspect = st.selectbox(
    label="Gene Ontology Aspect",
    options=aspects.keys(),
    format_func=lambda key: aspects[key],
)
predict = st.button(label='Run')

if predict:
    # Parse the pasted text as FASTA and run one prediction per record.
    for record in SeqIO.parse(StringIO(fasta_sequences), "fasta"):
        sequence = str(record.seq)
        inputs, prompt = prepare_inputs(
            protein_tokenizer=protein_tokenizer,
            term_tokenizer=text_tokenizer,
            aspect=aspect,
            protein_sequence=sequence,
        )
        # 5 = number of generated candidates requested from the model
        # (assumption based on call site — TODO confirm against inference.run_inference).
        outputs = run_inference(model, inputs, prompt, 5)
        predictions = post_process_outputs(model, outputs, text_tokenizer)
        predictions = join_predictions_with_terms(predictions, terms_df)
        st.header('Predictions')
        st.subheader(f'Protein Sequence: {sequence}')
        st.subheader(f'Sequence Length: {len(record.seq)}')
        st.subheader(f'GO Aspect: {aspects[aspect]}')
        st.write(predictions)