import base64
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration


class EndpointHandler:
    def __init__(self, path=""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_base = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"
        # self.pipe = Blip2ForConditionalGeneration.from_pretrained(self.model_base, load_in_8bit=True, torch_dtype=torch.float16)

        self.processor = AutoProcessor.from_pretrained(self.model_name)
        self.model = Blip2ForConditionalGeneration.from_pretrained(self.model_name,
                                                              device_map="auto", 
                                                             )#.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`):
                image (:obj:`str`): base64-encoded image, optionally as a data URI
                question (:obj:`str`): question to ask about the image
        Return:
            A :obj:`dict` that will be serialized and returned
        """
        inputs = data.get("inputs")
        image_base64 = inputs.get("image")
        question = inputs.get("question")

        # The image arrives base64-encoded; accept either a bare payload or a
        # data URI ("data:image/...;base64,<payload>").
        if "," in image_base64:
            image_base64 = image_base64.split(",", 1)[1]
        image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")

        # Earlier drafts of this handler also accepted an image URL
        # (Image.open(requests.get(image_url, stream=True).raw).convert('RGB'));
        # only base64 input is supported here.
        # Adapted from:
        # https://huggingface.co/SlowPacer/witron-image-captioning/blob/main/handler.py

        prompt = f"Question: {question}, Answer:"
        # Move the processed inputs to the device the model was dispatched to.
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            out = self.model.generate(**processed)

        text_output = self.processor.decode(out[0], skip_special_tokens=True)
        result = {"text_output": text_output}
        return result
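

# --- Local smoke test (illustrative sketch, not part of the endpoint contract) ---
# A minimal example of how a client payload might be built and the handler called
# directly. "example.jpg" is a hypothetical local image path used only for this sketch.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    payload = {
        "inputs": {
            "image": encoded_image,
            "question": "What is shown in the picture?",
        }
    }
    print(handler(payload))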