matanninio's picture
cleanup and normalization of tasks
b93c8a7
raw
history blame
4.65 kB
from abc import ABC, abstractmethod
import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal
class MammalObjectBroker:
    """Hold a model path and lazily load the matching model and tokenizer op.

    Nothing is loaded at construction time unless ``force_preload`` is set.
    Otherwise the first access to :attr:`model` or :attr:`tokenizer_op`
    calls the corresponding ``from_pretrained`` and the result is cached on
    the instance for later accesses.
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
        *,
        force_preload: bool = False,
    ) -> None:
        """Create a broker for the model stored at ``model_path``.

        Args:
            model_path: location passed to ``from_pretrained`` for both the
                model and the tokenizer op.
            name: display/registry name; defaults to ``model_path``.
            task_list: names of the tasks associated with this model;
                defaults to an empty list.
            force_preload: when true, load the model and the tokenizer op
                immediately instead of on first access.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Store a copy: keeping a reference to the caller's list would let
        # later mutations on either side silently change the other.
        self.tasks: list[str] = [] if task_list is None else list(task_list)
        self._model: Mammal | None = None
        self._tokenizer_op: ModularTokenizerOp | None = None
        if force_preload:
            self.force_preload()

    @property
    def model(self) -> Mammal:
        """The model for ``model_path``, loaded on first access and cached."""
        if self._model is None:
            loaded = Mammal.from_pretrained(self.model_path)
            # Call eval() before caching, so a failure here cannot leave
            # behind a cached model that was never switched to eval mode.
            # (presumably torch's nn.Module.eval -- confirm against Mammal)
            loaded.eval()
            self._model = loaded
        return self._model

    @property
    def tokenizer_op(self) -> ModularTokenizerOp:
        """The tokenizer op for ``model_path``, loaded on first access and cached."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op

    def force_preload(self) -> None:
        """pre-load the model and tokenizer (in this order)"""
        # Reading the properties is what triggers the lazy loading.
        _ = self.model
        _ = self.tokenizer_op
class MammalTask(ABC):
    """Abstract base for a task that can be run with registered Mammal models.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``.
    ``run_model`` and ``create_demo`` are optional overrides whose base
    versions raise ``NotImplementedError``.
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        """Store the task name and the mapping of model names to brokers."""
        self.model_dict = model_dict
        self.name = name
        self.description = None
        # Filled in by the first call to demo().
        self._demo = None

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Format the inputs into a sample dict matching the pre-training syntax.

        Args:
            sample_inputs (dict): raw inputs for a single sample
            model_holder (MammalObjectBroker): broker of the model the sample is built for

        Returns:
            dict: sample_dict for feeding into model
        """
        raise NotImplementedError()

    # NOTE: not marked as abstract, so a subclass may leave this out; the
    # base version always raises.
    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; must be overridden to be usable."""
        raise NotImplementedError()

    # NOTE(review): the `gr.component` annotations below look like they are
    # meant to refer to the gradio Component class -- confirm before changing.
    def create_demo(self, model_name_widget: gr.component) -> gr.Group:
        """Build the gradio demo group for this task.

        Args:
            model_name_widget (gr.Component): widget holding the model name to use. This is
                needed to create gradio actions with the current model name as an input

        Raises:
            NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: gr.component = None):
        """Return the demo for this task, creating it on the first call only."""
        if self._demo is not None:
            return self._demo
        self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal) -> list:
        """Turn the model output held in ``batch_dict`` into a list of results."""
        raise NotImplementedError()

    # classification helpers

    @staticmethod
    def positive_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """token for positive binding

        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer used to look up the token

        Returns:
            int: id of the "<1>" (positive binding) token
        """
        positive_token = "<1>"
        return tokenizer_op.get_token_id(positive_token)

    @staticmethod
    def negative_token_id(tokenizer_op: ModularTokenizerOp) -> int:
        """token for negative binding

        Args:
            tokenizer_op (ModularTokenizerOp): tokenizer used to look up the token

        Returns:
            int: id of the "<0>" (negative binding) token
        """
        negative_token = "<0>"
        return tokenizer_op.get_token_id(negative_token)

    @staticmethod
    def get_label_from_token(tokenizer_op: ModularTokenizerOp, token_id):
        """Map a token id to "negative"/"positive".

        Any id that is neither the negative nor the positive token id is
        returned unchanged.
        """
        negative_id = MammalTask.negative_token_id(tokenizer_op)
        positive_id = MammalTask.positive_token_id(tokenizer_op)
        labels_by_id = {negative_id: "negative", positive_id: "positive"}
        return labels_by_id.get(token_id, token_id)
class TaskRegistry(dict[str, MammalTask]):
    """A plain dict of task name -> task, plus a convenience register method."""

    def register_task(self, task: MammalTask):
        """Store ``task`` under its own ``name`` and return that name."""
        task_name = task.name
        self[task_name] = task
        return task_name
class ModelRegistry(dict[str, MammalObjectBroker]):
    """A plain dict of model name -> broker, plus a convenience register method."""

    def register_model(
        self, model_path, task_list=None, name=None, *, force_preload=False
    ):
        """Wrap a model path in a broker, store it, and return its name.

        Args:
            model_path: location of the model, forwarded to the broker
            task_list (optional): task names forwarded to the broker
            name (optional str): explicit name for the model
            force_preload (bool): forwarded to the broker as ``force_preload``

        Returns:
            str: model name
        """
        broker = MammalObjectBroker(
            model_path=model_path,
            name=name,
            task_list=task_list,
            force_preload=force_preload,
        )
        # The broker decides the final name, so key the registry on
        # ``broker.name`` rather than on the ``name`` argument.
        registered_name = broker.name
        self[registered_name] = broker
        return registered_name