matanninio commited on
Commit
19dfa7a
1 Parent(s): f8080fc

cleanup and minor touches + renamed to the standard app name

Browse files
Files changed (5) hide show
  1. app.py +72 -233
  2. mammal_demo/demo_framework.py +36 -40
  3. mammal_demo/dti_task.py +26 -25
  4. mammal_demo/ppi_task.py +44 -40
  5. new_app.py +0 -76
app.py CHANGED
@@ -1,247 +1,86 @@
1
  import gradio as gr
2
- import torch
3
- from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
- from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
5
- from mammal.keys import *
6
- from mammal.model import Mammal
7
 
8
- model_paths = dict()
9
-
10
- ppi = "Protein-Protein Interaction (PPI)"
11
- model_paths[ppi] = "ibm/biomed.omics.bl.sm.ma-ted-458m"
12
-
13
- #
14
- dti = "Drug-Target Binding Affinity"
15
- model_paths[dti] = "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd"
16
-
17
-
18
- # load models (should probably be lazy)
19
-
20
- models = dict()
21
- tokenizer_op = dict()
22
-
23
-
24
- for task, model_path in model_paths.items():
25
- if task not in models:
26
- models[task] = Mammal.from_pretrained(model_path)
27
- models[task].eval()
28
- # Load Tokenizer
29
- tokenizer_op[task] = ModularTokenizerOp.from_pretrained(model_path)
30
-
31
-
32
- ### PPI:
33
- # token for positive binding
34
- positive_token_id = tokenizer_op[ppi].get_token_id("<1>")
35
-
36
- # Default input proteins
37
- protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK"
38
- protein_calcineurin = "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ"
39
-
40
-
41
- def format_prompt_ppi(prot1, prot2):
42
- # Formatting prompt to match pre-training syntax
43
- return f"<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END><MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN><SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
44
-
45
-
46
- def run_prompt(prompt):
47
- # Create and load sample
48
- sample_dict = dict()
49
- sample_dict[ENCODER_INPUTS_STR] = prompt
50
-
51
- # Tokenize
52
- sample_dict = tokenizer_op[ppi](
53
- sample_dict=sample_dict,
54
- key_in=ENCODER_INPUTS_STR,
55
- key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
56
- key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
57
- )
58
- sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
59
- sample_dict[ENCODER_INPUTS_TOKENS]
60
- )
61
- sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
62
- sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
63
- )
64
-
65
- # Generate Prediction
66
- batch_dict = models[ppi].generate(
67
- [sample_dict],
68
- output_scores=True,
69
- return_dict_in_generate=True,
70
- max_new_tokens=5,
71
- )
72
-
73
- # Get output
74
- generated_output = tokenizer_op[ppi]._tokenizer.decode(batch_dict[CLS_PRED][0])
75
- score = batch_dict["model.out.scores"][0][1][positive_token_id].item()
76
-
77
- return generated_output, score
78
-
79
-
80
- def create_and_run_prompt(protein1, protein2):
81
- prompt = format_prompt_ppi(protein1, protein2)
82
- res = prompt, *run_prompt(prompt=prompt)
83
- return res
84
-
85
-
86
- def create_ppi_demo():
87
- markup_text = f"""
88
- # Mammal based Protein-Protein Interaction (PPI) demonstration
89
-
90
- Given two protein sequences, estimate if the proteins interact or not.
91
-
92
- ### Using the model from
93
-
94
- ```{model_paths[ppi]} ```
95
- """
96
- with gr.Group() as ppi_demo:
97
- gr.Markdown(markup_text)
98
- with gr.Row():
99
- prot1 = gr.Textbox(
100
- label="Protein 1 sequence",
101
- # info="standard",
102
- interactive=True,
103
- lines=3,
104
- value=protein_calmodulin,
105
- )
106
- prot2 = gr.Textbox(
107
- label="Protein 2 sequence",
108
- # info="standard",
109
- interactive=True,
110
- lines=3,
111
- value=protein_calcineurin,
112
- )
113
- with gr.Row():
114
- run_mammal = gr.Button(
115
- "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
116
- )
117
- with gr.Row():
118
- prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
119
-
120
- with gr.Row():
121
- decoded = gr.Textbox(label="Mammal output")
122
- run_mammal.click(
123
- fn=create_and_run_prompt,
124
- inputs=[prot1, prot2],
125
- outputs=[prompt_box, decoded, gr.Number(label="PPI score")],
126
- )
127
- with gr.Row():
128
- gr.Markdown(
129
- "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
130
- )
131
- ppi_demo.visible = False
132
- return ppi_demo
133
-
134
-
135
- ### DTI:
136
- # input
137
- target_seq = "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC"
138
- drug_seq = "CC(=O)NCCC1=CNc2c1cc(OC)cc2"
139
-
140
-
141
- # token for positive binding
142
- positive_token_id = tokenizer_op[dti].get_token_id("<1>")
143
-
144
-
145
- def format_prompt_dti(prot, drug):
146
- sample_dict = {"target_seq": target_seq, "drug_seq": drug_seq}
147
- sample_dict = DtiBindingdbKdTask.data_preprocessing(
148
- sample_dict=sample_dict,
149
- tokenizer_op=tokenizer_op[dti],
150
- target_sequence_key="target_seq",
151
- drug_sequence_key="drug_seq",
152
- norm_y_mean=None,
153
- norm_y_std=None,
154
- device=models[dti].device,
155
- )
156
- return sample_dict
157
-
158
-
159
- def create_and_run_prompt_dtb(prot, drug):
160
- sample_dict = format_prompt_dti(prot, drug)
161
- # Post-process the model's output
162
- # batch_dict = model_dti.forward_encoder_only([sample_dict])
163
- batch_dict = models[dti].forward_encoder_only([sample_dict])
164
- batch_dict = DtiBindingdbKdTask.process_model_output(
165
- batch_dict,
166
- scalars_preds_processed_key="model.out.dti_bindingdb_kd",
167
- norm_y_mean=5.79384684128215,
168
- norm_y_std=1.33808027428196,
169
- )
170
- ans = [
171
- "model.out.dti_bindingdb_kd",
172
- float(batch_dict["model.out.dti_bindingdb_kd"][0]),
173
- ]
174
- res = sample_dict["data.query.encoder_input"], *ans
175
- return res
176
-
177
-
178
- def create_tdb_demo():
179
- markup_text = f"""
180
- # Mammal based Target-Drug binding affinity demonstration
181
-
182
- Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
183
-
184
- ### Using the model from
185
-
186
- ```{model_paths[dti]} ```
187
- """
188
- with gr.Group() as tdb_demo:
189
- gr.Markdown(markup_text)
190
- with gr.Row():
191
- prot = gr.Textbox(
192
- label="Protein sequence",
193
- # info="standard",
194
- interactive=True,
195
- lines=3,
196
- value=target_seq,
197
- )
198
- drug = gr.Textbox(
199
- label="drug sequence (SMILES)",
200
- # info="standard",
201
- interactive=True,
202
- lines=3,
203
- value=drug_seq,
204
- )
205
- with gr.Row():
206
- run_mammal = gr.Button(
207
- "Run Mammal prompt for Target Drug Affinity", variant="primary"
208
- )
209
- with gr.Row():
210
- prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
211
-
212
- with gr.Row():
213
- decoded = gr.Textbox(label="Mammal output")
214
- run_mammal.click(
215
- fn=create_and_run_prompt_dtb,
216
- inputs=[prot, drug],
217
- outputs=[prompt_box, decoded, gr.Number(label="DTI score")],
218
- )
219
- tdb_demo.visible = False
220
- return tdb_demo
221
 
 
 
