Reorder functions
app.py (changed)
@@ -24,6 +24,72 @@ MAX_IMG_HEIGHT = 800
 st.set_page_config(layout="wide")


+# The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching images.
+def find_best_matches(text_features, image_features, image_ids):
+    # Compute the similarity between the search query and each image using the Cosine similarity
+    similarities = (image_features @ text_features.T).squeeze(1)
+
+    # Sort the images by their similarity score
+    best_image_idx = (-similarities).argsort()
+
+    # Return the image IDs of the best matches
+    return [[image_ids[i], similarities[i].item()] for i in best_image_idx]
+
+
+# The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
+def encode_search_query(search_query, model_type):
+    with torch.no_grad():
+        # Encode and normalize the search query using the multilingual model
+        if model_type == "M-CLIP (multiple languages)":
+            text_encoded = st.session_state.ml_model.forward(
+                search_query, st.session_state.ml_tokenizer
+            )
+            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+        else:  # model_type == "J-CLIP (日本語 only)"
+            t_text = st.session_state.ja_tokenizer(
+                search_query, padding=True, return_tensors="pt"
+            )
+            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
+            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
+
+    # Retrieve the feature vector
+    return text_encoded
+
+
+def clip_search(search_query):
+    if st.session_state.search_field_value != search_query:
+        st.session_state.search_field_value = search_query
+
+    model_type = st.session_state.active_model
+
+    if len(search_query) >= 1:
+        text_features = encode_search_query(search_query, model_type)
+
+        # Compute the similarity between the descrption and each photo using the Cosine similarity
+        # similarities = list((text_features @ photo_features.T).squeeze(0))
+
+        # Sort the photos by their similarity score
+        if model_type == "M-CLIP (multiple languages)":
+            matches = find_best_matches(
+                text_features,
+                st.session_state.ml_image_features,
+                st.session_state.image_ids,
+            )
+        else:  # model_type == "J-CLIP (日本語 only)"
+            matches = find_best_matches(
+                text_features,
+                st.session_state.ja_image_features,
+                st.session_state.image_ids,
+            )
+
+        st.session_state.search_image_ids = [match[0] for match in matches]
+        st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
+
+
+def string_search():
+    clip_search(st.session_state.search_field_value)
+
+
 def load_image_features():
     # Load the image feature vectors
     if st.session_state.vision_mode == "tiled":
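A note on the ranking step added above: `find_best_matches` only produces cosine similarities because both operands are L2-normalized. `encode_search_query` normalizes the text embedding, and the precomputed image features are presumably stored normalized the same way (that part is outside this diff). A minimal, self-contained sketch of the same ranking logic with made-up tensors (shapes and values are illustrative only, not taken from app.py):

import torch

# Stand-in features: 4 images and 1 text query, 512-dim (illustrative shapes).
image_features = torch.randn(4, 512)
text_features = torch.randn(1, 512)

# L2-normalize both sides, mirroring encode_search_query;
# dot products of unit vectors are exactly cosine similarities.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# (4, 512) @ (512, 1) -> (4, 1); squeeze to a vector of per-image scores.
similarities = (image_features @ text_features.T).squeeze(1)

# Indices from most to least similar, as find_best_matches sorts them.
best_image_idx = (-similarities).argsort()
print(best_image_idx, similarities[best_image_idx])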
@@ -112,72 +178,6 @@ if "images_info" not in st.session_state:
 init()


-# The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
-def encode_search_query(search_query, model_type):
-    with torch.no_grad():
-        # Encode and normalize the search query using the multilingual model
-        if model_type == "M-CLIP (multiple languages)":
-            text_encoded = st.session_state.ml_model.forward(
-                search_query, st.session_state.ml_tokenizer
-            )
-            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-        else:  # model_type == "J-CLIP (日本語 only)"
-            t_text = st.session_state.ja_tokenizer(
-                search_query, padding=True, return_tensors="pt"
-            )
-            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
-            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-
-    # Retrieve the feature vector
-    return text_encoded
-
-
-# The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching images.
-def find_best_matches(text_features, image_features, image_ids):
-    # Compute the similarity between the search query and each image using the Cosine similarity
-    similarities = (image_features @ text_features.T).squeeze(1)
-
-    # Sort the images by their similarity score
-    best_image_idx = (-similarities).argsort()
-
-    # Return the image IDs of the best matches
-    return [[image_ids[i], similarities[i].item()] for i in best_image_idx]
-
-
-def clip_search(search_query):
-    if st.session_state.search_field_value != search_query:
-        st.session_state.search_field_value = search_query
-
-    model_type = st.session_state.active_model
-
-    if len(search_query) >= 1:
-        text_features = encode_search_query(search_query, model_type)
-
-        # Compute the similarity between the descrption and each photo using the Cosine similarity
-        # similarities = list((text_features @ photo_features.T).squeeze(0))
-
-        # Sort the photos by their similarity score
-        if model_type == "M-CLIP (multiple languages)":
-            matches = find_best_matches(
-                text_features,
-                st.session_state.ml_image_features,
-                st.session_state.image_ids,
-            )
-        else:  # model_type == "J-CLIP (日本語 only)"
-            matches = find_best_matches(
-                text_features,
-                st.session_state.ja_image_features,
-                st.session_state.image_ids,
-            )
-
-        st.session_state.search_image_ids = [match[0] for match in matches]
-        st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
-
-
-def string_search():
-    clip_search(st.session_state.search_field_value)
-
-
 def visualize_gradcam(viz_image_id):
     if not st.session_state.search_field_value:
         return
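For context on how the moved callbacks are used: `string_search` simply re-runs `clip_search` on the current value of `st.session_state.search_field_value`, the usual pattern for a Streamlit `on_change` callback. The wiring below is an assumption about the rest of app.py (the widget itself is not part of this diff); the key name matches the session-state field referenced above:

import streamlit as st

# Assumed search box (not shown in this diff): its value lives in
# st.session_state.search_field_value and edits trigger string_search.
st.text_input(
    "Search query",
    key="search_field_value",
    on_change=string_search,
)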