|
import obonet |
|
import pandas as pd |
|
import streamlit as st |
|
from transformers import EncoderDecoderModel, PreTrainedTokenizerFast |
|
|
|
from io import StringIO |
|
from Bio import SeqIO |
|
|
|
from inference import run_inference, prepare_inputs, post_process_outputs, join_predictions_with_terms |
|
|
|
@st.cache_resource
def load_model(model_path):
    """Load a pretrained encoder-decoder model for inference (cached by Streamlit).

    The model is switched to eval mode and has gradient tracking disabled,
    since this app only runs generation, never training.
    """
    model = EncoderDecoderModel.from_pretrained(model_path)
    model = model.eval()
    return model.requires_grad_(False)
|
|
|
@st.cache_data
def get_terms_df():
    """Download the GO basic ontology and return it as a two-column DataFrame.

    Returns a ``pandas.DataFrame`` with columns ``term`` (GO identifier) and
    ``name`` (human-readable term name). Cached by Streamlit so the ontology
    is fetched only once per session.
    """
    graph = obonet.read_obo('http://purl.obolibrary.org/obo/go/go-basic.obo')
    rows = [
        {"term": term_id, "name": attrs['name']}
        for term_id, attrs in graph.nodes(data=True)
    ]
    return pd.DataFrame(rows, columns=["term", "name"])
|
|
|
def load_tokenizers(protein_tokenizer_path, term_tokenizer_path):
    """Load the protein-sequence and GO-term tokenizers from their checkpoints.

    Returns a ``(protein_tokenizer, text_tokenizer)`` pair of
    ``PreTrainedTokenizerFast`` instances.
    """
    return (
        PreTrainedTokenizerFast.from_pretrained(protein_tokenizer_path),
        PreTrainedTokenizerFast.from_pretrained(term_tokenizer_path),
    )
|
|
|
|
|
# Placeholder shown in the sequence text area before the user types anything.
textarea_placeholder = "Input your sequences in fasta format"

# Default FASTA record pre-filled in the text area (MOTS-c, a short human
# mitochondrial-derived peptide) so the app is runnable out of the box.
sample_protein = """>sp|A0A0C5B5G6|MOTSC_HUMAN Mitochondrial-derived peptide MOTS-c OS=Homo sapiens OX=9606 GN=MT-RNR1 PE=1 SV=1

MRWQEMGYIFYPRKLR

"""

# Hugging Face Hub checkpoints: (protein tokenizer, GO-term tokenizer).
tokenizers_path = ('khairi/ProtFormer', 'khairi/ProtNLA-V1')

# Maps the special aspect token fed to the model to its display label.
aspects = {

    '<molecular_function>': 'Molecular Function',

    '<cellular_component>': 'Cellular Component',

    '<biological_process>': 'Biological Process',

}
|
|
|
# Load the (cached) encoder-decoder model and its two tokenizers.
model = load_model("khairi/ProtNLA-V1")

protein_tokenizer, text_tokenizer = load_tokenizers(tokenizers_path[0], tokenizers_path[1])


# Keep the model config in sync with the term tokenizer's vocabulary size.
# NOTE(review): presumably required because the checkpoint config predates
# tokenizer additions — confirm against the training setup.
model.config.vocab_size = len(text_tokenizer)

# GO term-id -> name lookup table used to label the model's predictions.
terms_df = get_terms_df()
|
|
|
# --- UI widgets ---
# Multi-line input for one or more FASTA records.
fasta_sequences = st.text_area("Input Sequences (fasta format):", value=sample_protein, placeholder=textarea_placeholder)

# Aspect selector: stores the model's special token, displays the friendly label.
aspect = st.selectbox(label="Gene Ontology Aspect", options=aspects.keys(), format_func=lambda x: aspects[x])


# True for the script rerun triggered by the button click.
predict = st.button(label='Run')
|
|
|
if predict:
    # Parse every FASTA record from the text area and predict GO terms for each.
    parsed = SeqIO.parse(StringIO(fasta_sequences), "fasta")
    for entry in parsed:
        seq_str = str(entry.seq)

        # Tokenize the sequence and build the aspect-conditioned decoder prompt.
        model_inputs, decoder_prompt = prepare_inputs(
            protein_tokenizer=protein_tokenizer,
            term_tokenizer=text_tokenizer,
            aspect=aspect,
            protein_sequence=seq_str,
        )

        # Generate, decode, and attach human-readable GO term names.
        generated = run_inference(model, model_inputs, decoder_prompt, 5)
        decoded = post_process_outputs(model, generated, text_tokenizer)
        decoded = join_predictions_with_terms(decoded, terms_df)

        # Render one results section per input record.
        st.header('Predictions')
        st.subheader(f'Protein Sequence: {seq_str}')
        st.subheader(f'Sequence Length: {len(entry.seq)}')
        st.subheader(f'GO Aspect: {aspects[aspect]}')
        st.write(decoded)
|
|