222
 
223
- def create_application():
 
 
 
 
 
 
 
 
 
224
 
225
- with gr.Blocks() as demo:
226
- main_dropdown = gr.Dropdown(choices=["select demo", ppi, dti])
227
- main_dropdown.interactive = True
228
- ppi_demo = create_ppi_demo()
229
- dtb_demo = create_tdb_demo()
230
 
231
- def set_ppi_vis(main_text):
232
- return gr.Group(visible=main_text == ppi), gr.Group(
233
- visible=main_text == dti
234
- )
235
 
236
- main_dropdown.change(
237
- set_ppi_vis, inputs=main_dropdown, outputs=[ppi_demo, dtb_demo]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  )
239
- return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
 
242
  def main():
243
- demo = create_application()
244
- demo.launch(show_error=True, share=True)
 
245
 
246
 
247
  if __name__ == "__main__":
 
1
  import gradio as gr
 
 
 
 
 
2
 
3
+ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
4
+ from mammal_demo.dti_task import DtiTask
5
+ from mammal_demo.ppi_task import PpiTask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ all_tasks: dict[str, MammalTask] = dict()
8
+ all_models: dict[str, MammalObjectBroker] = dict()
9
 
10
+ ppi_task = PpiTask(model_dict=all_models)
11
+ all_tasks[ppi_task.name] = ppi_task
12
+
13
+ tdi_task = DtiTask(model_dict=all_models)
14
+ all_tasks[tdi_task.name] = tdi_task
15
+
16
+ ppi_model = MammalObjectBroker(
17
+ model_path="ibm/biomed.omics.bl.sm.ma-ted-458m", task_list=[ppi_task.name]
18
+ )
19
+ all_models[ppi_model.name] = ppi_model
20
 
