# Hugging Face file-viewer header (copied along with the code, kept as a comment):
#   last commit 4cf13c5 by matanninio: "package name issues"
#   file size: 7.72 kB
import gradio as gr
import torch
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal
# Task labels double as the dropdown choices shown in the UI and as keys into
# the model / tokenizer registries below.
ppi = "Protein-Protein Interaction (PPI)"
dti = "Drug-Target Binding Affinity"

# Hugging Face model id for each task (insertion order: PPI first, then DTI).
model_paths = {
    ppi: "ibm/biomed.omics.bl.sm.ma-ted-458m",
    dti: "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
}

# Every model and its tokenizer are loaded eagerly at import time.
# TODO: loading could be made lazy so startup does not pay for unused demos.
models = {}
tokenizer_op = {}
for task, model_path in model_paths.items():
    models[task] = Mammal.from_pretrained(model_path)
    models[task].eval()  # inference-only use in this app
    tokenizer_op[task] = ModularTokenizerOp.from_pretrained(model_path)
### PPI:
# Token id of the "<1>" class label in the PPI tokenizer. Per the explanation
# text in the PPI demo, "<1>" marks an interacting pair and "<0>" a
# non-interacting one.
positive_token_id = tokenizer_op[ppi].get_token_id("<1>")
# Default input proteins used to pre-fill the two PPI demo text boxes
# (amino-acid sequences, one letter per residue).
protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK"
protein_calcineurin = "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ"
def format_prompt_ppi(prot1, prot2):
    """Build the MAMMAL binding-affinity-class prompt for two protein sequences.

    The prompt follows the pre-training syntax: a tokenizer-type header, the
    task token, a sentinel the model fills with the predicted class, then one
    protein entity segment per sequence, terminated by ``<EOS>``.

    Args:
        prot1: First protein sequence.
        prot2: Second protein sequence.

    Returns:
        The prompt string.
    """

    def _protein_segment(seq):
        # One molecular-entity segment wrapping a natural protein sequence.
        return (
            "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
            f"<SEQUENCE_NATURAL_START>{seq}<SEQUENCE_NATURAL_END>"
        )

    header = "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
    return header + _protein_segment(prot1) + _protein_segment(prot2) + "<EOS>"
def run_prompt(prompt):
    """Run a formatted PPI prompt through the PPI model.

    Args:
        prompt: Prompt string, e.g. as produced by ``format_prompt_ppi``.

    Returns:
        A ``(generated_output, score)`` tuple: the decoded text of the
        generated tokens for the single sample, and the entry of the model's
        output scores at the ``<1>`` (interacting) token id, as a Python number.
    """
    # Look the positive-class id up from the PPI tokenizer itself instead of
    # relying on a shared module-level global that other sections may rebind
    # with a different tokenizer's id.
    positive_token_id = tokenizer_op[ppi].get_token_id("<1>")

    # Single-sample dict keyed by the MAMMAL key constants.
    sample_dict = {ENCODER_INPUTS_STR: prompt}

    # Tokenize the prompt into token ids plus an attention mask.
    sample_dict = tokenizer_op[ppi](
        sample_dict=sample_dict,
        key_in=ENCODER_INPUTS_STR,
        key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
        key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
    )
    # Convert the tokenizer outputs to tensors for the model.
    sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
        sample_dict[ENCODER_INPUTS_TOKENS]
    )
    sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
        sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        batch_dict = models[ppi].generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )

    generated_output = tokenizer_op[ppi]._tokenizer.decode(batch_dict[CLS_PRED][0])
    # Scores of sample 0 at generation step 1, read at the "<1>" token id.
    # NOTE(review): presumably step 1 is where the class label is emitted
    # (right after <SENTINEL_ID_0>); confirm against the decoder output layout.
    score = batch_dict["model.out.scores"][0][1][positive_token_id].item()
    return generated_output, score
