kaushalya commited on
Commit
1366c30
·
1 Parent(s): cffabcf

Show matching images

Browse files
Files changed (2) hide show
  1. app.py +11 -2
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
 
 
4
  from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
5
  from medclip.modeling_hybrid_clip import FlaxHybridCLIP
6
 
@@ -17,10 +19,10 @@ def load_image_embeddings():
17
  image_files = np.asarray(embeddings_df['files'].tolist())
18
  return image_files, image_embeds
19
 
20
- # def app():
21
  k = 5
22
  image_list, image_embeddings = load_image_embeddings()
23
  model, processor = load_model()
 
24
 
25
  query = st.text_input("Search:")
26
 
@@ -34,5 +36,12 @@ if st.button("Search"):
34
  query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
35
  dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
36
  matching_images = image_list[dot_prod.argsort()[-k:]]
37
- st.write(f"matching images: {matching_images}")
 
 
 
 
 
 
 
38
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
+ import os
5
+ import matplotlib.pyplot as plt
6
  from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
7
  from medclip.modeling_hybrid_clip import FlaxHybridCLIP
8
 
 
19
  image_files = np.asarray(embeddings_df['files'].tolist())
20
  return image_files, image_embeds
21
 
 
22
  k = 5
23
  image_list, image_embeddings = load_image_embeddings()
24
  model, processor = load_model()
25
+ img_dir = './images'
26
 
27
  query = st.text_input("Search:")
28
 
 
36
  query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
37
  dot_prod = np.sum(np.multiply(query_embedding, image_embeddings), axis=1)
38
  matching_images = image_list[dot_prod.argsort()[-k:]]
39
+
40
+ # st.write(f"matching images: {matching_images}")
41
+ #show images
42
+
43
+ for img_path in matching_images:
44
+ img = plt.imread(os.path.join(img_dir, img_path))
45
+ st.write(img_path)
46
+ st.image(img)
47
 
requirements.txt CHANGED
@@ -5,5 +5,6 @@ streamlit==0.84.1
5
  torch==1.9.0
6
  torchvision==0.10.0
7
  pandas==1.3.0
 
8
  transformers==4.8.2
9
  watchdog==2.1.3
 
5
  torch==1.9.0
6
  torchvision==0.10.0
7
  pandas==1.3.0
8
+ matplotlib>=3.4.0
9
  transformers==4.8.2
10
  watchdog==2.1.3