wiusdy's picture
simple VQA
8072ca2
raw
history blame
927 Bytes
import functools
import os

import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
_VQA_MODEL_ID = "dandelin/vilt-b32-finetuned-vqa"


@functools.lru_cache(maxsize=None)
def _load_vqa(model_id=_VQA_MODEL_ID):
    """Load the ViLT processor and VQA model for *model_id*, once.

    ``from_pretrained`` reads (and on first use downloads) the whole
    checkpoint, so the result is cached per model id instead of being
    rebuilt on every request.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = ViltProcessor.from_pretrained(model_id)
    model = ViltForQuestionAnswering.from_pretrained(model_id)
    return processor, model


def vqa(image, text):
    """Answer a natural-language question about an image with ViLT.

    Args:
        image: The input image, passed straight to ``ViltProcessor``.
            The UI supplies it via ``gr.Image(type="pil")``; it is ``None``
            when the user has cleared the image component.
        text: The question to ask about *image*.

    Returns:
        ``"<question>: <answer>"``, where the answer is the label in
        ``model.config.id2label`` with the highest logit. If *image* or
        *text* is missing, a short message asking for that input is
        returned instead.
    """
    if image is None:
        return "Please provide an image."
    if not text or not text.strip():
        return "Please enter a question."

    # torch is already a hard requirement: return_tensors="pt" below
    # produces PyTorch tensors.
    import torch

    processor, model = _load_vqa()
    encoding = processor(image, text, return_tensors="pt")
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        outputs = model(**encoding)
    idx = outputs.logits.argmax(-1).item()
    return f"{text}: {model.config.id2label[idx]}"
with gr.Blocks() as demo:
    # Question input and the read-only-by-convention answer box.
    txt = gr.Textbox(label="Insert a question..", lines=2)
    txt_3 = gr.Textbox(value="", label="Your answer is here..")
    btn = gr.Button(value="Submit")
    # Default example image expected next to this script. abspath() keeps
    # the lookup independent of the current working directory.
    dogs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "617.jpg")
    image = gr.Image(type="pil", value=dogs)
    # On click, run vqa(image, question) and write the result to txt_3.
    # Only this one Submit button is created, so the visible button is
    # always the one wired to the handler.
    btn.click(vqa, inputs=[image, txt], outputs=[txt_3])
def main():
    """Start the Gradio server for the VQA demo."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    main()