# biomed-multi-alignment / demo_framework.py
# Last commit: 93d0d1a ("refactor") by matanninio; file size 2.94 kB.
import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal
from abc import ABC, abstractmethod
class MammalObjectBroker:
    """Lazily load and cache a Mammal model and its tokenizer op.

    Nothing is loaded in ``__init__``. The model and the tokenizer op are each
    loaded from ``model_path`` the first time the corresponding property is
    read. After that, the same cached object is returned on every access.

    Attributes:
        model_path: Path/identifier handed to the ``from_pretrained`` loaders.
        name: Display name for this model; defaults to ``model_path``.
        tasks: Task names associated with this model; an empty list when no
            ``task_list`` was given.
    """

    def __init__(
        self,
        model_path: str,
        name: "str | None" = None,
        task_list: "list[str] | None" = None,
    ) -> None:
        """Store the configuration without loading anything.

        Args:
            model_path: Passed verbatim to ``Mammal.from_pretrained`` and
                ``ModularTokenizerOp.from_pretrained``.
            name: Display name; ``None`` means "use ``model_path``".
            task_list: Task names for this model; ``None`` means no tasks.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Always set ``tasks`` (plural). Previously the ``None`` branch wrote
        # ``self.task``, which left ``tasks`` undefined on such instances.
        self.tasks = task_list if task_list is not None else []
        # Populated on first access by the ``model`` / ``tokenizer_op`` properties.
        self._model = None
        self._tokenizer_op = None

    @property
    def model(self) -> "Mammal":
        """Return the cached model, loading it on first access.

        On first access the model is created with
        ``Mammal.from_pretrained(model_path)`` and ``.eval()`` is called on it.
        """
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            # NOTE(review): presumably switches the model to inference mode
            # (torch-style ``.eval()``) -- confirm against Mammal's base class.
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self) -> "ModularTokenizerOp":
        """Return the cached tokenizer op, loading it on first access via
        ``ModularTokenizerOp.from_pretrained(model_path)``."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
class MammalTask(ABC):
    """Abstract base class for a task exposed through the Gradio demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``.
    ``run_model`` and ``create_demo`` raise ``NotImplementedError`` unless a
    subclass overrides them.

    Attributes:
        name: Task name given at construction.
        description: Optional description; ``None`` unless set later.
        _demo: Result of ``create_demo``, cached by the first ``demo()`` call.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.description = None
        self._demo = None

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: "MammalObjectBroker"
    ) -> dict:
        """Build the sample dict that is fed into the model.

        The method name is misspelled ("crate"). It is kept unchanged because
        subclasses override it under this name.

        Args:
            sample_inputs: Raw inputs for one sample; the expected keys are
                defined by each subclass.
            model_holder: Broker that provides the model and tokenizer op.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    def run_model(self, sample_dict, model: "Mammal"):
        """Run ``model`` on ``sample_dict``; subclasses override this as needed.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def create_demo(
        self, model_name_widget: "gr.components.Component"
    ) -> "gr.Group":
        """Create a Gradio demo group for this task.

        Args:
            model_name_widget: Widget holding the model name to use. It is
                needed to create Gradio actions that take the current model
                name as an input.

        Returns:
            gr.Group: the demo group for this task.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: "gr.components.Component" = None):
        """Return this task's demo, building it with ``create_demo`` on first call.

        Later calls return the cached demo and ignore ``model_name_widgit``.
        The parameter's spelling is kept for backward compatibility with
        keyword callers.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: "Mammal"):
        """Turn ``batch_dict`` into this task's result; the contract is
        defined by each subclass."""
        raise NotImplementedError()