def create_and_run_prompt(protein1, protein2):
    """Format the PPI prompt for two proteins and run it through the model.

    This is the PPI button's click handler.

    Returns:
        ``(prompt, generated_output, score)``: the prompt string followed by
        the values returned by ``run_prompt``.
    """
    prompt = format_prompt_ppi(protein1, protein2)
    model_outputs = run_prompt(prompt=prompt)
    return (prompt, *model_outputs)
def create_ppi_demo():
    """Build the (initially hidden) Gradio group for the PPI demo.

    The group contains:
      - a header naming the PPI model in use;
      - two protein text boxes pre-filled with calmodulin and calcineurin;
      - a button wired to ``create_and_run_prompt``;
      - outputs for the prompt, the decoded model output and the PPI score;
      - a note explaining the ``<SENTINEL_ID_0>`` class tokens.

    Returns:
        The ``gr.Group`` holding the demo, with ``visible`` set to False.
    """
    markup_text = "\n".join(
        [
            "",
            "# Mammal based Protein-Protein Interaction (PPI) demonstration",
            "Given two protein sequences, estimate if the proteins interact or not.",
            "### Using the model from",
            f"```{model_paths[ppi]} ```",
            "",
        ]
    )
    # Options shared by both editable sequence boxes.
    sequence_box_options = dict(interactive=True, lines=3)

    with gr.Group() as ppi_demo:
        gr.Markdown(markup_text)

        with gr.Row():
            prot1 = gr.Textbox(
                label="Protein 1 sequence",
                value=protein_calmodulin,
                **sequence_box_options,
            )
            prot2 = gr.Textbox(
                label="Protein 2 sequence",
                value=protein_calcineurin,
                **sequence_box_options,
            )

        with gr.Row():
            run_mammal = gr.Button(
                "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
            )

        with gr.Row():
            prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

        with gr.Row():
            decoded = gr.Textbox(label="Mammal output")
            # Created inside this row, right after the output box.
            ppi_score = gr.Number(label="PPI score")
            run_mammal.click(
                fn=create_and_run_prompt,
                inputs=[prot1, prot2],
                outputs=[prompt_box, decoded, ppi_score],
            )

        with gr.Row():
            gr.Markdown(
                "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
            )

    # Hidden until selected in the top-level dropdown.
    ppi_demo.visible = False
    return ppi_demo
### DTI:
# Default inputs used to pre-fill the DTI demo: a protein target sequence and
# a drug given as a SMILES string.
target_seq = "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC"
drug_seq = "CC(=O)NCCC1=CNc2c1cc(OC)cc2"
# Token id of "<1>" in the DTI tokenizer. It is kept under its own name so it
# does not rebind the module-level `positive_token_id` defined for the PPI
# tokenizer.
dti_positive_token_id = tokenizer_op[dti].get_token_id("<1>")
def format_prompt_dti(prot, drug):
    """Prepare a DTI model input sample for one protein target and one drug.

    Args:
        prot: Target protein amino-acid sequence.
        drug: Drug molecule as a SMILES string.

    Returns:
        The sample dict returned by ``DtiBindingdbKdTask.data_preprocessing``,
        prepared with the DTI tokenizer and ``device=models[dti].device``.
    """
    # Use the caller's inputs. The module-level target_seq / drug_seq are only
    # demo defaults; reading them here made the UI ignore user input.
    sample_dict = {"target_seq": prot, "drug_seq": drug}
    return DtiBindingdbKdTask.data_preprocessing(
        sample_dict=sample_dict,
        tokenizer_op=tokenizer_op[dti],
        target_sequence_key="target_seq",
        drug_sequence_key="drug_seq",
        # The sample carries no label, so no label normalization is applied here.
        norm_y_mean=None,
        norm_y_std=None,
        device=models[dti].device,
    )
