visheratin commited on
Commit
f04d812
1 Parent(s): b1d45e0

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +125 -0
  2. lang_map.py +203 -0
  3. model-quant.onnx +3 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import onnxruntime
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import streamlit as st
7
+ import torch
8
+ from lang_map import langs
9
+ from PIL import Image
10
+ from transformers import AutoTokenizer, CLIPProcessor
11
+
12
+ st.set_page_config(layout="wide")
13
+
14
+ options = list(langs.keys())
15
+
16
+
17
+ class SessionState:
18
+ def __init__(self, **kwargs):
19
+ for key, val in kwargs.items():
20
+ setattr(self, key, val)
21
+
22
+
23
+ def get_state(**kwargs):
24
+ if "session_state" not in st.session_state:
25
+ st.session_state["session_state"] = SessionState(**kwargs)
26
+ return st.session_state["session_state"]
27
+
28
+
29
+ def add_selectbox_and_input(key):
30
+ col1, col2 = st.columns(2)
31
+ with col1:
32
+ select = st.selectbox("Select a language", options, key=f"{key}_select")
33
+ with col2:
34
+ user_input = st.text_input("Input text", key=f"{key}_text")
35
+
36
+ state.inputs[key] = (select, user_input)
37
+
38
+
39
+ state = get_state(count=1, inputs={})
40
+
41
+ st.title("Zero-shot image classification with CLIP in 201 languages")
42
+
43
+ col1, col2 = st.columns(2)
44
+
45
+ image: Image.Image = None
46
+ with col1:
47
+ st.subheader("Image")
48
+ uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
49
+ if uploaded_file is not None:
50
+ image = Image.open(uploaded_file)
51
+ st.image(image, caption="Uploaded Image.", use_column_width=True)
52
+
53
+
54
+ def process():
55
+ session_options = onnxruntime.SessionOptions()
56
+ session_options.graph_optimization_level = (
57
+ onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
58
+ )
59
+ onnx_path = "model-quant.onnx"
60
+ ort_session = onnxruntime.InferenceSession(onnx_path, session_options)
61
+
62
+ processor = CLIPProcessor.from_pretrained(
63
+ "openai/clip-vit-base-patch32"
64
+ ).image_processor
65
+
66
+ tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
67
+
68
+ image_inputs = processor(images=image, return_tensors="pt")
69
+
70
+ classes = []
71
+ languages = []
72
+ for key, value in state.inputs.items():
73
+ languages.append(str(value[0]))
74
+ classes.append(str(value[1]))
75
+
76
+ languages = [langs[lang] for lang in languages]
77
+
78
+ input_ids = []
79
+ attention_mask = []
80
+ for i, _ in enumerate(languages):
81
+ tokenizer.set_src_lang_special_tokens(languages[i])
82
+ input = tokenizer.batch_encode_plus(
83
+ [classes[i]],
84
+ return_tensors="pt",
85
+ padding="max_length",
86
+ truncation=True,
87
+ max_length=100,
88
+ )
89
+ input_ids.append(input["input_ids"])
90
+ attention_mask.append(input["attention_mask"])
91
+ input_ids = torch.concat(input_ids, dim=0)
92
+ attention_mask = torch.concat(attention_mask, dim=0)
93
+
94
+ ort_inputs = {
95
+ "pixel_values": image_inputs["pixel_values"].numpy(),
96
+ "input_ids": input_ids.numpy(),
97
+ "attention_mask": attention_mask.numpy(),
98
+ }
99
+ ort_outputs = ort_session.run(None, ort_inputs)
100
+ logits = torch.tensor(ort_outputs[0])
101
+ probabilities = logits.softmax(dim=-1).squeeze().detach().numpy()
102
+
103
+ chart_data = pd.DataFrame({"Class": classes, "Probability": probabilities})
104
+ chart_data = chart_data.sort_values(by=["Probability"], ascending=True)
105
+ fig = px.bar(chart_data, x="Probability", y="Class", orientation="h")
106
+ with col2:
107
+ st.subheader("Predictions")
108
+ st.write(fig)
109
+
110
+
111
+ with col2:
112
+ st.subheader("Classes")
113
+ add_selectbox_and_input("Input 1")
114
+
115
+ for i in range(2, state.count + 1):
116
+ add_selectbox_and_input(f"Input {i}")
117
+
118
+ if st.button("Add class"):
119
+ state.count += 1
120
+ add_selectbox_and_input(f"Input {state.count}")
121
+
122
+ st.markdown("""---""")
123
+ if st.button("Generate"):
124
+ with st.spinner("Processing the data"):
125
+ process()
lang_map.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langs = {
2
+ "Acehnese (Arabic script)": "ace_Arab",
3
+ "Acehnese (Latin script)": "ace_Latn",
4
+ "Mesopotamian Arabic": "acm_Arab",
5
+ "Ta'izzi-Adeni Arabic": "acq_Arab",
6
+ "Tunisian Arabic": "aeb_Arab",
7
+ "Afrikaans": "afr_Latn",
8
+ "South Levantine Arabic": "ajp_Arab",
9
+ "Akan": "aka_Latn",
10
+ "Amharic": "amh_Ethi",
11
+ "North Levantine Arabic": "apc_Arab",
12
+ "Modern Standard Arabic": "arb_Arab",
13
+ "Najdi Arabic": "ars_Arab",
14
+ "Moroccan Arabic": "ary_Arab",
15
+ "Egyptian Arabic": "arz_Arab",
16
+ "Assamese": "asm_Beng",
17
+ "Asturian": "ast_Latn",
18
+ "Awadhi": "awa_Deva",
19
+ "Central Aymara": "ayr_Latn",
20
+ "South Azerbaijani": "azb_Arab",
21
+ "North Azerbaijani": "azj_Latn",
22
+ "Bashkir": "bak_Cyrl",
23
+ "Bambara": "bam_Latn",
24
+ "Balinese": "ban_Latn",
25
+ "Belarusian": "bel_Cyrl",
26
+ "Bemba": "bem_Latn",
27
+ "Bengali": "ben_Beng",
28
+ "Bhojpuri": "bho_Deva",
29
+ "Banjar (Arabic script)": "bjn_Arab",
30
+ "Banjar (Latin script)": "bjn_Latn",
31
+ "Standard Tibetan": "bod_Tibt",
32
+ "Bosnian": "bos_Latn",
33
+ "Buginese": "bug_Latn",
34
+ "Bulgarian": "bul_Cyrl",
35
+ "Catalan": "cat_Latn",
36
+ "Cebuano": "ceb_Latn",
37
+ "Czech": "ces_Latn",
38
+ "Chokwe": "cjk_Latn",
39
+ "Central Kurdish": "ckb_Arab",
40
+ "Crimean Tatar": "crh_Latn",
41
+ "Welsh": "cym_Latn",
42
+ "Danish": "dan_Latn",
43
+ "German": "deu_Latn",
44
+ "Southwestern Dinka": "dik_Latn",
45
+ "Dyula": "dyu_Latn",
46
+ "Dzongkha": "dzo_Tibt",
47
+ "Greek": "ell_Grek",
48
+ "English": "eng_Latn",
49
+ "Esperanto": "epo_Latn",
50
+ "Estonian": "est_Latn",
51
+ "Basque": "eus_Latn",
52
+ "Ewe": "ewe_Latn",
53
+ "Faroese": "fao_Latn",
54
+ "Fijian": "fij_Latn",
55
+ "Finnish": "fin_Latn",
56
+ "Fon": "fon_Latn",
57
+ "French": "fra_Latn",
58
+ "Friulian": "fur_Latn",
59
+ "Nigerian Fulfulde": "fuv_Latn",
60
+ "Scottish Gaelic": "gla_Latn",
61
+ "Irish": "gle_Latn",
62
+ "Galician": "glg_Latn",
63
+ "Guarani": "grn_Latn",
64
+ "Gujarati": "guj_Gujr",
65
+ "Haitian Creole": "hat_Latn",
66
+ "Hausa": "hau_Latn",
67
+ "Hebrew": "heb_Hebr",
68
+ "Hindi": "hin_Deva",
69
+ "Chhattisgarhi": "hne_Deva",
70
+ "Croatian": "hrv_Latn",
71
+ "Hungarian": "hun_Latn",
72
+ "Armenian": "hye_Armn",
73
+ "Igbo": "ibo_Latn",
74
+ "Ilocano": "ilo_Latn",
75
+ "Indonesian": "ind_Latn",
76
+ "Icelandic": "isl_Latn",
77
+ "Italian": "ita_Latn",
78
+ "Javanese": "jav_Latn",
79
+ "Japanese": "jpn_Jpan",
80
+ "Kabyle": "kab_Latn",
81
+ "Jingpho": "kac_Latn",
82
+ "Kamba": "kam_Latn",
83
+ "Kannada": "kan_Knda",
84
+ "Kashmiri (Arabic script)": "kas_Arab",
85
+ "Kashmiri (Devanagari script)": "kas_Deva",
86
+ "Georgian": "kat_Geor",
87
+ "Central Kanuri (Arabic script)": "knc_Arab",
88
+ "Central Kanuri (Latin script)": "knc_Latn",
89
+ "Kazakh": "kaz_Cyrl",
90
+ "Kabiyè": "kbp_Latn",
91
+ "Kabuverdianu": "kea_Latn",
92
+ "Khmer": "khm_Khmr",
93
+ "Kikuyu": "kik_Latn",
94
+ "Kinyarwanda": "kin_Latn",
95
+ "Kyrgyz": "kir_Cyrl",
96
+ "Kimbundu": "kmb_Latn",
97
+ "Northern Kurdish": "kmr_Latn",
98
+ "Kikongo": "kon_Latn",
99
+ "Korean": "kor_Hang",
100
+ "Lao": "lao_Laoo",
101
+ "Ligurian": "lij_Latn",
102
+ "Limburgish": "lim_Latn",
103
+ "Lingala": "lin_Latn",
104
+ "Lithuanian": "lit_Latn",
105
+ "Lombard": "lmo_Latn",
106
+ "Latgalian": "ltg_Latn",
107
+ "Luxembourgish": "ltz_Latn",
108
+ "Luba-Kasai": "lua_Latn",
109
+ "Ganda": "lug_Latn",
110
+ "Luo": "luo_Latn",
111
+ "Mizo": "lus_Latn",
112
+ "Standard Latvian": "lvs_Latn",
113
+ "Magahi": "mag_Deva",
114
+ "Maithili": "mai_Deva",
115
+ "Malayalam": "mal_Mlym",
116
+ "Marathi": "mar_Deva",
117
+ "Minangkabau (Latin script)": "min_Latn",
118
+ "Macedonian": "mkd_Cyrl",
119
+ "Plateau Malagasy": "plt_Latn",
120
+ "Maltese": "mlt_Latn",
121
+ "Meitei (Bengali script)": "mni_Beng",
122
+ "Halh Mongolian": "khk_Cyrl",
123
+ "Mossi": "mos_Latn",
124
+ "Maori": "mri_Latn",
125
+ "Burmese": "mya_Mymr",
126
+ "Dutch": "nld_Latn",
127
+ "Norwegian Nynorsk": "nno_Latn",
128
+ "Norwegian Bokmål": "nob_Latn",
129
+ "Nepali": "npi_Deva",
130
+ "Northern Sotho": "nso_Latn",
131
+ "Nuer": "nus_Latn",
132
+ "Nyanja": "nya_Latn",
133
+ "Occitan": "oci_Latn",
134
+ "West Central Oromo": "gaz_Latn",
135
+ "Odia": "ory_Orya",
136
+ "Pangasinan": "pag_Latn",
137
+ "Eastern Panjabi": "pan_Guru",
138
+ "Papiamento": "pap_Latn",
139
+ "Western Persian": "pes_Arab",
140
+ "Polish": "pol_Latn",
141
+ "Portuguese": "por_Latn",
142
+ "Dari": "prs_Arab",
143
+ "Southern Pashto": "pbt_Arab",
144
+ "Ayacucho Quechua": "quy_Latn",
145
+ "Romanian": "ron_Latn",
146
+ "Rundi": "run_Latn",
147
+ "Russian": "rus_Cyrl",
148
+ "Sango": "sag_Latn",
149
+ "Sanskrit": "san_Deva",
150
+ "Sicilian": "scn_Latn",
151
+ "Shan": "shn_Mymr",
152
+ "Sinhala": "sin_Sinh",
153
+ "Slovak": "slk_Latn",
154
+ "Slovenian": "slv_Latn",
155
+ "Samoan": "smo_Latn",
156
+ "Shona": "sna_Latn",
157
+ "Sindhi": "snd_Arab",
158
+ "Somali": "som_Latn",
159
+ "Southern Sotho": "sot_Latn",
160
+ "Spanish": "spa_Latn",
161
+ "Tosk Albanian": "als_Latn",
162
+ "Sardinian": "srd_Latn",
163
+ "Serbian": "srp_Cyrl",
164
+ "Swati": "ssw_Latn",
165
+ "Sundanese": "sun_Latn",
166
+ "Swedish": "swe_Latn",
167
+ "Swahili": "swh_Latn",
168
+ "Silesian": "szl_Latn",
169
+ "Tamil": "tam_Taml",
170
+ "Tatar": "tat_Cyrl",
171
+ "Telugu": "tel_Telu",
172
+ "Tajik": "tgk_Cyrl",
173
+ "Tagalog": "tgl_Latn",
174
+ "Thai": "tha_Thai",
175
+ "Tigrinya": "tir_Ethi",
176
+ "Tamasheq (Latin script)": "taq_Latn",
177
+ "Tamasheq (Tifinagh script)": "taq_Tfng",
178
+ "Tok Pisin": "tpi_Latn",
179
+ "Tswana": "tsn_Latn",
180
+ "Tsonga": "tso_Latn",
181
+ "Turkmen": "tuk_Latn",
182
+ "Tumbuka": "tum_Latn",
183
+ "Turkish": "tur_Latn",
184
+ "Twi": "twi_Latn",
185
+ "Central Atlas Tamazight": "tzm_Tfng",
186
+ "Uyghur": "uig_Arab",
187
+ "Ukrainian": "ukr_Cyrl",
188
+ "Umbundu": "umb_Latn",
189
+ "Urdu": "urd_Arab",
190
+ "Northern Uzbek": "uzn_Latn",
191
+ "Venetian": "vec_Latn",
192
+ "Vietnamese": "vie_Latn",
193
+ "Waray": "war_Latn",
194
+ "Wolof": "wol_Latn",
195
+ "Xhosa": "xho_Latn",
196
+ "Eastern Yiddish": "ydd_Hebr",
197
+ "Yoruba": "yor_Latn",
198
+ "Yue Chinese": "yue_Hant",
199
+ "Chinese (Simplified)": "zho_Hans",
200
+ "Chinese (Traditional)": "zho_Hant",
201
+ "Standard Malay": "zsm_Latn",
202
+ "Zulu": "zul_Latn",
203
+ }
model-quant.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e480324809ba9e7ba30e3a804f7a8d98ec445855abe11cded716abcd956c554
3
+ size 504902375
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ plotly
3
+ pandas
4
+ onnx
5
+ onnxruntime
6
+ transformers