21
+ tdi_model = MammalObjectBroker(
22
+ model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
23
+ task_list=[tdi_task.name],
24
+ )
25
+ all_models[tdi_model.name] = tdi_model
26
 
 
 
 
 
27
 
28
+ def create_application():
29
+ def task_change(value):
30
+ visibility = [gr.update(visible=(task == value)) for task in all_tasks.keys()]
31
+ # all_tasks[task].demo().visible =
32
+ choices = [
33
+ model_name
34
+ for model_name, model in all_models.items()
35
+ if value in model.tasks
36
+ ]
37
+ if choices:
38
+ return (gr.update(choices=choices, value=choices[0], visible=True), *visibility)
39
+ else:
40
+ return (gr.skip, *visibility)
41
+ # return model_name_dropdown
42
+
43
+ with gr.Blocks() as application:
44
+ task_dropdown = gr.Dropdown(choices=["select demo"] + list(all_tasks.keys()), label="Mammal Task")
45
+ task_dropdown.interactive = True
46
+ model_name_dropdown = gr.Dropdown(
47
+ choices=[
48
+ model_name
49
+ for model_name, model in all_models.items()
50
+ if task_dropdown.value in model.tasks
51
+ ],
52
+ interactive=True,
53
+ label="Matching Mammal models",
54
+ visible=False,
55
+ )
56
+
57
+ task_dropdown.change(
58
+ task_change,
59
+ inputs=[task_dropdown],
60
+ outputs=[model_name_dropdown]
61
+ + [all_tasks[task].demo(model_name_widgit=model_name_dropdown) for task in all_tasks],
62
  )
63
+
64
+ # def set_demo_vis(main_text):
65
+ # main_text=main_text
66
+ # print(f"main text is {main_text}")
67
+ # return gr.Group(visible=True)
68
+ # #return gr.Group(visible=(main_text == "PPI"))
69
+ # # , gr.Group( visible=(main_text == "DTI") )
70
+
71
+ # task_dropdown.change(
72
+ # set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo]
73
+ # )
74
+ return application
75
+
76
+
77
+ full_demo = None
78
 
79
 
80
  def main():
81
+ global full_demo
82
+ full_demo = create_application()
83
+ full_demo.launch(show_error=True, share=False)
84
 
85
 
86
  if __name__ == "__main__":
mammal_demo/demo_framework.py CHANGED
@@ -1,51 +1,48 @@
 
 
1
  import gradio as gr
2
  from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
3
- from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
4
- from mammal.keys import *
5
  from mammal.model import Mammal
6
- from abc import ABC, abstractmethod
7
 
8
-
9
 
10
-
11
-
12
- class MammalObjectBroker():
13
- def __init__(self, model_path: str, name:str= None, task_list: list[str]=None) -> None:
 
 
 
14
  self.model_path = model_path
15
  if name is None:
16
  name = model_path
17
- self.name = name
18
-
 
19
  if task_list is not None:
