import os

import jax
import jax.numpy as jnp
import nmslib
import numpy as np
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, CLIPProcessor

from model import FlaxHybridCLIP

st.header("Under construction")
st.title("CLIP Reply Demo")
st.sidebar.markdown(
    """
    Validation set: 351 images / 273 after deduplication (some duplicates remain)

    Example queries:
    """
)


@st.cache(allow_output_mutation=True)
def load_model():
    # Hybrid CLIP: OpenAI's ViT image encoder paired with a Twitter RoBERTa text encoder
    model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    return model, processor


@st.cache(allow_output_mutation=True)
def load_image_index():
    # Prebuilt HNSW index over the validation-set image embeddings
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.loadIndex("./features/image_embeddings", load_data=True)
    return index


file_names = sorted(os.listdir("./imgs"))

image_index = load_image_index()
model, processor = load_model()

col_count = 4
top_k = 10

show_val = st.sidebar.button("Show all validation set images")
if show_val:
    cols = st.sidebar.beta_columns(col_count)
    for i, im in enumerate(file_names):
        cols[i % col_count].image("./imgs/" + im)


# TODO: wire this up so uploaded images can be added to the index
def add_image_emb(image):
    image = Image.open(image).convert("RGB")
    inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
    # FlaxHybridCLIP expects NHWC pixel values; the processor emits NCHW
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    features = model(**inputs).image_embeds
    # nmslib's addDataPoint takes (id, data); using the next free id is an
    # assumption, and the HNSW index must be rebuilt (createIndex) before the
    # new point becomes searchable
    image_index.addDataPoint(len(file_names), np.asarray(features)[0])


def query_with_images(query_images, query_text):
    # Score each uploaded image against the query text, rank by softmax probability
    images = [Image.open(im).convert("RGB") for im in query_images]
    inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image.reshape(-1)
    probs = jax.nn.softmax(logits_per_image)
    # Debug output left from development:
    # st.write(logits_per_image)
    # st.write(probs)
    results = sorted(zip(images, probs), key=lambda x: x[1], reverse=True)
    # st.write(results)
    return zip(*results)


q_cols = st.beta_columns(2)
query_text = q_cols[0].text_input("Input text", value="I love you")
query_images = q_cols[1].file_uploader(
    "(optional) Upload query images", type=["jpg", "jpeg"], accept_multiple_files=True
)

if query_images:
    st.write("Ranking uploaded images with respect to the input text")
    ids, dists = query_with_images(query_images, query_text)
else:
    st.write("Searching within the validation set")
    proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
    vec = np.asarray(model.get_text_features(**proc))
    ids, dists = image_index.knnQuery(vec, k=top_k)

res_cols = st.beta_columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    with res_cols[i % col_count]:
        if isinstance(id_, np.int32):
            # knnQuery returned validation-set indices and cosine distances
            st.image("./imgs/" + file_names[id_])
            # st.write(file_names[id_])
            st.write(1.0 - dist)  # convert cosine distance to similarity
        else:
            # query_with_images returned (PIL image, probability) pairs
            st.image(id_)
            st.write(dist)
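
# ---------------------------------------------------------------------------
# Offline index build (hedged sketch, not part of the running app).
# load_image_index() above assumes "./features/image_embeddings" already
# exists on disk. A separate one-off script along these lines could produce
# it; the batching, the default sequential nmslib ids, and the NHWC transpose
# (mirrored from query_with_images above) are assumptions, not code shipped
# with this demo.
#
#     import os
#     import nmslib
#     import numpy as np
#     import jax.numpy as jnp
#     from PIL import Image
#
#     file_names = sorted(os.listdir("./imgs"))
#     embeddings = []
#     for name in file_names:
#         image = Image.open("./imgs/" + name).convert("RGB")
#         inputs = processor(text=[""], images=image, return_tensors="jax", padding=True)
#         inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
#         embeddings.append(np.asarray(model(**inputs).image_embeds)[0])
#
#     index = nmslib.init(method="hnsw", space="cosinesimil")
#     index.addDataPointBatch(np.stack(embeddings))  # ids default to 0..N-1,
#                                                    # matching file_names order
#     index.createIndex({"post": 2}, print_progress=True)
#     index.saveIndex("./features/image_embeddings", save_data=True)
# ---------------------------------------------------------------------------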