class MammalObjectBroker:
    """Lazy holder for a Mammal model and the tokenizer stored next to it.

    Nothing is loaded at construction time. The model and the tokenizer are
    fetched with ``from_pretrained(model_path)`` the first time the matching
    property is read and are cached on the instance afterwards.
    """

    def __init__(self, model_path: str, name: str = None, task_list: list[str] = None) -> None:
        """Record where the model lives and which tasks it can serve.

        Args:
            model_path: identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
            name: key used to register the broker; defaults to ``model_path``.
            task_list: names of the tasks this model supports; defaults to an
                empty list.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # ``tasks`` must exist even when no task list is given: the model
        # dropdown is filtered with ``value in model.tasks``.
        self.tasks = task_list if task_list is not None else []
        self._model = None
        self._tokenizer_op = None

    @property
    def model(self) -> "Mammal":
        """The model, loaded from ``model_path`` and put in eval mode on first access."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self):
        """The ``ModularTokenizerOp`` loaded from ``model_path`` on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op


class MammalTask(ABC):
    """Base class for one demo task: prompt building, inference and its UI group.

    Note: ``crate_sample_dict`` is spelled this way throughout the notebook;
    the name is kept so existing subclasses and callers continue to work.
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name: key under which the task is registered in ``all_tasks``.
        """
        self.name = name
        self.description = None
        # Cache for the gradio group built by ``create_demo``.
        self._demo = None

    @abstractmethod
    def generate_prompt(self, **kwargs) -> str:
        """Format the task inputs into a prompt string.

        Args:
            **kwargs: task-specific inputs (each subclass defines its own).

        Returns:
            str: the prompt.
        """
        raise NotImplementedError()

    @abstractmethod
    def crate_sample_dict(self, prompt: str, **kwargs) -> dict:
        """Turn a prompt into the sample dict that is fed to the model.

        Args:
            prompt (str): prompt produced by ``generate_prompt``.
            **kwargs: task-specific extras (for example the model holder
                that provides the tokenizer).

        Returns:
            dict: sample_dict for feeding into model
        """
        raise NotImplementedError()

    @abstractmethod
    def run_model(self, sample_dict, model: "Mammal"):
        """Run ``model`` on ``sample_dict`` and return its raw output."""
        raise NotImplementedError()

    @abstractmethod
    def decode_output(self, batch_dict, model: "Mammal"):
        """Convert the raw output of ``run_model`` into displayable values."""
        raise NotImplementedError()

    @abstractmethod
    def create_demo(self, model_name_dropdown=None):
        """Build the gradio group for this task.

        Args:
            model_name_dropdown: the component ``demo`` forwards here; it is
                meant to supply the selected model name to the callbacks.

        Returns:
            the gradio group holding the task's widgets.
        """
        raise NotImplementedError()

    def demo(self, model_dropdown=None):
        """Return the task's gradio group, building it on the first call only."""
        if self._demo is None:
            self._demo = self.create_demo(model_dropdown)
        return self._demo


# Registries shared by the notebook: task name -> MammalTask and
# model name -> MammalObjectBroker.
all_tasks = dict()
all_models = dict()


