File size: 2,940 Bytes
93d0d1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal
from abc import ABC, abstractmethod

    



class MammalObjectBroker:
    """Hold a MAMMAL model and its tokenizer op, loading each lazily.

    Nothing is loaded at construction time. The model and the tokenizer op are
    each created from ``model_path`` on first access of the corresponding
    property and then cached on the instance.
    """

    def __init__(
        self, model_path: str, name: str = None, task_list: list[str] = None
    ) -> None:
        """
        Args:
            model_path: value passed to ``from_pretrained`` for both the model
                and the tokenizer op.
            name: display name of this broker; defaults to ``model_path``.
            task_list: task names associated with this model; defaults to an
                empty list.
        """
        self.model_path = model_path
        self.name = model_path if name is None else name
        # The fallback used to assign ``self.task`` (typo), which left
        # ``self.tasks`` undefined whenever ``task_list`` was omitted.
        self.tasks = task_list if task_list is not None else []
        # Lazily populated caches; see the ``model`` / ``tokenizer_op`` properties.
        self._model = None
        self._tokenizer_op = None

    @property
    def model(self) -> Mammal:
        """The pretrained model, loaded on first access and switched to eval mode."""
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            # Inference-only use: put the model in evaluation mode once, on load.
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self):
        """The modular tokenizer op for ``model_path``, loaded on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
    
    
    

class MammalTask(ABC):
    """Abstract base for a MAMMAL task that can be exposed as a gradio demo.

    Concrete tasks must implement :meth:`crate_sample_dict` and
    :meth:`decode_output`, and are expected to override :meth:`run_model`
    and :meth:`create_demo`. The base versions of those two raise
    ``NotImplementedError``.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        # Human-readable description; left as None for subclasses to set.
        self.description = None
        # Cache for the gradio group built by ``demo``.
        self._demo = None

    @abstractmethod
    def crate_sample_dict(self, sample_inputs: dict, model_holder: MammalObjectBroker) -> dict:
        """Format task inputs into a sample matching the pre-training syntax.

        The method name keeps its original spelling because subclasses
        override it under this name.

        Args:
            sample_inputs: raw task inputs; the expected keys are defined by
                the concrete subclass.
            model_holder: broker giving access to the model and tokenizer op.

        Returns:
            dict: sample_dict for feeding into the model.
        """
        raise NotImplementedError()

    # Intentionally not abstract: subclasses may be instantiated without
    # overriding it, but calling the base version raises.
    def run_model(self, sample_dict, model: Mammal):
        """Run ``model`` on ``sample_dict``; subclasses should override this."""
        raise NotImplementedError()

    def create_demo(self, model_name_widget: "gr.components.Component") -> gr.Group:
        """Create a gradio demo group for this task.

        Args:
            model_name_widget: widget holding the name of the model to use.
                It is needed so gradio actions can take the currently
                selected model name as an input.

        Raises:
            NotImplementedError: always, in the base class; subclasses override.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: "gr.components.Component" = None):
        """Return this task's demo group, building it on the first call.

        The group is created via :meth:`create_demo` only once. Later calls
        return the cached group and ignore ``model_name_widgit``.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: Mammal):
        """Decode the model output stored in ``batch_dict``; implemented by subclasses."""
        raise NotImplementedError()