kaushalya commited on
Commit
ba9c45e
·
1 Parent(s): f15bc50

Add spinner to search

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -36,21 +36,19 @@ model, processor = load_model()
36
  query = st.text_input("Enter your query here:")
37
 
38
  if st.button("Search"):
39
- st.write(f"Searching our image database for {query}...")
40
-
41
- inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
42
-
43
- query_embedding = model.get_text_features(**inputs)
44
- query_embedding = np.asarray(query_embedding)
45
- query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
46
- dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
47
- topk_images = dot_prod.argsort()[-k:]
48
- matching_images = image_list[topk_images]
49
- top_scores = 1. - dot_prod[topk_images]
50
- #show images
51
-
52
- for img_path, score in zip(matching_images, top_scores):
53
- img = plt.imread(os.path.join(img_dir, img_path))
54
- st.image(img)
55
- st.write(f"{img_path} ({score:.2f})", help="score")
56
 
 
36
  query = st.text_input("Enter your query here:")
37
 
38
  if st.button("Search"):
39
+ with st.spinner(f"Searching ROCO test set for {query}..."):
40
+ inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
41
+
42
+ query_embedding = model.get_text_features(**inputs)
43
+ query_embedding = np.asarray(query_embedding)
44
+ query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
45
+ dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
46
+ topk_images = dot_prod.argsort()[-k:]
47
+ matching_images = image_list[topk_images]
48
+ top_scores = 1. - dot_prod[topk_images]
49
+ #show images
50
+ for img_path, score in zip(matching_images, top_scores):
51
+ img = plt.imread(os.path.join(img_dir, img_path))
52
+ st.image(img)
53
+ st.write(f"{img_path} ({score:.2f})", help="score")
 
 
54