class PpiTask(MammalTask):
    """Protein-protein interaction demo task."""

    def __init__(self):
        super().__init__(name="PPI")
        self.description = "Protein-Protein Interaction (PPI)"
        # Default sequences shown in the two input boxes.
        self.examples = {
            "protein_calmodulin": "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK",
            "protein_calcineurin": "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ",
        }
        # f-string so the heading shows the description instead of the
        # literal placeholder text.
        self.markup_text = f"""
# Mammal based {self.description} demonstration

Given two protein sequences, estimate if the proteins interact or not."""

    @staticmethod
    def positive_token_id(model_holder: MammalObjectBroker):
        """token for positive binding

        Args:
            model_holder (MammalObjectBroker): broker whose tokenizer is queried

        Returns:
            int: id of positive binding token
        """
        return model_holder.tokenizer_op.get_token_id("<1>")

    def generate_prompt(self, prot1, prot2):
        """Formatting prompt to match pre-training syntax

        Args:
            prot1 (str): sequence of protein number 1
            prot2 (str): sequence of protein number 2

        Returns:
            str: prompt
        """
        # NOTE(review): the empty literals below look like slots where special
        # tokens were lost from this copy of the notebook - confirm against the
        # model's documented prompt syntax before relying on predictions.
        prompt = (
            "<@TOKENIZER-TYPE=AA>"
            ""
            f"{prot1}"
            ""
            f"{prot2}"
        )
        return prompt

    def crate_sample_dict(self, prompt: str, model_holder: MammalObjectBroker):
        """Tokenize ``prompt`` with the holder's tokenizer.

        Args:
            prompt (str): prompt produced by ``generate_prompt``.
            model_holder (MammalObjectBroker): provides ``tokenizer_op``.

        Returns:
            dict: sample dict whose token ids and attention mask have been
            converted to ``torch`` tensors.
        """
        sample_dict = {ENCODER_INPUTS_STR: prompt}

        # Tokenize
        sample_dict = model_holder.tokenizer_op(
            sample_dict=sample_dict,
            key_in=ENCODER_INPUTS_STR,
            key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
            key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
        )
        for key in (ENCODER_INPUTS_TOKENS, ENCODER_INPUTS_ATTENTION_MASK):
            sample_dict[key] = torch.tensor(sample_dict[key])
        return sample_dict

    def run_model(self, sample_dict, model: "Mammal"):
        """Generate up to five new tokens for one sample and return the raw batch dict."""
        batch_dict = model.generate(
            [sample_dict],
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=5,
        )
        return batch_dict

    def decode_output(self, batch_dict, model_holder):
        """Extract the decoded text and the positive-class score.

        Args:
            batch_dict: output of ``run_model``.
            model_holder (MammalObjectBroker): provides the tokenizer.

        Returns:
            tuple: ``(generated_output, score)`` where ``generated_output`` is
            the decoded first entry of ``batch_dict[CLS_PRED]`` and ``score``
            is read from ``batch_dict["model.out.scores"][0][1]`` at the id of
            the ``<1>`` token (presumably the second generated position of the
            first sample - confirm).
        """
        generated_output = model_holder.tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
        positive_id = self.positive_token_id(model_holder)
        score = batch_dict["model.out.scores"][0][1][positive_id].item()
        return generated_output, score

    def create_and_run_prompt(self, model_name, protein1, protein2):
        """Gradio callback: build the prompt, run the selected model, decode.

        Args:
            model_name: key of the model in ``all_models``.
            protein1 (str): first protein sequence.
            protein2 (str): second protein sequence.

        Returns:
            tuple: ``(prompt, decoded_output, score)`` - one value per output
            component wired in ``create_demo``.
        """
        model_holder = all_models[model_name]
        prompt = self.generate_prompt(protein1, protein2)
        sample_dict = self.crate_sample_dict(prompt=prompt, model_holder=model_holder)
        batch_dict = self.run_model(sample_dict=sample_dict, model=model_holder.model)
        # Decode here: the raw batch dict cannot be shown in the UI as is.
        decoded_output, score = self.decode_output(batch_dict, model_holder)
        return prompt, decoded_output, score

    def create_demo(self, model_name_dropdown):
        """Build the PPI gradio group.

        Args:
            model_name_dropdown: dropdown whose value names the model to run.

        Returns:
            the ``gr.Group`` holding the PPI widgets.
        """
        with gr.Group() as demo:
            gr.Markdown(self.markup_text)
            with gr.Row():
                prot1 = gr.Textbox(
                    label="Protein 1 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calmodulin"],
                )
                prot2 = gr.Textbox(
                    label="Protein 2 sequence",
                    interactive=True,
                    lines=3,
                    value=self.examples["protein_calcineurin"],
                )
            with gr.Row():
                run_mammal = gr.Button(
                    "Run Mammal prompt for Protein-Protein Interaction", variant="primary"
                )
            with gr.Row():
                prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

            with gr.Row():
                decoded = gr.Textbox(label="Mammal output")
                score_box = gr.Number(label="PPI score")
                run_mammal.click(
                    fn=self.create_and_run_prompt,
                    inputs=[model_name_dropdown, prot1, prot2],
                    outputs=[prompt_box, decoded, score_box],
                )
            with gr.Row():
                # NOTE(review): the first code span in this text is empty; a
                # token name appears to be missing - confirm the intended text.
                gr.Markdown(
                    "`````` contains the binding affinity class, which is ```<1>``` for interacting and ```<0>``` for non-interacting"
                )
            demo.visible = True
        return demo


