import json
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from mtranslate import translate
from PIL import Image
from model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForSequenceClassification,
)
from session import _get_state
from utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
translate_labels,
)
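# `_get_state()` (from session.py) returns a session-scoped object that
# persists values across Streamlit reruns.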
state = _get_state()
@st.cache(persist=True)
def load_model(ckpt):
    # Cached so the checkpoint is only loaded once per session.
    return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)


@st.cache(persist=True)
def predict(transformed_image, question_inputs):
    # `model` is the module-level global assigned once the checkpoint is loaded below.
    # Returns the logits for the single example in the batch.
    return np.array(model(pixel_values=transformed_image, **question_inputs)[0][0])


def softmax(logits):
    # Subtract the max logit before exponentiating for numerical stability.
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / np.sum(exp_logits, axis=0)
def read_markdown(path, parent="./sections/"):
with open(os.path.join(parent, path)) as f:
return f.read()
# def resize_height(image, new_height):
# h, w, c = image.shape
# new_width = int(w * new_height / h)
# return cv2.resize(image, (new_width, new_height))
checkpoints = ["./ckpt/ckpt-60k-5999"] # TODO: Maybe add more checkpoints?
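# Seeded multilingual VQA examples backing the "Get a random example" button.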
dummy_data = pd.read_csv("dummy_vqa_multilingual.tsv", sep="\t")
code_to_name = {
"en": "English",
"fr": "French",
"de": "German",
"es": "Spanish",
}
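# Maps label indices (as strings) back to the original answer strings.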
with open("answer_reverse_mapping.json") as f:
answer_reverse_mapping = json.load(f)
st.set_page_config(
page_title="Multilingual VQA",
layout="wide",
initial_sidebar_state="collapsed",
page_icon="./misc/mvqa-logo-3-white.png",
)
st.title("Multilingual Visual Question Answering")
st.write(
"[Gunjan Chhablani](https://huggingface.co/gchhablani), [Bhavitvya Malik](https://huggingface.co/bhavitvyamalik)"
)
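# NOTE: `st.beta_columns` / `st.beta_expander` are the pre-1.0 Streamlit APIs;
# on Streamlit >= 1.0 these became `st.columns` / `st.expander`.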
image_col, intro_col = st.beta_columns([3, 8])
image_col.image("./misc/mvqa-logo-3-white.png", use_column_width="always")
intro_col.write(read_markdown("intro.md"))
with st.beta_expander("Usage"):
st.write(read_markdown("usage.md"))
with st.beta_expander("Article"):
st.write(read_markdown("abstract.md"))
st.write(read_markdown("caveats.md"))
st.write("## Methodology")
st.image(
"./misc/Multilingual-VQA.png",
caption="Masked LM model for Image-text Pretraining.",
)
st.markdown(read_markdown("pretraining.md"))
st.markdown(read_markdown("finetuning.md"))
st.write(read_markdown("challenges.md"))
st.write(read_markdown("social_impact.md"))
st.write(read_markdown("references.md"))
st.write(read_markdown("checkpoints.md"))
st.write(read_markdown("acknowledgements.md"))
first_index = 20
# Init Session State
if state.image_file is None:
state.image_file = dummy_data.loc[first_index, "image_file"]
state.question = dummy_data.loc[first_index, "question"].strip("- ")
state.answer_label = dummy_data.loc[first_index, "answer_label"]
state.question_lang_id = dummy_data.loc[first_index, "lang_id"]
state.answer_lang_id = dummy_data.loc[first_index, "lang_id"]
image_path = os.path.join("images", state.image_file)
image = plt.imread(image_path)
state.image = image
# col1, col2 = st.beta_columns([6, 4])
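# Sampling a new example overwrites the session state so it survives the
# script rerun triggered by the button click.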
if st.button(
"Get a random example",
help="Get a random example from the 100 `seeded` image-text pairs.",
):
sample = dummy_data.sample(1).reset_index()
state.image_file = sample.loc[0, "image_file"]
state.question = sample.loc[0, "question"].strip("- ")
state.answer_label = sample.loc[0, "answer_label"]
state.question_lang_id = sample.loc[0, "lang_id"]
state.answer_lang_id = sample.loc[0, "lang_id"]
image_path = os.path.join("images", state.image_file)
image = plt.imread(image_path)
state.image = image
# col2.write("OR")
# uploaded_file = col2.file_uploader(
# "Upload your image",
# type=["png", "jpg", "jpeg"],
# help="Upload a file of your choosing.",
# )
# if uploaded_file is not None:
# state.image_file = os.path.join("images/val2014", uploaded_file.name)
# state.image = np.array(Image.open(uploaded_file))
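# Convert the (H, W, C) numpy image into the normalized pixel values the
# CLIP vision encoder expects (see utils.get_transformed_image).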
transformed_image = get_transformed_image(state.image)
# Display Image
st.image(state.image, use_column_width="auto")
new_col1, new_col2 = st.beta_columns([5, 5])
# Display Question
question = new_col1.text_input(
label="Question",
value=state.question,
help="Type your question regarding the image above in one of the four languages.",
)
new_col1.markdown(
f"""**English Translation**: {question if state.question_lang_id == "en" else translate(question, 'en')}"""
)
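# Tokenize the question into the text inputs (e.g. input_ids, attention_mask)
# that are unpacked into the model call in `predict`.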
question_inputs = get_text_attributes(question)
# Select Language
options = ["en", "de", "es", "fr"]
state.answer_lang_id = new_col2.selectbox(
"Answer Language",
index=options.index(state.answer_lang_id),
options=options,
format_func=lambda x: code_to_name[x],
help="The language to be used to show the top-5 labels.",
)
actual_answer = answer_reverse_mapping[str(state.answer_label)]
new_col2.markdown(
    f"**Actual Answer**: {translate_labels([actual_answer], state.answer_lang_id)[0]} ({actual_answer})"
)
# Display Top-5 Predictions
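# The model load is cached by st.cache, so subsequent reruns reuse it.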
with st.spinner("Loading model..."):
model = load_model(checkpoints[0])
with st.spinner("Predicting..."):
logits = predict(transformed_image, dict(question_inputs))
logits = softmax(logits)
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
translated_labels = translate_labels(labels, state.answer_lang_id)
fig = plotly_express_horizontal_bar_plot(values, translated_labels)
st.plotly_chart(fig, use_container_width=True)