"""Streamlit live demo for multimodal visual question answering (VQA).

The user uploads an image and types a question about it; the app displays
the answers of the original ViLT VQA checkpoint, of a fine-tuned ViLT
checkpoint, and of BLIP.
"""
import numpy as np
import torch
from PIL import Image
from transformers import ViltConfig, ViltProcessor, ViltForQuestionAnswering
from transformers import BlipProcessor, BlipForQuestionAnswering
import cv2
import streamlit as st

VILT_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
FINETUNED_VILT_CHECKPOINT = "Minqin/carets_vqa_finetuned"
BLIP_CHECKPOINT = "Salesforce/blip-vqa-base"
# Maximum generated length for BLIP answers (value taken from the original,
# previously commented-out, generate call).
BLIP_MAX_LENGTH = 50


@st.cache_resource
def load_models():
    """Load all processors and models once per server process.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, all three checkpoints would be reloaded on each rerun.

    Returns:
        Tuple ``(processor, model, orig_model, blip_processor, blip_model)``:
        the ViLT processor, the fine-tuned ViLT model, the original ViLT model,
        the BLIP processor and the BLIP model.
    """
    processor = ViltProcessor.from_pretrained(VILT_CHECKPOINT)
    model = ViltForQuestionAnswering.from_pretrained(FINETUNED_VILT_CHECKPOINT)
    orig_model = ViltForQuestionAnswering.from_pretrained(VILT_CHECKPOINT)
    blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
    blip_model = BlipForQuestionAnswering.from_pretrained(BLIP_CHECKPOINT)
    return processor, model, orig_model, blip_processor, blip_model


def decode_upload(uploaded_file):
    """Decode an uploaded image file into an RGB ``uint8`` array.

    Returns:
        The decoded image in RGB channel order, or None when the upload is
        empty or OpenCV cannot decode it.
    """
    # getvalue() returns the whole payload regardless of the buffer's read
    # position, whereas read() yields b"" once the buffer has been consumed.
    data = uploaded_file.getvalue()
    if not data:
        return None
    file_bytes = np.asarray(bytearray(data), dtype=np.uint8)
    bgr = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if bgr is None:  # imdecode signals "not an image" with None, not an error
        return None
    # OpenCV decodes to BGR; st.image(channels="RGB") and the models need RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def vilt_answer(vilt_model, encoding):
    """Return the top-scoring answer label of a ViLT VQA model.

    Args:
        vilt_model: A ``ViltForQuestionAnswering`` instance.
        encoding: Output of ``ViltProcessor(..., return_tensors="pt")``.
    """
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = vilt_model(**encoding).logits
    return vilt_model.config.id2label[logits.argmax(-1).item()]


def blip_answer(blip_processor, blip_model, image, question):
    """Generate BLIP's free-form answer to *question* about *image*."""
    # Let the processor build pixel_values plus properly batched input_ids and
    # attention_mask; hand-built 1-D input_ids are not a valid generate() batch.
    inputs = blip_processor(images=image, text=question, return_tensors="pt")
    with torch.no_grad():
        generate_ids = blip_model.generate(**inputs, max_length=BLIP_MAX_LENGTH)
    return blip_processor.decode(generate_ids[0], skip_special_tokens=True)


def main():
    """Render the page and answer once both an image and a question exist."""
    st.title("Live demo of multimodal vqa")
    processor, model, orig_model, blip_processor, blip_model = load_models()

    uploaded_file = st.file_uploader(
        "Please upload one image", type=["jpg", "png", "bmp", "jpeg"]
    )
    question = st.text_input("Type here one question on the image")
    if uploaded_file is None:
        return

    image_rgb = decode_upload(uploaded_file)
    if image_rgb is None:
        st.error("Could not decode the uploaded file as an image.")
        return
    st.image(image_rgb, channels="RGB")

    # Don't run three models on an empty question.
    if not question.strip():
        st.info("Type a question to get the models' answers.")
        return

    img = Image.fromarray(image_rgb)
    # Both ViLT checkpoints share the same processor, so encode once.
    encoding = processor(images=img, text=question, return_tensors="pt")
    orig_pred = vilt_answer(orig_model, encoding)
    pred = vilt_answer(model, encoding)
    blip_output = blip_answer(blip_processor, blip_model, img, question)

    st.text(f"Answer of ViLT: {orig_pred}")
    st.text(f"Answer after fine-tuning: {pred}")
    st.text(f"Answer of BLIP: {blip_output}")


# `streamlit run` executes this file with __name__ == "__main__".
if __name__ == "__main__":
    main()