File size: 2,920 Bytes
19dfa7a
 
93d0d1a
 
 
 
 
19dfa7a
 
 
 
 
 
 
93d0d1a
 
 
19dfa7a
 
 
93d0d1a
19dfa7a
 
93d0d1a
19dfa7a
93d0d1a
19dfa7a
93d0d1a
19dfa7a
 
93d0d1a
19dfa7a
93d0d1a
 
 
19dfa7a
93d0d1a
19dfa7a
93d0d1a
 
19dfa7a
 
 
 
 
93d0d1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19dfa7a
 
 
93d0d1a
 
 
 
 
 
 
 
 
 
 
19dfa7a
93d0d1a
19dfa7a
93d0d1a
 
 
 
 
 
 
 
 
 
 
 
 
19dfa7a
93d0d1a
19dfa7a
93d0d1a
 
 
 
19dfa7a
93d0d1a
 
19dfa7a
 
93d0d1a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from abc import ABC, abstractmethod

import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.model import Mammal


class MammalObjectBroker:
    """Hold a Mammal model and its tokenizer op, loading each lazily.

    Neither object is created in ``__init__``; each is built with
    ``from_pretrained(model_path)`` the first time its property is read and
    cached on the instance afterwards.

    Args:
        model_path: value handed to ``Mammal.from_pretrained`` and
            ``ModularTokenizerOp.from_pretrained``.
        name: display name for this model; falls back to ``model_path``.
        task_list: task names associated with this model; falls back to a
            new empty list. A provided list is stored as-is (not copied).
    """

    def __init__(
        self,
        model_path: str,
        name: str | None = None,
        task_list: list[str] | None = None,
    ) -> None:
        self.model_path = model_path
        self.name = model_path if name is None else name

        self.tasks: list[str] = [] if task_list is None else task_list
        # Filled in on first access by the `model` / `tokenizer_op` properties.
        self._model: Mammal | None = None
        self._tokenizer_op = None

    @property
    def model(self) -> Mammal:
        """Return the cached model, loading it on the first call.

        ``eval()`` is invoked on every access, not only right after loading,
        so callers always receive the model after an ``eval()`` call.
        """
        loaded = self._model
        if loaded is None:
            loaded = Mammal.from_pretrained(self.model_path)
            self._model = loaded
        loaded.eval()
        return loaded

    @property
    def tokenizer_op(self):
        """Return the cached ``ModularTokenizerOp``, loading it on the first call."""
        op = self._tokenizer_op
        if op is None:
            op = ModularTokenizerOp.from_pretrained(self.model_path)
            self._tokenizer_op = op
        return op


class MammalTask(ABC):
    """Base class for a task run with a Mammal model and exposed as a gradio demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``
    (both abstract). ``run_model`` and ``create_demo`` raise
    ``NotImplementedError`` here and are meant to be overridden; they are not
    declared abstract, so a subclass can be instantiated without them.

    Args:
        name: task name.
        model_dict: ``MammalObjectBroker`` objects available to this task
            (presumably keyed by model name -- confirm against callers).
    """

    def __init__(self, name: str, model_dict: dict[str, MammalObjectBroker]) -> None:
        self.name = name
        self.description: str | None = None
        # Cached result of create_demo(); built on the first demo() call.
        self._demo: gr.Group | None = None
        self.model_dict = model_dict

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: MammalObjectBroker
    ) -> dict:
        """Format task inputs into a sample dict matching the pre-training prompt syntax.

        The method name's spelling ("crate") is part of the public interface
        that subclasses override, so it is kept as-is.

        Args:
            sample_inputs: raw task inputs (keys are task specific).
            model_holder: broker giving access to the model and tokenizer op.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    # Intentionally not abstract: subclasses that never run a model can still be instantiated.
    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; override in subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def create_demo(
        self, model_name_widget: "gr.components.Component"
    ) -> gr.Group:
        """Create a gradio demo group for this task.

        Args:
            model_name_widget: widget holding the name of the model to use.
                It is needed to build gradio actions that take the current
                model name as an input.

        Returns:
            gr.Group: the demo group.

        Raises:
            NotImplementedError: always, in this base class; override in subclasses.
        """
        raise NotImplementedError()

    def demo(
        self, model_name_widgit: "gr.components.Component | None" = None
    ) -> gr.Group:
        """Return this task's demo group, creating it on the first call.

        The group is built once via ``create_demo`` and cached; the widget
        argument is only used on that first call and ignored afterwards.

        Args:
            model_name_widgit: widget holding the model name, forwarded to
                ``create_demo``. The parameter keeps its original spelling
                for callers that pass it by keyword.

        Returns:
            gr.Group: the cached demo group.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal):
        """Decode the model output contained in ``batch_dict`` into a task result.

        The result format is task specific and defined by subclasses.
        """
        raise NotImplementedError()