matanninio commited on
Commit
72dbfd7
·
1 Parent(s): 022cccc

both first demos now work

Browse files
Files changed (1) hide show
  1. new_app.py +14 -16
new_app.py CHANGED
@@ -291,21 +291,23 @@ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
291
 
292
  def run_model(self, sample_dict, model: Mammal):
293
  # Generate Prediction
294
- batch_dict = model.generate(
295
- [sample_dict],
296
- output_scores=True,
297
- return_dict_in_generate=True,
298
- max_new_tokens=5,
299
- )
300
  return batch_dict
301
 
302
  def decode_output(self,batch_dict, model_holder):
303
 
304
  # Get output
305
- generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
306
- score = batch_dict["model.out.scores"][0][1][self.positive_token_id(model_holder)].item()
307
-
308
- return generated_output, score
 
 
 
 
 
 
 
309
 
310
 
311
  def create_and_run_prompt(self,model_name,target_seq, drug_seq):
@@ -353,15 +355,11 @@ Given a protein sequence and a drug (in SMILES), estimate the binding affinity.
353
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
354
 
355
  with gr.Row():
356
- decoded = gr.Textbox(label="Mammal output")
357
  run_mammal.click(
358
  fn=self.create_and_run_prompt,
359
  inputs=[model_name_widget, target_textbox, drug_textbox],
360
- outputs=[prompt_box, decoded, gr.Number(label="PPI score")],
361
- )
362
- with gr.Row():
363
- gr.Markdown(
364
- "```<SENTINEL_ID_0>``` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
365
  )
366
  demo.visible = False
367
  return demo
 
291
 
292
  def run_model(self, sample_dict, model: Mammal):
293
  # Generate Prediction
294
+ batch_dict = model.forward_encoder_only([sample_dict])
 
 
 
 
 
295
  return batch_dict
296
 
297
  def decode_output(self,batch_dict, model_holder):
298
 
299
  # Get output
300
+ batch_dict = DtiBindingdbKdTask.process_model_output(
301
+ batch_dict,
302
+ scalars_preds_processed_key="model.out.dti_bindingdb_kd",
303
+ norm_y_mean=5.79384684128215,
304
+ norm_y_std=1.33808027428196,
305
+ )
306
+ ans = (
307
+ "model.out.dti_bindingdb_kd",
308
+ float(batch_dict["model.out.dti_bindingdb_kd"][0]),
309
+ )
310
+ return ans
311
 
312
 
313
  def create_and_run_prompt(self,model_name,target_seq, drug_seq):
 
355
  prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
356
 
357
  with gr.Row():
358
+ decoded = gr.Textbox(label="Mammal output key")
359
  run_mammal.click(
360
  fn=self.create_and_run_prompt,
361
  inputs=[model_name_widget, target_textbox, drug_textbox],
362
+ outputs=[prompt_box, decoded, gr.Number(label="binding affinity")],
 
 
 
 
363
  )
364
  demo.visible = False
365
  return demo