"""Streamlit front-end for ProtNLA Gene Ontology (GO) term prediction.

Accepts protein sequences in FASTA format, runs an encoder-decoder model
(ProtFormer protein encoder + ProtNLA term decoder) and displays the
predicted GO terms joined with their human-readable names taken from the
GO basic ontology.
"""
from io import StringIO

import obonet
import pandas as pd
import streamlit as st
from Bio import SeqIO
from transformers import EncoderDecoderModel, PreTrainedTokenizerFast

from inference import (
    join_predictions_with_terms,
    post_process_outputs,
    prepare_inputs,
    run_inference,
)


@st.cache_resource
def load_model(model_path):
    """Load the seq2seq model once per server process, frozen for inference.

    `eval()` + `requires_grad_(False)` disable dropout/grad tracking since
    the app only ever runs forward passes.
    """
    return EncoderDecoderModel.from_pretrained(model_path).eval().requires_grad_(False)


@st.cache_data
def get_terms_df():
    """Download the GO basic ontology and return a (term, name) DataFrame.

    Cached with ``st.cache_data`` so the (large) OBO file is fetched and
    parsed only once per server process.
    """
    graph = obonet.read_obo('http://purl.obolibrary.org/obo/go/go-basic.obo')
    data = {"term": [], "name": []}
    for term, obj in graph.nodes(data=True):
        data['term'].append(term)
        data['name'].append(obj['name'])
    return pd.DataFrame(data)


# Cached for consistency with load_model / get_terms_df: without it the
# tokenizers were re-fetched from the Hub on every Streamlit rerun.
@st.cache_resource
def load_tokenizers(protein_tokenizer_path, term_tokenizer_path):
    """Load the protein-sequence tokenizer and the GO-term text tokenizer.

    Returns:
        (protein_tokenizer, text_tokenizer) tuple of PreTrainedTokenizerFast.
    """
    protein_tokenizer = PreTrainedTokenizerFast.from_pretrained(protein_tokenizer_path)
    text_tokenizer = PreTrainedTokenizerFast.from_pretrained(term_tokenizer_path)
    return protein_tokenizer, text_tokenizer


textarea_placeholder = "Input your sequences in fasta format"
sample_protein = """>sp|A0A0C5B5G6|MOTSC_HUMAN Mitochondrial-derived peptide MOTS-c OS=Homo sapiens OX=9606 GN=MT-RNR1 PE=1 SV=1
MRWQEMGYIFYPRKLR
"""

# (protein tokenizer repo, term tokenizer repo) on the Hugging Face Hub.
tokenizers_path = ('khairi/ProtFormer', 'khairi/ProtNLA-V1')

# Maps the aspect prompt key fed to prepare_inputs() -> display label.
# NOTE(review): all three keys are the empty string here — they look like
# special tokens (e.g. angle-bracket markers such as <mf>/<cc>/<bp>) that
# were stripped from the file at some point. As written the duplicate keys
# collapse so the dict holds only {'': 'Biological Process'}; the original
# token keys need to be restored for aspect selection to work. Behavior is
# preserved here pending confirmation of the real tokens.
aspects = {
    '': 'Molecular Function',
    '': 'Cellular Component',
    '': 'Biological Process',
}

model = load_model("khairi/ProtNLA-V1")
protein_tokenizer, text_tokenizer = load_tokenizers(tokenizers_path[0], tokenizers_path[1])
# Align the decoder's configured vocab with the term tokenizer (covers any
# tokens added after the model config was saved).
model.config.vocab_size = len(text_tokenizer)
terms_df = get_terms_df()

fasta_sequences = st.text_area(
    "Input Sequences (fasta format):",
    value=sample_protein,
    placeholder=textarea_placeholder,
)
aspect = st.selectbox(
    label="Gene Ontology Aspect",
    options=aspects.keys(),
    format_func=lambda x: aspects[x],
)
predict = st.button(label='Run')

if predict:
    # Parse the textarea contents as FASTA; each record is predicted on
    # independently and rendered in its own section.
    fasta_io = StringIO(fasta_sequences)
    for record in SeqIO.parse(fasta_io, "fasta"):
        sequence = str(record.seq)
        inputs, prompt = prepare_inputs(
            protein_tokenizer=protein_tokenizer,
            term_tokenizer=text_tokenizer,
            aspect=aspect,
            protein_sequence=sequence,
        )
        # 5: number of predictions requested — presumably beam width /
        # top-k; confirm against inference.run_inference's signature.
        outputs = run_inference(model, inputs, prompt, 5)
        predictions = post_process_outputs(model, outputs, text_tokenizer)
        predictions = join_predictions_with_terms(predictions, terms_df)
        st.header('Predictions')
        st.subheader(f'Protein Sequence: {sequence}')
        st.subheader(f'Sequence Length: {len(record.seq)}')
        st.subheader(f'GO Aspect: {aspects[aspect]}')
        st.write(predictions)