File size: 1,369 Bytes
27ec10e
 
6cb3558
27ec10e
 
 
 
 
 
6cb3558
 
29f27f5
27ec10e
 
 
 
 
 
 
 
 
 
 
 
6cb3558
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#!/usr/bin/env python3
from typing import Dict, List, Any
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

class EndpointHandler():
    """Inference-endpoint handler that runs PaliGemma on a prompt plus an image.

    The model and processor are loaded once in ``__init__``. Each call builds
    processor inputs from a text prompt and an image, runs ``generate`` and
    returns the decoded text.
    """

    # Image used when the request does not carry one. This is the image the
    # handler originally always used; override on a subclass/instance to change it.
    DEFAULT_IMAGE_URL = (
        "https://huggingface.co/datasets/huggingface/documentation-images/"
        "resolve/main/bee.jpg?download=true"
    )
    # Number of new tokens passed to ``generate`` (the original hard-coded value).
    MAX_NEW_TOKENS = 20
    # Seconds to wait on the image download; without a timeout a stalled
    # server would block the request indefinitely.
    IMAGE_FETCH_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the PaliGemma model and its processor.

        Args:
            path: Model location passed to both ``from_pretrained`` calls.
        """
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(path)
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _split_inputs(raw_inputs):
        """Return ``(prompt, image)`` taken from the request's ``inputs`` value.

        A plain string is used as the prompt. A mapping must have a
        ``"prompt"`` key and may have an ``"image"`` key; ``image`` is
        ``None`` when the request supplied none.

        Raises:
            KeyError: if a mapping has no ``"prompt"`` key.
        """
        if isinstance(raw_inputs, str):
            return raw_inputs, None
        return raw_inputs["prompt"], raw_inputs.get("image")

    def _fetch_image(self, url):
        """Download ``url`` and open it as a PIL image.

        Raises:
            requests.HTTPError: on a non-2xx response, rather than handing an
                error page to PIL.
        """
        # Imported here: the original code used these names without ever
        # importing them, which failed with NameError at call time.
        import requests
        from PIL import Image

        response = requests.get(url, stream=True, timeout=self.IMAGE_FETCH_TIMEOUT)
        response.raise_for_status()
        return Image.open(response.raw)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate text for one request.

        Args:
            data: Request payload. Its ``"inputs"`` entry (or ``data`` itself
                when that key is absent) is either a prompt string or a dict
                with a ``"prompt"`` string and an optional ``"image"`` object
                (e.g. a ``PIL.Image``) that is passed to the processor as-is.
                Without an image, ``DEFAULT_IMAGE_URL`` is downloaded.

        Returns:
            The decoded ``output[0]`` from ``generate``. For generate() this
            sequence typically starts with the input tokens, so the text may
            include the prompt — TODO confirm whether callers want it stripped.

        Raises:
            KeyError: if a dict payload has no ``"prompt"``.
        """
        raw_inputs = data.pop("inputs", data)
        prompt, image = self._split_inputs(raw_inputs)
        if image is None:
            image = self._fetch_image(self.DEFAULT_IMAGE_URL)

        inputs = self.processor(prompt, image, return_tensors="pt")
        output = self.model.generate(**inputs, max_new_tokens=self.MAX_NEW_TOKENS)
        # The processor lives on the instance; a bare `processor` is undefined here.
        return self.processor.decode(output[0], skip_special_tokens=True)