File size: 1,242 Bytes
c9d8c8e
 
 
 
25162b9
c9d8c8e
 
 
 
 
ced1bec
c9d8c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ced1bec
c9d8c8e
b44ea03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Dict, List, Any
import torch 
from transformers import AltCLIPModel, AltCLIPProcessor, AutoProcessor
# Select the inference device once at import time; the handler moves the
# model and every request's tensors here. Prefers GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# https://huggingface.co/docs/inference-endpoints/guides/custom_handler
class EndpointHandler:
    """Custom Inference Endpoints handler returning AltCLIP text features.

    See: https://huggingface.co/docs/inference-endpoints/guides/custom_handler
    """

    def __init__(self, path: str = ""):
        # Preload the model and processor once at startup so each request
        # only pays for the forward pass.
        self.md_model = AltCLIPModel.from_pretrained(path).to(device)
        self.md_processor = AltCLIPProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[List[float]]]:
        """Embed the request's text input(s) with AltCLIP.

        Args:
            data: Request payload. ``data["inputs"]`` is a string or list of
                strings to embed; if the key is absent, ``data`` itself is
                passed to the processor as the text input.

        Returns:
            ``{"feature": vectors}`` where ``vectors`` is a nested Python
            list (one embedding per input text), JSON-serializable.
        """
        # inference_mode disables autograd tracking for the forward pass.
        with torch.inference_mode():
            texts = data.pop("inputs", data)
            # truncation=True guards against inputs longer than the model's
            # maximum sequence length (previously such inputs would raise).
            inputs = self.md_processor(
                text=texts, padding=True, truncation=True, return_tensors="pt"
            ).to(device)
            text_feature = self.md_model.get_text_features(**inputs)
            # Move to CPU and convert to plain lists for JSON serialization.
            return {"feature": text_feature.cpu().tolist()}