llava-mistral / handler.py
JadS's picture
Update handler.py
a0eb4c0 verified
raw
history blame contribute delete
No virus
1.98 kB
from typing import Dict, List, Any
from transformers import AutoProcessor, AutoModelForVision2Seq
import io
import base64
from PIL import Image
class EndpointHandler:
    """Custom inference handler serving the Idefics2 vision-language model.

    The model and processor are loaded once in the constructor and reused for
    every request handled through ``__call__``.
    """

    # Hub checkpoint used for both the model weights and the processor.
    MODEL_ID = "HuggingFaceM4/idefics2-8b"

    def __init__(self, path=""):
        """Load the model and its processor.

        Args:
            path: Accepted for interface compatibility. It is not read here;
                the checkpoint is always pulled from ``MODEL_ID``.
        """
        # The model is moved to the first CUDA device; there is no CPU fallback.
        self.device = "cuda:0"
        self.model = AutoModelForVision2Seq.from_pretrained(self.MODEL_ID).to(self.device)
        self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """Generate text for a chat-style request with optional images.

        Example payload::

            {
                "inputs": {
                    "messages": [{
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "What's the difference between these two images?"},
                            {"type": "image"},
                            {"type": "image"}
                        ]
                    }],
                    "images": ["<base64 image>", "<base64 image>"]
                },
                "parameters": {"max_new_tokens": 500}
            }

        Args:
            data: Request body. ``data["inputs"]["messages"]`` is handed to the
                processor's chat template. ``data["inputs"]["images"]`` is an
                optional list of base64-encoded images (plain base64 or a
                ``data:`` URI). ``data["parameters"]`` is optional and may
                override ``max_new_tokens`` (default 500) and
                ``add_generation_prompt`` (default False).

        Returns:
            The list of strings produced by the processor's ``batch_decode``.

        Raises:
            KeyError: If ``"inputs"`` or ``"messages"`` is missing.
        """
        payload = data["inputs"]
        options = data.get("parameters") or {}

        # NOTE(review): the upstream Idefics2 usage example passes
        # add_generation_prompt=True; the default is kept at False here so
        # existing callers see unchanged output -- confirm which is intended.
        prompt = self.processor.apply_chat_template(
            payload["messages"],
            add_generation_prompt=options.get("add_generation_prompt", False),
        )

        encoded_images = payload.get("images") or []
        images = [self.decode_image_base64(item) for item in encoded_images]

        # Pass None rather than an empty list for text-only prompts, so the
        # processor skips its image handling instead of receiving [].
        model_inputs = self.processor(
            images=images or None,
            text=prompt,
            return_tensors="pt",
        )
        model_inputs = {
            name: tensor.to(self.device) for name, tensor in model_inputs.items()
        }

        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=options.get("max_new_tokens", 500),
        )
        return self.processor.batch_decode(generated_ids, skip_special_tokens=True)

    @staticmethod
    def _decode_base64_payload(encoded_image) -> bytes:
        """Return the raw bytes of a base64 string, bytes object, or data URI."""
        if isinstance(encoded_image, str):
            encoded_image = encoded_image.encode("utf-8")
        # A "data:<mime>;base64," header is not part of the payload. Without
        # stripping it, b64decode would drop the punctuation and decode the
        # header letters as if they were image data.
        if encoded_image.startswith(b"data:"):
            _, _, encoded_image = encoded_image.partition(b",")
        return base64.b64decode(encoded_image)

    def decode_image_base64(self, encoded_image):
        """Decode a base64-encoded image into a PIL image.

        Args:
            encoded_image: Base64 text (``str`` or ``bytes``), optionally
                prefixed with a ``data:...;base64,`` header.

        Returns:
            The image object returned by ``PIL.Image.open``.
        """
        raw_bytes = self._decode_base64_payload(encoded_image)
        # Image.open reads from a file-like object, so wrap the bytes in memory.
        return Image.open(io.BytesIO(raw_bytes))