ppi_task = PpiTask()
all_tasks[ppi_task.name] = ppi_task
all_tasks
### DTI:

dti = "Drug-Target Binding Affinity"

# Example inputs shown in the DTI demo.
target_seq = "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC"
drug_seq = "CC(=O)NCCC1=CNc2c1cc(OC)cc2"

# Mean / std handed to DtiBindingdbKdTask.process_model_output.
# NOTE(review): presumably the label statistics of the fine-tuning set - confirm.
DTI_NORM_Y_MEAN = 5.79384684128215
DTI_NORM_Y_STD = 1.33808027428196

# Lazy holder for the DTI model and tokenizer, so this cell does not depend on
# globals that are not defined anywhere in the notebook.
# NOTE(review): the checkpoint id is the one named in the commented-out
# registration line of the model-registration cell - confirm it is the intended one.
dti_model_holder = MammalObjectBroker(
    model_path="ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
    name=dti,
    task_list=["DTI"],
)

# token for positive binding (reading ``tokenizer_op`` loads the DTI tokenizer here)
positive_token_id = dti_model_holder.tokenizer_op.get_token_id("<1>")


def generate_prompt_dti(prot, drug):
    """Build the preprocessed DTI sample for one protein / drug pair.

    Args:
        prot (str): target protein sequence.
        drug (str): drug sequence (the demo labels this field as SMILES).

    Returns:
        dict: sample dict returned by ``DtiBindingdbKdTask.data_preprocessing``.
    """
    # Use the values typed into the UI, not the module-level examples.
    sample_dict = {"target_seq": prot, "drug_seq": drug}
    sample_dict = DtiBindingdbKdTask.data_preprocessing(
        sample_dict=sample_dict,
        tokenizer_op=dti_model_holder.tokenizer_op,
        target_sequence_key="target_seq",
        drug_sequence_key="drug_seq",
        norm_y_mean=None,
        norm_y_std=None,
        device=dti_model_holder.model.device,
    )
    return sample_dict


def create_and_run_prompt_dtb(prot, drug):
    """Gradio callback: run the DTI model on one protein / drug pair.

    Args:
        prot (str): target protein sequence.
        drug (str): drug sequence.

    Returns:
        tuple: ``(encoder_input, "model.out.dti_bindingdb_kd", prediction)`` where
        ``prediction`` is the first processed scalar converted to ``float``.
    """
    sample_dict = generate_prompt_dti(prot, drug)
    batch_dict = dti_model_holder.model.forward_encoder_only([sample_dict])
    # Post-process the model's output
    batch_dict = DtiBindingdbKdTask.process_model_output(
        batch_dict,
        scalars_preds_processed_key="model.out.dti_bindingdb_kd",
        norm_y_mean=DTI_NORM_Y_MEAN,
        norm_y_std=DTI_NORM_Y_STD,
    )
    prediction = float(batch_dict["model.out.dti_bindingdb_kd"][0])
    return sample_dict["data.query.encoder_input"], "model.out.dti_bindingdb_kd", prediction