20
- self.tasks=task_list
21
- else:
22
- self.task = []
23
- self._model = None
24
  self._tokenizer_op = None
25
-
26
-
27
  @property
28
- def model(self)-> Mammal:
29
  if self._model is None:
30
- self._model = Mammal.from_pretrained(self.model_path)
31
- self._model.eval()
32
  return self._model
33
-
34
  @property
35
  def tokenizer_op(self):
36
  if self._tokenizer_op is None:
37
- self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
38
  return self._tokenizer_op
39
-
40
-
41
-
42
 
43
  class MammalTask(ABC):
44
- def __init__(self, name:str, model_dict: dict[str,MammalObjectBroker]) -> None:
45
- self.name = name
46
- self.description = None
47
- self._demo = None
48
- self.model_dict = model_dict
49
 
50
  # @abstractmethod
51
  # def _generate_prompt(self, **kwargs) -> str:
@@ -61,7 +58,9 @@ class MammalTask(ABC):
61
  # raise NotImplementedError()
62
 
63
  @abstractmethod
64
- def crate_sample_dict(self,sample_inputs: dict, model_holder:MammalObjectBroker) -> dict:
 
 
65
  """Formatting prompt to match pre-training syntax
66
 
67
  Args:
@@ -73,9 +72,9 @@ class MammalTask(ABC):
73
  raise NotImplementedError()
74
 
75
  # @abstractmethod
76
- def run_model(self, sample_dict, model:Mammal):
77
  raise NotImplementedError()
78
-
79
  def create_demo(self, model_name_widget: gr.component) -> gr.Group:
80
  """create an gradio demo group
81
 
@@ -89,20 +88,17 @@ class MammalTask(ABC):
89
  """
90
  raise NotImplementedError()
91
 
92
-
93
-
94
- def demo(self,model_name_widgit:gr.component=None):
95
  if self._demo is None:
96
- model_name_widget:gr.component
97
  self._demo = self.create_demo(model_name_widget=model_name_widgit)
98
  return self._demo
99
 
100
  @abstractmethod
101
- def decode_output(self,batch_dict, model:Mammal):
102
  raise NotImplementedError()
103
 
104
- #self._setup()
105
-
106
  # def _setup(self):
107
  # pass
108
-
 
1
+ from abc import ABC, abstractmethod
2
+
3
  import gradio as gr
4
  from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
 
 
5
  from mammal.model import Mammal
 
6
 
 
7
 
8
+ class MammalObjectBroker:
9
+ def __init__(
10
+ self,
11
+ model_path: str,
12
+ name: str | None = None,
13
+ task_list: list[str] | None = None,
14
+ ) -> None:
15
  self.model_path = model_path
16
  if name is None:
17
  name = model_path
18
+ self.name = name
19
+
20
+ self.tasks: list[str] = []
21
  if task_list is not None:
22
+ self.tasks = task_list
23
+ self._model: Mammal | None = None
 
 
24
  self._tokenizer_op = None
25
+
 
26
  @property
27
+ def model(self) -> Mammal:
28
  if self._model is None:
29
+ self._model = Mammal.from_pretrained(self.model_path)
30
+ self._model.eval()
31
  return self._model
32
+
33
  @property
34
  def tokenizer_op(self):
35
  if self._tokenizer_op is None:
36
+ self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
37
  return self._tokenizer_op
38
+
 
 
39
 
40
  class MammalTask(ABC):
41
+ def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
42
+ self.name = name
43
+ self.description = None
44
+ self._demo = None
45
+ self.model_dict = model_dict
46
 
47
  # @abstractmethod
48
  # def _generate_prompt(self, **kwargs) -> str:
 
58
  # raise NotImplementedError()
59
 
60
  @abstractmethod
61
+ def crate_sample_dict(
62
+ self, sample_inputs: dict, model_holder: MammalObjectBroker
63
+ ) -> dict:
64
  """Formatting prompt to match pre-training syntax
65
 
66
  Args:
 
72
  raise NotImplementedError()
73
 
74
  # @abstractmethod
75
+ def run_model(self, sample_dict, model: Mammal):
76
  raise NotImplementedError()
77
+
78
  def create_demo(self, model_name_widget: gr.component) -> gr.Group:
