File size: 4,123 Bytes
405f2d4
 
 
 
 
2c8f495
405f2d4
 
 
 
 
 
 
2c8f495
fb3c77c
405f2d4
 
 
 
 
2c8f495
405f2d4
 
 
2c8f495
 
405f2d4
fb3c77c
 
 
 
2c8f495
405f2d4
4b29c6a
 
2c8f495
 
 
 
 
405f2d4
 
 
2c8f495
405f2d4
 
f15eef4
2c8f495
 
 
405f2d4
2c8f495
 
 
 
405f2d4
2c8f495
405f2d4
2c8f495
405f2d4
4b29c6a
0cb8576
 
4b29c6a
405f2d4
 
 
 
 
 
2c8f495
405f2d4
2c8f495
 
 
 
405f2d4
2c8f495
405f2d4
2c8f495
405f2d4
2c8f495
405f2d4
 
 
 
2c8f495
405f2d4
 
 
 
 
2c8f495
405f2d4
 
 
2c8f495
 
 
405f2d4
 
 
 
2c8f495
 
 
 
405f2d4
2c8f495
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from .utils import (
    get_text_attributes,
    get_top_5_predictions,
    get_transformed_image,
    plotly_express_horizontal_bar_plot,
    bert_tokenizer,
)

import streamlit as st
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from mtranslate import translate
from .utils import read_markdown

from .model.flax_clip_vision_bert.modeling_clip_vision_bert import (
    FlaxCLIPVisionBertForMaskedLM,
)


def softmax(logits):
    """Return the softmax of ``logits`` along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This leaves
    the result mathematically unchanged but prevents ``np.exp`` from
    overflowing to ``inf`` (which would turn the output into ``nan``) when the
    logits are large.

    Args:
        logits: Array-like of raw scores. For 1-D input this is the usual
            softmax. For N-D input, each slice along axis 0 is normalised
            independently (same semantics as ``np.sum(..., axis=0)``).

    Returns:
        ``np.ndarray`` with the same shape as ``logits``, whose entries along
        axis 0 sum to 1.
    """
    logits = np.asarray(logits)
    # keepdims keeps the reductions broadcastable against the input for any rank.
    shifted = logits - np.max(logits, axis=0, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=0, keepdims=True)

def _load_masked_example(mlm_state, row):
    """Store one dataset row in ``mlm_state`` with a randomly masked caption.

    The row's caption is stripped of leading/trailing ``-`` and spaces, then
    encoded. One randomly chosen token between the first and last id is
    replaced by the tokenizer's mask token, and the ids are decoded back
    without their first and last entries. The matching image is read from
    ``cc12m_data/images_vqa``.

    Args:
        mlm_state: Session-state object. Its ``mlm_image_file``, ``caption``,
            ``caption_lang_id`` and ``mlm_image`` attributes are overwritten.
        row: One row of the ``vqa_val.tsv`` frame, indexable by
            ``"image_file"``, ``"caption"`` and ``"lang_id"``.
    """
    mlm_state.mlm_image_file = row["image_file"]
    caption = row["caption"].strip("- ")
    ids = bert_tokenizer.encode(caption)
    # The first/last ids are dropped when decoding below (presumably the
    # special [CLS]/[SEP] tokens). Only mask when something lies between
    # them; np.random.randint(1, 1) would raise ValueError.
    if len(ids) > 2:
        ids[np.random.randint(1, len(ids) - 1)] = bert_tokenizer.mask_token_id
    mlm_state.caption = bert_tokenizer.decode(ids[1:-1])
    mlm_state.caption_lang_id = row["lang_id"]

    image_path = os.path.join("cc12m_data/images_vqa", mlm_state.mlm_image_file)
    mlm_state.mlm_image = plt.imread(image_path)


def app(state):
    """Render the masked-language-modelling demo page.

    The page shows usage/intro markdown and initialises the session state with
    a fixed example on first run. It loads the model once, and lets the user
    draw a random example or edit the masked caption. It then displays the
    image, the caption's English translation, and the top-5 predicted tokens
    for the first ``[MASK]`` as a table and a bar chart.

    Args:
        state: Session-state object. It is read and written through the
            ``mlm_image_file``, ``caption``, ``caption_lang_id``,
            ``mlm_image`` and ``mlm_model`` attributes.
    """
    mlm_state = state

    with st.beta_expander("Usage"):
        st.write(read_markdown("mlm_usage.md"))
    st.write(read_markdown("mlm_intro.md"))

    # @st.cache(persist=False) # TODO: Make this work with mlm_state. Currently not supported.
    def predict(transformed_image, caption_inputs):
        """Return the raw vocabulary logits for the first [MASK] token.

        Returns None (without running the model) when the input ids contain
        no mask token.
        """
        # input_ids is indexed as 2-D here: axis 1 holds the token positions.
        mask_positions = np.where(
            caption_inputs["input_ids"] == bert_tokenizer.mask_token_id
        )[1]
        if mask_positions.size == 0:
            return None
        outputs = mlm_state.mlm_model(pixel_values=transformed_image, **caption_inputs)
        preds = outputs.logits[0][mask_positions[0]]
        return np.array(preds)

    # @st.cache(persist=False)
    def load_model(ckpt):
        return FlaxCLIPVisionBertForMaskedLM.from_pretrained(ckpt)

    mlm_checkpoints = ["flax-community/clip-vision-bert-cc12m-70k"]
    dummy_data = pd.read_csv("cc12m_data/vqa_val.tsv", sep="\t")

    first_index = 15
    # Init Session mlm_state with a fixed example on the first run.
    if mlm_state.mlm_image_file is None:
        _load_masked_example(mlm_state, dummy_data.loc[first_index])

    if mlm_state.mlm_model is None:
        # Load the model once and keep it in the session state.
        with st.spinner("Loading model..."):
            mlm_state.mlm_model = load_model(mlm_checkpoints[0])

    if st.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_masked_example(mlm_state, dummy_data.sample(1).iloc[0])

    transformed_image = get_transformed_image(mlm_state.mlm_image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(mlm_state.mlm_image, use_column_width="always")

    # Display caption
    new_col2.write("Write your text with exactly one [MASK] token.")
    caption = new_col2.text_input(
        label="Text",
        value=mlm_state.caption,
        help="Type your masked caption regarding the image above in one of the four languages.",
    )

    new_col2.markdown(
        f"""**English Translation**: {caption if mlm_state.caption_lang_id == "en" else translate(caption, 'en')}"""
    )
    caption_inputs = get_text_attributes(caption)

    # Display Top-5 Predictions
    with st.spinner("Predicting..."):
        scores = predict(transformed_image, dict(caption_inputs))
    if scores is None:
        # Previously an IndexError; show a friendly message instead.
        st.warning("No [MASK] token found. Write your text with exactly one [MASK] token.")
        return
    scores = softmax(scores)
    labels, values = get_top_5_predictions(scores)
    fig = plotly_express_horizontal_bar_plot(values, labels)
    # Translate explicitly to English, matching the caption translation above
    # and the "English Translation" column header.
    translations = [translate(label, "en") for label in labels]
    st.dataframe(
        pd.DataFrame({"Tokens": labels, "English Translation": translations}).T
    )
    st.plotly_chart(fig, use_container_width=True)