def create_tdb_demo():
    """Build the (initially hidden) gradio group for the target-drug demo.

    Returns:
        the ``gr.Group`` holding the DTI widgets.
    """
    markup_text = f"""
# Mammal based Target-Drug binding affinity demonstration

Given a protein sequence and a drug (in SMILES), estimate the binding affinity.

### Using the model from

 ```{dti_model_holder.model_path} ```
"""
    with gr.Group() as tdb_demo:
        gr.Markdown(markup_text)
        with gr.Row():
            prot = gr.Textbox(
                label="Protein sequence",
                interactive=True,
                lines=3,
                value=target_seq,
            )
            drug = gr.Textbox(
                label="drug sequence (SMILES)",
                interactive=True,
                lines=3,
                value=drug_seq,
            )
        with gr.Row():
            run_mammal = gr.Button(
                "Run Mammal prompt for Target Drug Affinity", variant="primary"
            )
        with gr.Row():
            prompt_box = gr.Textbox(label="Mammal prompt", lines=5)

        with gr.Row():
            decoded = gr.Textbox(label="Mammal output")
            score_box = gr.Number(label="DTI score")
            run_mammal.click(
                fn=create_and_run_prompt_dtb,
                inputs=[prot, drug],
                outputs=[prompt_box, decoded, score_box],
            )
        tdb_demo.visible = False
    return tdb_demo


ppi_model = MammalObjectBroker(model_path="ibm/biomed.omics.bl.sm.ma-ted-458m", task_list=["PPI"])
all_models[ppi_model.name] = ppi_model
# TODO: register the DTI model in ``all_models`` once a DTI task exists in ``all_tasks``.


def create_application():
    """Assemble the full gradio application.

    Returns:
        the ``gr.Blocks`` app containing the task selector, the model selector
        and the PPI demo group.
    """

    def task_change(value):
        """Offer only the models whose task list contains the selected task."""
        choices = [model_name for model_name, model in all_models.items() if value in model.tasks]
        if choices:
            return gr.update(choices=choices, value=choices[0])
        # No model serves this task: clear the dropdown so that models of a
        # previously selected task are no longer offered.
        return gr.update(choices=[], value=None)

    def set_ppi_vis(main_text):
        """Show the PPI group; the selected task is currently not consulted."""
        return gr.Group(visible=True)

    with gr.Blocks() as demo:
        task_dropdown = gr.Dropdown(
            choices=["select demo"] + list(all_tasks.keys()),
            interactive=True,
        )
        model_dropdown = gr.Dropdown(
            choices=[
                model_name
                for model_name, model in all_models.items()
                if task_dropdown.value in model.tasks
            ],
            interactive=True,
        )
        task_dropdown.change(task_change, inputs=[task_dropdown], outputs=[model_dropdown])

        ppi_demo = all_tasks["PPI"].demo(model_dropdown=model_dropdown)
        ppi_demo.visible = True

        task_dropdown.change(set_ppi_vis, inputs=task_dropdown, outputs=[ppi_demo])
    return demo


full_demo = None


def main():
    """Build the application, keep it in ``full_demo`` and launch it locally."""
    global full_demo
    full_demo = create_application()
    full_demo.launch(show_error=True, share=False)


if __name__ == "__main__":
    main()

# Quick check of which registered models advertise the PPI task.
for model_name, model_holder in all_models.items():
    print(model_name)
    print(model_holder.tasks, "PPI" in model_holder.tasks)

from mammal.examples.tcr_epitope_binding.main_infer import load_model, task_infer

# Example inputs for the TCR-epitope binding task.
tcr_beta_seq = "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT"
epitope_seq = "LLQTGIHVRVSQPSL"
class AllModels(dict):
    """Dictionary of ``model name -> model path`` with a registration helper."""

    def register_model(self, name):
        """Add ``name`` to the registry, mapped to the path string ``f"{name}_path"``."""
        self[name] = f"{name}_path"


def register_model(self, name):
    """Function form of ``AllModels.register_model``.

    Works on any mapping that provides ``update``; call it as
    ``register_model(mapping, name)``.
    """
    self.update({name: f'{name}_path'})


# Assisted by watsonx Code Assistant
# Instances of the built-in ``dict`` do not accept new attributes, so a bound
# method cannot be attached to a plain dict. The subclass above provides the
# method instead.
# The experiment uses its own variable names so that the ``all_models``
# registry read by the demo callbacks is not replaced by path strings.
demo_registry = AllModels({'model1': 'model1_path', 'model2': 'model2_path'})
demo_registry.register_model("model3")
print(demo_registry)

model_path_registry = AllModels()
model_path_registry.register_model("abc")
model_path_registry
"python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 2 }