79
  """create an gradio demo group
80
 
 
88
  """
89
  raise NotImplementedError()
90
 
91
+ def demo(self, model_name_widgit: gr.component = None):
 
 
92
  if self._demo is None:
93
+ model_name_widget: gr.component
94
  self._demo = self.create_demo(model_name_widget=model_name_widgit)
95
  return self._demo
96
 
97
  @abstractmethod
98
+ def decode_output(self, batch_dict, model: Mammal):
99
  raise NotImplementedError()
100
 
101
+ # self._setup()
102
+
103
  # def _setup(self):
104
  # pass
 
mammal_demo/dti_task.py CHANGED
@@ -3,7 +3,8 @@ from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
3
  from mammal.keys import *
4
  from mammal.model import Mammal
5
 
6
- from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
 
7
 
8
  class DtiTask(MammalTask):
9
  def __init__(self, model_dict):
@@ -11,15 +12,15 @@ class DtiTask(MammalTask):
11
  self.description = "Drug-Target Binding Affinity (tdi)"
12
  self.examples = {
13
  "target_seq": "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
14
- "drug_seq":"CC(=O)NCCC1=CNc2c1cc(OC)cc2"
15
- }
16
  self.markup_text = """
17
  # Mammal based Target-Drug binding affinity demonstration
18
 
19
  Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
20
  """
21
-
22
- def crate_sample_dict(self, sample_inputs:dict, model_holder:MammalObjectBroker):
23
  """convert sample_inputs to sample_dict including creating a proper prompt
24
 
25
  Args:
@@ -39,14 +40,13 @@ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
39
  device=model_holder.model.device,
40
  )
41
  return sample_dict
42
-
43
 
44
  def run_model(self, sample_dict, model: Mammal):
45
  # Generate Prediction
46
  batch_dict = model.forward_encoder_only([sample_dict])
47
  return batch_dict
48
-
49
- def decode_output(self,batch_dict, model_holder):
50
 
51
  # Get output
52
  batch_dict = DtiBindingdbKdTask.process_model_output(
@@ -54,34 +54,34 @@ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
54
  scalars_preds_processed_key="model.out.dti_bindingdb_kd",
55
  norm_y_mean=5.79384684128215,
56
  norm_y_std=1.33808027428196,
57
- )
58
  ans = (
59
- "model.out.dti_bindingdb_kd",
60
- float(batch_dict["model.out.dti_bindingdb_kd"][0]),
61
- )
62
  return ans
63
 
64
-
65
- def create_and_run_prompt(self,model_name,target_seq, drug_seq):
66
  model_holder = self.model_dict[model_name]
67
  inputs = {
68
  "target_seq": target_seq,
69
  "drug_seq": drug_seq,
70
  }
71
- sample_dict = self.crate_sample_dict(sample_inputs=inputs, model_holder=model_holder)
72
- prompt=sample_dict[ENCODER_INPUTS_STR]
 
 
73
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
74
- res = prompt, *self.decode_output(batch_dict,model_holder=model_holder)
75
  return res
76
 
77
-
78
- def create_demo(self,model_name_widget):
79
-
80
- # """
81
- # ### Using the model from
82
 
83
- # ```{model} ```
84
- # """
85
  with gr.Group() as demo:
86
  gr.Markdown(self.markup_text)
87
  with gr.Row():
@@ -101,7 +101,8 @@ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
101
  )
102
  with gr.Row():
103
  run_mammal = gr.Button(
104
- "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
 
105
  )
106
  with gr.Row():
107
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
 
3
  from mammal.keys import *
4
  from mammal.model import Mammal
5
 
6
+ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
7
+
8
 
9
  class DtiTask(MammalTask):
10
  def __init__(self, model_dict):
 
12
  self.description = "Drug-Target Binding Affinity (tdi)"
13
  self.examples = {
14
  "target_seq": "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
15
+ "drug_seq": "CC(=O)NCCC1=CNc2c1cc(OC)cc2",
16
+ }
17
  self.markup_text = """
18
  # Mammal based Target-Drug binding affinity demonstration
19
 
20
  Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
21
  """
