Spaces:
Runtime error
Runtime error
File size: 3,700 Bytes
6cc012f a8d91bf 6cc012f 578e499 6cc012f 578e499 6cc012f 578e499 6cc012f 578e499 6cc012f 578e499 6cc012f 578e499 6cc012f 4474721 578e499 6cc012f 578e499 a8d91bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor
from model import FlaxHybridCLIP
from PIL import Image
import jax.numpy as jnp
import os
import jax
# Page header: a blank sidebar spacer, the app title, a link to the model
# card, and a blank line separating the header from the query widgets.
st.sidebar.write("")
st.title("CLIP React Demo")
st.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")
st.write(" ")
@st.cache(allow_output_mutation=True)
def load_model():
    """Return the ``(model, processor)`` pair used for every query.

    The model is the ``ceyda/clip-reply`` FlaxHybridCLIP checkpoint. The
    processor is OpenAI's CLIP ViT-B/32 processor whose tokenizer has been
    replaced by the ``cardiffnlp/twitter-roberta-base`` tokenizer (presumably
    matching the checkpoint's text encoder — confirm against the model card).
    Wrapped in ``st.cache`` so the objects are reused across script reruns.
    """
    hybrid_model = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    tweet_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    clip_processor.tokenizer = tweet_tokenizer
    return hybrid_model, clip_processor
@st.cache(allow_output_mutation=True)
def load_image_index():
    """Load the prebuilt nmslib HNSW index of validation-image embeddings.

    The index uses the ``cosinesimil`` space and is read from
    ``./features/image_embeddings`` together with its stored data points.
    Cached with ``st.cache`` so it is loaded once per process.
    """
    embeddings_path = "./features/image_embeddings"
    hnsw = nmslib.init(method='hnsw', space='cosinesimil')
    hnsw.loadIndex(embeddings_path, load_data=True)
    return hnsw
# Validation-set image filenames. Position i in this sorted list is used as
# the image for nmslib id i below, which assumes the index was built over the
# same sorted listing — TODO confirm against the indexing script.
file_names = sorted(os.listdir("./imgs"))
image_index = load_image_index()
model, processor = load_model()

# Number of columns used by both the sidebar gallery and the result grid.
col_count = 4
top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
show_val = st.sidebar.button("show all validation set images")
if show_val:
    # Lay the whole validation set out row-major across the sidebar columns.
    gallery = st.sidebar.beta_columns(col_count)
    for pos, fname in enumerate(file_names):
        gallery[pos % col_count].image(f"./imgs/{fname}")
# TODO: not called anywhere yet.
def add_image_emb(image, data_id=None):
    """Embed one image and add it to the in-memory nmslib ``image_index``.

    NOTE(review): work in progress. The new point gets no entry in
    ``file_names``, so the result grid cannot display it. For an HNSW index
    the graph presumably has to be rebuilt with ``image_index.createIndex()``
    before the point becomes searchable — TODO confirm.

    Args:
        image: a path or file object accepted by ``PIL.Image.open``.
        data_id: id to store the embedding under. Defaults to the current
            number of points in the index, i.e. the next sequential id.

    Returns:
        The id the embedding was stored under.
    """
    # Close the source file once the RGB copy has been decoded.
    with Image.open(image) as raw:
        rgb = raw.convert("RGB")
    inputs = processor(text=[""], images=rgb, return_tensors="jax", padding=True)
    # Move channels last (NCHW -> NHWC), as the rest of this app does before
    # calling the model.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    # image_embeds carries a leading batch axis; nmslib wants a flat vector.
    features = np.asarray(model(**inputs).image_embeds).reshape(-1)
    if data_id is None:
        data_id = len(image_index)
    # nmslib's signature is addDataPoint(id, data); the id was missing before.
    image_index.addDataPoint(data_id, features)
    return data_id
def query_with_images(query_images, query_text):
    """Rank user-supplied images by how well they match ``query_text``.

    Args:
        query_images: non-empty sequence of files accepted by
            ``PIL.Image.open`` (here, Streamlit uploads).
        query_text: the text prompt every image is scored against.

    Returns:
        A ``zip`` yielding two tuples, best match first: the RGB PIL images
        and their softmax scores as Python floats.
    """
    images = [Image.open(im).convert("RGB") for im in query_images]
    inputs = processor(text=[query_text], images=images, return_tensors="jax", padding=True)
    # Move channels last (NCHW -> NHWC) before calling the model.
    inputs["pixel_values"] = jnp.transpose(inputs["pixel_values"], axes=[0, 2, 3, 1])
    outputs = model(**inputs)
    # One logit per image for the single query text; flatten to (n_images,).
    logits_per_image = outputs.logits_per_image.reshape(-1)
    # Softmax across the uploaded images, so the scores are relative to this
    # batch. Convert to plain floats for sorting and display. (The previous
    # debug st.write dumps of logits/probs/pairs are intentionally gone.)
    probs = [float(p) for p in jax.nn.softmax(logits_per_image)]
    results = sorted(zip(images, probs), key=lambda pair: pair[1], reverse=True)
    return zip(*results)
# Query area: a wide left column (example prompts + text box), a narrow
# middle column left empty as a gutter, and a wide right column (uploader).
text_col, _gutter, upload_col = st.beta_columns([5, 2, 5])
text_col.markdown(
    """
Example Queries :
- I'm so scared right now
- I got the job 🎉
- OMG that is disgusting
- I'm awesome
"""
)
upload_col.markdown(
    """
Searches among the validation set images if not specified
(There may be non-exact duplicates)
"""
)
query_text = text_col.text_input(
    "Input text you want to get reaction for",
    value="I love you ❤️",
)
query_images = upload_col.file_uploader(
    "(optional) Upload images to rank them",
    type=["jpg", "jpeg"],
    accept_multiple_files=True,
)
# Produce a uniform list of (image, score) pairs from whichever source applies,
# then render them. Which branch ran is decided here explicitly rather than by
# sniffing whether each id happens to be an np.int32.
if query_images:
    st.write("Ranking your uploaded images with respect to input text:")
    images, scores = query_with_images(query_images, query_text)
    # Uploaded images are displayed directly, next to their softmax score.
    results = list(zip(images, scores))
else:
    st.write("Found these images within validation set:")
    proc = processor(text=[query_text], images=None, return_tensors="jax", padding=True)
    vec = np.asarray(model.get_text_features(**proc))
    ids, dists = image_index.knnQuery(vec, k=top_k)
    # ids index into the sorted file_names listing; in the 'cosinesimil'
    # space 1 - distance turns each distance back into a similarity score.
    results = [("./imgs/" + file_names[id_], 1.0 - dist) for id_, dist in zip(ids, dists)]

# Lay the results out row-major across col_count columns.
res_cols = st.beta_columns(col_count)
for i, (image, score) in enumerate(results):
    with res_cols[i % col_count]:
        st.image(image)
        # st.write has no `help` parameter (the former help="score" was
        # ignored), so the score is written on its own.
        st.write(score)
|