Minqin's picture
Update app.py
30b0855
raw
history blame
No virus
1.03 kB
import numpy as np
from PIL import Image
from transformers import ViltConfig, ViltProcessor, ViltForQuestionAnswering
import streamlit as st
# Streamlit app: the user uploads a JPEG and types a question; a ViLT
# visual-question-answering model predicts an answer, which is shown as text.
st.title("Live demo of multimodal vqa")

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the three loads below are repeated on each rerun. Consider wrapping them
# in a cached loader (e.g. st.cache_resource) if the deployed Streamlit
# version provides it.
# `config` is not read anywhere in the visible code; it is kept so the
# module-level name stays available.
config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
# Pre-processing comes from the base VQA checkpoint; the weights come from the
# fine-tuned checkpoint.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")

uploaded_file = st.file_uploader("Please upload one image (jpg)", type="jpg")
question = st.text_input("Type here one question on the image")

if uploaded_file is not None:
    # The upload is an *encoded* JPEG byte stream, not a pixel array, so it has
    # to be decoded with Image.open; wrapping the raw bytes with
    # Image.fromarray would treat compressed data as pixels.
    # convert("RGB") normalises grayscale/CMYK JPEGs to three channels before
    # the image is handed to the processor.
    img = Image.open(uploaded_file).convert("RGB")
    st.image(img, caption="Here is the uploaded image", use_column_width=True)

    if question.strip():
        # Tokenise the question and pre-process the decoded image together.
        encoding = processor(images=img, text=question, return_tensors="pt")
        outputs = model(**encoding)
        logits = outputs.logits
        # Highest-scoring class index, mapped to its answer string through the
        # model's own label table.
        idx = logits.argmax(-1).item()
        pred = model.config.id2label[idx]
        st.text(f"Answer: {pred}")
    else:
        # Without a question there is nothing meaningful to predict.
        st.info("Please type a question about the image to get an answer.")