broadwell committed on
Commit
62d2147
·
verified ·
1 Parent(s): 5954f19

Reorder functions

Browse files
Files changed (1) hide show
  1. app.py +66 -66
app.py CHANGED
@@ -24,6 +24,72 @@ MAX_IMG_HEIGHT = 800
24
  st.set_page_config(layout="wide")
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def load_image_features():
28
  # Load the image feature vectors
29
  if st.session_state.vision_mode == "tiled":
@@ -112,72 +178,6 @@ if "images_info" not in st.session_state:
112
  init()
113
 
114
 
115
- # The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
116
- def encode_search_query(search_query, model_type):
117
- with torch.no_grad():
118
- # Encode and normalize the search query using the multilingual model
119
- if model_type == "M-CLIP (multiple languages)":
120
- text_encoded = st.session_state.ml_model.forward(
121
- search_query, st.session_state.ml_tokenizer
122
- )
123
- text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
124
- else: # model_type == "J-CLIP (日本語 only)"
125
- t_text = st.session_state.ja_tokenizer(
126
- search_query, padding=True, return_tensors="pt"
127
- )
128
- text_encoded = st.session_state.ja_model.get_text_features(**t_text)
129
- text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
130
-
131
- # Retrieve the feature vector
132
- return text_encoded
133
-
134
-
135
- # The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching images.
136
- def find_best_matches(text_features, image_features, image_ids):
137
- # Compute the similarity between the search query and each image using the Cosine similarity
138
- similarities = (image_features @ text_features.T).squeeze(1)
139
-
140
- # Sort the images by their similarity score
141
- best_image_idx = (-similarities).argsort()
142
-
143
- # Return the image IDs of the best matches
144
- return [[image_ids[i], similarities[i].item()] for i in best_image_idx]
145
-
146
-
147
- def clip_search(search_query):
148
- if st.session_state.search_field_value != search_query:
149
- st.session_state.search_field_value = search_query
150
-
151
- model_type = st.session_state.active_model
152
-
153
- if len(search_query) >= 1:
154
- text_features = encode_search_query(search_query, model_type)
155
-
156
- # Compute the similarity between the descrption and each photo using the Cosine similarity
157
- # similarities = list((text_features @ photo_features.T).squeeze(0))
158
-
159
- # Sort the photos by their similarity score
160
- if model_type == "M-CLIP (multiple languages)":
161
- matches = find_best_matches(
162
- text_features,
163
- st.session_state.ml_image_features,
164
- st.session_state.image_ids,
165
- )
166
- else: # model_type == "J-CLIP (日本語 only)"
167
- matches = find_best_matches(
168
- text_features,
169
- st.session_state.ja_image_features,
170
- st.session_state.image_ids,
171
- )
172
-
173
- st.session_state.search_image_ids = [match[0] for match in matches]
174
- st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
175
-
176
-
177
- def string_search():
178
- clip_search(st.session_state.search_field_value)
179
-
180
-
181
  def visualize_gradcam(viz_image_id):
182
  if not st.session_state.search_field_value:
183
  return
 
24
  st.set_page_config(layout="wide")
25
 
26
 
27
+ # The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching images.
28
+ def find_best_matches(text_features, image_features, image_ids):
29
+ # Compute the similarity between the search query and each image using the Cosine similarity
30
+ similarities = (image_features @ text_features.T).squeeze(1)
31
+
32
+ # Sort the images by their similarity score
33
+ best_image_idx = (-similarities).argsort()
34
+
35
+ # Return the image IDs of the best matches
36
+ return [[image_ids[i], similarities[i].item()] for i in best_image_idx]
37
+
38
+
39
+ # The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
40
+ def encode_search_query(search_query, model_type):
41
+ with torch.no_grad():
42
+ # Encode and normalize the search query using the multilingual model
43
+ if model_type == "M-CLIP (multiple languages)":
44
+ text_encoded = st.session_state.ml_model.forward(
45
+ search_query, st.session_state.ml_tokenizer
46
+ )
47
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
48
+ else: # model_type == "J-CLIP (日本語 only)"
49
+ t_text = st.session_state.ja_tokenizer(
50
+ search_query, padding=True, return_tensors="pt"
51
+ )
52
+ text_encoded = st.session_state.ja_model.get_text_features(**t_text)
53
+ text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
54
+
55
+ # Retrieve the feature vector
56
+ return text_encoded
57
+
58
+
59
+ def clip_search(search_query):
60
+ if st.session_state.search_field_value != search_query:
61
+ st.session_state.search_field_value = search_query
62
+
63
+ model_type = st.session_state.active_model
64
+
65
+ if len(search_query) >= 1:
66
+ text_features = encode_search_query(search_query, model_type)
67
+
68
+ # Compute the similarity between the descrption and each photo using the Cosine similarity
69
+ # similarities = list((text_features @ photo_features.T).squeeze(0))
70
+
71
+ # Sort the photos by their similarity score
72
+ if model_type == "M-CLIP (multiple languages)":
73
+ matches = find_best_matches(
74
+ text_features,
75
+ st.session_state.ml_image_features,
76
+ st.session_state.image_ids,
77
+ )
78
+ else: # model_type == "J-CLIP (日本語 only)"
79
+ matches = find_best_matches(
80
+ text_features,
81
+ st.session_state.ja_image_features,
82
+ st.session_state.image_ids,
83
+ )
84
+
85
+ st.session_state.search_image_ids = [match[0] for match in matches]
86
+ st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
87
+
88
+
89
+ def string_search():
90
+ clip_search(st.session_state.search_field_value)
91
+
92
+
93
  def load_image_features():
94
  # Load the image feature vectors
95
  if st.session_state.vision_mode == "tiled":
 
178
  init()
179
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  def visualize_gradcam(viz_image_id):
182
  if not st.session_state.search_field_value:
183
  return