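# Header apparently scraped from a hosted file page; kept as comments so the
# file remains valid Python.
# Author: Minqin
# Commit: 520d399 - "first commit image display"
# File size: 1.03 kB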
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import ViltConfig, ViltProcessor, ViltForQuestionAnswering
# Streamlit front end for visual question answering: the user uploads a JPEG
# and types a question; a ViLT model predicts one answer label.
# Streamlit re-executes this whole script on every widget interaction.
st.title("Live demo of multimodal vqa")

# NOTE(review): `config` is never used below; it is kept only so the
# script's module-level names and downloads are unchanged.
config = ViltConfig.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
# Preprocessing (tokenizer + image processor) comes from the base ViLT VQA
# checkpoint while the weights come from a fine-tuned checkpoint; presumably
# the fine-tuned model shares the base vocabulary/preprocessing - TODO confirm.
# NOTE(review): these are reloaded on every rerun; consider caching them
# (e.g. `st.cache_resource`) if the deployed Streamlit version supports it.
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("Minqin/carets_vqa_finetuned")

# Accept both common JPEG extensions; the user-facing label is unchanged.
uploaded_file = st.file_uploader(
    "Please upload one image (jpg)", type=["jpg", "jpeg"]
)
question = st.text_input("Type here one question on the image")

# Only run inference once both an image and a non-blank question exist.
if uploaded_file is not None and question.strip():
    # Decode the upload into pixels. Previously the raw (still
    # JPEG-compressed) bytes were wrapped in a 1-D uint8 array and handed to
    # the processor, so the model never saw the actual picture. Converting to
    # RGB also makes grayscale/CMYK JPEGs yield 3-channel input.
    img = Image.open(uploaded_file).convert("RGB")
    # st.image(img, caption="Here is the uploaded image", use_column_width=True)
    encoding = processor(images=img, text=question, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**encoding)
    # Take the highest-scoring answer class and map it to its text label.
    idx = outputs.logits.argmax(-1).item()
    pred = model.config.id2label[idx]
    st.text(f"Answer: {pred}")