22
+
23
+ def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
24
  """convert sample_inputs to sample_dict including creating a proper prompt
25
 
26
  Args:
 
40
  device=model_holder.model.device,
41
  )
42
  return sample_dict
 
43
 
44
  def run_model(self, sample_dict, model: Mammal):
45
  # Generate Prediction
46
  batch_dict = model.forward_encoder_only([sample_dict])
47
  return batch_dict
48
+
49
+ def decode_output(self, batch_dict, model_holder):
50
 
51
  # Get output
52
  batch_dict = DtiBindingdbKdTask.process_model_output(
 
54
  scalars_preds_processed_key="model.out.dti_bindingdb_kd",
55
  norm_y_mean=5.79384684128215,
56
  norm_y_std=1.33808027428196,
57
+ )
58
  ans = (
59
+ "model.out.dti_bindingdb_kd",
60
+ float(batch_dict["model.out.dti_bindingdb_kd"][0]),
61
+ )
62
  return ans
63
 
64
+ def create_and_run_prompt(self, model_name, target_seq, drug_seq):
 
65
  model_holder = self.model_dict[model_name]
66
  inputs = {
67
  "target_seq": target_seq,
68
  "drug_seq": drug_seq,
69
  }
70
+ sample_dict = self.crate_sample_dict(
71
+ sample_inputs=inputs, model_holder=model_holder
72
+ )
73
+ prompt = sample_dict[ENCODER_INPUTS_STR]
74
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
75
+ res = prompt, *self.decode_output(batch_dict, model_holder=model_holder)
76
  return res
77
 
78
+ def create_demo(self, model_name_widget):
79
+
80
+ # """
81
+ # ### Using the model from
 
82
 
83
+ # ```{model} ```
84
+ # """
85
  with gr.Group() as demo:
86
  gr.Markdown(self.markup_text)
87
  with gr.Row():
 
101
  )
102
  with gr.Row():
103
  run_mammal = gr.Button(
104
+ "Run Mammal prompt for Protein-Protein Interaction",
105
+ variant="primary",
106
  )
107
  with gr.Row():
108
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
mammal_demo/ppi_task.py CHANGED
@@ -1,12 +1,14 @@
1
  import gradio as gr
2
  import torch
3
- from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
4
- from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
5
- from mammal.keys import *
 
 
 
6
  from mammal.model import Mammal
7
 
8
- from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
9
-
10
 
11
 
12
  class PpiTask(MammalTask):
@@ -19,11 +21,9 @@ class PpiTask(MammalTask):
19
  }
20
  self.markup_text = """
21
  # Mammal based {self.description} demonstration
22
-
23
  Given two protein sequences, estimate if the proteins interact or not."""
24
-
25
-
26
-
27
  @staticmethod
28
  def positive_token_id(model_holder: MammalObjectBroker):
29
  """token for positive binding
@@ -35,7 +35,7 @@ class PpiTask(MammalTask):
35
  int: id of positive binding token
36
  """
37
  return model_holder.tokenizer_op.get_token_id("<1>")
38
-
39
  def generate_prompt(self, prot1, prot2):
40
  """Formatting prompt to match pre-training syntax
41
 
@@ -45,16 +45,17 @@ class PpiTask(MammalTask):
45
 
46
  Returns:
47
  str: prompt
48
- """
49
- prompt = f"<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"\
50
- "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
51
- "<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"\
52
- "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"\
53
- "<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
 
 
54
  return prompt
55
-
56
-
57
- def crate_sample_dict(self,sample_inputs: dict, model_holder:MammalObjectBroker):
58
  # Create and load sample
59
  sample_dict = dict()
60
  prompt = self.generate_prompt(*sample_inputs)
@@ -84,35 +85,37 @@ class PpiTask(MammalTask):
84
  max_new_tokens=5,
85
  )
86
  return batch_dict
87
-
88
- def decode_output(self,batch_dict, model_holder:MammalObjectBroker):
89
 
90
  # Get output
91
- generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
92
- score = batch_dict["model.out.scores"][0][1][self.positive_token_id(model_holder)].item()
 
 
 
 
93
 
94
  return generated_output, score
95
 
