from io import BytesIO
from typing import Any, Dict, List
import base64
import string

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the BLIP-2 processor and model from the checkpoint at `path`,
        # sharding across available devices and quantizing the weights to 4-bit.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, device_map="auto", load_in_4bit=True
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data:
                Dict with an "inputs" key holding a base64-encoded image and a
                text prompt, e.g. {"inputs": {"image": "<base64>", "text": "a photo of"}}.

        Returns:
            A list containing a single {"generated_text": ...} dict.
        """
        inputs = data.pop("inputs", data)

        # Decode the base64-encoded image and preprocess it together with the prompt.
        image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
        model_inputs = self.processor(
            images=image, text=inputs["text"], return_tensors="pt"
        ).to("cuda", torch.float16)
        # Beam-search decoding; do_sample defaults to False, so temperature and
        # top_p have no effect here and only the beam/length settings apply.
        generated_ids = self.model.generate(
            **model_inputs,
            temperature=1.0,
            length_penalty=1.0,
            repetition_penalty=1.5,
            max_length=30,
            min_length=1,
            num_beams=5,
            top_p=0.9,
        )
        # Decode the generated tokens and make sure the caption ends with punctuation.
        result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if result and result[-1] not in string.punctuation:
            result += "."

        return [{"generated_text": result}]
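

# Minimal local smoke test: a sketch, not part of the Inference Endpoints contract.
# It assumes a BLIP-2 checkpoint such as "Salesforce/blip2-opt-2.7b" is available,
# that a GPU plus bitsandbytes are present for the 4-bit load above, and that
# "example.jpg" is a hypothetical local image path.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(path="Salesforce/blip2-opt-2.7b")
    payload = {
        "inputs": {
            "image": encoded_image,
            "text": "Question: what is shown in the picture? Answer:",
        }
    }
    # Expected output shape: [{"generated_text": "..."}]
    print(handler(payload))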