VQA_fashion_hvar / inference.py
wiusdy's picture
solv assertion error
658f534
raw
history blame
1.51 kB
import torch
from transformers import ViltProcessor, ViltForQuestionAnswering, Pix2StructProcessor, Pix2StructForConditionalGeneration, Blip2Processor, Blip2ForConditionalGeneration
class Inference:
    """Answer questions about an image with one of two pretrained models.

    Holds a ViLT VQA model (``dandelin/vilt-b32-finetuned-vqa``) and a DePlot
    Pix2Struct model (``google/deplot``). :meth:`inference` routes a question
    to one of them based on a model label.
    """

    def __init__(self):
        # from_pretrained resolves these Hub ids (downloading on first use
        # unless already cached locally).
        self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        self.deplot_processor = Pix2StructProcessor.from_pretrained('google/deplot')
        self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')

    def inference(self, selected, image, text):
        """Answer ``text`` about ``image`` using the model named by ``selected``.

        Args:
            selected: Model label. "Model 1" and "Model 2" use DePlot;
                "Model 3" uses ViLT.
            image: Image handed unchanged to the chosen processor (presumably
                a PIL image or array -- confirm against the caller).
            text: The question / prompt string.

        Returns:
            The model's answer as a string.

        Raises:
            ValueError: If ``selected`` is not one of the known labels.
        """
        # NOTE(review): "Model 1" and "Model 2" both route to DePlot, and the
        # Blip2 classes imported at the top of the file are never used --
        # confirm whether "Model 1" was meant to be a different model.
        if selected in ("Model 1", "Model 2"):
            return self.__inference_deplot(image, text)
        if selected == "Model 3":
            return self.__inference_vilt(image, text)
        # Fail loudly instead of silently returning None for an unknown label.
        raise ValueError(f"Unknown model selection: {selected!r}")

    def __inference_vilt(self, image, text):
        """Run ViLT and return the label of the highest-scoring answer."""
        encoding = self.vilt_processor(image, text, return_tensors="pt")
        # Pure inference: do not build an autograd graph (saves memory/time).
        with torch.no_grad():
            outputs = self.vilt_model(**encoding)
        # The highest-scoring logit indexes into the model's id2label mapping.
        idx = outputs.logits.argmax(-1).item()
        return str(self.vilt_model.config.id2label[idx])

    def __inference_deplot(self, image, text):
        """Generate a free-text answer with DePlot and decode it to a string."""
        inputs = self.deplot_processor(images=image, text=text, return_tensors="pt")
        # Pure inference: no gradients needed during generation.
        with torch.no_grad():
            predictions = self.deplot_model.generate(**inputs, max_new_tokens=512)
        # A single image/prompt pair was encoded, so decode the first sequence.
        return str(self.deplot_processor.decode(predictions[0], skip_special_tokens=True))