File size: 3,700 Bytes
f04d812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import random

import onnxruntime
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from lang_map import langs
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

# Use Streamlit's wide page layout for the whole app.
st.set_page_config(layout="wide")

# Keys of ``langs``; offered as the choices of every language selectbox.
options = list(langs)


class SessionState:
    """Plain attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Copy all keyword arguments onto the instance in one step.
        vars(self).update(kwargs)


def get_state(**kwargs):
    """Return the ``SessionState`` stored in ``st.session_state``.

    On the first call of a session the object is created from ``kwargs`` and
    stored under the ``"session_state"`` key; later calls return the stored
    object and ignore ``kwargs``.
    """
    store = st.session_state
    if "session_state" in store:
        return store["session_state"]

    store["session_state"] = SessionState(**kwargs)
    return store["session_state"]


def add_selectbox_and_input(key):
    """Render one class row and record its current values.

    The row is a language selectbox next to a free-text input. The pair
    ``(selected language, entered text)`` is stored in ``state.inputs[key]``.

    Args:
        key: Row identifier; also used as the prefix of both widget keys.
    """
    language_col, text_col = st.columns(2)
    language = language_col.selectbox(
        "Select a language", options, key=f"{key}_select"
    )
    text = text_col.text_input("Input text", key=f"{key}_text")

    state.inputs[key] = (language, text)


# State kept in st.session_state: number of class rows and each row's
# (language, text) pair.
state = get_state(count=1, inputs={})

st.title("Zero-shot image classification with CLIP in 201 languages")

# col1 holds the image upload and preview; col2 holds the class inputs and
# the prediction chart.
col1, col2 = st.columns(2)

# Stays None until a file has been uploaded; read by process().
image: Image.Image = None
col1.subheader("Image")
uploaded_file = col1.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    col1.image(image, caption="Uploaded Image.", use_column_width=True)


def process():
    """Classify the uploaded image against the entered classes and plot the result.

    Reads the module-level ``image`` and ``state.inputs``, runs the quantized
    ONNX model on the image together with every class text (each tokenized
    with its selected source language), applies a softmax over the returned
    logits and draws a horizontal bar chart of the probabilities in ``col2``.

    If no image has been uploaded yet, a warning is shown and nothing is run.
    """
    if image is None:
        # Without this guard the image processor is handed ``None`` and the
        # app fails with a traceback instead of a usable message.
        st.warning("Please upload an image first.")
        return

    # NOTE(review): the session, image processor and tokenizer are rebuilt on
    # every call. Caching them would avoid the reload cost, but is left out so
    # the code does not depend on a specific Streamlit caching API.
    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    onnx_path = "model-quant.onnx"
    ort_session = onnxruntime.InferenceSession(onnx_path, session_options)

    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    ).image_processor

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

    image_inputs = processor(images=image, return_tensors="pt")

    # Each stored entry is a (language name, class text) pair.
    selections = list(state.inputs.values())
    classes = [str(text) for _, text in selections]
    # Map the displayed language names to the values stored in ``langs``,
    # which are passed to the tokenizer as the source language.
    language_codes = [langs[str(language)] for language, _ in selections]

    input_ids = []
    attention_mask = []
    for language_code, class_text in zip(language_codes, classes):
        # The source language is set per class, so texts are encoded one at a
        # time; fixed-length padding lets the rows be concatenated afterwards.
        tokenizer.set_src_lang_special_tokens(language_code)
        encoded = tokenizer.batch_encode_plus(
            [class_text],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
        )
        input_ids.append(encoded["input_ids"])
        attention_mask.append(encoded["attention_mask"])
    input_ids = torch.concat(input_ids, dim=0)
    attention_mask = torch.concat(attention_mask, dim=0)

    ort_inputs = {
        "pixel_values": image_inputs["pixel_values"].numpy(),
        "input_ids": input_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)
    logits = torch.tensor(ort_outputs[0])
    # Flatten to 1-D so the result lines up with ``classes`` even when only a
    # single class was entered (``squeeze()`` would yield a 0-d array there).
    # Assumes the first output is (1, num_classes) - TODO confirm against the
    # exported model.
    probabilities = logits.softmax(dim=-1).reshape(-1).detach().numpy()

    chart_data = pd.DataFrame({"Class": classes, "Probability": probabilities})
    chart_data = chart_data.sort_values(by=["Probability"], ascending=True)
    fig = px.bar(chart_data, x="Probability", y="Class", orientation="h")
    with col2:
        st.subheader("Predictions")
        st.write(fig)


with col2:
    st.subheader("Classes")

    # "Input 1" is always shown; rows 2..state.count are the ones added so far.
    row_keys = ["Input 1", *(f"Input {n}" for n in range(2, state.count + 1))]
    for row_key in row_keys:
        add_selectbox_and_input(row_key)

    add_clicked = st.button("Add class")
    if add_clicked:
        # Persist the new row count and render the new row in this run.
        state.count += 1
        add_selectbox_and_input(f"Input {state.count}")

    st.markdown("---")
    generate_clicked = st.button("Generate")
    if generate_clicked:
        with st.spinner("Processing the data"):
            process()