Spaces:
Runtime error
Runtime error
import json | |
import os | |
from io import BytesIO | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from mtranslate import translate | |
from PIL import Image | |
from streamlit.elements import markdown | |
from model.flax_clip_vision_bert.modeling_clip_vision_bert import ( | |
FlaxCLIPVisionBertForSequenceClassification, | |
) | |
from session import _get_state | |
from utils import ( | |
get_text_attributes, | |
get_top_5_predictions, | |
get_transformed_image, | |
plotly_express_horizontal_bar_plot, | |
translate_labels, | |
) | |
# Per-session state object from the project's `session` helper; attributes such
# as image_file / question / image are read and written below. Presumably it
# survives Streamlit reruns (defaults are only seeded when image_file is None)
# -- NOTE(review): confirm against session._get_state.
state = _get_state()


@st.cache(allow_output_mutation=True, show_spinner=False)
def load_model(ckpt):
    """Load the Flax CLIP-Vision-BERT sequence classifier from checkpoint ``ckpt``.

    Streamlit re-executes this whole script on every widget interaction, so the
    loaded model is memoised with ``st.cache`` (keyed on ``ckpt``). Without it,
    the checkpoint would be re-read from disk on every rerun.

    ``allow_output_mutation=True`` skips Streamlit's hash-the-return-value
    mutation check, which is costly or unsupported for a large model object.
    In this script the model is only called for inference.
    ``show_spinner=False`` avoids stacking st.cache's own spinner on top of the
    caller's "Loading model..." spinner.

    Args:
        ckpt: path or identifier accepted by ``from_pretrained``.

    Returns:
        The loaded ``FlaxCLIPVisionBertForSequenceClassification`` instance.
    """
    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
def predict(transformed_image, question_inputs):
    """Run the classifier on one image/question pair and return a NumPy array.

    Uses the module-level ``model`` global, which must be assigned before this
    is called.

    Args:
        transformed_image: preprocessed pixel values, passed as ``pixel_values``.
        question_inputs: mapping of tokenizer outputs, splatted as keyword args.

    Returns:
        ``np.ndarray`` built from ``outputs[0][0]``. Presumably these are the
        logits for the first (only) example in the batch -- NOTE(review): confirm.
    """
    outputs = model(pixel_values=transformed_image, **question_inputs)
    first_row = outputs[0][0]
    return np.array(first_row)
def softmax(logits, axis=0):
    """Return the softmax of ``logits`` along ``axis`` (default 0).

    Softmax is invariant to subtracting a constant, so the maximum along
    ``axis`` is removed before exponentiating. This keeps ``np.exp`` from
    overflowing to ``inf`` (which would turn the result into ``nan``) when
    logits are large.

    Args:
        logits: array-like of scores.
        axis: axis to normalise over; 0 matches the previous behaviour.

    Returns:
        Array of the same shape as ``logits`` whose entries along ``axis``
        sum to 1.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to normalise; np.max would raise on an empty array.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def read_markdown(path, parent="./sections/"):
    """Return the full text of the markdown file ``path`` under ``parent``.

    The file is decoded explicitly as UTF-8. Section text containing non-ASCII
    characters then renders the same whatever the host's default locale
    encoding is.

    Args:
        path: file name relative to ``parent``.
        parent: directory holding the section files.

    Returns:
        The file contents as a ``str``.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()
# def resize_height(image, new_height): | |
# h, w, c = image.shape | |
# new_width = int(w * new_height / h) | |
# return cv2.resize(image, (new_width, new_height)) | |
# Model checkpoints available to the app; only the first one is loaded below.
# TODO: Maybe add more checkpoints?
checkpoints = ["./ckpt/vqa/ckpt-60k-5999"]

# Seeded example image/question pairs (tab-separated). Columns used below:
# image_file, question, answer_label, lang_id.
dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")

# Display names for the supported language codes.
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}

# Answer-label mapping, looked up with str(answer_label). Presumably it maps
# label ids to answer text -- NOTE(review): confirm. JSON is UTF-8 by
# specification, so don't depend on the platform default encoding.
with open("answer_reverse_mapping.json", encoding="utf-8") as f:
    answer_reverse_mapping = json.load(f)
# Logo asset shared by the browser-tab icon and the page header.
logo_path = "./misc/mvqa-logo-3-white.png"

# Streamlit documents set_page_config as needing to run before other st.* calls.
st.set_page_config(
    page_title="Multilingual VQA",
    layout="wide",
    initial_sidebar_state="collapsed",
    page_icon=logo_path,
)

