from collections.abc import Mapping

import ripple
import streamlit as stl
from tqdm.auto import tqdm

# streamlit app
stl.set_page_config(
    page_title="Ripple",
)

stl.title("ripple search")
stl.write(
    "An app that uses text input to search for described images, using embeddings of selected image datasets. Uses contrastive learning models(CLIP) and the sentence-transformers"
)
stl.link_button(
    label="Full library code",
    url="https://github.com/kelechi-c/ripple_net",
)

dataset = stl.selectbox(
    "choose huggingface dataset(bigger datasets take more time to embed..)",
    options=[
        "huggan/few-shot-art-painting",
        "huggan/wikiart",
        "zh-plus/tiny-imagenet",
        "huggan/flowers-102-categories",
        "lambdalabs/naruto-blip-captions",
        "detection-datasets/fashionpedia",
        "fantasyfish/laion-art",
        "Chris1/cityscapes",
    ],
)

# Module-level state. Streamlit re-executes this script from the top on every
# widget interaction, so anything that must outlive a single run (the embedded
# dataset and its embedder) is restored from ``stl.session_state`` instead of
# being reset to None each time.
embedded_data = stl.session_state.get("embedded_data")
embedder = stl.session_state.get("embedder")
embedded_dataset_name = stl.session_state.get("embedded_dataset_name")
finder = None
search_term = None
ret_images = None
scores = None


# NOTE(review): a ``stl.cache_data`` decorator was previously commented out
# here; the reason is not recorded. Results are kept in ``stl.session_state``
# by the caller instead, so embedding runs once per button click.
def embed_data(dataset):
    """Create image embeddings for a Hugging Face dataset on the CPU.

    Args:
        dataset: Dataset identifier passed straight to ``ripple.ImageEmbedder``.

    Returns:
        A ``(embedded_data, embedder)`` tuple: the value returned by
        ``embedder.create_embeddings(device="cpu")`` and the embedder itself.
    """
    embedder = ripple.ImageEmbedder(
        dataset, retrieval_type="text-image", dataset_type="huggingface"
    )
    embedded_data = embedder.create_embeddings(device="cpu")

    return embedded_data, embedder


@stl.cache_resource
def init_search(_embedded_data, _embedder, dataset_name=None):
    """Build (and cache) the text-search object for an embedded dataset.

    Args:
        _embedded_data: Embedded dataset; underscore-prefixed so Streamlit
            does not try to hash it.
        _embedder: Embedder whose ``embed_model`` is handed to
            ``ripple.TextSearch``; also excluded from hashing.
        dataset_name: Optional hashable cache key. Because the other two
            parameters are skipped by the cache, this is the only thing that
            distinguishes one dataset from another; without it the first
            search object built would be returned for every later dataset.

    Returns:
        The ``ripple.TextSearch`` instance.
    """
    text_search = ripple.TextSearch(_embedded_data, _embedder.embed_model)
    stl.success("Initialized text search class")
    return text_search


def get_images_from_description(finder, description):
    """Return ``(scores, ret_images)`` for the 4 best matches of *description*."""
    scores, ret_images = finder.get_similar_images(description, k_images=4)
    return scores, ret_images


def _matched_images(results):
    """Flatten search results into a plain list of images.

    NOTE(review): the result shape is not visible from this file. It is
    presumably a column-oriented mapping (``{"image": [...], ...}``) as
    returned by Hugging Face ``get_nearest_examples`` -- confirm against the
    ripple library. A sequence of per-record mappings is accepted as well.
    """
    if isinstance(results, Mapping):
        return list(results["image"])
    return [record["image"] for record in results]


if dataset and stl.button("embed image dataset"):
    with stl.spinner("Initializing and creating image embeddings from dataset"):
        embedded_data, embedder = embed_data(dataset)
    embedded_dataset_name = dataset
    # Persist across reruns: the button is only True on the run right after
    # the click, so without this the search box below would vanish as soon as
    # the user typed into it.
    stl.session_state["embedded_data"] = embedded_data
    stl.session_state["embedder"] = embedder
    stl.session_state["embedded_dataset_name"] = embedded_dataset_name
    stl.success("Successfully embedded and created image index")

if embedded_data is not None and embedder is not None:
    finder = init_search(
        embedded_data, embedder, dataset_name=embedded_dataset_name
    )

    search_term = stl.text_input("Text description/search for image")

    # ``text_input`` yields "" (not None) while empty; skip searching until
    # the user has actually entered a description.
    if search_term:
        with stl.spinner("retrieving images with description.."):
            scores, ret_images = get_images_from_description(finder, search_term)

        try:
            images = _matched_images(ret_images)
            stl.success(f"sucessfully retrieved {len(images)} images")
            for score, image in tqdm(zip(scores, images), total=len(images)):
                stl.image(image)
                stl.write(score)
        except Exception as e:
            # Top-level UI boundary: surface the failure in the page rather
            # than crashing the app.
            stl.error(e)