Spaces:
Runtime error
Runtime error
import os | |
import jax | |
import jax.numpy as jnp | |
import nmslib | |
import numpy as np | |
import streamlit as st | |
from PIL import Image | |
from transformers import AutoTokenizer, CLIPProcessor | |
from model import FlaxHybridCLIP | |
# --- Sidebar: title, model-card link, logo and motivation blurb ---------
MOTIVATION_TEXT = """
Reaction GIFs became an integral part of communication.
They convey complex emotions with many levels, in a short compact format.
If a picture is worth a thousand words then a GIF is worth more.
A lot of people would agree it is not always easy to find the perfect reaction GIF.
This is just a first step in the more ambitious goal of GIF/Image generation.
"""

st.sidebar.title("CLIP React Demo")
st.sidebar.write("[Model Card](https://huggingface.co/flax-community/clip-reply)")

# Two sidebar columns: logo on the left, tagline on the right.
logo_col, tagline_col = st.sidebar.columns(2)
logo_col.image("./huggingface_explode3.png", width=150)
# Two blank writes push the tagline down before it is rendered.
for _ in range(2):
    tagline_col.write(" ")
tagline_col.markdown("## Researching fun")

with st.sidebar.expander("Motivation", expanded=True):
    st.markdown(MOTIVATION_TEXT)
# --- Sidebar controls and page heading -----------------------------------
# Number of nearest neighbours requested from the image index.
top_k = st.sidebar.slider("Show top-K", min_value=1, max_value=50, value=20)
# Number of columns used for every image grid on the page.
col_count = 4
# Sorted listing of ./jpg; later code indexes into this list by position.
file_names = sorted(os.listdir("./jpg"))

# Optional gallery: lay every file from ./jpg out across the sidebar columns.
if st.sidebar.button("show all validation set images"):
    gallery_cols = st.sidebar.columns(col_count)
    for position, file_name in enumerate(file_names):
        gallery_cols[position % col_count].image(f"./jpg/{file_name}")

st.write("# Search Reaction GIFs with CLIP ")
st.write(" ")
st.write(" ")
def load_model():
    """Load the hybrid CLIP model together with its input processor.

    Returns:
        A ``(model, processor)`` tuple. ``model`` is the ``FlaxHybridCLIP``
        checkpoint ``ceyda/clip-reply``; ``processor`` is the
        ``openai/clip-vit-base-patch32`` ``CLIPProcessor`` whose tokenizer
        has been replaced by the ``cardiffnlp/twitter-roberta-base`` one.
    """
    # NOTE(review): this is called at module level without any caching, so
    # it presumably runs again on every Streamlit rerun -- consider wrapping
    # it in Streamlit's resource cache once the target version is confirmed.
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Swap the processor's default tokenizer for the RoBERTa tokenizer.
    roberta_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base")
    clip_processor.tokenizer = roberta_tokenizer

    hybrid_clip = FlaxHybridCLIP.from_pretrained("ceyda/clip-reply")
    return hybrid_clip, clip_processor
def load_image_index():
    """Load the saved nmslib index from ``./features/image_embeddings``.

    Returns:
        An nmslib index created with the ``hnsw`` method in ``cosinesimil``
        space, populated from disk (``load_data=True``).
    """
    hnsw_index = nmslib.init(space="cosinesimil", method="hnsw")
    hnsw_index.loadIndex("./features/image_embeddings", load_data=True)
    return hnsw_index
# Module-level resources shared by the query code further down:
# the saved image-embedding index used for nearest-neighbour lookups ...
image_index = load_image_index()
# ... and the CLIP model with the processor that prepares its inputs.
(model, processor) = load_model()
# TODO
def add_image_emb(image):
    """Embed one image with the CLIP model and hand it to the image index.

    Unfinished helper (see the TODO above).

    Args:
        image: Anything ``PIL.Image.open`` accepts (path or file-like object).
    """
    rgb_image = Image.open(image).convert("RGB")
    # The processor is called with a single empty text alongside the image.
    model_inputs = processor(
        text=[""],
        images=rgb_image,
        return_tensors="jax",
        padding=True,
    )
    # Move axis 1 to the end (assumes pixel_values arrives channels-first,
    # i.e. (batch, channels, height, width) -- TODO confirm).
    model_inputs["pixel_values"] = jnp.transpose(
        model_inputs["pixel_values"], axes=[0, 2, 3, 1]
    )
    image_features = model(**model_inputs).image_embeds
    # NOTE(review): the nmslib Python bindings appear to document
    # addDataPoint(id, data); this one-argument call likely needs an id and
    # an index rebuild before the point is searchable -- confirm before use.
    image_index.addDataPoint(image_features)