def create_and_run_prompt_dtb(prot, drug):
    """Predict the binding affinity for one protein/drug pair.

    This is the DTI button's click handler.

    Args:
        prot: Target protein amino-acid sequence.
        drug: Drug molecule as a SMILES string.

    Returns:
        A 3-tuple feeding the DTI demo's prompt box, output box and score field:
          - the sample's ``"data.query.encoder_input"`` entry;
          - the literal key name ``"model.out.dti_bindingdb_kd"``;
          - the processed model prediction as a float.
    """
    sample_dict = format_prompt_dti(prot, drug)

    # Encoder-only forward pass. This is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        batch_dict = models[dti].forward_encoder_only([sample_dict])

    # Post-process the model's output with the given mean/std.
    # NOTE(review): these presumably are the label statistics the fine-tuned
    # model was trained with (used to undo normalization); confirm against
    # the training config.
    batch_dict = DtiBindingdbKdTask.process_model_output(
        batch_dict,
        scalars_preds_processed_key="model.out.dti_bindingdb_kd",
        norm_y_mean=5.79384684128215,
        norm_y_std=1.33808027428196,
    )

    prediction = float(batch_dict["model.out.dti_bindingdb_kd"][0])
    return (
        sample_dict["data.query.encoder_input"],
        "model.out.dti_bindingdb_kd",
        prediction,
    )
def create_tdb_demo():
    """Build the (initially hidden) Gradio group for the drug-target affinity demo.

    The group contains:
      - a header naming the DTI model in use;
      - a protein text box and a drug (SMILES) text box, pre-filled with
        ``target_seq`` and ``drug_seq``;
      - a button wired to ``create_and_run_prompt_dtb``;
      - outputs for the prompt, the model output text and the DTI score.

    Returns:
        The ``gr.Group`` holding the demo, with ``visible`` set to False.
    """
    markup_text = "\n".join(
        [
            "",
            "# Mammal based Target-Drug binding affinity demonstration",
            "Given a protein sequence and a drug (in SMILES), estimate the binding affinity.",
            "### Using the model from",
            f"```{model_paths[dti]} ```",
            "",
        ]
    )
    # Options shared by both editable input boxes.
    input_box_options = dict(interactive=True, lines=3)

    with gr.Group() as tdb_demo:
        gr.Markdown(markup_text)

        with gr.Row():
            prot = gr.Textbox(
                label="Protein sequence",
                value=target_seq,
                **input_box_options,
            )
            drug = gr.Textbox(
                label="drug sequence (SMILES)",
                value=drug_seq,
                **input_box_options,
            )

        with gr.Row():
            run_mammal = gr.Button(
                "Run Mammal prompt for Target Drug Affinity", variant="primary"
            )

        with gr.Row():
            prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

        with gr.Row():
            decoded = gr.Textbox(label="Mammal output")
            # Created inside this row, right after the output box.
            dti_score = gr.Number(label="DTI score")
            run_mammal.click(
                fn=create_and_run_prompt_dtb,
                inputs=[prot, drug],
                outputs=[prompt_box, decoded, dti_score],
            )

    # Hidden until selected in the top-level dropdown.
    tdb_demo.visible = False
    return tdb_demo
def create_application():
    """Assemble the top-level Gradio app.

    A dropdown selects which demo group (PPI or DTI) is visible; both groups
    start hidden.

    Returns:
        The ``gr.Blocks`` application.
    """
    with gr.Blocks() as demo:
        selector = gr.Dropdown(choices=["select demo", ppi, dti])
        selector.interactive = True
        ppi_group = create_ppi_demo()
        dti_group = create_tdb_demo()

        # Keep this function's name: Gradio derives the event's API name from it.
        def set_ppi_vis(main_text):
            show_ppi = main_text == ppi
            show_dti = main_text == dti
            return gr.Group(visible=show_ppi), gr.Group(visible=show_dti)

        selector.change(set_ppi_vis, inputs=selector, outputs=[ppi_group, dti_group])

    return demo
def main():
    """Build the demo application and start the Gradio server."""
    app = create_application()
    app.launch(share=True, show_error=True)


if __name__ == "__main__":
    main()