import random
from typing import Optional

import onnxruntime
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from lang_map import langs

st.set_page_config(layout="wide")

# Language labels shown in the selectboxes; `langs` maps each label to the
# code handed to the tokenizer in `_encode_classes`.
options = list(langs.keys())

ONNX_MODEL_PATH = "model-quant.onnx"
# Every class prompt is padded/truncated to this many tokens.
MAX_TEXT_LENGTH = 100
# st.session_state key under which the loaded model/processors are cached.
_RESOURCES_KEY = "clip_resources"


class SessionState:
    """Attribute bag whose fields are set from the constructor's keyword arguments."""

    def __init__(self, **kwargs):
        for key, val in kwargs.items():
            setattr(self, key, val)


def get_state(**kwargs):
    """Return the per-session SessionState, creating it from ``kwargs`` on first use.

    The object lives in ``st.session_state`` so it survives Streamlit reruns.
    """
    if "session_state" not in st.session_state:
        st.session_state["session_state"] = SessionState(**kwargs)
    return st.session_state["session_state"]


def add_selectbox_and_input(key):
    """Render a language selectbox and a class-text input side by side.

    The current ``(language_label, class_text)`` pair is stored in
    ``state.inputs[key]`` so that ``process`` can read every class.
    """
    col1, col2 = st.columns(2)
    with col1:
        select = st.selectbox("Select a language", options, key=f"{key}_select")
    with col2:
        user_input = st.text_input("Input text", key=f"{key}_text")
    state.inputs[key] = (select, user_input)


state = get_state(count=1, inputs={})

st.title("Zero-shot image classification with CLIP in 201 languages")

col1, col2 = st.columns(2)

# Stays None until the user uploads a file during this run.
image: Optional[Image.Image] = None

with col1:
    st.subheader("Image")
    uploaded_file = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image.", use_column_width=True)


def _load_resources():
    """Return ``(ort_session, image_processor, tokenizer)``, loading them once per session.

    Building these objects is expensive, so they are kept in
    ``st.session_state`` instead of being recreated on every Generate click.
    """
    if _RESOURCES_KEY not in st.session_state:
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        )
        ort_session = onnxruntime.InferenceSession(ONNX_MODEL_PATH, session_options)
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32"
        ).image_processor
        tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        st.session_state[_RESOURCES_KEY] = (ort_session, processor, tokenizer)
    return st.session_state[_RESOURCES_KEY]


def _collect_classes():
    """Return ``(classes, language_codes)`` for every class row with non-blank text."""
    classes = []
    language_codes = []
    for language_label, class_text in state.inputs.values():
        class_text = str(class_text)
        if not class_text.strip():
            continue  # skip rows the user has not filled in
        classes.append(class_text)
        language_codes.append(langs[str(language_label)])
    return classes, language_codes


def _encode_classes(tokenizer, classes, language_codes):
    """Tokenize each class in its own language.

    Returns ``(input_ids, attention_mask)``, each of shape
    ``(len(classes), MAX_TEXT_LENGTH)``.
    """
    input_ids = []
    attention_mask = []
    for class_text, language_code in zip(classes, language_codes):
        # Reconfigure the tokenizer's language-specific special tokens for this
        # class before encoding it (presumably the NLLB source-language tag).
        tokenizer.set_src_lang_special_tokens(language_code)
        encoded = tokenizer.batch_encode_plus(
            [class_text],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=MAX_TEXT_LENGTH,
        )
        input_ids.append(encoded["input_ids"])
        attention_mask.append(encoded["attention_mask"])
    return torch.cat(input_ids, dim=0), torch.cat(attention_mask, dim=0)


def process():
    """Score the uploaded image against the user's classes and plot the result in ``col2``.

    Shows a warning instead of running the model when no image has been
    uploaded or no class text has been entered.
    """
    if image is None:
        st.warning("Please upload an image first.")
        return
    classes, language_codes = _collect_classes()
    if not classes:
        st.warning("Please enter at least one class.")
        return

    ort_session, processor, tokenizer = _load_resources()
    image_inputs = processor(images=image, return_tensors="pt")
    input_ids, attention_mask = _encode_classes(tokenizer, classes, language_codes)

    ort_inputs = {
        "pixel_values": image_inputs["pixel_values"].numpy(),
        "input_ids": input_ids.numpy(),
        "attention_mask": attention_mask.numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)
    logits = torch.tensor(ort_outputs[0])
    # reshape(-1) instead of squeeze() so that a single class still yields a
    # 1-D array. Assumes the first output is (1, n_classes) — TODO confirm.
    probabilities = logits.softmax(dim=-1).reshape(-1).numpy()

    chart_data = pd.DataFrame({"Class": classes, "Probability": probabilities})
    chart_data = chart_data.sort_values(by=["Probability"], ascending=True)
    fig = px.bar(chart_data, x="Probability", y="Class", orientation="h")
    with col2:
        st.subheader("Predictions")
        st.plotly_chart(fig)


with col2:
    st.subheader("Classes")
    add_selectbox_and_input("Input 1")
    for i in range(2, state.count + 1):
        add_selectbox_and_input(f"Input {i}")
    if st.button("Add class"):
        state.count += 1
        # Render the new row immediately; later reruns draw it in the loop above.
        add_selectbox_and_input(f"Input {state.count}")
    st.markdown("""---""")
    if st.button("Generate"):
        with st.spinner("Processing the data"):
            process()