blip2-flan-t5-xxl / handler.py
smdesai's picture
Update handler.py
02685d6
from typing import Dict, List, Any
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import base64
from io import BytesIO
from PIL import Image
import string
import torch
class EndpointHandler:
    """Request handler wrapping a BLIP-2 conditional-generation model.

    The handler is constructed once with the model directory and is then
    called once per request with a dict carrying a base64-encoded image and
    an optional text prompt.
    """

    def __init__(self, path: str = ""):
        """Load the processor and the model from ``path``.

        Args:
            path: Directory (or model identifier) passed to ``from_pretrained``
                for both the processor and the model.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        # device_map="auto" lets the library decide weight placement;
        # load_in_4bit=True requests 4-bit quantized weights.
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_4bit=True
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for one image (and optional prompt).

        Args:
            data: Request body. The payload is read from ``data["inputs"]``
                when that key exists, otherwise ``data`` itself is the
                payload. The payload must contain ``"image"`` (base64-encoded
                image bytes) and may contain ``"text"`` (prompt string).
                ``data`` is not modified.

        Returns:
            A one-element list ``[{"generated_text": <str>}]``. A trailing
            period is appended when the non-empty generated text does not
            already end with an ASCII punctuation character.

        Raises:
            KeyError: If the payload has no ``"image"`` entry.
        """
        # Read without popping so the caller's dict is left untouched.
        payload = data.get("inputs", data)
        encoded_image = payload["image"]
        # The prompt is optional so image-only requests do not fail on a
        # missing key; None is forwarded to the processor in that case.
        prompt = payload.get("text")

        # Normalize the mode so palette/alpha/grayscale uploads are handled
        # the same way as plain RGB ones.
        image = Image.open(BytesIO(base64.b64decode(encoded_image))).convert("RGB")

        # Follow the device the model was actually placed on rather than
        # assuming "cuda"; the float16 cast matches the original behavior.
        model_inputs = self.processor(
            images=image, text=prompt, return_tensors="pt"
        ).to(self.model.device, torch.float16)

        # NOTE(review): temperature/top_p normally only take effect when
        # sampling is enabled (do_sample=True); with num_beams=5 and no
        # do_sample they are presumably ignored unless the checkpoint's
        # generation config turns sampling on -- confirm the intent.
        generated_ids = self.model.generate(
            **model_inputs,
            temperature=1.0,
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )

        decoded = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )
        result = decoded[0].strip()
        # Make the output read as a sentence, but never touch an empty result.
        if result and result[-1] not in string.punctuation:
            result += "."
        return [{"generated_text": result}]