matanninio committed on
Commit
41a03fb
·
1 Parent(s): 5988df0

added three MOLNET tasks. text in demo still not done

Browse files
mammal_demo/__init__.py CHANGED
@@ -7,7 +7,7 @@ from mammal_demo.dti_task import DtiTask
7
  from mammal_demo.ppi_task import PpiTask
8
  from mammal_demo.ps_task import PsTask
9
  from mammal_demo.tcr_task import TcrTask
10
-
11
 
12
  def tasks_and_models():
13
  all_tasks = TaskRegistry()
@@ -21,6 +21,9 @@ def tasks_and_models():
21
  tdi_task = all_tasks.register_task(DtiTask(model_dict=all_models))
22
  ps_task = all_tasks.register_task(PsTask(model_dict=all_models))
23
  tcr_task = all_tasks.register_task(TcrTask(model_dict=all_models))
 
 
 
24
 
25
  # create the model holders. hold the model and the tokenizer, lazy download
26
  # note that the list of relevant tasks needs to be stated.
@@ -46,13 +49,16 @@ def tasks_and_models():
46
  task_list=[ppi_task],
47
  )
48
  all_models.register_model(
49
- "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox"
 
50
  )
51
  all_models.register_model(
52
- "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda"
 
53
  )
54
  all_models.register_model(
55
- "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_bbbp"
 
56
  )
57
 
58
  return all_tasks,all_models
 
7
  from mammal_demo.ppi_task import PpiTask
8
  from mammal_demo.ps_task import PsTask
9
  from mammal_demo.tcr_task import TcrTask
10
+ from mammal_demo.molnet_task import MolnetTask
11
 
12
  def tasks_and_models():
13
  all_tasks = TaskRegistry()
 
21
  tdi_task = all_tasks.register_task(DtiTask(model_dict=all_models))
22
  ps_task = all_tasks.register_task(PsTask(model_dict=all_models))
23
  tcr_task = all_tasks.register_task(TcrTask(model_dict=all_models))
24
+ bbbp_task = all_tasks.register_task(MolnetTask(model_dict=all_models,task_name="BBBP"))
25
+ toxicity_task = all_tasks.register_task(MolnetTask(model_dict=all_models,task_name="TOXICITY"))
26
+ fda_appr_task = all_tasks.register_task(MolnetTask(model_dict=all_models,task_name="FDA_APPR"))
27
 
28
  # create the model holders. hold the model and the tokenizer, lazy download
29
  # note that the list of relevant tasks needs to be stated.
 
49
  task_list=[ppi_task],
50
  )
51
  all_models.register_model(
52
+ "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox",
53
+ task_list=[toxicity_task]
54
  )
55
  all_models.register_model(
56
+ "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda",
57
+ task_list=[fda_appr_task]
58
  )
59
  all_models.register_model(
60
+ "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_bbbp",
61
+ task_list=[bbbp_task],
62
  )
63
 
64
  return all_tasks,all_models
mammal_demo/molnet_task.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from mammal.examples.molnet.molnet_infer import create_sample_dict as molnet_create_sample_dict, get_predictions, process_model_output
3
+ from mammal.keys import *
4
+ from mammal.model import Mammal
5
+
6
+ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
7
+
8
+
9
class MolnetTask(MammalTask):
    """Demo task that runs a MoleculeNet prediction on a single SMILES string.

    One instance is created per MoleculeNet task; ``task_name`` is forwarded
    unchanged to ``molnet_create_sample_dict`` when the prompt is built.
    """

    def __init__(self, model_dict, task_name: str = "BBBP"):
        """Set up the task's name, description, example input and UI text.

        Args:
            model_dict: Registry mapping a model name to its model holder;
                indexed by model name in ``create_and_run_prompt``.
            task_name: MoleculeNet task identifier passed to the sample-dict
                builder and shown in the demo texts.
        """
        super().__init__(name=f"Molnet: {task_name}", model_dict=model_dict)
        self.description = f"MOLNET {task_name}"
        self.examples = {
            "drug_seq": "CC(=O)NCCC1=CNc2c1cc(OC)cc2",
        }
        self.task_name = task_name
        # The text describes this task's real input (a single SMILES string);
        # it previously carried the drug-target binding affinity description.
        self.markup_text = f"""
# Mammal based MoleculeNet {task_name} demonstration

Given a drug (in SMILES), predict the outcome of the MoleculeNet {task_name} task.
"""

    # NOTE(review): the method name is misspelled ("crate"), but it is kept
    # as-is because it is part of the class interface; presumably the base
    # class declares the same hook name - confirm before renaming.
    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker) -> dict:
        """Build the model-ready sample dict for the given SMILES input.

        Args:
            sample_inputs: Dict holding the SMILES string under ``"drug_seq"``.
            model_holder: Holder supplying the ``tokenizer_op`` and ``model``.

        Returns:
            The dict produced by ``molnet_create_sample_dict``.
        """
        return molnet_create_sample_dict(
            task_name=self.task_name,
            smiles_seq=sample_inputs["drug_seq"],
            tokenizer_op=model_holder.tokenizer_op,
            model=model_holder.model,
        )

    def run_model(self, sample_dict, model: Mammal):
        """Run the model on a prepared sample and return the resulting batch dict."""
        return get_predictions(model=model, sample_dict=sample_dict)

    def decode_output(self, batch_dict, model_holder):
        """Turn the raw model output into displayable values.

        Args:
            batch_dict: Output of ``run_model``; read at ``CLS_PRED`` and
                ``SCORES`` (first element of each).
            model_holder: Holder supplying the ``tokenizer_op``.

        Returns:
            A tuple ``(generated_output, pred, score)``: the decoded token
            string followed by the ``'pred'`` and ``'score'`` entries of the
            processed model output.
        """
        result = process_model_output(
            tokenizer_op=model_holder.tokenizer_op,
            decoder_output=batch_dict[CLS_PRED][0],
            decoder_output_scores=batch_dict[SCORES][0],
        )
        generated_output = model_holder.tokenizer_op._tokenizer.decode(
            batch_dict[CLS_PRED][0]
        )
        return generated_output, result["pred"], result["score"]

    def create_and_run_prompt(self, model_name, drug_seq):
        """Build the prompt for ``drug_seq``, run the model and decode the result.

        Args:
            model_name: Key of the model holder inside ``self.model_dict``.
            drug_seq: Drug given as a SMILES string.

        Returns:
            A tuple ``(prompt, generated_output, pred, score)`` matching the
            four output widgets wired up in ``create_demo``.
        """
        model_holder = self.model_dict[model_name]
        inputs = {
            "drug_seq": drug_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        return prompt, *self.decode_output(batch_dict, model_holder=model_holder)

    def create_demo(self, model_name_widget):
        """Create the gradio widget group for this task.

        Args:
            model_name_widget: Widget whose value is passed as ``model_name``
                to ``create_and_run_prompt`` when the button is clicked.

        Returns:
            The ``gr.Group`` holding the task UI, created hidden
            (``visible = False``).
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                drug_textbox = gr.Textbox(
                    label="Drug sequence (in SMILES)",
                    interactive=True,
                    lines=3,
                    value=self.examples["drug_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for task",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                prediction_box = gr.Textbox(label="Mammal prediction")
                score_box = gr.Number(label="score")
            # The outputs follow the order of the tuple returned by
            # create_and_run_prompt.
            run_mammal.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_widget, drug_textbox],
                outputs=[prompt_box, decoded, prediction_box, score_box],
            )
        demo.visible = False
        return demo