# AptaBLE / gui.py
# Atom Bioworks
# Update gui.py
# 05b3c61 verified
from api_prediction import AptaBLE_Pipeline
import gradio as gr
import pandas as pd
import torch
import tempfile
from tabulate import tabulate
import itertools
import os
import random
# --- UI / server configuration ---------------------------------------------
# Bind address handed to Gradio through its environment variable.
os.environ['GRADIO_SERVER_NAME'] = '0.0.0.0'

title = 'AptaBLE Model Inference'
desc = (
    'AptaBLE, trained to predict the likelihood an aptamer will form a complex with a target protein!\n\n'
    'Pass in a FASTA-formatted file of all aptamers and input your protein target amino acid sequence. '
    'Your output scores are available for download via an Excel file. '
    'At the moment, our demo only supports inference with DNA aptamers.'
)

# --- Model pipeline ---------------------------------------------------------
# Options that are simply passed as None in this demo
# (presumably training-only settings -- confirm against AptaBLE_Pipeline).
_unset_options = dict(
    weight_decay=None,
    epochs=None,
    model_type=None,
    model_version=None,
    model_save_path=None,
    accelerate_save_path=None,
    tensorboard_logdir=None,
)

# Network hyper-parameters forwarded to the pipeline constructor.
_architecture = dict(
    d_model=128,
    d_ff=512,
    n_layers=6,
    n_heads=8,
    dropout=0.1,
)

# Module-level singleton shared by every request handler in this file.
# Per the original author's note, load_best_pt=True already loads the
# pretrained model using the datasets included in the repo.
pipeline = AptaBLE_Pipeline(
    lr=1e-6,
    load_best_pt=True,
    device='cuda',
    seed=1004,
    **_unset_options,
    **_architecture,
)
def comparison(protein, aptamer_file, analysis=None):
    """Score every aptamer in a FASTA file against a single protein sequence.

    Args:
        protein: Target protein sequence exactly as entered in the UI.
        aptamer_file: Path to the uploaded FASTA file containing the aptamers.
        analysis: Optional value that is only logged. It defaults to ``None``
            because the Gradio interface in this file supplies just two
            inputs; as a required parameter it made every request fail with
            a ``TypeError``.

    Returns:
        A ``(text, excel_path)`` tuple. ``text`` is a plain-text table of
        aptamer headers and scores (empty string when the FASTA file held no
        records); ``excel_path`` is the path of a temporary ``.xlsx`` file
        containing the full results table.
    """
    print('analysis: ', analysis)

    r_names, aptamers = read_fasta(aptamer_file)
    # The same protein is paired with every aptamer.
    proteins = [protein] * len(aptamers)
    # print('Number of aptamers: ', len(aptamers))
    scores = get_scores(aptamers, proteins)

    df = pd.DataFrame(columns=['Protein', 'Protein Seq', 'Aptamer', 'Aptamer Seq', 'Score'])
    df['Protein'] = ['protein_prov.'] * len(aptamers)
    df['Aptamer'] = r_names
    df['Protein Seq'] = proteins
    df['Aptamer Seq'] = aptamers
    df['Score'] = scores

    # Reserve a unique path, then close the handle *before* pandas opens the
    # path itself: reopening a NamedTemporaryFile by name while it is still
    # open is platform-dependent (it fails on Windows). delete=False keeps
    # the file around so Gradio can serve it for download.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
        temp_file_path = temp_file.name
    df.to_excel(temp_file_path, index=False, engine='openpyxl')

    print('Saving to excel!')
    # NOTE(review): second copy written next to the uploaded file, kept from
    # the original implementation -- confirm whether it is still wanted.
    df.to_excel(f'{aptamer_file}.xlsx')

    torch.cuda.empty_cache()

    # Fill the "Scores" textbox, which previously always received ''.
    summary = '' if df.empty else df[['Aptamer', 'Score']].to_string(index=False)
    return summary, temp_file_path
def read_fasta(file_path):
    """Parse a FASTA file into parallel lists of headers and sequences.

    Unlike a strict "header line, sequence line" reader, this accepts
    sequences wrapped over several lines, blank lines anywhere in the file,
    and a header with no sequence after it (which yields an empty string
    instead of raising ``IndexError``).

    Args:
        file_path: Path of the FASTA file to read.

    Returns:
        A ``(headers, sequences)`` tuple of equal-length lists. Each header
        keeps its leading ``'>'``; each sequence is the concatenation of all
        non-blank lines between its header and the next one. Lines that
        appear before the first header are ignored.
    """
    headers = []
    fragments = []  # one list of sequence pieces per header, same order
    with open(file_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                headers.append(line)
                fragments.append([])
            elif fragments:
                # Only collect sequence text once a record has been opened.
                fragments[-1].append(line)
    return headers, [''.join(parts) for parts in fragments]
def get_scores(aptamers, proteins):
    """Run the shared pipeline on aptamer/protein pairs and return its scores.

    The model is moved to the GPU for the duration of the call and moved back
    to the CPU afterwards. The move back happens in a ``finally`` clause so
    the model does not stay on the GPU when inference raises.

    Args:
        aptamers: Aptamer sequences to score.
        proteins: Protein sequences, one per aptamer.

    Returns:
        Whatever ``pipeline.inference`` returns for the given pairs.
    """
    pipeline.model.to('cuda')
    try:
        # Third argument is a list of zeros, one per aptamer
        # (presumably placeholder labels -- confirm against the pipeline).
        return pipeline.inference(aptamers, proteins, [0] * len(aptamers))
    finally:
        pipeline.model.to('cpu')
# Gradio front end: a protein text box plus a FASTA upload go in; a text
# summary and a downloadable Excel file come out.
iface = gr.Interface(
    fn=comparison,
    inputs=[
        gr.Textbox(lines=2, placeholder="Protein"),
        gr.File(type="filepath"),
    ],
    outputs=[
        gr.Textbox(placeholder="Scores"),
        gr.File(label="Download Excel"),
    ],
    # `title` is defined at module level but was never passed to the interface.
    title=title,
    description=desc,
)

# Start the web server (blocks while the app is running).
iface.launch()