import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal
from abc import ABC, abstractmethod
from typing import Optional


class MammalObjectBroker:
    """Hold a Mammal model and its tokenizer op, both loaded lazily from one path.

    Neither object is loaded in ``__init__``; each is created by
    ``from_pretrained(model_path)`` on first access of the corresponding
    property and cached on the instance for subsequent accesses.
    """

    def __init__(
        self,
        model_path: str,
        name: Optional[str] = None,
        task_list: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
            model_path: path/identifier passed to ``from_pretrained`` for both
                the model and the tokenizer op.
            name: display name of this broker; defaults to ``model_path``.
            task_list: task identifiers associated with this model; defaults
                to a new empty list. The given list is stored as-is (not copied).
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Always define ``tasks`` so the attribute exists whether or not a
        # task list was supplied.
        self.tasks = task_list if task_list is not None else []
        # Lazily-populated caches for the ``model`` / ``tokenizer_op`` properties.
        self._model = None
        self._tokenizer_op = None

    @property
    def model(self) -> Mammal:
        """The Mammal model; loaded from ``model_path`` and set to eval mode on first access."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self):
        """The ModularTokenizerOp; loaded from ``model_path`` on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op


class MammalTask(ABC):
    """Abstract base for a task run with a Mammal model and exposed as a gradio demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``;
    ``run_model`` and ``create_demo`` raise ``NotImplementedError`` unless
    overridden.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.description = None
        # Cached result of ``create_demo``; populated on the first ``demo`` call.
        self._demo = None

    @abstractmethod
    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker) -> dict:
        """Format raw task inputs into a sample dict matching the pre-training syntax.

        NOTE: the method name's spelling ("crate") is kept because subclasses
        implement it under this exact name.

        Args:
            sample_inputs: task-specific raw inputs (schema defined by each subclass).
            model_holder: broker giving access to the model and tokenizer op.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; subclasses must override this."""
        raise NotImplementedError()

    def create_demo(self, model_name_widget: "gr.components.Component") -> gr.Group:
        """Create a gradio demo group for this task.

        Args:
            model_name_widget: widget holding the name of the model to use.
                It is needed so gradio actions can take the currently selected
                model name as an input.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError()

    def demo(
        self, model_name_widgit: "Optional[gr.components.Component]" = None
    ) -> gr.Group:
        """Return this task's demo group, creating it on the first call.

        The group is built once via ``create_demo`` and cached; on later calls
        the ``model_name_widgit`` argument is ignored and the cached group is
        returned.

        Args:
            model_name_widgit: widget holding the model name, forwarded to
                ``create_demo`` (parameter spelling kept for caller compatibility).
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal):
        """Decode the model's output stored in ``batch_dict``; subclasses must implement."""
        raise NotImplementedError()