import base64
import io
from typing import Any, Dict

import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering


class EndpointHandler:
    def __init__(self, path: str = ""):
        # load model and processor from the local checkpoint path
        self.processor = ViltProcessor.from_pretrained(path)
        self.model = ViltForQuestionAnswering.from_pretrained(path)
        # move the model to GPU when available and switch to inference mode
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # unpack the request payload
        inputs = data.pop("inputs", data)
        image = inputs["image"]
        # accept raw bytes or a base64-encoded string (avoids calling eval() on request data)
        if isinstance(image, str):
            image = base64.b64decode(image)
        image = Image.open(io.BytesIO(image)).convert("RGB")
        text = inputs["text"]

        # preprocess image + question and move the tensors to the model's device
        encoding = self.processor(image, text, return_tensors="pt").to(self.device)

        # forward pass without gradient tracking
        with torch.no_grad():
            outputs = self.model(**encoding)

        # postprocess: pick the highest-scoring label as the best answer
        logits = outputs.logits
        best_idx = logits.argmax(-1).item()
        best_answer = self.model.config.id2label[best_idx]

        # score every candidate answer with a softmax over the logits
        probabilities = torch.softmax(logits, dim=-1)[0]
        id2label = self.model.config.id2label
        answers = [
            {"answer": id2label[idx], "answer_score": float(prob)}
            for idx, prob in enumerate(probabilities)
        ]

        return {"best_answer": best_answer, "answers": answers}
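

# A minimal local smoke test (a sketch, not part of the handler contract): it assumes the
# checkpoint lives in "./vilt-b32-finetuned-vqa" and that "cats.jpg" exists next to this
# file -- both paths are hypothetical. The payload mirrors what an Inference Endpoint
# would pass to __call__, with the image sent as a base64-encoded string.
if __name__ == "__main__":
    handler = EndpointHandler(path="./vilt-b32-finetuned-vqa")

    with open("cats.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    payload = {"inputs": {"image": image_b64, "text": "How many cats are there?"}}
    result = handler(payload)

    print("best answer:", result["best_answer"])
    # print the five highest-scoring candidates as a quick sanity check
    top5 = sorted(result["answers"], key=lambda a: a["answer_score"], reverse=True)[:5]
    for a in top5:
        print(f"{a['answer']}: {a['answer_score']:.3f}")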