# blip2-vizwizqa / handler.py
import base64
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration
class EndpointHandler():
    def __init__(self, path=""):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_base = "Salesforce/blip2-opt-2.7b"
        self.model_name = "sooh-j/blip2-vizwizqa"
        self.processor = AutoProcessor.from_pretrained(self.model_name)
        # device_map="auto" lets transformers place the weights on the available
        # device(s), so no explicit .to(self.device) call is needed here.
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            self.model_name,
            device_map="auto",
        )
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`): a dict with a base64-encoded "image" string and a "question" string
        Return:
            A :obj:`dict` with the generated answer; it will be serialized and returned
        """
inputs = data.get("inputs")
imageBase64 = inputs.get("image")
# imageURL = inputs.get("image")
question = inputs.get("question")
# print(imageURL)
# print(text)
# image = Image.open(requests.get(imageBase64, stream=True).raw)
import base64
from PIL import Image
from io image BytesIO
import matplotlib.pyplot as plt
#try2
# image = Image.open(BytesIO(base64.b64decode(imageBase64)))
#try1
image = Image.open(BytesIO(base64.b64decode(imageBase64.split(",")[0].encode())))
###################
prompt = f"Question: {question}, Answer:"
processed = self.processor(images=image, text=prompt, return_tensors="pt")#.to(self.device)
# answer = self._generate_answer(
# model_path, prompt, image,
# )
with torch.no_grad():
out = self.model.generate(**processed).to(self.device)
result = {}
text_output = self.processor.decode(out[0], skip_special_tokens=True)
result["text_output"] = text_output
return text_output
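

if __name__ == "__main__":
    # Minimal local-test sketch, not part of the Inference Endpoints contract.
    # It only illustrates the expected payload shape: "example.jpg" and the
    # question are hypothetical placeholders, not files or data from this repo.
    handler = EndpointHandler()
    with open("example.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {"inputs": {"image": image_b64, "question": "What is written on the sign?"}}
    print(handler(payload))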