Spaces:
Running
Running
File size: 2,940 Bytes
93d0d1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
from mammal.keys import *
from mammal.model import Mammal
from abc import ABC, abstractmethod
class MammalObjectBroker:
    """Hold a model path and lazily load the Mammal model and tokenizer from it.

    Args:
        model_path: location passed to both ``Mammal.from_pretrained`` and
            ``ModularTokenizerOp.from_pretrained``.
        name: display name; defaults to ``model_path`` when omitted.
        task_list: task names associated with this model; defaults to an
            empty list when omitted.
    """

    def __init__(
        self,
        model_path: str,
        name: "str | None" = None,
        task_list: "list[str] | None" = None,
    ) -> None:
        self.model_path = model_path
        self.name = model_path if name is None else name
        # Always set ``tasks`` so the attribute exists whichever branch is taken.
        self.tasks = task_list if task_list is not None else []
        # Loaded on first access via the properties below.
        self._model = None
        self._tokenizer_op = None

    @property
    def model(self) -> "Mammal":
        """The Mammal model, loaded from ``model_path`` on first access.

        ``eval()`` is called on the model once, right after loading.
        """
        if self._model is None:
            self._model = Mammal.from_pretrained(self.model_path)
            self._model.eval()
        return self._model

    @property
    def tokenizer_op(self) -> "ModularTokenizerOp":
        """The modular tokenizer op, loaded from ``model_path`` on first access."""
        if self._tokenizer_op is None:
            self._tokenizer_op = ModularTokenizerOp.from_pretrained(self.model_path)
        return self._tokenizer_op
class MammalTask(ABC):
    """Abstract base for a Mammal task exposed through a gradio demo.

    Subclasses must implement ``crate_sample_dict`` and ``decode_output``.
    ``run_model`` and ``create_demo`` raise ``NotImplementedError`` unless
    overridden.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.description = None
        # Cache for the group returned by ``create_demo`` (filled by ``demo``).
        self._demo = None

    @abstractmethod
    def crate_sample_dict(
        self, sample_inputs: dict, model_holder: "MammalObjectBroker"
    ) -> dict:
        """Format raw task inputs into a sample dict matching the pre-training syntax.

        The method name's spelling (``crate``) is kept because subclasses
        override it under this name.

        Args:
            sample_inputs: raw inputs for one sample.
            model_holder: broker giving access to the model and tokenizer.

        Returns:
            dict: sample dict for feeding into the model.
        """
        raise NotImplementedError()

    def run_model(self, sample_dict, model: "Mammal"):
        """Presumably runs ``model`` on ``sample_dict`` in subclasses.

        Not declared abstract; this default implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def create_demo(
        self, model_name_widget: "gr.components.Component"
    ) -> "gr.Group":
        """Create a gradio demo group for this task.

        Args:
            model_name_widget: widget holding the model name to use. This is
                needed so gradio actions can take the current model name as
                an input.

        Returns:
            gr.Group: the demo group.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    def demo(self, model_name_widgit: "gr.components.Component | None" = None):
        """Return this task's demo group, building it on the first call.

        Args:
            model_name_widgit: forwarded to ``create_demo`` on the first call
                only. Later calls return the cached group and ignore this
                argument.
        """
        if self._demo is None:
            self._demo = self.create_demo(model_name_widget=model_name_widgit)
        return self._demo

    @abstractmethod
    def decode_output(self, batch_dict, model: "Mammal"):
        """Turn ``batch_dict`` into the task's result; must be implemented by subclasses.

        NOTE(review): ``batch_dict`` is presumably the model's output batch — confirm.
        """
        raise NotImplementedError()
|