tensorkelechi commited on
Commit
e51ef53
1 Parent(s): 0aae610

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -22,6 +22,7 @@ dataset = stl.selectbox(
22
  "huggan/few-shot-art-painting",
23
  "huggan/wikiart",
24
  "zh-plus/tiny-imagenet",
 
25
  "lambdalabs/naruto-blip-captions",
26
  "detection-datasets/fashionpedia",
27
  "fantasyfish/laion-art",
@@ -37,30 +38,43 @@ search_term = None
37
  ret_images = []
38
  scores = []
39
 
40
-
41
- if dataset and stl.button("embed image dataset"):
42
- with stl.spinner("Initializing and creating image embeddings from dataset"):
43
- embedder = ripple.ImageEmbedder(
44
  dataset, retrieval_type="text-image", dataset_type="huggingface"
45
- )
46
-
47
- embedded_data = embedder.create_embeddings(device="cpu")
48
- stl.success("Sucessfully embedded and dcreated image index")
49
 
50
- if embedded_data is not None:
 
51
  text_search = ripple.TextSearch(embedded_data, embedder.embed_model)
52
  stl.success("Initialized text search class")
 
 
 
 
 
 
 
 
 
 
53
 
 
 
 
54
  search_term = stl.text_input("Text description/search for image")
55
 
56
- if search_term:
57
- with stl.spinner("retrieving images with description.."):
58
- scores, ret_images = text_search.get_similar_images(search_term, k_images=4)
59
- stl.success(f"sucessfully retrieved {len(ret_images)} images")
60
 
61
  try:
62
  for count, score, image in tqdm(zip(range(len(ret_images)), scores, ret_images)):
63
  stl.image(image["image"][count])
64
  stl.write(score)
65
- except Exceptio as e:
 
66
  st.error(e)
 
22
  "huggan/few-shot-art-painting",
23
  "huggan/wikiart",
24
  "zh-plus/tiny-imagenet",
25
+ "huggan/flowers-102-categories",
26
  "lambdalabs/naruto-blip-captions",
27
  "detection-datasets/fashionpedia",
28
  "fantasyfish/laion-art",
 
38
  ret_images = []
39
  scores = []
40
 
41
+ #@stl.cache_data
42
+ def embed_data(dataset):
43
+ embedder = ripple.ImageEmbedder(
 
44
  dataset, retrieval_type="text-image", dataset_type="huggingface"
45
+ )
46
+ embedded_data = embedder.create_embeddings(device="cpu")
47
+ return embedded_data
 
48
 
49
+ @stl.cache_resource
50
+ def init_search(embedded_data):
51
  text_search = ripple.TextSearch(embedded_data, embedder.embed_model)
52
  stl.success("Initialized text search class")
53
+ return text_search
54
+
55
+ def get_images_from_description(description):
56
+ scores, ret_images = description, k_images=4)
57
+ return scores, ret_images
58
+
59
+ if dataset and stl.button("embed image dataset"):
60
+ with stl.spinner("Initializing and creating image embeddings from dataset"):
61
+ embedded_data = embed_data(dataset)
62
+ stl.success("Successfully embedded and created image index")
63
 
64
+ if embedded_data:
65
+ finder = init_search(embedded_data)
66
+
67
  search_term = stl.text_input("Text description/search for image")
68
 
69
+ if search_term:
70
+ with stl.spinner("retrieving images with description.."):
71
+ scores, ret_images = get_images_from_description(search_term)
72
+ stl.success(f"sucessfully retrieved {len(ret_images)} images")
73
 
74
  try:
75
  for count, score, image in tqdm(zip(range(len(ret_images)), scores, ret_images)):
76
  stl.image(image["image"][count])
77
  stl.write(score)
78
+
79
+ except Exception as e:
80
  st.error(e)