import gradio as gr

from mammal.examples.molnet.molnet_infer import (
    create_sample_dict as molnet_create_sample_dict,
    get_predictions,
    process_model_output,
)

# Star import kept as-is; this module relies on CLS_PRED, SCORES and
# ENCODER_INPUTS_STR from it.
from mammal.keys import *
from mammal.model import Mammal
from mammal_demo.demo_framework import MammalObjectBroker, MammalTask


class MolnetTask(MammalTask):
    """Gradio demo task that runs a MolNet prompt (``BBBP`` by default) on a drug.

    The user enters a drug as a SMILES string. The task:

    1. builds the MAMMAL prompt for ``task_name``;
    2. runs the selected model;
    3. shows the prompt, the decoded model output, the prediction and its score.
    """

    def __init__(self, model_dict, task_name="BBBP", name=None):
        """Create the task.

        Args:
            model_dict: Mapping from model name to a model holder. It is
                indexed by ``create_and_run_prompt``. Each holder must expose
                ``model`` and ``tokenizer_op`` (presumably
                ``MammalObjectBroker`` instances — confirm against the caller).
            task_name: MolNet task name forwarded to the sample-dict builder.
            name: Display name. Defaults to ``"Molnet: <task_name>"``.
        """
        if name is None:
            name = f"Molnet: {task_name}"
        super().__init__(name=name, model_dict=model_dict)

        self.description = f"MOLNET {task_name}"
        # Default value shown in the SMILES textbox.
        self.examples = {
            "drug_seq": "CC(=O)NCCC1=CNc2c1cc(OC)cc2",
        }
        self.task_name = task_name
        self.markup_text = """
# Mammal demonstration
"""

    # NOTE(review): "crate" is a misspelling of "create", but the name is kept
    # because it may be the hook defined by MammalTask — confirm before renaming.
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Build the model input sample for ``sample_inputs["drug_seq"]``.

        The work is delegated to the MolNet example's ``create_sample_dict``,
        using this task's ``task_name`` and the holder's tokenizer and model.
        """
        return molnet_create_sample_dict(
            task_name=self.task_name,
            smiles_seq=sample_inputs["drug_seq"],
            tokenizer_op=model_holder.tokenizer_op,
            model=model_holder.model,
        )

    def run_model(self, sample_dict, model: Mammal):
        """Run generation on ``sample_dict`` and return the resulting batch dict."""
        return get_predictions(model=model, sample_dict=sample_dict)

    def decode_output(self, batch_dict, model_holder):
        """Turn the first sample of ``batch_dict`` into displayable values.

        Returns:
            A tuple ``(generated_output, prediction, score)``:

            - ``generated_output`` is the raw decoded token sequence.
            - ``prediction`` and ``score`` come from ``process_model_output``'s
              ``'pred'`` and ``'score'`` entries.
        """
        first_pred = batch_dict[CLS_PRED][0]
        result = process_model_output(
            tokenizer_op=model_holder.tokenizer_op,
            decoder_output=first_pred,
            decoder_output_scores=batch_dict[SCORES][0],
        )
        generated_output = model_holder.tokenizer_op._tokenizer.decode(first_pred)
        return generated_output, result["pred"], result["score"]

    def create_and_run_prompt(self, model_name, drug_seq):
        """Gradio callback: build the prompt for ``drug_seq`` and run ``model_name``.

        Returns:
            ``(prompt, generated_output, prediction, score)``, matching the
            four output widgets wired up in ``create_demo``.
        """
        model_holder = self.model_dict[model_name]
        inputs = {
            "drug_seq": drug_seq,
        }
        sample_dict = self.crate_sample_dict(
            sample_inputs=inputs, model_holder=model_holder
        )
        prompt = sample_dict[ENCODER_INPUTS_STR]
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        return prompt, *self.decode_output(batch_dict, model_holder=model_holder)

    def create_demo(self, model_name_widget):
        """Build the (initially hidden) Gradio group for this task.

        Args:
            model_name_widget: Existing Gradio component whose value selects
                the key into ``self.model_dict``. It is passed as the first
                click input.

        Returns:
            The ``gr.Group`` containing the task's widgets, with
            ``visible = False``.
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                drug_textbox = gr.Textbox(
                    label="Drug sequence (in SMILES)",
                    interactive=True,
                    lines=3,
                    value=self.examples["drug_seq"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for task",
                    variant="primary",
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)
            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                prediction_box = gr.Textbox(label="Mammal prediction")
                score_box = gr.Number(label="score")

            # Output order must match the tuple returned by create_and_run_prompt.
            run_mammal.click(
                fn=self.create_and_run_prompt,
                inputs=[model_name_widget, drug_textbox],
                outputs=[prompt_box, decoded, prediction_box, score_box],
            )
        demo.visible = False
        return demo