from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    translate_labels,
)

import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import json

from mtranslate import translate
from .utils import read_markdown

from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForSequenceClassification,
)


def softmax(logits):
    """Return the softmax of *logits*, normalised along axis 0.

    The per-column maximum is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (and the quotient from turning into ``nan``) for large logits.

    Args:
        logits: Array-like of scores; normalisation runs over the first axis.

    Returns:
        A NumPy array of the same shape whose entries along axis 0 sum to 1.
    """
    values = np.asarray(logits)
    if values.size:
        # A max-reduction over an empty axis raises, so only shift when there
        # is data; empty input then goes through the plain formula below.
        values = values - np.max(values, axis=0)
    exps = np.exp(values)
    return exps / np.sum(exps, axis=0)


def _load_example(vqa_state, frame, row_index):
    """Copy the example at ``frame.loc[row_index]`` into *vqa_state*.

    Sets the image file name, question (with leading/trailing ``-`` and
    spaces stripped), answer label and both language ids, then reads the
    image from the ``resized_images`` directory into ``vqa_state.vqa_image``.
    """
    vqa_state.vqa_image_file = frame.loc[row_index, "image_file"]
    vqa_state.question = frame.loc[row_index, "question"].strip("- ")
    vqa_state.answer_label = frame.loc[row_index, "answer_label"]
    # The answer language starts out equal to the question language; the user
    # can change it afterwards through the "Answer Language" select box.
    vqa_state.question_lang_id = frame.loc[row_index, "lang_id"]
    vqa_state.answer_lang_id = frame.loc[row_index, "lang_id"]

    image_path = os.path.join("resized_images", vqa_state.vqa_image_file)
    vqa_state.vqa_image = plt.imread(image_path)


def app(state):
    """Render the visual-question-answering page of the Streamlit app.

    Args:
        state: Session-state object. This function reads and writes its
            ``vqa_image_file``, ``question``, ``answer_label``,
            ``question_lang_id``, ``answer_lang_id``, ``vqa_image`` and
            ``vqa_model`` attributes.

    Reads ``dummy_vqa_multilingual.tsv``, ``answer_reverse_mapping.json`` and
    images under ``resized_images`` relative to the working directory.
    """
    vqa_state = state

    with st.beta_expander("Usage"):
        st.write(read_markdown("vqa_usage.md"))
    st.write(read_markdown("vqa_intro.md"))

    # @st.cache(persist=False)
    def predict(transformed_image, question_inputs):
        """Run the model and return element ``[0][0]`` of its output as an array.

        NOTE(review): presumably this is the logits vector of the single
        example in the batch - confirm against the model's output type.
        """
        return np.array(
            vqa_state.vqa_model(pixel_values=transformed_image, **question_inputs)[0][0]
        )

    # @st.cache(persist=False)
    def load_model(ckpt):
        """Load the sequence-classification model from checkpoint *ckpt*."""
        return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)

    vqa_checkpoints = [
        "flax-community/clip-vision-bert-vqa-ft-6k"
    ]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
    code_to_name = {
        "en": "English",
        "fr": "French",
        "de": "German",
        "es": "Spanish",
    }

    # JSON text is UTF-8; be explicit so the platform default cannot interfere.
    with open("answer_reverse_mapping.json", encoding="utf-8") as f:
        answer_reverse_mapping = json.load(f)

    first_index = 20

    # Initialise the session state with a fixed example on first run.
    if vqa_state.vqa_image_file is None:
        _load_example(vqa_state, dummy_data, first_index)

    if vqa_state.vqa_model is None:
        with st.spinner("Loading model..."):
            vqa_state.vqa_model = load_model(vqa_checkpoints[0])

    # Swap in a randomly sampled example on request.
    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        # reset_index() makes the single sampled row addressable as label 0.
        sample = dummy_data.sample(1).reset_index()
        _load_example(vqa_state, sample, 0)

    transformed_image = get_transformed_image(vqa_state.vqa_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(vqa_state.vqa_image, use_column_width="always")

    # Display Question
    question = new_col2.text_input(
        label="Question",
        value=vqa_state.question,
        help="Type your question regarding the image above in one of the four languages.",
    )
    # Only call the translation service when the example is not already English.
    if vqa_state.question_lang_id == "en":
        english_question = question
    else:
        english_question = translate(question, "en")
    new_col2.markdown(f"""**English Translation**: {english_question}""")

    question_inputs = get_text_attributes(question)

    # Select Language
    options = ["en", "de", "es", "fr"]
    vqa_state.answer_lang_id = new_col2.selectbox(
        "Answer Language",
        index=options.index(vqa_state.answer_lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language to be used to show the top-5 labels.",
    )

    # Show the ground-truth answer, translated plus the original label.
    actual_answer = answer_reverse_mapping[str(vqa_state.answer_label)]
    new_col2.markdown(
        "**Actual Answer**: "
        + translate_labels([actual_answer], vqa_state.answer_lang_id)[0]
        + " ("
        + actual_answer
        + ")"
    )

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        logits = predict(transformed_image, dict(question_inputs))
    logits = softmax(logits)
    labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
    translated_labels = translate_labels(labels, vqa_state.answer_lang_id)
    fig = plotly_express_horizontal_bar_plot(values, translated_labels)
    st.plotly_chart(fig, use_container_width=True)