from io import BytesIO
import streamlit as st
import pandas as pd
import json
import os
import numpy as np
from streamlit.elements import markdown
from PIL import Image
from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)
from utils import (
    get_transformed_image,
    get_text_attributes,
    get_top_5_predictions,
    plotly_express_horizontal_bar_plot,
    translate_labels,
)
import matplotlib.pyplot as plt
from mtranslate import translate
from session import _get_state

# Per-session state object (project-local helper); fields are set further down.
state = _get_state()


@st.cache(persist=True)
def load_model(ckpt):
    """Load the CLIP-Vision-BERT sequence-classification model from ``ckpt``.

    Cached by Streamlit (``persist=True`` also stores the result on disk),
    so the checkpoint is only loaded once per argument value.
    """
    # NOTE(review): st.cache hashes/pickles the returned model; for large
    # model objects `allow_output_mutation=True` may be needed — confirm.
    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)


@st.cache(persist=True)
def predict(transformed_image, question_inputs):
    """Run the model on one image/question pair and return its logits.

    Args:
        transformed_image: pixel values as produced by ``get_transformed_image``.
        question_inputs: mapping of tokenizer outputs, splatted into the model
            call as keyword arguments.

    Returns:
        ``np.ndarray`` built from ``outputs[0][0]`` — presumably the logits
        of the single batch item (TODO confirm against the model class).
    """
    # Uses the module-level ``model`` global, which the script assigns via
    # ``load_model`` before the first call; it is not a parameter.
    return np.array(model(pixel_values=transformed_image, **question_inputs)[0][0])


def softmax(logits):
    """Return the softmax of ``logits`` along axis 0.

    The per-column maximum is subtracted before exponentiating; this leaves
    the result mathematically unchanged but prevents ``np.exp`` from
    overflowing to ``inf`` (and the ratio becoming ``nan``) for large logits.
    """
    logits = np.asarray(logits)
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0, keepdims=True)


def read_markdown(path, parent="./sections/"):
    """Return the text of the markdown file ``path`` located under ``parent``.

    Read explicitly as UTF-8 so non-ASCII content does not depend on the
    platform's default encoding.
    """
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()


checkpoints = ["./ckpt/ckpt-60k-5999"]  # TODO: Maybe add more checkpoints?
dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")

code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}

with open("answer_reverse_mapping.json") as f:
    answer_reverse_mapping = json.load(f)

st.set_page_config(
    page_title="Multilingual VQA",
    layout="wide",
    initial_sidebar_state="collapsed",
    page_icon="./misc/mvqa-logo.png",
)

st.title("Multilingual Visual Question Answering")
st.write(
    "[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
)

with st.beta_expander("Usage"):
    st.markdown(read_markdown("usage.md"))

first_index = 20


def _load_example(frame, index):
    """Copy row ``index`` of ``frame`` into the session state and load its image.

    Sets image file, question (with leading/trailing "- " characters
    stripped), answer label, and both language ids from the row's
    ``lang_id``; the image is read from ``images/<image_file>``.
    """
    state.image_file = frame.loc[index, "image_file"]
    state.question = frame.loc[index, "question"].strip("- ")
    state.answer_label = frame.loc[index, "answer_label"]
    state.question_lang_id = frame.loc[index, "lang_id"]
    state.answer_lang_id = frame.loc[index, "lang_id"]
    state.image = plt.imread(os.path.join("images", state.image_file))


# Init Session State with a fixed example on the first run.
if state.image_file is None:
    _load_example(dummy_data, first_index)

col1, col2 = st.beta_columns([6, 4])

if col2.button("Get a random example"):
    _load_example(dummy_data.sample(1).reset_index(), 0)

col2.write("OR")

uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    state.image_file = os.path.join("images/val2014", uploaded_file.name)
    # convert("RGB") so grayscale, RGBA or palette uploads still yield an
    # H x W x 3 array instead of a 2-D or 4-channel one.
    state.image = np.array(Image.open(uploaded_file).convert("RGB"))

transformed_image = get_transformed_image(state.image)

# Display Image
col1.image(state.image, use_column_width="auto")

# Display Question
question = col2.text_input(label="Question", value=state.question)
col2.markdown(
    f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
)

# The stored answer label belongs to the dataset example; it says nothing
# about a user-uploaded image, so only show it when no upload is active.
if uploaded_file is None:
    col2.markdown(
        "**Actual Answer in English**: "
        + answer_reverse_mapping[str(state.answer_label)]
    )

question_inputs = get_text_attributes(question)

# Select Language
options = ["en", "de", "es", "fr"]

state.answer_lang_id = col2.selectbox(
    "Answer Language",
    index=options.index(state.answer_lang_id),
    options=options,
    format_func=lambda x: code_to_name[x],
)

# Display Top-5 Predictions
with st.spinner("Loading model..."):
    model = load_model(checkpoints[0])

with st.spinner("Predicting..."):
    logits = predict(transformed_image, dict(question_inputs))
logits = softmax(logits)
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
translated_labels = translate_labels(labels, state.answer_lang_id)
fig = plotly_express_horizontal_bar_plot(values, translated_labels)
st.plotly_chart(fig, use_container_width=True)

st.write(read_markdown("abstract.md"))
st.write(read_markdown("caveats.md"))
st.write("# Methodology")
st.image(
    "./misc/Multilingual-VQA.png", caption="Masked LM model for Image-text Pretraining."
)
st.markdown(read_markdown("pretraining.md"))
st.markdown(read_markdown("finetuning.md"))
st.write(read_markdown("challenges.md"))
st.write(read_markdown("social_impact.md"))
st.write(read_markdown("references.md"))
st.write(read_markdown("checkpoints.md"))
st.write(read_markdown("acknowledgements.md"))