# Scraped from a Hugging Face Spaces page; the Space's status read "Runtime error".
from io import BytesIO | |
import streamlit as st | |
import pandas as pd | |
import json | |
import os | |
import numpy as np | |
from streamlit.elements import markdown | |
from model.flax_clip_vision_bert.modeling_clip_vision_bert import FlaxCLIPVisionBertForSequenceClassification | |
from utils import get_transformed_image, get_text_attributes, get_top_5_predictions, plotly_express_horizontal_bar_plot, translate_labels | |
import matplotlib.pyplot as plt | |
from mtranslate import translate | |
from PIL import Image | |
from session import _get_state | |
# Session-scoped state object returned by the local `session` helper module.
# Attributes assigned on it (image_file, question, image, ...) are read back on
# later reruns of this script.
# NOTE(review): the first-run `state.image_file is None` check assumes unset
# attributes read as None — confirm in session.py.
state = _get_state()
@st.cache(allow_output_mutation=True)
def load_model(ckpt):
    """Load the fine-tuned CLIP-Vision-BERT VQA classifier from ``ckpt``.

    Streamlit re-executes this whole script on every widget interaction, so
    without memoisation the checkpoint would be re-read from disk on each
    rerun. ``st.cache`` stores the result per ``ckpt`` value, and
    ``allow_output_mutation=True`` skips re-hashing the returned model object
    on every cache hit.

    Args:
        ckpt: Path (or hub id) of the checkpoint passed to ``from_pretrained``.

    Returns:
        The loaded ``FlaxCLIPVisionBertForSequenceClassification`` instance.
    """
    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
def predict(transformed_image, question_inputs):
    """Run the module-level ``model`` on one image/question pair.

    ``transformed_image`` is passed as ``pixel_values`` and the entries of
    ``question_inputs`` are unpacked as extra keyword arguments.
    Returns ``model(...)[0][0]`` converted to a NumPy array — presumably the
    logits of the first (only) batch element; confirm against the model's
    output type.

    Note: ``model`` is a global assigned further down this script, so this
    function may only be called after the model has been loaded.
    """
    model_outputs = model(pixel_values=transformed_image, **question_inputs)
    first_row = model_outputs[0][0]
    return np.array(first_row)
