Spaces:
Runtime error
Runtime error
File size: 4,123 Bytes
405f2d4 2c8f495 405f2d4 2c8f495 fb3c77c 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 fb3c77c 2c8f495 405f2d4 4b29c6a 2c8f495 405f2d4 2c8f495 405f2d4 f15eef4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 4b29c6a 0cb8576 4b29c6a 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 405f2d4 2c8f495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from .utils import (
get_text_attributes,
get_top_5_predictions,
get_transformed_image,
plotly_express_horizontal_bar_plot,
bert_tokenizer,
)
import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from mtranslate import translate
from .utils import read_markdown
from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
FlaxCLIPVisionBertForMaskedLM,
)
def softmax(logits):
    """Return the softmax of *logits*, normalised along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This leaves
    the result mathematically unchanged but keeps ``np.exp`` from overflowing
    to ``inf`` (and the quotient from becoming ``nan``) for large logits.

    Args:
        logits: Array-like of scores; a 1-D vector or an array whose first
            axis indexes the classes.

    Returns:
        ``numpy.ndarray`` with the same shape as *logits* whose entries along
        axis 0 sum to 1.
    """
    logits = np.asarray(logits)
    if logits.size == 0:
        # ``np.max`` rejects empty reductions; an empty input simply maps to
        # an empty output of the same shape.
        return np.exp(logits)
    shifted = logits - np.max(logits, axis=0)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0)
def _mask_random_token(caption):
    """Return *caption* with one randomly chosen token replaced by the mask token."""
    ids = bert_tokenizer.encode(caption)
    # The first and last ids are never masked and are dropped before decoding
    # (presumably the tokenizer's special start/end tokens -- confirm).
    ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
    return bert_tokenizer.decode(ids[1:-1])


def _load_example(mlm_state, row):
    """Store the image/caption pair described by *row* in the session state."""
    mlm_state.mlm_image_file = row["image_file"]
    mlm_state.caption = _mask_random_token(row["caption"].strip("- "))
    mlm_state.caption_lang_id = row["lang_id"]

    image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
    mlm_state.mlm_image = plt.imread(image_path)


def app(state):
    """Render the masked-language-modelling demo page.

    Shows an image next to an editable caption containing a ``[MASK]`` token,
    runs the model on the pair and displays the five highest-scoring tokens
    for the masked position together with their English translations.

    Args:
        state: Session-state object. The attributes ``mlm_image_file``,
            ``caption``, ``caption_lang_id``, ``mlm_image`` and ``mlm_model``
            are read and, when unset or when a new example is requested,
            written by this function.

    Returns:
        None. If the entered text contains no mask token, an error message is
        shown and the function returns before running the model.
    """
    mlm_state = state

    with st.beta_expander("Usage"):
        st.write(read_markdown("mlm_usage.md"))
    st.write(read_markdown("mlm_intro.md"))

    # TODO: cache with st.cache once it works with mlm_state (currently not supported).
    def predict(transformed_image, caption_inputs, mask_index):
        """Return the vocabulary logits at *mask_index* of the first sequence."""
        outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
        return np.array(outputs.logits[0][mask_index])

    def load_model(ckpt):
        """Load the pretrained masked-LM checkpoint *ckpt*."""
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")

    first_index = 15
    # Initialise the session state with a fixed example on first run.
    if mlm_state.mlm_image_file is None:
        _load_example(mlm_state, dummy_data.loc[first_index])

    # Load the model once per session.
    if mlm_state.mlm_model is None:
        with st.spinner("Loading model..."):
            mlm_state.mlm_model = load_model(mlm_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_example(mlm_state, dummy_data.sample(1).iloc[0])

    transformed_image = get_transformed_image(mlm_state.mlm_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image
    new_col1.image(mlm_state.mlm_image, use_column_width="always")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    new_col2.markdown(
        f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
    )
    caption_inputs = dict(get_text_attributes(caption))

    # Locate the mask token before running the model: user-edited text may not
    # contain one, which previously crashed with an IndexError after inference.
    # The second np.where result holds the positions along the sequence axis
    # (input_ids is indexed as a 2-D array here).
    mask_positions = np.where(
        caption_inputs["input_ids"] == bert_tokenizer.mask_token_id
    )[1]
    if len(mask_positions) == 0:
        st.error(
            "No [MASK] token found. Please write your text with exactly one [MASK] token."
        )
        return

    # Display top-5 predictions; if several masks are present the first is used.
    with st.spinner("Predicting..."):
        scores = predict(transformed_image, caption_inputs, mask_positions[0])
    scores = softmax(scores)
    labels, values = get_top_5_predictions(scores)
    fig = plotly_express_horizontal_bar_plot(values, labels)
    translations = [translate(label, "en") for label in labels]
    st.dataframe(
        pd.DataFrame({"Tokens": labels, "English Translation": translations}).T
    )
    st.plotly_chart(fig, use_container_width=True)
|