AptaBLE / gui.py
Atom Bioworks
Update gui.py
529b593 verified
raw
history blame
2.88 kB
from api_prediction import AptaBLE_Pipeline
import gradio as gr
import pandas as pd
import torch
import tempfile
from tabulate import tabulate
import itertools
import os
import random
# Visualization
# Gradio reads GRADIO_SERVER_NAME when launching; '0.0.0.0' binds the server
# to all network interfaces instead of localhost only.
os.environ['GRADIO_SERVER_NAME'] = '0.0.0.0'

# NOTE(review): `title` is defined here but is not passed to the
# gr.Interface constructed later in this file.
title = 'DNAptaBLE Model Inference'
desc = 'AptaBLE (cross-attention network), trained to predict the likelihood a DNA aptamer will form a complex with a target protein!\n\nPass in a FASTA-formatted file of all aptamers and input your protein target amino acid sequence. Your output scores are available for download via an Excel file.'

# Single shared pipeline instance, built once at import time and reused by
# get_scores() for every request. A module-level assignment already creates
# a module global, so no `global` statement is needed here.
pipeline = AptaBLE_Pipeline(
    lr=1e-6,
    # The arguments set to None look like training/checkpointing settings
    # that inference does not need -- TODO confirm against AptaBLE_Pipeline.
    weight_decay=None,
    epochs=None,
    model_type=None,
    model_version=None,
    model_save_path=None,
    accelerate_save_path=None,
    tensorboard_logdir=None,
    d_model=128,
    d_ff=512,
    n_layers=6,
    n_heads=8,
    dropout=0.1,
    # Per the original author's note, this makes the pipeline load the
    # pretrained model on construction, so no separate loading step is needed.
    load_best_pt=True,
    device='cuda',
    seed=1004,
)
def comparison(protein, aptamer_file, analysis=None):
    """Score every aptamer in a FASTA file against one protein and export the results.

    Args:
        protein: Amino-acid sequence of the target protein. The same string
            is paired with every aptamer read from ``aptamer_file``.
        aptamer_file: Path to the FASTA file containing the aptamers.
        analysis: Optional value that is only echoed to stdout. It defaults
            to ``None`` so the function can be called with just the two
            inputs that the Gradio interface in this file provides.

    Returns:
        A ``(text, excel_path)`` tuple. ``text`` is always an empty string
        (no textual summary is produced); ``excel_path`` is the path of a
        temporary ``.xlsx`` workbook with one row per aptamer and the columns
        ``Protein``, ``Protein Seq``, ``Aptamer``, ``Aptamer Seq``, ``Score``.
    """
    print('analysis: ', analysis)

    r_names, aptamers = read_fasta(aptamer_file)
    # Pair the single target protein with each aptamer.
    proteins = [protein] * len(aptamers)
    # print('Number of aptamers: ', len(aptamers))
    scores = get_scores(aptamers, proteins)

    df = pd.DataFrame({
        'Protein': ['protein_prov.'] * len(aptamers),
        'Protein Seq': proteins,
        'Aptamer': r_names,
        'Aptamer Seq': aptamers,
    })
    # Assigned separately so whatever sequence type the pipeline returns is
    # handled by pandas' normal column assignment.
    df['Score'] = scores

    # Reserve a unique path and close the handle *before* writing: reopening
    # a still-open NamedTemporaryFile by name is not portable. delete=False
    # keeps the file on disk so it can be served for download afterwards.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".xlsx") as temp_file:
        temp_file_path = temp_file.name

    print('Saving to excel!')
    with pd.ExcelWriter(temp_file_path, engine='openpyxl') as writer:
        df.to_excel(writer, index=False)

    # Release cached GPU memory held by the allocator after inference.
    torch.cuda.empty_cache()
    return '', temp_file_path
def read_fasta(file_path):
    """Parse a FASTA file into parallel lists of headers and sequences.

    Blank lines are ignored, sequences wrapped over several lines are joined
    into one string, and any text before the first ``>`` header is skipped.
    A header that has no sequence data is dropped so that the two returned
    lists always stay aligned.

    Args:
        file_path: Path to the FASTA file.

    Returns:
        A ``(headers, sequences)`` tuple of equal-length lists. Each header
        is the stripped header line including its leading ``>``; each
        sequence is the concatenation of that record's stripped lines.
    """
    # Each record is (header, [sequence chunks]); chunks are joined at the
    # end to avoid repeated string concatenation.
    records = []
    with open(file_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                records.append((line, []))
            elif records:
                # Sequence data only counts once a header has been seen.
                records[-1][1].append(line)

    headers = []
    sequences = []
    for header, chunks in records:
        if chunks:
            headers.append(header)
            sequences.append(''.join(chunks))
    return headers, sequences
def get_scores(aptamers, proteins):
    """Run the shared pipeline on paired aptamer/protein sequences.

    The model is moved to the GPU only for the duration of the call and is
    always moved back to the CPU afterwards, even if inference raises.

    Args:
        aptamers: Aptamer sequences to score.
        proteins: Protein sequences, one per aptamer, in the same order.

    Returns:
        Whatever ``pipeline.inference`` returns for the batch.
    """
    pipeline.model.to('cuda')
    try:
        # Third argument: one zero per aptamer -- presumably placeholder
        # labels required by the inference signature; TODO confirm.
        return pipeline.inference(aptamers, proteins, [0] * len(aptamers))
    finally:
        # Without this, a failed inference would leave the model on the GPU.
        pipeline.model.to('cpu')
# Web UI definition. The input components are passed to `comparison`
# positionally in the order listed (protein text, then the uploaded file's
# path); its two return values fill the two output components in order.
iface = gr.Interface(
    description=desc,
    fn=comparison,
    inputs=[
        gr.Textbox(lines=2, placeholder="Protein"),
        # type="filepath" hands the function a path string, not a file object.
        gr.File(type="filepath"),
    ],
    outputs=[
        gr.Textbox(placeholder="Scores"),
        gr.File(label="Download Excel"),
    ],
)

# Start the Gradio server as soon as this module is executed.
iface.launch()