File size: 2,813 Bytes
c97c8cb
b244c27
313c6b1
254bcc9
c97c8cb
 
 
254bcc9
042fc02
9e2d39f
50da7fb
254bcc9
 
6ff1d6b
629ae8a
f8654b9
 
 
 
313c6b1
 
a6def05
629ae8a
254bcc9
 
c97c8cb
 
 
 
 
 
 
b4bc0d9
 
 
 
 
 
 
50da7fb
3607fbc
 
 
254bcc9
79bde87
50da7fb
 
 
0a6c0fe
254bcc9
629ae8a
254bcc9
046b29c
79bde87
f8654b9
9d5c1cc
 
 
 
 
79bde87
c97c8cb
 
 
 
90c4e38
50da7fb
90c4e38
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipForQuestionAnswering, BitsAndBytesConfig
from transformers import AutoProcessor, AutoModelForCausalLM
from typing import Dict, List, Any
from PIL import Image
from transformers import pipeline
import requests
import torch
from io import BytesIO
import base64

class EndpointHandler():
    """Inference Endpoints handler for BLIP-2 visual question answering.

    The processor and model for ``self.model_name`` are loaded once in
    ``__init__``. Each call answers a free-text question about one image,
    supplied either as an http(s) URL or as a base64 string (bare, or a
    ``data:<mime>;base64,<data>`` URL).
    """

    # Seconds to wait for an image download, so a slow or unresponsive host
    # cannot hang the worker indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the BLIP-2 processor and model onto the best available device.

        Args:
            path: Model directory supplied by the endpoint runtime. Not used
                here; the model is loaded by hub name from ``self.model_name``.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        print("device:", self.device)
        # Other checkpoints referenced by this handler:
        # self.model_base = "Salesforce/blip2-opt-2.7b"
        # self.model_name = "sooh-j/blip2-vizwizqa"
        self.model_name = "Salesforce/blip2-opt-2.7b"

        self.processor = AutoProcessor.from_pretrained(self.model_name)
        # Pin every module to self.device at load time instead of combining
        # device_map="auto" with a later .to(): "auto" may split or offload
        # modules, and an accelerate-dispatched model should not be moved
        # afterwards. The inputs in __call__ are always sent to self.device.
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_name,
            device_map={"": self.device},
        )

    @staticmethod
    def _is_url(image_ref: str) -> bool:
        """Return True when ``image_ref`` is an http(s) URL rather than base64 data."""
        return image_ref.startswith(("http://", "https://"))

    @staticmethod
    def _decode_base64_image_bytes(image_ref: str) -> bytes:
        """Decode a base64 image payload into raw bytes.

        Accepts bare base64 or a data URL such as
        ``"data:image/png;base64,<data>"``. Only the text after the last comma
        is decoded, so a ``data:`` header is skipped; bare base64 contains no
        comma and is decoded whole.
        """
        payload = image_ref.rsplit(",", 1)[-1]
        return base64.b64decode(payload.encode())

    def _load_image(self, image_ref: str) -> "Image.Image":
        """Open the request image from a URL or a base64 string, converted to RGB.

        Raises:
            requests.HTTPError: if downloading the URL returns an error status.
            requests.Timeout: if the download exceeds ``REQUEST_TIMEOUT``.
        """
        if self._is_url(image_ref):
            response = requests.get(image_ref, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            # .content (unlike .raw) applies Content-Encoding such as gzip.
            raw_bytes = response.content
        else:
            raw_bytes = self._decode_base64_image_bytes(image_ref)
        # convert() also forces Pillow's lazy loader to read the pixel data now.
        return Image.open(BytesIO(raw_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Answer a question about an image.

        Args:
            data: Request payload of the form
                ``{"inputs": {"image": <url or base64 str>, "question": <str>}}``,
                matching the ``inputs`` of a huggingface.js
                ``visualQuestionAnswering`` call.

        Returns:
            A one-element list ``[{"answer": <generated text>, "score": 0}]``.
            ``score`` is a fixed placeholder; no confidence is computed.

        Raises:
            ValueError: if ``inputs``, ``inputs.image`` or ``inputs.question``
                is missing or has the wrong type.
        """
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError('"inputs" must be an object with "image" and "question" fields')
        image_ref = inputs.get("image")
        question = inputs.get("question")
        if not isinstance(image_ref, str) or not image_ref:
            raise ValueError('"inputs.image" must be a non-empty URL or base64 string')
        if question is None:
            raise ValueError('"inputs.question" is required')

        image = self._load_image(image_ref)

        prompt = f"Question: {question}, Answer:"
        processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            # do_sample=True: answers are sampled, so repeated calls may differ.
            out = self.model.generate(
                **processed,
                max_new_tokens=20,
                temperature=0.5,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                repetition_penalty=1.2,
            )

        text_output = self.processor.decode(out[0], skip_special_tokens=True)
        return [{"answer": text_output, "score": 0}]