# laion-nllb / app.py
# Uploaded by visheratin ("Upload 4 files", commit f04d812).
import random
import onnxruntime
import pandas as pd
import plotly.express as px
import streamlit as st
import torch
from lang_map import langs
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor
# Kept as the very first Streamlit call of the script.
st.set_page_config(layout="wide")

# Language names offered in every selectbox, in the mapping's key order.
options = [*langs.keys()]
class SessionState:
    """Simple attribute container: each keyword argument becomes an attribute.

    Example: ``SessionState(count=1).count == 1``.
    """

    def __init__(self, **kwargs):
        # Equivalent to calling setattr() once per keyword for a plain class.
        vars(self).update(kwargs)
def get_state(**kwargs):
    """Return the ``SessionState`` stored in ``st.session_state``.

    The object is created from ``kwargs`` only on the first call. Later calls
    return the stored instance unchanged and ignore ``kwargs``.
    """
    store = st.session_state
    if "session_state" in store:
        return store["session_state"]
    store["session_state"] = SessionState(**kwargs)
    return store["session_state"]
def add_selectbox_and_input(key):
    """Render one class row: a language selectbox next to a text input.

    Widget keys are derived from ``key`` (``"<key>_select"`` and
    ``"<key>_text"``). The current ``(language name, text)`` pair is recorded
    in ``state.inputs[key]``, which ``process()`` later reads.
    """
    lang_col, text_col = st.columns(2)
    with lang_col:
        language = st.selectbox("Select a language", options, key=f"{key}_select")
    with text_col:
        text = st.text_input("Input text", key=f"{key}_text")
    state.inputs[key] = (language, text)
# Session object: `count` = number of class rows, `inputs` = row key -> (language, text).
state = get_state(count=1, inputs={})

st.title("Zero-shot image classification with CLIP in 201 languages")
col1, col2 = st.columns(2)

# Stays None until the user uploads a file; read by process().
image: "Image.Image | None" = None
with col1:
    st.subheader("Image")
    upload = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg"])
    if upload is not None:
        image = Image.open(upload)
        st.image(image, caption="Uploaded Image.", use_column_width=True)
def process():
    """Classify the uploaded image against the user's class descriptions.

    Reads the module-level ``image`` (set by the uploader) and
    ``state.inputs`` (one ``(language name, text)`` tuple per class row). It
    then runs the quantized ONNX model ``model-quant.onnx`` and draws a
    horizontal bar chart of per-class probabilities in ``col2``.

    Shows a Streamlit warning and returns early, without loading any model,
    when no image has been uploaded or every class description is blank.
    """
    if image is None:
        st.warning("Please upload an image first.")
        return

    # Keep only rows the user actually filled in. Otherwise a blank text box
    # would be scored as a class named "".
    rows = [(str(lang_name), str(text)) for lang_name, text in state.inputs.values()]
    rows = [(lang_name, text) for lang_name, text in rows if text.strip()]
    if not rows:
        st.warning("Please enter at least one class description.")
        return
    classes = [text for _, text in rows]
    languages = [langs[lang_name] for lang_name, _ in rows]

    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    onnx_path = "model-quant.onnx"
    ort_session = onnxruntime.InferenceSession(onnx_path, session_options)
    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    ).image_processor
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

    image_inputs = processor(images=image, return_tensors="pt")

    # The tokenizer's special tokens are set per language, and each row may use
    # a different language, so every class is encoded on its own.
    input_ids = []
    attention_mask = []
    for language, text in zip(languages, classes):
        tokenizer.set_src_lang_special_tokens(language)
        encoded = tokenizer.batch_encode_plus(
            [text],
            return_tensors="pt",
            # Fixed-length padding gives every row the same width, so the rows
            # can be stacked along dim 0 below.
            padding="max_length",
            truncation=True,
            max_length=100,
        )
        input_ids.append(encoded["input_ids"])
        attention_mask.append(encoded["attention_mask"])

    ort_inputs = {
        "pixel_values": image_inputs["pixel_values"].numpy(),
        "input_ids": torch.concat(input_ids, dim=0).numpy(),
        "attention_mask": torch.concat(attention_mask, dim=0).numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)
    logits = torch.tensor(ort_outputs[0])
    # Use reshape(-1) instead of squeeze(). With a single class, squeeze() would
    # collapse the result to a 0-d array instead of a length-1 vector.
    # NOTE(review): assumes the first output is one row of per-class logits
    # for the single image -- TODO confirm against the exported model.
    probabilities = logits.softmax(dim=-1).reshape(-1).numpy()

    chart_data = pd.DataFrame({"Class": classes, "Probability": probabilities})
    # Ascending order, presumably so the most likely class is drawn at the top
    # of the horizontal bar chart.
    chart_data = chart_data.sort_values(by=["Probability"], ascending=True)
    fig = px.bar(chart_data, x="Probability", y="Class", orientation="h")
    with col2:
        st.subheader("Predictions")
        st.write(fig)
with col2:
    st.subheader("Classes")
    # Re-render every row added so far. state.count starts at 1 and is only
    # ever incremented, so "Input 1" is always rendered first.
    for row_number in range(1, state.count + 1):
        add_selectbox_and_input(f"Input {row_number}")

    if st.button("Add class"):
        state.count += 1
        # Show the new row immediately in this run.
        add_selectbox_and_input(f"Input {state.count}")

    st.markdown("---")

    if st.button("Generate"):
        with st.spinner("Processing the data"):
            process()