def query_with_images(query_images, query_text):
    """Rank uploaded images by how well they match ``query_text``.

    Args:
        query_images: Iterable of uploaded files; each is opened with PIL
            and its ``name`` attribute is inspected.
        query_text: The text the images are scored against.

    Returns:
        The iterator produced by ``zip(*ranked)``. When unpacked into two
        names it yields a tuple of images followed by a tuple of their
        softmax scores, both ordered from highest to lowest score.
    """
    images = []
    for upload in query_images:
        frame = Image.open(upload).convert("RGB")
        if upload.name.endswith(".gif"):
            frame.seek(0)
        images.append(frame)

    model_inputs = processor(
        text=[query_text],
        images=images,
        return_tensors="jax",
        padding=True,
    )
    # Move axis 1 to the end (assumes pixel_values arrives channels-first,
    # i.e. (batch, channels, height, width) -- TODO confirm).
    model_inputs["pixel_values"] = jnp.transpose(
        model_inputs["pixel_values"], axes=[0, 2, 3, 1]
    )

    # Flatten the per-image logits and normalise them across the images.
    flat_logits = model(**model_inputs).logits_per_image.reshape(-1)
    scores = jax.nn.softmax(flat_logits)

    ranked = sorted(zip(images, scores), key=lambda pair: pair[1], reverse=True)
    return zip(*ranked)
# --- Query inputs: text on the left, optional image upload on the right ---
EXAMPLE_QUERIES = [
    "OMG that is disgusting",
    "I'm so scared right now",
    " I got the job 🎉",
    "Congratulations to all the flax-community week teams",
    "You're awesome",
    "I love you ❤️",
]
UPLOAD_HINT = """
Searches among the validation set images if not specified
(There may be non-exact duplicates)
"""

# Three columns with relative widths 5 / 2 / 5; the middle one stays empty.
text_col, _gap_col, upload_col = st.columns([5, 2, 5])

# The selected example pre-fills the free-text box below it.
chosen_example = text_col.radio(
    "Example Queries :",
    EXAMPLE_QUERIES,
    index=4,
    help="These are examples I wrote off the top of my head. They don't occur in the dataset",
)
upload_col.markdown(UPLOAD_HINT)
query_text = text_col.text_input(
    "Write text you want to get reaction for", value=chosen_example
)
query_images = upload_col.file_uploader(
    "(optional) Upload images to rank them",
    type=["jpg", "jpeg", "gif"],
    accept_multiple_files=True,
)
# --- Run the query --------------------------------------------------------
# Both branches bind ``ids`` and ``dists`` for the rendering code below.
if not query_images:
    # No uploads: embed the text and look up its neighbours in the index.
    st.write("Found these images within validation set:")
    with st.spinner("Calculating..."):
        text_inputs = processor(
            text=[query_text], images=None, return_tensors="jax", padding=True
        )
        text_features = model.get_text_features(**text_inputs)
        query_vec = np.asarray(text_features)
        ids, dists = image_index.knnQuery(query_vec, k=top_k)
else:
    # Uploads present: rank them against the text; here ``ids`` holds the
    # images themselves and ``dists`` their scores.
    st.write("Ranking your uploaded images with respect to input text:")
    with st.spinner("Calculating..."):
        ids, dists = query_with_images(query_images, query_text)
# --- Render the results grid ----------------------------------------------
show_gif = st.checkbox(
    "Play GIFs",
    value=True,
    help="Will play the original animation. Only first frame is used in training!",
)
# The media folder and the file extension share the same name ("gif"/"jpg").
ext = "gif" if show_gif else "jpg"
res_cols = st.columns(col_count)
for i, (id_, dist) in enumerate(zip(ids, dists)):
    # Fill the grid row by row.
    with res_cols[i % col_count]:
        # Index hits are NumPy integers pointing into ``file_names``; accept
        # any integer width rather than only int32 so a different dtype is
        # not mistaken for an uploaded image.
        if isinstance(id_, np.integer):
            # Strip the real extension instead of assuming it is exactly
            # three characters long (".jpeg" or extension-less names).
            stem = os.path.splitext(file_names[id_])[0]
            st.image(f"./{ext}/{stem}.{ext}")
            # The index uses cosine distance; show it as a similarity.
            st.write(1.0 - dist)
        else:
            # Uploaded-image ranking: ``id_`` is the image, ``dist`` its score.
            st.image(id_)
            st.write(dist)

# Credits
st.sidebar.caption("Made by [Ceyda Cinarel](https://huggingface.co/ceyda)")