96
-
97
- def create_and_run_prompt(self,model_name,protein1, protein2):
98
  model_holder = self.model_dict[model_name]
99
- sample_inputs = {"prot1":protein1,
100
- "prot2":protein2
101
- }
102
- sample_dict = self.crate_sample_dict(sample_inputs=sample_inputs, model_holder=model_holder)
103
  prompt = sample_dict[ENCODER_INPUTS_STR]
104
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
105
- res = prompt, *self.decode_output(batch_dict,model_holder=model_holder)
106
  return res
107
 
108
-
109
- def create_demo(self,model_name_widget:gr.component):
110
-
111
- # """
112
- # ### Using the model from
113
 
114
- # ```{model} ```
115
- # """
116
  with gr.Group() as demo:
117
  gr.Markdown(self.markup_text)
118
  with gr.Row():
@@ -132,17 +135,18 @@ class PpiTask(MammalTask):
132
  )
133
  with gr.Row():
134
  run_mammal: gr.Button = gr.Button(
135
- "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
 
136
  )
137
  with gr.Row():
138
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
139
-
140
  with gr.Row():
141
  decoded = gr.Textbox(label="Mammal output")
142
  run_mammal.click(
143
  fn=self.create_and_run_prompt,
144
  inputs=[model_name_widget, prot1, prot2],
145
- outputs=[prompt_box, decoded, gr.Number(label="PPI score")],
146
  )
147
  with gr.Row():
148
  gr.Markdown(
 
1
  import gradio as gr
2
  import torch
3
+ from mammal.keys import (
4
+ CLS_PRED,
5
+ ENCODER_INPUTS_ATTENTION_MASK,
6
+ ENCODER_INPUTS_STR,
7
+ ENCODER_INPUTS_TOKENS,
8
+ )
9
  from mammal.model import Mammal
10
 
11
+ from mammal_demo.demo_framework import MammalObjectBroker, MammalTask
 
12
 
13
 
14
  class PpiTask(MammalTask):
 
21
  }
22
  self.markup_text = """
23
  # Mammal based {self.description} demonstration
24
+
25
  Given two protein sequences, estimate if the proteins interact or not."""
26
+
 
 
27
  @staticmethod
28
  def positive_token_id(model_holder: MammalObjectBroker):
29
  """token for positive binding
 
35
  int: id of positive binding token
36
  """
37
  return model_holder.tokenizer_op.get_token_id("<1>")
38
+
39
  def generate_prompt(self, prot1, prot2):
40
  """Formatting prompt to match pre-training syntax
41
 
 
45
 
46
  Returns:
47
  str: prompt
48
+ """
49
+ prompt = (
50
+ "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_CLASS><SENTINEL_ID_0>"
51
+ + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
52
+ + f"<SEQUENCE_NATURAL_START>{prot1}<SEQUENCE_NATURAL_END>"
53
+ + "<MOLECULAR_ENTITY><MOLECULAR_ENTITY_GENERAL_PROTEIN>"
54
+ + f"<SEQUENCE_NATURAL_START>{prot2}<SEQUENCE_NATURAL_END><EOS>"
55
+ )
56
  return prompt
57
+
58
+ def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker):
 
59
  # Create and load sample
60
  sample_dict = dict()
61
  prompt = self.generate_prompt(*sample_inputs)
 
85
  max_new_tokens=5,
86
  )
87
  return batch_dict
88
+
89
+ def decode_output(self, batch_dict, model_holder: MammalObjectBroker):
90
 
91
  # Get output
92
+ generated_output = model_holder.tokenizer_op._tokenizer.decode(
93
+ batch_dict[CLS_PRED][0]
94
+ )
95
+ score = batch_dict["model.out.scores"][0][1][
96
+ self.positive_token_id(model_holder)
97
+ ].item()
98
 
99
  return generated_output, score
100
 
101
+ def create_and_run_prompt(self, model_name, protein1, protein2):
 
102
  model_holder = self.model_dict[model_name]
103
+ sample_inputs = {"prot1": protein1, "prot2": protein2}
104
+ sample_dict = self.crate_sample_dict(
105
+ sample_inputs=sample_inputs, model_holder=model_holder
106
+ )
107
  prompt = sample_dict[ENCODER_INPUTS_STR]
