Spaces:
Sleeping
Sleeping
import ripple | |
import streamlit as stl | |
from tqdm.auto import tqdm | |
# Streamlit app: page metadata, heading, short description and a link to the
# underlying library. Page config is issued first because older Streamlit
# versions require set_page_config to be the first Streamlit command.
stl.set_page_config(page_title="Ripple")
stl.title("ripple search")
stl.write(
    "An app that uses text input to search for described images, "
    "using embeddings of selected image datasets. "
    "Uses contrastive learning models(CLIP) and the sentence-transformers"
)
stl.link_button(
    "Full library code",
    "https://github.com/kelechi-c/ripple_net",
)
# Hugging Face dataset ids the user can pick from; the chosen id is passed
# straight to embed_data(). Larger datasets take longer to embed.
DATASET_OPTIONS = [
    "huggan/few-shot-art-painting",
    "huggan/wikiart",
    "zh-plus/tiny-imagenet",
    "huggan/flowers-102-categories",
    "lambdalabs/naruto-blip-captions",
    "detection-datasets/fashionpedia",
    "fantasyfish/laion-art",
    "Chris1/cityscapes",
]
dataset = stl.selectbox(
    label="choose huggingface dataset(bigger datasets take more time to embed..)",
    options=DATASET_OPTIONS,
)
# Module-level state shared by the helpers and the main flow below. Streamlit
# re-executes the whole script on every interaction, so every run starts with
# all of these unset (None).
embedded_data = embedder = finder = None
search_term = ret_images = scores = None
# NOTE(review): st.cache_data pickles return values, which is presumably why it
# was disabled here (the returned embedder wraps a model); st.cache_resource is
# likely the right decorator if cross-rerun caching is wanted — confirm.
def embed_data(dataset):
    """Embed the images of a Hugging Face dataset for text-to-image search.

    Args:
        dataset: Hugging Face dataset identifier, e.g. "huggan/wikiart".

    Returns:
        A ``(embedded_data, embedder)`` pair: whatever
        ``ImageEmbedder.create_embeddings(device="cpu")`` produced, and the
        ``ripple.ImageEmbedder`` instance that produced it.
    """
    image_embedder = ripple.ImageEmbedder(
        dataset,
        retrieval_type="text-image",
        dataset_type="huggingface",
    )
    # Embeddings are always computed on CPU here.
    embeddings = image_embedder.create_embeddings(device="cpu")
    return embeddings, image_embedder
def init_search(_embedded_data, embedder):
    """Create a ``ripple.TextSearch`` over already-embedded data.

    Args:
        _embedded_data: Embeddings as returned by ``embed_data``.
        embedder: The embedder that produced them; only its ``embed_model``
            attribute is used, to encode text queries.

    Returns:
        The constructed ``ripple.TextSearch`` instance. A success message is
        shown in the UI once construction has completed.
    """
    text_model = embedder.embed_model
    searcher = ripple.TextSearch(_embedded_data, text_model)
    stl.success("Initialized text search class")
    return searcher
def get_images_from_description(description, k_images=4):
    """Retrieve the dataset entries nearest to a text description.

    Uses the module-level ``finder`` (set after the dataset is embedded).

    Args:
        description: Free-text query.
        k_images: Number of nearest examples to retrieve. Defaults to 4, the
            value that was previously hard-coded.

    Returns:
        ``(scores, ret_images)`` exactly as returned by
        ``finder.get_nearest_examples``.

    Raises:
        RuntimeError: If ``finder`` has not been initialised yet, i.e. no
            dataset has been embedded. Previously this surfaced as an opaque
            AttributeError on ``None``.
    """
    if finder is None:
        raise RuntimeError("Text search is not initialized; embed a dataset first.")
    scores, ret_images = finder.get_nearest_examples(description, k_images=k_images)
    return scores, ret_images
# Streamlit re-executes this script top to bottom on every widget interaction,
# so the module-level globals are reset to None on each rerun, and
# stl.button(...) is True only on the single rerun right after the click.
# Without persisting the index, typing a search query triggers a fresh rerun in
# which nothing is embedded and the search never executes. The index is
# therefore stored in st.session_state, tagged with the dataset it was built
# from.
if dataset and stl.button("embed image dataset"):
    with stl.spinner("Initializing and creating image embeddings from dataset"):
        embedded_data, embedder = embed_data(dataset)
    stl.success("Successfully embedded and created image index")
    stl.session_state["ripple_index"] = (
        dataset,
        embedded_data,
        embedder,
        init_search(embedded_data, embedder),
    )

# Reuse the stored index only when it matches the currently selected dataset;
# a stale index for another dataset would return unrelated images.
_cached_index = stl.session_state.get("ripple_index")
if _cached_index is not None and _cached_index[0] == dataset:
    _, embedded_data, embedder, finder = _cached_index

if finder is not None:
    search_term = stl.text_input("Text description/search for image")
    if search_term:
        with stl.spinner("retrieving images with description.."):
            scores, ret_images = get_images_from_description(search_term)
        # One score per retrieved hit; len(ret_images) would count columns,
        # not hits, when the results are column-oriented.
        stl.success(f"successfully retrieved {len(scores)} images")
        try:
            # NOTE(review): assumes get_nearest_examples returns column-oriented
            # examples (column name -> list of values, one entry per hit), as
            # the previous `image["image"][count]` indexing implies. Confirm
            # against ripple.TextSearch.
            for score, image in zip(scores, ret_images["image"]):
                stl.image(image)
                stl.write(score)
        except Exception as e:  # top-level UI boundary: show the error, don't crash
            stl.error(e)