# zsp / app.py
# Last commit 475d75f by MassimoGregorioTotaro: "checkbox fix, instructions update" (file size 4.39 kB)
from html import escape
from tempfile import NamedTemporaryFile

from gradio import Blocks, Button, Checkbox, Dropdown, Examples, File, HTML, Markdown, Textbox

from model import get_models
from data import Data
# Available scoring strategies. The "higher accuracy" checkbox picks one by
# boolean index: unchecked -> "wt-marginals", checked -> "masked-marginals".
SCORING = [
    "wt-marginals",
    "masked-marginals",
]
# Model identifiers offered in the "Model" dropdown, as reported by model.get_models().
MODELS = get_models()
def app(*argv):
    """
    Main application function: score the requested substitutions and render the result.

    Positional arguments (same order as the ``inputs`` lists wired below):
        seq: protein sequence text from the "Sequence" box.
        trg: substitutions text from the "Substitutions" box.
        model_name: model identifier chosen in the "Model" dropdown.
        use_higher_accuracy (optional): current state of the
            "Use higher accuracy scoring" checkbox. Callers that pass only the
            first three values (e.g. the Examples widget, whose inputs list has
            three components) fall back to the checkbox's initial value.
        Any further positional arguments are ignored.

    Returns:
        A 2-tuple for the ``[out, bto]`` outputs. On success it holds ``repr()``
        of the calculated Data and a visible File pointing at ``out_file``. On
        failure it holds an HTML error page and ``None``.
    """
    seq, trg, model_name, *rest = argv
    # Read the checkbox from the event arguments. The component's ``.value``
    # attribute is only the value it was constructed with, not what the user ticked.
    use_higher_accuracy = rest[0] if rest else scoring_strategy.value
    # Checked -> "masked-marginals", unchecked (or None) -> "wt-marginals".
    scoring = SCORING[1 if use_higher_accuracy else 0]
    try:
        # Calculate the data based on the input parameters
        data = Data(seq, trg, model_name, scoring, out_file).calculate()
    except Exception as e:  # UI boundary: report any failure to the user instead of crashing
        # The message may echo user-supplied text and is rendered as raw HTML, so escape it.
        message = escape(str(e))
        return (
            '<!DOCTYPE html><html><body>'
            '<h1 style="background-color:#F70D1A;text-align:center;">'
            f'Error: {message}</h1></body></html>',
            None,
        )
    # If no error occurs, return the calculated data and reveal the download
    return repr(data), File(value=out_file.name, visible=True)


# Wild-type sequence shared by all of the examples below.
_EXAMPLE_SEQUENCE = "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMRGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ"

# Create the Gradio interface
# NOTE(review): every call of app() writes to this single out_file, so concurrent
# sessions can overwrite each other's download; consider a per-request file.
with open("instructions.md", "r", encoding="utf-8") as md, \
        NamedTemporaryFile(mode='w+') as out_file, \
        Blocks() as esm_scan:
    # Define the interface components
    Markdown(md.read())
    seq = Textbox(
        lines=2,
        label="Sequence",
        placeholder="FASTA sequence here...",
        value=''
    )
    trg = Textbox(
        lines=1,
        label="Substitutions",
        placeholder="Substitutions here...",
        value=""
    )
    model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
    scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
    btn = Button(value="Run")
    out = HTML()
    # Download component, hidden until app() returns a result.
    bto = File(
        value=out_file.name,
        visible=False,
        label="Download",
        file_count='single',
        interactive=False
    )
    btn.click(
        fn=app,
        # The checkbox is an input so that app() receives its current state.
        inputs=[seq, trg, model_name, scoring_strategy],
        outputs=[out, bto]
    )
    # Examples fill the three text/dropdown inputs only. If gradio calls app()
    # for them, it passes three arguments and app() uses the checkbox default.
    ex = Examples(
        examples=[
            [
                _EXAMPLE_SEQUENCE,
                "deep mutational scanning",
                "facebook/esm2_t6_8M_UR50D"
            ],
            [
                _EXAMPLE_SEQUENCE,
                "217 218 219",
                "facebook/esm2_t12_35M_UR50D"
            ],
            [
                _EXAMPLE_SEQUENCE,
                "R218K R218S R218N R218A R218V R218D",
                "facebook/esm2_t30_150M_UR50D",
            ],
            [
                _EXAMPLE_SEQUENCE,
                # Full mutant sequence (R218W) given instead of a substitution list.
                "MVEQYLLEAIVRDARDGITISDCSRPDNPLVFVNDAFTRMTGYDAEEVIGKNCRFLQRGDINLSAVHTIKIAMLTHEPCLVTLKNYRKDGTIFWNELSLTPIINKNGLITHYLGIQKDVSAQVILNQTLHEENHLLKSNKEMLEYLVNIDALTGLHNRRFLEDQLVIQWKLASRHINTITIFMIDIDYFKAFNDTYGHTAGDEALRTIAKTLNNCFMWGSDFVARYGGEEFTILAIGMTELQAHEYSTKLVQKIENLNIHHKGSPLGHLTISLGYSQANPQYHNDQNLVIEQADRALYSAKVEGKNRAVAYREQ",
                "facebook/esm2_t33_650M_UR50D",
            ],
        ],
        inputs=[seq,
                trg,
                model_name],
        outputs=[out,
                 bto],
        fn=app
    )

# Launch the Gradio interface
# NOTE(review): the `with` above has already exited here, so out_file is closed
# and (NamedTemporaryFile defaults to delete=True) removed before any request
# runs. This only works if Data writes via out_file.name; confirm in data.py.
if __name__ == "__main__":
    esm_scan.launch()