Spaces:
Sleeping
Sleeping
matanninio
commited on
Commit
·
41a03fb
1
Parent(s):
5988df0
added three MOLNET tasks. text in demo still not done
Browse files- mammal_demo/__init__.py +10 -4
- mammal_demo/molnet_task.py +88 -0
mammal_demo/__init__.py
CHANGED
@@ -7,7 +7,7 @@ from mammal_demo.dti_task import DtiTask
|
|
7 |
from mammal_demo.ppi_task import PpiTask
|
8 |
from mammal_demo.ps_task import PsTask
|
9 |
from mammal_demo.tcr_task import TcrTask
|
10 |
-
|
11 |
|
12 |
def tasks_and_models():
|
13 |
all_tasks = TaskRegistry()
|
@@ -21,6 +21,9 @@ def tasks_and_models():
|
|
21 |
tdi_task = all_tasks.register_task(DtiTask(model_dict=all_models))
|
22 |
ps_task = all_tasks.register_task(PsTask(model_dict=all_models))
|
23 |
tcr_task = all_tasks.register_task(TcrTask(model_dict=all_models))
|
|
|
|
|
|
|
24 |
|
25 |
# create the model holders. hold the model and the tokenizer, lazy download
|
26 |
# note that the list of relevent tasks needs to be stated.
|
@@ -46,13 +49,16 @@ def tasks_and_models():
|
|
46 |
task_list=[ppi_task],
|
47 |
)
|
48 |
all_models.register_model(
|
49 |
-
"ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox"
|
|
|
50 |
)
|
51 |
all_models.register_model(
|
52 |
-
"ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda"
|
|
|
53 |
)
|
54 |
all_models.register_model(
|
55 |
-
"ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_bbbp"
|
|
|
56 |
)
|
57 |
|
58 |
return all_tasks,all_models
|
|
|
7 |
from mammal_demo.ppi_task import PpiTask
|
8 |
from mammal_demo.ps_task import PsTask
|
9 |
from mammal_demo.tcr_task import TcrTask
|
10 |
+
from mammal_demo.molnet_task import MolnetTask
|
11 |
|
12 |
def tasks_and_models():
|
13 |
all_tasks = TaskRegistry()
|
|
|
21 |
tdi_task = all_tasks.register_task(DtiTask(model_dict=all_models))
|
22 |
ps_task = all_tasks.register_task(PsTask(model_dict=all_models))
|
23 |
tcr_task = all_tasks.register_task(TcrTask(model_dict=all_models))
|
24 |
+
bbbp_task = all_tasks.register_task(MolnetTask(model_dict=all_models,task_name="BBBP"))
|
25 |
+
toxicity_task = all_tasks.register_task(MolnetTask(model_dict=all_models,task_name="TOXICITY"))
|
26 |
+
fda_appr_task = all_tasks.register_task(MolnetTask(model_dict=all_models,task_name="FDA_APPR"))
|
27 |
|
28 |
# create the model holders. hold the model and the tokenizer, lazy download
|
29 |
# note that the list of relevent tasks needs to be stated.
|
|
|
49 |
task_list=[ppi_task],
|
50 |
)
|
51 |
all_models.register_model(
|
52 |
+
"ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox",
|
53 |
+
task_list=[toxicity_task]
|
54 |
)
|
55 |
all_models.register_model(
|
56 |
+
"ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda",
|
57 |
+
task_list=[fda_appr_task]
|
58 |
)
|
59 |
all_models.register_model(
|
60 |
+
"ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_bbbp",
|
61 |
+
task_list=[bbbp_task],
|
62 |
)
|
63 |
|
64 |
return all_tasks,all_models
|
mammal_demo/molnet_task.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from mammal.examples.molnet.molnet_infer import create_sample_dict as molnet_create_sample_dict, get_predictions, process_model_output
|
3 |
+
from mammal.keys import *
|
4 |
+
from mammal.model import Mammal
|
5 |
+
|
6 |
+
from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
|
7 |
+
|
8 |
+
|
9 |
+
class MolnetTask(MammalTask):
|
10 |
+
def __init__(self, model_dict, task_name="BBBP"):
|
11 |
+
super().__init__(name=f"Molnet: {task_name}", model_dict=model_dict)
|
12 |
+
self.description = f"MOLNET {task_name}"
|
13 |
+
self.examples = {
|
14 |
+
"drug_seq": "CC(=O)NCCC1=CNc2c1cc(OC)cc2",
|
15 |
+
}
|
16 |
+
self.task_name=task_name
|
17 |
+
self.markup_text = """
|
18 |
+
# Mammal based Drug-Target binding affinity demonstration
|
19 |
+
|
20 |
+
Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker) -> dict:
|
24 |
+
return molnet_create_sample_dict(task_name=self.task_name, smiles_seq=sample_inputs["drug_seq"], tokenizer_op=model_holder.tokenizer_op, model=model_holder.model)
|
25 |
+
|
26 |
+
def run_model(self, sample_dict, model: Mammal):
|
27 |
+
# Generate Prediction
|
28 |
+
batch_dict = get_predictions(model=model,sample_dict=sample_dict)
|
29 |
+
return batch_dict
|
30 |
+
|
31 |
+
def decode_output(self, batch_dict, model_holder):
|
32 |
+
result = process_model_output(
|
33 |
+
tokenizer_op=model_holder.tokenizer_op,
|
34 |
+
decoder_output=batch_dict[CLS_PRED][0],
|
35 |
+
decoder_output_scores=batch_dict[SCORES][0],
|
36 |
+
)
|
37 |
+
generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
|
38 |
+
return generated_output, result['pred'], result['score']
|
39 |
+
|
40 |
+
def create_and_run_prompt(self, model_name, drug_seq):
|
41 |
+
model_holder = self.model_dict[model_name]
|
42 |
+
inputs = {
|
43 |
+
"drug_seq": drug_seq,
|
44 |
+
}
|
45 |
+
sample_dict = self.crate_sample_dict(
|
46 |
+
sample_inputs=inputs, model_holder=model_holder
|
47 |
+
)
|
48 |
+
prompt = sample_dict[ENCODER_INPUTS_STR]
|
49 |
+
batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
|
50 |
+
res = prompt, *self.decode_output(batch_dict, model_holder=model_holder)
|
51 |
+
return res
|
52 |
+
|
53 |
+
def create_demo(self, model_name_widget):
|
54 |
+
|
55 |
+
# """
|
56 |
+
# ### Using the model from
|
57 |
+
|
58 |
+
# ```{model} ```
|
59 |
+
# """
|
60 |
+
with gr.Group() as demo:
|
61 |
+
gr.Markdown(self.markup_text)
|
62 |
+
with gr.Row():
|
63 |
+
drug_textbox = gr.Textbox(
|
64 |
+
label="Drug sequance (in SMILES)",
|
65 |
+
# info="standard",
|
66 |
+
interactive=True,
|
67 |
+
lines=3,
|
68 |
+
value=self.examples["drug_seq"],
|
69 |
+
)
|
70 |
+
with gr.Row():
|
71 |
+
run_mammal = gr.Button(
|
72 |
+
"Run Mammal prompt for task",
|
73 |
+
variant="primary",
|
74 |
+
)
|
75 |
+
with gr.Row():
|
76 |
+
prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
|
77 |
+
|
78 |
+
with gr.Row():
|
79 |
+
decoded = gr.Textbox(label="Mammal output")
|
80 |
+
prediction_box=gr.Textbox(label="Mammal prediction")
|
81 |
+
score_box=gr.Number(label="score")
|
82 |
+
run_mammal.click(
|
83 |
+
fn=self.create_and_run_prompt,
|
84 |
+
inputs=[model_name_widget, drug_textbox],
|
85 |
+
outputs=[prompt_box, decoded, prediction_box, score_box],
|
86 |
+
)
|
87 |
+
demo.visible = False
|
88 |
+
return demo
|