def softmax(logits, axis=0):
    """Convert raw scores into probabilities that sum to 1 along ``axis``.

    The per-slice maximum is subtracted before exponentiating. Softmax is
    shift-invariant, so the result is mathematically unchanged. Without the
    shift, large logits overflow ``np.exp`` to ``inf`` and very negative ones
    all underflow to 0, and either way the division produces ``nan``.

    Args:
        logits: Array-like of scores.
        axis: Axis to normalise over (default 0, matching the original
            behaviour for 1-D logits).

    Returns:
        A NumPy array of the same shape as ``logits``.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # Nothing to normalise (and np.max would raise on an empty axis).
        return logits.astype(float)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def read_markdown(path, parent="./sections/"):
    """Return the full text of the markdown file ``path`` inside ``parent``.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    locale encoding, so non-ASCII characters in the write-up render the same
    on every host.

    Args:
        path: File name relative to ``parent``.
        parent: Directory holding the markdown sections.

    Returns:
        The file contents as a string.
    """
    with open(os.path.join(parent, path), encoding="utf-8") as f:
        return f.read()
# Model checkpoints available to the demo; the script loads checkpoints[0].
checkpoints = ['./ckpt/ckpt-60k-5999']  # TODO: Maybe add more checkpoints?

# Tab-separated example rows (image_file, question, answer_label, lang_id, ...)
# used to seed the initial and "random example" inputs.
dummy_data = pd.read_csv('dummy_vqa_multilingual.tsv', sep='\t')

# Answer-language codes mapped to the names shown in the language selector.
code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}

# Passed to get_top_5_predictions to turn model outputs into answer labels
# (presumably class index -> answer string; confirm in utils).
# Decoded as UTF-8 explicitly rather than with the locale default encoding.
with open('answer_reverse_mapping.json', encoding='utf-8') as f:
    answer_reverse_mapping = json.load(f)
# Page chrome: browser-tab metadata, headline and author credits.
st.set_page_config(
    page_title="Multilingual VQA",
    page_icon="./misc/mvqa-logo.png",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.title("Multilingual Visual Question Answering")
st.write(
    "[Gunjan Chhablani](https://huggingface.co/gchhablani), "
    "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
)

# Collapsible usage instructions, read from the sections/ directory.
with st.beta_expander("Usage"):
    st.markdown(read_markdown("usage.md"))
# Row of dummy_data shown on the very first run of a session.
first_index = 20


def _load_example(row):
    """Copy one example row into the session state and load its image.

    Shared by the first-run initialisation and the "random example" button,
    which previously duplicated this logic line for line.

    Args:
        row: A single row of ``dummy_data`` (pandas Series) with
            ``image_file``, ``question``, ``answer_label`` and ``lang_id``.
    """
    state.image_file = row['image_file']
    # Questions in the TSV carry a leading "- " bullet; drop it.
    state.question = row['question'].strip('- ')
    state.answer_label = row['answer_label']
    # The example's language seeds both the question and answer language.
    state.question_lang_id = row['lang_id']
    state.answer_lang_id = row['lang_id']
    state.image = plt.imread(os.path.join('images', state.image_file))


# Init Session State: seed with a fixed example only when nothing is loaded yet.
if state.image_file is None:
    _load_example(dummy_data.loc[first_index])

col1, col2 = st.beta_columns([6, 4])
if col2.button('Get a random example'):
    _load_example(dummy_data.sample(1).iloc[0])
# Alternative input: let the user upload their own image.
col2.write("OR")
uploaded_file = col2.file_uploader('Upload your image', type=['png', 'jpg', 'jpeg'])
if uploaded_file is not None:
    # After an upload, this synthetic path is only consulted by the first-run
    # `is None` check in this script, i.e. it marks the state as initialised.
    state.image_file = os.path.join('images/val2014', uploaded_file.name)
    # Normalise to 3-channel RGB. PNG uploads may be RGBA or palette images
    # and some JPEGs are greyscale, which would otherwise yield a 4-channel or
    # 2-D array.
    # NOTE(review): assumes get_transformed_image expects an H x W x 3 image —
    # confirm in utils.
    # NOTE(review): if the uploader keeps returning the same file on later
    # reruns, it also overrides an image chosen via "Get a random example".
    state.image = np.array(Image.open(uploaded_file).convert('RGB'))

transformed_image = get_transformed_image(state.image)
# Display Image
col1.image(state.image, use_column_width='always')

# Display Question (editable), plus an English rendering of it.
question = col2.text_input(label="Question", value=state.question)
# NOTE(review): state.question_lang_id is the language of the loaded example,
# not of whatever the user may have typed into the box.
if state.question_lang_id == "en":
    english_question = question
else:
    english_question = translate(question, 'en')
col2.markdown(f"**English Translation**: {english_question}")
question_inputs = get_text_attributes(question)

# Select Language for the predicted answers, defaulting to the current choice.
options = ['en', 'de', 'es', 'fr']
current_index = options.index(state.answer_lang_id)
state.answer_lang_id = col2.selectbox(
    'Answer Language',
    options=options,
    index=current_index,
    format_func=code_to_name.__getitem__,
)
# Display Top-5 Predictions
# `model` must keep this global name: predict() reads it.
with st.spinner('Loading model...'):
    model = load_model(checkpoints[0])

with st.spinner('Predicting...'):
    raw_logits = predict(transformed_image, dict(question_inputs))

probabilities = softmax(raw_logits)
labels, values = get_top_5_predictions(probabilities, answer_reverse_mapping)
st.plotly_chart(
    plotly_express_horizontal_bar_plot(
        values, translate_labels(labels, state.answer_lang_id)
    ),
    use_container_width=True,
)
# Project write-up, rendered from sections/*.md in page order.
for _section in ("abstract.md", "caveats.md"):
    st.write(read_markdown(_section))

st.write("# Methodology")
st.image("./misc/Multilingual-VQA.png", caption="Masked LM model for Image-text Pretraining.")
for _section in ("pretraining.md", "finetuning.md"):
    st.markdown(read_markdown(_section))

for _section in (
    "challenges.md",
    "social_impact.md",
    "references.md",
    "checkpoints.md",
    "acknowledgements.md",
):
    st.write(read_markdown(_section))