broadwell committed
Commit 3d8e28d
1 Parent(s): 3e811d1

Delete CLIP_Explainability/app.py

Files changed (1)
  1. CLIP_Explainability/app.py +0 -801
CLIP_Explainability/app.py DELETED
@@ -1,801 +0,0 @@
-# from base64 import b64encode
-from io import BytesIO
-from math import ceil
-
-import clip
-from multilingual_clip import legacy_multilingual_clip, pt_multilingual_clip
-import numpy as np
-import pandas as pd
-from PIL import Image
-import requests
-import streamlit as st
-import torch
-from torchvision.transforms import ToPILImage
-from transformers import AutoTokenizer, AutoModel, BertTokenizer
-
-from CLIP_Explainability.clip_ import load, tokenize
-from CLIP_Explainability.rn_cam import (
-    # interpret_rn,
-    interpret_rn_overlapped,
-    rn_perword_relevance,
-)
-from CLIP_Explainability.vit_cam import (
-    # interpret_vit,
-    vit_perword_relevance,
-    interpret_vit_overlapped,
-)
-
-from pytorch_grad_cam.grad_cam import GradCAM
-
-RUN_LITE = True  # Load models for CAM viz for M-CLIP and J-CLIP only
-
-MAX_IMG_WIDTH = 500
-MAX_IMG_HEIGHT = 800
-
-st.set_page_config(layout="wide")
-
-
-# The `find_best_matches` function compares the text feature vector to the feature vectors of all images and finds the best matches. The function returns the IDs of the best matching images.
-def find_best_matches(text_features, image_features, image_ids):
-    # Compute the similarity between the search query and each image using the Cosine similarity
-    similarities = (image_features @ text_features.T).squeeze(1)
-
-    # Sort the images by their similarity score
-    best_image_idx = (-similarities).argsort()
-
-    # Return the image IDs of the best matches
-    return [[image_ids[i], similarities[i].item()] for i in best_image_idx]
-
-
-# The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
-def encode_search_query(search_query, model_type):
-    with torch.no_grad():
-        # Encode and normalize the search query using the multilingual model
-        if model_type == "M-CLIP (multilingual ViT)":
-            text_encoded = st.session_state.ml_model.forward(
-                search_query, st.session_state.ml_tokenizer
-            )
-            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-        elif model_type == "J-CLIP (日本語 ViT)":
-            t_text = st.session_state.ja_tokenizer(
-                search_query,
-                padding=True,
-                return_tensors="pt",
-                device=st.session_state.device,
-            )
-            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
-            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-        else:  # model_type == legacy
-            text_encoded = st.session_state.rn_model(search_query)
-            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
-
-    # Retrieve the feature vector
-    return text_encoded.to(st.session_state.device)
-
-
-def clip_search(search_query):
-    if st.session_state.search_field_value != search_query:
-        st.session_state.search_field_value = search_query
-
-    model_type = st.session_state.active_model
-
-    if len(search_query) >= 1:
-        text_features = encode_search_query(search_query, model_type)
-
-        # Compute the similarity between the description and each photo using the Cosine similarity
-        # similarities = list((text_features @ photo_features.T).squeeze(0))
-
-        # Sort the photos by their similarity score
-        if model_type == "M-CLIP (multilingual ViT)":
-            matches = find_best_matches(
-                text_features,
-                st.session_state.ml_image_features,
-                st.session_state.image_ids,
-            )
-        elif model_type == "J-CLIP (日本語 ViT)":
-            matches = find_best_matches(
-                text_features,
-                st.session_state.ja_image_features,
-                st.session_state.image_ids,
-            )
-        else:  # model_type == legacy
-            matches = find_best_matches(
-                text_features,
-                st.session_state.rn_image_features,
-                st.session_state.image_ids,
-            )
-
-        st.session_state.search_image_ids = [match[0] for match in matches]
-        st.session_state.search_image_scores = {match[0]: match[1] for match in matches}
-
-
-def string_search():
-    st.session_state.disable_uploader = (
-        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
-    )
-
-    if "search_field_value" in st.session_state:
-        clip_search(st.session_state.search_field_value)
-
-
-def load_image_features():
-    # Load the image feature vectors
-    if st.session_state.vision_mode == "tiled":
-        ml_image_features = np.load("./image_features/tiled_ml_features.npy")
-        ja_image_features = np.load("./image_features/tiled_ja_features.npy")
-        rn_image_features = np.load("./image_features/tiled_rn_features.npy")
-    elif st.session_state.vision_mode == "stretched":
-        ml_image_features = np.load("./image_features/resized_ml_features.npy")
-        ja_image_features = np.load("./image_features/resized_ja_features.npy")
-        rn_image_features = np.load("./image_features/resized_rn_features.npy")
-    else:  # st.session_state.vision_mode == "cropped":
-        ml_image_features = np.load("./image_features/cropped_ml_features.npy")
-        ja_image_features = np.load("./image_features/cropped_ja_features.npy")
-        rn_image_features = np.load("./image_features/cropped_rn_features.npy")
-
-    # Convert features to Tensors: Float32 on CPU and Float16 on GPU
-    device = st.session_state.device
-    if device == "cpu":
-        ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
-        ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
-        rn_image_features = torch.from_numpy(rn_image_features).float().to(device)
-    else:
-        ml_image_features = torch.from_numpy(ml_image_features).to(device)
-        ja_image_features = torch.from_numpy(ja_image_features).to(device)
-        rn_image_features = torch.from_numpy(rn_image_features).to(device)
-
-    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
-        dim=-1, keepdim=True
-    )
-    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
-        dim=-1, keepdim=True
-    )
-    st.session_state.rn_image_features = rn_image_features / rn_image_features.norm(
-        dim=-1, keepdim=True
-    )
-
-    string_search()
-
-
-def init():
-    st.session_state.current_page = 1
-
-    # device = "cuda" if torch.cuda.is_available() else "cpu"
-    device = "cpu"
-
-    st.session_state.device = device
-
-    # Load the open CLIP models
-
-    with st.spinner("Loading models and data, please wait..."):
-        ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
-        ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"
-
-        st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
-            ml_model_path, device=device, jit=False
-        )
-
-        st.session_state.ml_model = (
-            pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
-        ).to(device)
-        st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
-
-        ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
-        ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
-
-        st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
-            ja_model_path, device=device, jit=False
-        )
-
-        st.session_state.ja_model = AutoModel.from_pretrained(
-            ja_model_name, trust_remote_code=True
-        ).to(device)
-        st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
-            ja_model_name, trust_remote_code=True
-        )
-
-        if not RUN_LITE:
-            st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
-                clip.load("RN50x4", device=device)
-            )
-
-        st.session_state.rn_model = legacy_multilingual_clip.load_model(
-            "M-BERT-Base-69"
-        ).to(device)
-        st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
-            "bert-base-multilingual-cased"
-        )
-
-        # Load the image IDs
-        st.session_state.images_info = pd.read_csv("./metadata.csv")
-        st.session_state.images_info.set_index("filename", inplace=True)
-
-        with open("./images_list.txt", "r", encoding="utf-8") as images_list:
-            st.session_state.image_ids = list(images_list.read().strip().split("\n"))
-
-        st.session_state.active_model = "M-CLIP (multilingual ViT)"
-
-        st.session_state.vision_mode = "tiled"
-        st.session_state.search_image_ids = []
-        st.session_state.search_image_scores = {}
-        st.session_state.text_table_df = None
-        st.session_state.disable_uploader = (
-            RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
-        )
-
-    with st.spinner("Loading models and data, please wait..."):
-        load_image_features()
-
-
-if "images_info" not in st.session_state:
-    init()
-
-
-def get_overlay_vis(image, img_dim, image_model):
-    orig_img_dims = image.size
-
-    ##### If the features are based on tiled image slices
-    tile_behavior = None
-
-    if st.session_state.vision_mode == "tiled":
-        scaled_dims = [img_dim, img_dim]
-
-        if orig_img_dims[0] > orig_img_dims[1]:
-            scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
-            if scale_ratio > 1:
-                scaled_dims = [scale_ratio * img_dim, img_dim]
-                tile_behavior = "width"
-        elif orig_img_dims[0] < orig_img_dims[1]:
-            scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
-            if scale_ratio > 1:
-                scaled_dims = [img_dim, scale_ratio * img_dim]
-                tile_behavior = "height"
-
-        resized_image = image.resize(scaled_dims, Image.LANCZOS)
-
-        if tile_behavior == "width":
-            image_tiles = []
-            for x in range(0, scale_ratio):
-                box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
-                image_tiles.append(resized_image.crop(box))
-
-        elif tile_behavior == "height":
-            image_tiles = []
-            for y in range(0, scale_ratio):
-                box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
-                image_tiles.append(resized_image.crop(box))
-
-        else:
-            image_tiles = [resized_image]
-
-    elif st.session_state.vision_mode == "stretched":
-        image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
-
-    else:  # vision_mode == "cropped"
-        if orig_img_dims[0] > orig_img_dims[1]:
-            scale_factor = orig_img_dims[0] / orig_img_dims[1]
-            resized_img_dims = (round(scale_factor * img_dim), img_dim)
-            resized_img = image.resize(resized_img_dims)
-        elif orig_img_dims[0] < orig_img_dims[1]:
-            scale_factor = orig_img_dims[1] / orig_img_dims[0]
-            resized_img_dims = (img_dim, round(scale_factor * img_dim))
-        else:
-            resized_img_dims = (img_dim, img_dim)
-
-        resized_img = image.resize(resized_img_dims)
-
-        left = round((resized_img_dims[0] - img_dim) / 2)
-        top = round((resized_img_dims[1] - img_dim) / 2)
-        x_right = round(resized_img_dims[0] - img_dim) - left
-        x_bottom = round(resized_img_dims[1] - img_dim) - top
-        right = resized_img_dims[0] - x_right
-        bottom = resized_img_dims[1] - x_bottom
-
-        # Crop the center of the image
-        image_tiles = [resized_img.crop((left, top, right, bottom))]
-
-    image_visualizations = []
-    image_features = []
-    image_similarities = []
-
-    if st.session_state.active_model == "M-CLIP (multilingual ViT)":
-        text_features = st.session_state.ml_model.forward(
-            st.session_state.search_field_value, st.session_state.ml_tokenizer
-        )
-
-        if st.session_state.device == "cpu":
-            text_features = text_features.float().to(st.session_state.device)
-        else:
-            text_features = text_features.to(st.session_state.device)
-
-        for altered_image in image_tiles:
-            p_image = (
-                st.session_state.ml_image_preprocess(altered_image)
-                .unsqueeze(0)
-                .to(st.session_state.device)
-            )
-
-            vis_t, img_feats, similarity = interpret_vit_overlapped(
-                p_image.type(image_model.dtype),
-                text_features.type(image_model.dtype),
-                image_model.visual,
-                st.session_state.device,
-                img_dim=img_dim,
-            )
-
-            image_visualizations.append(vis_t)
-            image_features.append(img_feats)
-            image_similarities.append(similarity.item())
-
-    elif st.session_state.active_model == "J-CLIP (日本語 ViT)":
-        t_text = st.session_state.ja_tokenizer(
-            st.session_state.search_field_value,
-            return_tensors="pt",
-            device=st.session_state.device,
-        )
-
-        text_features = st.session_state.ja_model.get_text_features(**t_text)
-
-        if st.session_state.device == "cpu":
-            text_features = text_features.float().to(st.session_state.device)
-        else:
-            text_features = text_features.to(st.session_state.device)
-
-        for altered_image in image_tiles:
-            p_image = (
-                st.session_state.ja_image_preprocess(altered_image)
-                .unsqueeze(0)
-                .to(st.session_state.device)
-            )
-
-            vis_t, img_feats, similarity = interpret_vit_overlapped(
-                p_image.type(image_model.dtype),
-                text_features.type(image_model.dtype),
-                image_model.visual,
-                st.session_state.device,
-                img_dim=img_dim,
-            )
-
-            image_visualizations.append(vis_t)
-            image_features.append(img_feats)
-            image_similarities.append(similarity.item())
-
-    else:  # st.session_state.active_model == Legacy
-        text_features = st.session_state.rn_model(st.session_state.search_field_value)
-
-        if st.session_state.device == "cpu":
-            text_features = text_features.float().to(st.session_state.device)
-        else:
-            text_features = text_features.to(st.session_state.device)
-
-        for altered_image in image_tiles:
-            p_image = (
-                st.session_state.rn_image_preprocess(altered_image)
-                .unsqueeze(0)
-                .to(st.session_state.device)
-            )
-
-            vis_t = interpret_rn_overlapped(
-                p_image.type(image_model.dtype),
-                text_features.type(image_model.dtype),
-                image_model.visual,
-                GradCAM,
-                st.session_state.device,
-                img_dim=img_dim,
-            )
-
-            text_features_norm = text_features.norm(dim=-1, keepdim=True)
-            text_features_new = text_features / text_features_norm
-
-            image_feats = image_model.encode_image(p_image.type(image_model.dtype))
-            image_feats_norm = image_feats.norm(dim=-1, keepdim=True)
-            image_feats_new = image_feats / image_feats_norm
-
-            similarity = image_feats_new[0].dot(text_features_new[0])
-
-            image_visualizations.append(vis_t)
-            image_features.append(p_image)
-            image_similarities.append(similarity.item())
-
-    transform = ToPILImage()
-
-    vis_images = [transform(vis_t) for vis_t in image_visualizations]
-
-    if st.session_state.vision_mode == "cropped":
-        resized_img.paste(vis_images[0], (left, top))
-        vis_images = [resized_img]
-
-    if orig_img_dims[0] > orig_img_dims[1]:
-        scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
-        scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)]
-    else:
-        scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
-        scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]
-
-    if tile_behavior == "width":
-        vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
-        for x, v_img in enumerate(vis_images):
-            vis_image.paste(v_img, (x * img_dim, 0))
-        activations_image = vis_image.resize(scaled_dims)
-
-    elif tile_behavior == "height":
-        vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
-        for y, v_img in enumerate(vis_images):
-            vis_image.paste(v_img, (0, y * img_dim))
-        activations_image = vis_image.resize(scaled_dims)
-
-    else:
-        activations_image = vis_images[0].resize(scaled_dims)
-
-    return activations_image, image_features, np.mean(image_similarities)
-
-
-def visualize_gradcam(image):
-    if "search_field_value" not in st.session_state:
-        return
-
-    header_cols = st.columns([80, 20], vertical_alignment="bottom")
-    with header_cols[0]:
-        st.title("Image + query activation gradients")
-    with header_cols[1]:
-        if st.button("Close"):
-            st.rerun()
-
-    if st.session_state.active_model == "M-CLIP (multilingual ViT)":
-        img_dim = 240
-        image_model = st.session_state.ml_image_model
-        # Sometimes used for token importance viz
-        tokenized_text = st.session_state.ml_tokenizer.tokenize(
-            st.session_state.search_field_value
-        )
-    elif st.session_state.active_model == "Legacy (multilingual ResNet)":
-        img_dim = 288
-        image_model = st.session_state.rn_image_model
-        # Sometimes used for token importance viz
-        tokenized_text = st.session_state.rn_tokenizer.tokenize(
-            st.session_state.search_field_value
-        )
-    else:  # J-CLIP
-        img_dim = 224
-        image_model = st.session_state.ja_image_model
-        # Sometimes used for token importance viz
-        tokenized_text = st.session_state.ja_tokenizer.tokenize(
-            st.session_state.search_field_value
-        )
-
-    st.image(image)
-
-    with st.spinner("Calculating..."):
-        # info_text = st.text("Calculating activation regions...")
-
-        activations_image, image_features, similarity_score = get_overlay_vis(
-            image, img_dim, image_model
-        )
-
-        st.markdown(
-            f"**Query text:** {st.session_state.search_field_value} | **Approx. image relevance:** {round(similarity_score.item(), 3)}"
-        )
-
-        st.image(activations_image)
-
-        # image_io = BytesIO()
-        # activations_image.save(image_io, "PNG")
-        # dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode(
-        #     "ascii"
-        # )
-
-        # st.html(
-        #     f"""<div style="display: flex; flex-direction: column; align-items: center;">
-        #         <img src="{dataurl}" />
-        #     </div>"""
-        # )
-
-    tokenized_text = [
-        tok.replace("▁", "").replace("#", "") for tok in tokenized_text if tok != "▁"
-    ]
-    tokenized_text = [
-        tok
-        for tok in tokenized_text
-        if tok
-        not in ["s", "ed", "a", "the", "an", "ing", "て", "に", "の", "は", "と", "た"]
-    ]
-
-    if (
-        len(tokenized_text) > 1
-        and len(tokenized_text) < 25
-        and st.button(
-            "Calculate text importance (may take some time)",
-        )
-    ):
-        scores_per_token = {}
-
-        progress_text = f"Processing {len(tokenized_text)} text tokens"
-        progress_bar = st.progress(0.0, text=progress_text)
-
-        for t, tok in enumerate(tokenized_text):
-            token = tok
-
-            for img_feats in image_features:
-                if st.session_state.active_model == "Legacy (multilingual ResNet)":
-                    word_rel = rn_perword_relevance(
-                        img_feats,
-                        st.session_state.search_field_value,
-                        image_model,
-                        tokenize,
-                        GradCAM,
-                        st.session_state.device,
-                        token,
-                        data_only=True,
-                        img_dim=img_dim,
-                    )
-                else:
-                    word_rel = vit_perword_relevance(
-                        img_feats,
-                        st.session_state.search_field_value,
-                        image_model,
-                        tokenize,
-                        st.session_state.device,
-                        token,
-                        img_dim=img_dim,
-                    )
-                avg_score = np.mean(word_rel)
-                if avg_score == 0 or np.isnan(avg_score):
-                    continue
-
-                if token not in scores_per_token:
-                    scores_per_token[token] = [1 / avg_score]
-                else:
-                    scores_per_token[token].append(1 / avg_score)
-
-            progress_bar.progress(
-                (t + 1) / len(tokenized_text),
-                text=f"Processing token {t+1} of {len(tokenized_text)}",
-            )
-        progress_bar.empty()
-
-        avg_scores_per_token = [
-            np.mean(scores_per_token[tok]) for tok in list(scores_per_token.keys())
-        ]
-
-        normed_scores = torch.softmax(torch.tensor(avg_scores_per_token), dim=0)
-
-        token_scores = [f"{round(score.item() * 100, 3)}%" for score in normed_scores]
-        st.session_state.text_table_df = pd.DataFrame(
-            {"token": list(scores_per_token.keys()), "importance": token_scores}
-        )
-
-        st.markdown("**Importance of each text token to relevance score**")
-        st.table(st.session_state.text_table_df)
-
-
-@st.dialog(" ", width="large")
-def image_modal(image):
-    visualize_gradcam(image)
-
-
-def vis_known_image(vis_image_id):
-    image_url = st.session_state.images_info.loc[vis_image_id]["image_url"]
-    image_response = requests.get(image_url)
-    image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF", "PNG"])
-    image = image.convert("RGB")
-
-    image_modal(image)
-
-
-def vis_uploaded_image():
-    uploaded_file = st.session_state.uploaded_image
-    if uploaded_file is not None:
-        # To read file as bytes:
-        bytes_data = uploaded_file.getvalue()
-        image = Image.open(BytesIO(bytes_data), formats=["JPEG", "GIF", "PNG"])
-        image = image.convert("RGB")
-
-        image_modal(image)
-
-
-def format_vision_mode(mode_stub):
-    return mode_stub.capitalize()
-
-
-st.title("Explore Japanese visual aesthetics with CLIP models")
-
-st.markdown(
-    """
-    <style>
-    [data-testid=stImageCaption] {
-        padding: 0 0 0 0;
-    }
-    [data-testid=stVerticalBlockBorderWrapper] {
-        line-height: 1.2;
-    }
-    [data-testid=stVerticalBlock] {
-        gap: .75rem;
-    }
-    [data-testid=baseButton-secondary] {
-        min-height: 1rem;
-        padding: 0 0.75rem;
-        margin: 0 0 1rem 0;
-    }
-    div[aria-label="dialog"]>button[aria-label="Close"] {
-        display: none;
-    }
-    [data-testid=stFullScreenFrame] {
-        display: flex;
-        flex-direction: column;
-        align-items: center;
-    }
-    </style>
-    """,
-    unsafe_allow_html=True,
-)
-
-search_row = st.columns([45, 8, 8, 10, 1, 8, 20], vertical_alignment="center")
-with search_row[0]:
-    search_field = st.text_input(
-        label="search",
-        label_visibility="collapsed",
-        placeholder="Type something, or click a suggested search below.",
-        on_change=string_search,
-        key="search_field_value",
-    )
-with search_row[1]:
-    st.button(
-        "Search", on_click=string_search, use_container_width=True, type="primary"
-    )
-with search_row[2]:
-    st.markdown("**Vision mode:**")
-with search_row[3]:
-    st.selectbox(
-        "Vision mode",
-        options=["tiled", "stretched", "cropped"],
-        key="vision_mode",
-        help="How to consider images that aren't square",
-        on_change=load_image_features,
-        format_func=format_vision_mode,
-        label_visibility="collapsed",
-    )
-with search_row[4]:
-    st.empty()
-with search_row[5]:
-    st.markdown("**CLIP model:**")
-with search_row[6]:
-    st.selectbox(
-        "CLIP Model:",
-        options=[
-            "M-CLIP (multilingual ViT)",
-            "J-CLIP (日本語 ViT)",
-            "Legacy (multilingual ResNet)",
-        ],
-        key="active_model",
-        on_change=string_search,
-        label_visibility="collapsed",
-    )
-
-canned_searches = st.columns([12, 22, 22, 22, 22], vertical_alignment="top")
-with canned_searches[0]:
-    st.markdown("**Suggested searches:**")
-if st.session_state.active_model == "J-CLIP (日本語 ViT)":
-    with canned_searches[1]:
-        st.button(
-            "間",
-            on_click=clip_search,
-            args=["間"],
-            use_container_width=True,
-        )
-    with canned_searches[2]:
-        st.button("奥", on_click=clip_search, args=["奥"], use_container_width=True)
-    with canned_searches[3]:
-        st.button("山", on_click=clip_search, args=["山"], use_container_width=True)
-    with canned_searches[4]:
-        st.button(
-            "花に酔えり 羽織着て刀 さす女",
-            on_click=clip_search,
-            args=["花に酔えり 羽織着て刀 さす女"],
-            use_container_width=True,
-        )
-else:
-    with canned_searches[1]:
-        st.button(
-            "negative space",
-            on_click=clip_search,
-            args=["negative space"],
-            use_container_width=True,
-        )
-    with canned_searches[2]:
-        st.button("間", on_click=clip_search, args=["間"], use_container_width=True)
-    with canned_searches[3]:
-        st.button("음각", on_click=clip_search, args=["음각"], use_container_width=True)
-    with canned_searches[4]:
-        st.button(
-            "αρνητικός χώρος",
-            on_click=clip_search,
-            args=["αρνητικός χώρος"],
-            use_container_width=True,
-        )
-
-controls = st.columns([25, 25, 20, 35], gap="large", vertical_alignment="center")
-with controls[0]:
-    im_per_pg = st.columns([30, 70], vertical_alignment="center")
-    with im_per_pg[0]:
-        st.markdown("**Images/page:**")
-    with im_per_pg[1]:
-        batch_size = st.select_slider(
-            "Images/page:", range(10, 50, 10), label_visibility="collapsed"
-        )
-with controls[1]:
-    im_per_row = st.columns([30, 70], vertical_alignment="center")
-    with im_per_row[0]:
-        st.markdown("**Images/row:**")
-    with im_per_row[1]:
-        row_size = st.select_slider(
-            "Images/row:", range(1, 6), value=5, label_visibility="collapsed"
-        )
-num_batches = ceil(len(st.session_state.image_ids) / batch_size)
-with controls[2]:
-    pager = st.columns([40, 60], vertical_alignment="center")
-    with pager[0]:
-        st.markdown(f"Page **{st.session_state.current_page}** of **{num_batches}** ")
-    with pager[1]:
-        st.number_input(
-            "Page",
-            min_value=1,
-            max_value=num_batches,
-            step=1,
-            label_visibility="collapsed",
-            key="current_page",
-        )
-with controls[3]:
-    st.file_uploader(
-        "Upload an image",
-        type=["jpg", "jpeg", "gif", "png"],
-        key="uploaded_image",
-        label_visibility="collapsed",
-        on_change=vis_uploaded_image,
-        disabled=st.session_state.disable_uploader,
-    )
-
-
-if len(st.session_state.search_image_ids) == 0:
-    batch = []
-else:
-    batch = st.session_state.search_image_ids[
-        (st.session_state.current_page - 1) * batch_size : st.session_state.current_page
-        * batch_size
-    ]
-
-grid = st.columns(row_size)
-col = 0
-for image_id in batch:
-    with grid[col]:
-        link_text = st.session_state.images_info.loc[image_id]["permalink"].split("/")[
-            2
-        ]
-        # st.image(
-        #     st.session_state.images_info.loc[image_id]["image_url"],
-        #     caption=st.session_state.images_info.loc[image_id]["caption"],
-        # )
-        st.html(
-            f"""<div style="display: flex; flex-direction: column; align-items: center">
-                <img src="{st.session_state.images_info.loc[image_id]['image_url']}" style="max-width: 100%; max-height: {MAX_IMG_HEIGHT}px" />
-                <div>{st.session_state.images_info.loc[image_id]['caption']} <b>[{round(st.session_state.search_image_scores[image_id], 3)}]</b></div>
-            </div>"""
-        )
-        st.caption(
-            f"""<div style="display: flex; flex-direction: column; align-items: center; position: relative; top: -12px">
-                <a href="{st.session_state.images_info.loc[image_id]['permalink']}">{link_text}</a>
-            </div>""",
-            unsafe_allow_html=True,
-        )
-        if not (
-            RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
-        ):
-            st.button(
-                "Explain this",
-                on_click=vis_known_image,
-                args=[image_id],
-                use_container_width=True,
-                key=image_id,
-            )
-        else:
-            st.empty()
-    col = (col + 1) % row_size