st.title("Multilingual Visual Question Answering")
author_links = [
    "[Gunjan Chhablani](https://huggingface.co/gchhablani)",
    "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)",
]
st.write(", ".join(author_links))

# NOTE(review): st.beta_columns / st.beta_expander were renamed st.columns /
# st.expander in later Streamlit releases; keep in sync with the pinned version.
image_col, intro_col = st.beta_columns([3, 8])
image_col.image(logo_path, use_column_width="always")
intro_col.write(read_markdown("intro.md"))

with st.beta_expander("Usage"):
    st.write(read_markdown("usage.md"))

with st.beta_expander("Article"):
    for _section in ("abstract.md", "caveats.md"):
        st.write(read_markdown(_section))
    st.write("## Methodology")
    st.image(
        "./misc/Multilingual-VQA.png",
        caption="Masked LM model for Image-text Pretraining.",
    )
    # These two sections go through st.markdown rather than st.write.
    for _section in ("pretraining.md", "finetuning.md"):
        st.markdown(read_markdown(_section))
    for _section in (
        "challenges.md",
        "social_impact.md",
        "references.md",
        "checkpoints.md",
        "acknowledgements.md",
    ):
        st.write(read_markdown(_section))
first_index = 20


def _load_example(row):
    """Copy one ``dummy_data`` row into the session state and load its image.

    Args:
        row: a single row (e.g. from ``DataFrame.loc[i]`` / ``.iloc[i]``) that
            has the ``image_file``, ``question``, ``answer_label`` and
            ``lang_id`` columns.
    """
    state.image_file = row["image_file"]
    # Strip leading/trailing "-" and space characters from the stored question.
    state.question = row["question"].strip("- ")
    state.answer_label = row["answer_label"]
    # The example's language seeds both the question and the answer language.
    state.question_lang_id = row["lang_id"]
    state.answer_lang_id = row["lang_id"]
    state.image = plt.imread(os.path.join("images", state.image_file))


# Init Session State: seed with a fixed example on first run only.
if state.image_file is None:
    _load_example(dummy_data.loc[first_index])

# col1, col2 = st.beta_columns([6, 4])
if st.button(
    "Get a random example",
    help="Get a random example from the 100 `seeded` image-text pairs.",
):
    _load_example(dummy_data.sample(1).iloc[0])

# col2.write("OR")
# uploaded_file = col2.file_uploader(
#     "Upload your image",
#     type=["png", "jpg", "jpeg"],
#     help="Upload a file of your choosing.",
# )
# if uploaded_file is not None:
#     state.image_file = os.path.join("images/val2014", uploaded_file.name)
#     state.image = np.array(Image.open(uploaded_file))
# Preprocess the currently selected image for the vision encoder.
transformed_image = get_transformed_image(state.image)

# Display Image
st.image(state.image, use_column_width="auto")

new_col1, new_col2 = st.beta_columns([5, 5])

# Display Question: free-text input pre-filled from the session state.
question = new_col1.text_input(
    label="Question",
    value=state.question,
    help="Type your question regarding the image above in one of the four languages.",
)
# NOTE(review): the "already English" check uses the seeded example's language,
# not the language of whatever the user typed.
if state.question_lang_id == "en":
    english_question = question
else:
    english_question = translate(question, "en")
new_col1.markdown(f"""**English Translation**: {english_question}""")
question_inputs = get_text_attributes(question)

# Select Language
options = ["en", "de", "es", "fr"]
state.answer_lang_id = new_col2.selectbox(
    "Answer Language",
    options=options,
    index=options.index(state.answer_lang_id),
    format_func=lambda code: code_to_name[code],
    help="The language to be used to show the top-5 labels.",
)

# Ground-truth answer: translated form first, mapping value in parentheses.
actual_answer = answer_reverse_mapping[str(state.answer_label)]
translated_actual_answer = translate_labels([actual_answer], state.answer_lang_id)[0]
new_col2.markdown(
    "**Actual Answer**: " + translated_actual_answer + " (" + actual_answer + ")"
)

# Display Top-5 Predictions
with st.spinner("Loading model..."):
    # Must stay a module-level name: predict() reads the global `model`.
    model = load_model(checkpoints[0])
with st.spinner("Predicting..."):
    raw_logits = predict(transformed_image, dict(question_inputs))
logits = softmax(raw_logits)
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
translated_labels = translate_labels(labels, state.answer_lang_id)
fig = plotly_express_horizontal_bar_plot(values, translated_labels)
st.plotly_chart(fig, use_container_width=True)