# Spaces: Sleeping  (page-header residue from extraction; not part of the module)
from abc import ABC, abstractmethod | |
import gradio as gr | |
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp | |
from mammal.model import Mammal | |
class MammalObjectBroker:
    """Lazily load and cache the model and tokenizer op for one model path.

    Nothing is loaded at construction time. The first call to ``model()`` or
    ``tokenizer_op()`` performs the load via ``from_pretrained``; later calls
    return the cached object.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
    ) -> None:
        """Record where the model lives and which tasks are listed for it.

        Args:
            model_path: Value passed to ``from_pretrained`` when loading.
            name: Name of this model entry; falls back to ``model_path``
                when ``None``.
            task_list: Task identifiers listed for this model; ``None``
                results in an empty list.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Copy the argument so that later changes to the caller's list do not
        # silently change this broker's task list.
        self.tasks: list[str] = [] if task_list is None else list(task_list)
        self._model: Mammal | None = None
        self._tokenizer_op = None

    def model(self) -> Mammal:
        """Return the model, loading it and calling ``eval()`` on first use."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    def tokenizer_op(self):
        """Return the tokenizer op, loading it from ``model_path`` on first use."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
class MammalTask(ABC):
    """Base class for a task that can be run and shown as a gradio demo.

    The hook methods below are not decorated with ``@abstractmethod``; their
    base implementations raise ``NotImplementedError`` instead, so the error
    only appears when a hook that was not overridden is actually called.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """Store the task name and the available models.

        Args:
            name: Name of the task.
            model_dict: Mapping of string keys to ``MammalObjectBroker``
                objects (presumably keyed by model name; verify against caller).
        """
        self.name = name
        self.description = None
        self._demo = None  # filled in by the first call to demo()
        self.model_dict = model_dict

    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Format the inputs to match the pre-training syntax.

        The method name (``crate``, sic) is kept unchanged so that existing
        overrides and callers keep working.

        Args:
            sample_inputs: Raw inputs for one sample.
            model_holder: Broker for the model the sample is prepared for.

        Returns:
            dict: sample_dict for feeding into the model.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def create_demo(self, model_name_widget: gr.components.Component) -> gr.Group:
        """Create the gradio demo group for this task.

        Args:
            model_name_widget: Widget holding the model name to use. This is
                needed to create gradio actions with the current model name
                as an input.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.components.Component | None = None):
        """Return the demo for this task, creating it on first use.

        The value returned by ``create_demo`` is cached in ``self._demo``.
        Once a non-``None`` demo is cached, ``model_name_widgit`` is ignored
        on later calls.

        Args:
            model_name_widgit: Widget forwarded to ``create_demo`` as
                ``model_name_widget`` (parameter spelling kept for
                compatibility with keyword callers).
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Decode the model output in ``batch_dict`` into a list; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class TaskRegistry(dict[str, MammalTask]):
    """Dictionary of tasks keyed by task name, with a registration helper."""

    def register_task(self, task: MammalTask):
        """Store ``task`` under its own ``name`` and return that name.

        An entry already stored under the same name is replaced.
        """
        task_name = task.name
        self[task_name] = task
        return task_name
class ModelRegistry(dict[str, MammalObjectBroker]):
    """Dictionary of model brokers keyed by model name, with a registration helper."""

    def register_model(self, model_path, task_list=None, name=None):
        """Wrap a model path in a ``MammalObjectBroker`` and store it by name.

        Args:
            model_path: Path/identifier of the model, forwarded to the broker.
            task_list: Optional list of tasks, forwarded to the broker.
            name (optional str): Explicit name for the model; the broker
                falls back to ``model_path`` when this is ``None``.

        Returns:
            str: The name under which the broker was stored.
        """
        broker = MammalObjectBroker(
            model_path=model_path,
            task_list=task_list,
            name=name,
        )
        registered_name = broker.name
        self[registered_name] = broker
        return registered_name