VQA_fashion_hvar / inference.py
wiusdy's picture
engineering the code
0a5203f
raw
history blame
460 Bytes
import functools

import torch
from transformers import ViltProcessor, ViltForQuestionAnswering
_VILT_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"


@functools.lru_cache(maxsize=1)
def _load_vilt():
    """Load the ViLT VQA processor and model once and reuse them afterwards.

    ``from_pretrained`` reads (and on first use fetches) the checkpoint, which
    is far more expensive than a single forward pass, so the pair is cached
    for the lifetime of the process instead of being rebuilt per question.

    Returns:
        A ``(processor, model)`` tuple built from ``_VILT_CHECKPOINT``.
    """
    processor = ViltProcessor.from_pretrained(_VILT_CHECKPOINT)
    model = ViltForQuestionAnswering.from_pretrained(_VILT_CHECKPOINT)
    return processor, model


def inference(image, text):
    """Answer a question about an image with the fine-tuned ViLT VQA model.

    Args:
        image: The image to query, in whatever form ``ViltProcessor`` accepts
            (presumably a single PIL image -- confirm against the caller).
        text: The question to ask about ``image``.

    Returns:
        The answer label (as a string) of the highest-scoring class, looked
        up in the model's ``config.id2label`` mapping.
    """
    processor, model = _load_vilt()
    encoding = processor(image, text, return_tensors="pt")

    # Prediction never needs gradients; disabling autograd avoids building a
    # computation graph and the memory that goes with it.
    with torch.no_grad():
        logits = model(**encoding).logits

    # .item() only works on a one-element tensor, so this expects a single
    # image/question pair rather than a batch.
    idx = logits.argmax(-1).item()
    return str(model.config.id2label[idx])