108
  batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
109
+ res = prompt, *self.decode_output(batch_dict, model_holder=model_holder)
110
  return res
111
 
112
+ def create_demo(self, model_name_widget: gr.component):
113
+
114
+ # """
115
+ # ### Using the model from
 
116
 
117
+ # ```{model} ```
118
+ # """
119
  with gr.Group() as demo:
120
  gr.Markdown(self.markup_text)
121
  with gr.Row():
 
135
  )
136
  with gr.Row():
137
  run_mammal: gr.Button = gr.Button(
138
+ "Run Mammal prompt for Protein-Protein Interaction",
139
+ variant="primary",
140
  )
141
  with gr.Row():
142
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
143
+ score_box = gr.Number(label="PPI score")
144
  with gr.Row():
145
  decoded = gr.Textbox(label="Mammal output")
146
  run_mammal.click(
147
  fn=self.create_and_run_prompt,
148
  inputs=[model_name_widget, prot1, prot2],
149
+ outputs=[prompt_box, decoded, score_box],
150
  )
151
  with gr.Row():
152
  gr.Markdown(
new_app.py DELETED
@@ -1,76 +0,0 @@
1
- import gradio as gr
2
- from mammal.keys import *
3
-
4
- from mammal_demo.demo_framework import MammalObjectBroker
5
-
6
-
7
- from mammal_demo.ppi_task import PpiTask
8
- from mammal_demo.dti_task import DtiTask
9
-
10
- all_tasks = dict()
11
- all_models= dict()
12
-
13
- ppi_task = PpiTask(model_dict = all_models)
14
- all_tasks[ppi_task.name]=ppi_task
15
-
16
- tdi_task = DtiTask(model_dict = all_models)
17
- all_tasks[tdi_task.name]=tdi_task
18
-
19
- ppi_model = MammalObjectBroker(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m", task_list=[ppi_task.name])
20
- all_models[ppi_model.name]=ppi_model
21
-
22
- tdi_model = MammalObjectBroker(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd", task_list=[tdi_task.name])
23
- all_models[tdi_model.name]=tdi_model
24
-
25
-
26
- def create_application():
27
- def task_change(value):
28
- visibility = [gr.update(visible=(task==value)) for task in all_tasks.keys()]
29
- # all_tasks[task].demo().visible =
30
- choices=[model_name for model_name, model in all_models.items() if value in model.tasks]
31
- if choices:
32
- return (gr.update(choices=choices, value=choices[0]),*visibility)
33
- else:
34
- return (gr.skip,*visibility)
35
- # return model_name_dropdown
36
-
37
-
38
- with gr.Blocks() as application:
39
- task_dropdown = gr.Dropdown(choices=["select demo"] + list(all_tasks.keys()))
40
- task_dropdown.interactive = True
41
- model_name_dropdown = gr.Dropdown(choices=[model_name for model_name, model in all_models.items() if task_dropdown.value in model.tasks], interactive=True)
42
-
43
-
44
-
45
-
46
-
47
- ppi_demo = all_tasks[ppi_task.name].demo(model_name_widgit = model_name_dropdown)
48
- # ppi_demo.visible = True
49
- dtb_demo = all_tasks[tdi_task.name].demo(model_name_widgit = model_name_dropdown)
50
-
51
- task_dropdown.change(task_change,inputs=[task_dropdown],outputs=[model_name_dropdown]+[all_tasks[task].demo() for task in all_tasks])
52
-
53
- # def set_demo_vis(main_text):
54
- # main_text=main_text
55
- # print(f"main text is {main_text}")
56
- # return gr.Group(visible=True)
57
- # #return gr.Group(visible=(main_text == "PPI"))
58
- # # , gr.Group( visible=(main_text == "DTI") )
59
-
60
-
61
- # task_dropdown.change(
62
- # set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo]
63
- # )
64
- return application
65
-
66
- full_demo=None
67
-
68
- def main():
69
- global full_demo
70
- full_demo = create_application()
71
- full_demo.launch(show_error=True, share=False)
72
-
73
-
74
- if __name__ == "__main__":
75
- main()
76
-