File size: 2,575 Bytes
44883c6
 
52719f3
44883c6
 
 
 
 
 
 
 
7963048
44883c6
 
5064dd2
44883c6
 
 
 
52719f3
44883c6
52719f3
 
5064dd2
e51ef53
52719f3
 
 
 
44883c6
 
 
 
 
 
d945c9f
5064dd2
e572f71
 
44883c6
e51ef53
 
 
44883c6
e51ef53
 
e4c4ba8
44883c6
e51ef53
9812662
b40ae18
44883c6
b40ae18
e51ef53
0b253de
dc3bcae
e51ef53
d945c9f
e51ef53
 
 
a646756
e51ef53
44883c6
a646756
72c69ab
0b253de
 
b40ae18
 
 
 
 
 
 
 
 
 
7963048
5064dd2
 
 
 
e51ef53
 
7f88749
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from collections.abc import Mapping

import ripple
import streamlit as stl
from tqdm.auto import tqdm

# --- Streamlit page setup ----------------------------------------------------
# Kept as the very first Streamlit call; Streamlit has historically required
# `set_page_config` to run before any other st.* command.
stl.set_page_config(page_title="Ripple")

stl.title("ripple search")
stl.write(
    "An app that uses text input to search for described images, "
    "using embeddings of selected image datasets. "
    "Uses contrastive learning models(CLIP) and the sentence-transformers"
)
stl.link_button(
    label="Full library code",
    url="https://github.com/kelechi-c/ripple_net",
)

# Hugging Face image datasets offered in the picker.
_DATASET_CHOICES = [
    "huggan/few-shot-art-painting",
    "huggan/wikiart",
    "zh-plus/tiny-imagenet",
    "huggan/flowers-102-categories",
    "lambdalabs/naruto-blip-captions",
    "detection-datasets/fashionpedia",
    "fantasyfish/laion-art",
    "Chris1/cityscapes",
]
dataset = stl.selectbox(
    "choose huggingface dataset(bigger datasets take more time to embed..)",
    options=_DATASET_CHOICES,
)

# Initialized global variables. Streamlit re-executes the script top to
# bottom on every interaction, so each run starts with these set to None.
embedded_data = embedder = finder = None
search_term = ret_images = scores = None

#@stl.cache_data
def embed_data(dataset):
    embedder = ripple.ImageEmbedder(
            dataset, retrieval_type="text-image", dataset_type="huggingface"
    )
    embedded_data = embedder.create_embeddings(device="cpu")
    return embedded_data, embedder

def init_search(_embedded_data, _embedder):
    """Create a ``ripple.TextSearch`` over an embedded dataset.

    This is deliberately NOT wrapped in ``stl.cache_resource``. Streamlit leaves
    underscore-prefixed parameters out of the cache key, and both parameters
    here are underscore-prefixed. A cached version therefore had a single
    cache entry and kept returning the search object built for the first
    dataset, even after a different dataset had been embedded. Callers that
    want to reuse the search object should keep the returned value.

    Args:
        _embedded_data: embeddings as produced by ``embed_data`` (the result
            of ``ImageEmbedder.create_embeddings``).
        _embedder: the ``ripple.ImageEmbedder``; only its ``embed_model``
            attribute is used here.

    Returns:
        The newly constructed ``ripple.TextSearch`` instance.
    """
    finder = ripple.TextSearch(_embedded_data, _embedder.embed_model)
    stl.success("Initialized text search class")
    return finder

def get_images_from_description(finder, description, k_images=4):
    """Run a free-text query against a search object and return its matches.

    Args:
        finder: object exposing ``get_similar_images(query, k_images=...)``,
            such as the ``ripple.TextSearch`` returned by ``init_search``.
        description: text describing the wanted images.
        k_images: number of nearest images to request. Defaults to 4, the
            value that used to be hard-coded.

    Returns:
        ``(scores, ret_images)`` exactly as returned by
        ``finder.get_similar_images``.
    """
    scores, ret_images = finder.get_similar_images(description, k_images=k_images)
    return scores, ret_images
    

# Streamlit re-executes this whole script on every widget interaction, and
# `stl.button` returns True only on the single run in which it was clicked.
# The embedded index and its search object are therefore kept in
# `stl.session_state`. Without that, they were lost as soon as the user typed
# a query, and every search ran against `finder = None`.
if dataset and stl.button("embed image dataset"):
    with stl.spinner("Initializing and creating image embeddings from dataset"):
        embedded_data, embedder = embed_data(dataset)
        stl.success("Successfully embedded and created image index")
    # Build the search object once per embedding, not on every rerun.
    stl.session_state["ripple_search"] = {
        "dataset": dataset,
        "embedded_data": embedded_data,
        "embedder": embedder,
        "finder": init_search(embedded_data, embedder),
    }

_saved_search = stl.session_state.get("ripple_search")
# Only reuse the stored index when it was built from the currently selected
# dataset; after switching datasets the user has to embed again.
if _saved_search is not None and _saved_search["dataset"] == dataset:
    embedded_data = _saved_search["embedded_data"]
    embedder = _saved_search["embedder"]
    finder = _saved_search["finder"]


def _pair_scores_with_images(scores, ret_images):
    """Return a list of ``(score, image)`` pairs for one search result.

    Hugging Face ``Dataset.get_nearest_examples`` returns the matched rows as
    a column mapping (``{"image": [...], ...}``). Iterating such a mapping
    yields column names, which made the old per-item ``item["image"]`` lookup
    always fail, so the "image" column is zipped with the scores instead.
    NOTE(review): assumes ripple forwards that HF result unchanged -- confirm.
    Any non-mapping result is indexed exactly as before.
    """
    if isinstance(ret_images, Mapping):
        return list(zip(scores, ret_images["image"]))
    return [
        (score, item["image"][count])
        for count, (score, item) in enumerate(zip(scores, ret_images))
    ]


search_term = stl.text_input("Text description/search for image")

# `text_input` yields "" (not None) until something is typed, so checking for
# emptiness is what stops the app from searching on every page load.
if search_term:
    if finder is None:
        stl.info("Embed the selected dataset before searching.")
    else:
        try:
            with stl.spinner(f"retrieving images with description..'{search_term}'"):
                scores, ret_images = get_images_from_description(finder, search_term)
                results = _pair_scores_with_images(scores, ret_images)
                stl.success(f"successfully retrieved {len(results)} images")
        except Exception as e:  # UI boundary: surface library errors to the user
            stl.error(e)
        else:
            try:
                for score, image in results:
                    stl.image(image)
                    stl.write(score)
            except Exception as e:  # e.g. an image type stl.image cannot render
                stl.error(e)