from abc import ABC, abstractmethod

import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal.model import Mammal


class MammalObjectBroker:
    """Lazily load and hold a Mammal model together with its tokenizer.

    Both objects are loaded from the same ``model_path`` on first access
    (see :attr:`model` and :attr:`tokenizer_op`), or immediately during
    construction when ``force_preload`` is set.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
        *,
        force_preload: bool = False,
    ) -> None:
        """
        Args:
            model_path: path/identifier passed to ``from_pretrained`` for both
                the model and the tokenizer.
            name: name of this broker; defaults to ``model_path``.
            task_list: names of the tasks this model is used for. The list is
                copied, so later changes to the caller's list do not leak in.
            force_preload: if True, load the model and tokenizer now instead
                of on first access.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Shallow copy: do not alias (and silently track) the caller's list.
        self.tasks: list[str] = [] if task_list is None else list(task_list)
        # Populated lazily by the `model` / `tokenizer_op` properties.
        self._model: Mammal | None = None
        self._tokenizer_op: ModularTokenizerOp | None = None
        if force_preload:
            self.force_preload()

    @property
    def model(self) -> Mammal:
        """The Mammal model; loaded and put in eval mode on first access."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self) -> ModularTokenizerOp:
        """The tokenizer op; loaded from ``model_path`` on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op

    def force_preload(self) -> None:
        """Pre-load the model and tokenizer (in this order)."""
        _ = self.model
        _ = self.tokenizer_op


class MammalTask(ABC):
    """Base class for a task that is run with one of several Mammal models.

    Subclasses must implement :meth:`crate_sample_dict` and
    :meth:`decode_output`. :meth:`run_model` and :meth:`create_demo` raise
    ``NotImplementedError`` unless overridden.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """
        Args:
            name: task name (used as the key by :class:`TaskRegistry`).
            model_dict: mapping from model name to its object broker.
        """
        self.name = name
        # Initialized to None here; presumably set by subclasses.
        self.description = None
        # Cache for the demo built by `demo()`.
        self._demo = None
        self.model_dict = model_dict

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Format the task inputs to match the model's pre-training prompt syntax.

        NOTE: the method name keeps its original spelling ("crate") because
        subclasses must override it under this exact name.

        Args:
            sample_inputs: raw task inputs; the expected keys are defined by
                each subclass.
            model_holder: broker giving access to the model and tokenizer.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    # NOTE: intentionally not marked @abstractmethod, so subclasses are not
    # forced to override it.
    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on a prepared ``sample_dict``.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError()

    def create_demo(self, model_name_widget: gr.component) -> gr.Group:
        """Create a gradio demo group for this task.

        Args:
            model_name_widget (gr.Component): widget holding the name of the
                model to use. It is needed so that gradio actions can take the
                currently selected model name as an input.

        Returns:
            gr.Group: the demo group.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.component = None):
        """Return this task's demo, building it via :meth:`create_demo` on first call.

        The result is cached, so ``model_name_widgit`` is only used on the
        first call and ignored on later calls.

        Args:
            model_name_widgit: widget holding the model name (parameter name
                kept as-is for backward compatibility with keyword callers).
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Turn ``batch_dict`` into a list of results.

        The contents of ``batch_dict`` and of the returned list are defined by
        each subclass (presumably the output of :meth:`run_model` — confirm in
        subclasses).
        """
        raise NotImplementedError()

    # classification helpers
    @staticmethod
    def positive_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Return the id of the positive-class token ``<1>``.

        Args:
            tokenizer_op: tokenizer used to look up the token id.

        Returns:
            int: id of the positive binding token.
        """
        return tokenizer_op.get_token_id("<1>")

    @staticmethod
    def negative_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """Return the id of the negative-class token ``<0>``.

        Args:
            tokenizer_op: tokenizer used to look up the token id.

        Returns:
            int: id of the negative binding token.
        """
        return tokenizer_op.get_token_id("<0>")

    @staticmethod
    def get_label_from_token(tokenizer_op: ModularTokenizerOp, token_id):
        """Map a class token id to a label.

        Returns ``"negative"`` for the ``<0>`` token id, ``"positive"`` for the
        ``<1>`` token id, and ``token_id`` unchanged for any other value.
        """
        label_mapping = {
            MammalTask.negative_token_id(tokenizer_op): "negative",
            MammalTask.positive_token_id(tokenizer_op): "positive",
        }
        return label_mapping.get(token_id, token_id)


class TaskRegistry(dict[str, MammalTask]):
    """A dict from task name to task, with a :meth:`register_task` helper."""

    def register_task(self, task: MammalTask) -> str:
        """Register ``task`` under ``task.name`` and return that name.

        An existing entry with the same name is replaced.
        """
        self[task.name] = task
        return task.name


class ModelRegistry(dict[str, MammalObjectBroker]):
    """A dict from model name to object broker, with a :meth:`register_model` helper."""

    def register_model(
        self,
        model_path: str,
        task_list: list[str] | None = None,
        name: str | None = None,
        *,
        force_preload: bool = False,
    ) -> str:
        """Create a broker for ``model_path``, register it, and return its name.

        Args:
            model_path: path/identifier the model and tokenizer are loaded from.
            task_list: optional names of the tasks this model is used for.
            name: optional explicit name for the model; defaults to
                ``model_path``.
            force_preload: if True, load the model and tokenizer immediately.

        Returns:
            str: the name the model was registered under. An existing entry
            with the same name is replaced.
        """
        model_holder = MammalObjectBroker(
            model_path=model_path,
            task_list=task_list,
            name=name,
            force_preload=force_preload,
        )
        self[model_holder.name] = model_holder
        return model_holder.name