import ast
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import ViltForQuestionAnswering, ViltProcessor


class EndpointHandler:
    """Inference-endpoint handler for ViLT visual question answering.

    Loads a ``ViltProcessor`` / ``ViltForQuestionAnswering`` pair from ``path``
    and answers one text question about one image per call.
    """

    def __init__(self, path: str = ""):
        """Load the processor and model and place the model on the best device.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.processor = ViltProcessor.from_pretrained(path)
        self.model = ViltForQuestionAnswering.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Actually use the selected device (previously computed but ignored).
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _decode_image(raw: Any) -> Image.Image:
        """Turn the request's ``image`` field into an RGB PIL image.

        Accepts raw ``bytes``/``bytearray``, or a string holding a Python
        bytes literal (e.g. ``repr(image_bytes)``), which is the format the
        original ``eval``-based code consumed.

        Raises:
            TypeError: if the decoded payload is not bytes.
            ValueError / SyntaxError: if the string is not a valid literal.
        """
        if isinstance(raw, (bytes, bytearray)):
            payload = bytes(raw)
        else:
            # SECURITY: ``eval`` on request data allowed arbitrary code
            # execution; ``ast.literal_eval`` only evaluates literals.
            payload = ast.literal_eval(raw)
        if not isinstance(payload, (bytes, bytearray)):
            raise TypeError(
                f"image must decode to bytes, got {type(payload).__name__}"
            )
        # Normalise mode so palette/RGBA/grayscale uploads get 3 channels.
        return Image.open(io.BytesIO(payload)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Answer a question about an image.

        Args:
            data: Request body. Either ``{"inputs": {"image": ..., "text": ...}}``
                or the inner ``{"image": ..., "text": ...}`` dict directly.
                It is not mutated.

        Returns:
            A one-element list holding ``best_answer`` (the label of the
            highest-scoring class) and ``logits`` (the raw scores as nested
            Python lists, so the response is JSON-serializable).
        """
        # ``get`` instead of ``pop`` so the caller's dict is left untouched.
        inputs = data.get("inputs", data)
        image = self._decode_image(inputs["image"])
        text = inputs["text"]

        encoding = self.processor(image, text, return_tensors="pt").to(self.device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**encoding)

        logits = outputs.logits
        # ``.item()`` assumes a single (image, question) pair, as built above.
        idx = logits.argmax(-1).item()
        return [
            {
                "best_answer": self.model.config.id2label[idx],
                "logits": logits.cpu().tolist(),
            }
        ]