broadwell committed on
Commit 3adc1a0
1 Parent(s): 00bcde0

Can select and visualize results from cropping, stretching, or tiling images

Files changed (1):
  1. app.py +209 -92
app.py CHANGED
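The vision mode introduced in this commit only changes how a non-square image is fitted to the model's square input before encoding. Condensed into a standalone sketch for reference (prepare_tiles and its names are illustrative, not taken from app.py):

    from PIL import Image

    def prepare_tiles(image: Image.Image, img_dim: int = 224, mode: str = "tiled"):
        """Return a list of img_dim x img_dim PIL images for one input image."""
        w, h = image.size
        if mode == "stretched":
            # Ignore aspect ratio and squash/stretch to a single square.
            return [image.resize((img_dim, img_dim), Image.LANCZOS)]
        if mode == "cropped":
            # Scale the short side to img_dim, then center-crop a square.
            scale = img_dim / min(w, h)
            resized = image.resize((round(w * scale), round(h * scale)))
            left = (resized.width - img_dim) // 2
            top = (resized.height - img_dim) // 2
            return [resized.crop((left, top, left + img_dim, top + img_dim))]
        # "tiled": scale to a 1xN (or Nx1) strip of squares and cut it up.
        ratio = max(round(max(w, h) / min(w, h)), 1)
        if w >= h:
            strip = image.resize((ratio * img_dim, img_dim), Image.LANCZOS)
            boxes = [(i * img_dim, 0, (i + 1) * img_dim, img_dim) for i in range(ratio)]
        else:
            strip = image.resize((img_dim, ratio * img_dim), Image.LANCZOS)
            boxes = [(0, i * img_dim, img_dim, (i + 1) * img_dim) for i in range(ratio)]
        return [strip.crop(box) for box in boxes]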
@@ -18,12 +18,43 @@ from CLIP_Explainability.vit_cam import (
     vit_perword_relevance,
 ) # , interpret_vit_overlapped

-MAX_IMG_WIDTH = 450 # For small dialog
+MAX_IMG_WIDTH = 500
 MAX_IMG_HEIGHT = 800

 st.set_page_config(layout="wide")


+def load_image_features():
+    # Load the image feature vectors
+    if st.session_state.vision_mode == "tiled":
+        ml_image_features = np.load("./image_features/tiled_ml_features.npy")
+        ja_image_features = np.load("./image_features/tiled_ja_features.npy")
+    elif st.session_state.vision_mode == "stretched":
+        ml_image_features = np.load("./image_features/resized_ml_features.npy")
+        ja_image_features = np.load("./image_features/resized_ja_features.npy")
+    else:  # st.session_state.vision_mode == "cropped"
+        ml_image_features = np.load("./image_features/cropped_ml_features.npy")
+        ja_image_features = np.load("./image_features/cropped_ja_features.npy")
+
+    # Convert features to Tensors: Float32 on CPU and Float16 on GPU
+    device = st.session_state.device
+    if device == "cpu":
+        ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
+        ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
+    else:
+        ml_image_features = torch.from_numpy(ml_image_features).to(device)
+        ja_image_features = torch.from_numpy(ja_image_features).to(device)
+
+    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
+        dim=-1, keepdim=True
+    )
+    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
+        dim=-1, keepdim=True
+    )
+
+    string_search()
+
+
 def init():
     st.session_state.current_page = 1

@@ -34,74 +65,51 @@ def init():
     ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
     ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"

-    st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
-        ml_model_path, device=device, jit=False
-    )
-
-    st.session_state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
-        ml_model_name
-    )
-    st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
-
-    ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
-    ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
-
-    st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
-        ja_model_path, device=device, jit=False
-    )
-
-    st.session_state.ja_model = AutoModel.from_pretrained(
-        ja_model_name, trust_remote_code=True
-    ).to(device)
-    st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
-        ja_model_name, trust_remote_code=True
-    )
-
-    st.session_state.active_model = "M-CLIP (multiple languages)"
+    with st.spinner("Loading models and data, please wait..."):
+        st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
+            ml_model_path, device=device, jit=False
+        )
+
+        st.session_state.ml_model = (
+            pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
+        )
+        st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)
+
+        ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
+        ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
+
+        st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
+            ja_model_path, device=device, jit=False
+        )
+
+        st.session_state.ja_model = AutoModel.from_pretrained(
+            ja_model_name, trust_remote_code=True
+        ).to(device)
+        st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
+            ja_model_name, trust_remote_code=True
+        )
+
+        # Load the image IDs
+        st.session_state.images_info = pd.read_csv("./metadata.csv")
+        st.session_state.images_info.set_index("filename", inplace=True)
+
+        with open("./images_list.txt", "r", encoding="utf-8") as images_list:
+            st.session_state.image_ids = list(images_list.read().strip().split("\n"))
+
+    st.session_state.active_model = "M-CLIP (multiple languages)"

+    st.session_state.vision_mode = "tiled"
     st.session_state.search_image_ids = []
     st.session_state.search_image_scores = {}
     st.session_state.activations_image = None
     st.session_state.text_table_df = None

-    # Load the image IDs
-    st.session_state.images_info = pd.read_csv("./metadata.csv")
-    st.session_state.images_info.set_index("filename", inplace=True)
-
-    st.session_state.image_ids = list(
-        open("./images_list.txt", "r", encoding="utf-8").read().strip().split("\n")
-    )
-
-    # Load the image feature vectors
-    # ml_image_features = np.load("./multilingual_features.npy")
-    # ja_image_features = np.load("./hakuhodo_features.npy")
-    ml_image_features = np.load("./resized_ml_features.npy")
-    ja_image_features = np.load("./resized_ja_features.npy")
-    # ml_image_features = np.load("./tiled_ml_features.npy")
-    # ja_image_features = np.load("./tiled_ja_features.npy")
-
-    # Convert features to Tensors: Float32 on CPU and Float16 on GPU
-    if device == "cpu":
-        ml_image_features = torch.from_numpy(ml_image_features).float().to(device)
-        ja_image_features = torch.from_numpy(ja_image_features).float().to(device)
-    else:
-        ml_image_features = torch.from_numpy(ml_image_features).to(device)
-        ja_image_features = torch.from_numpy(ja_image_features).to(device)
-
-    st.session_state.ml_image_features = ml_image_features / ml_image_features.norm(
-        dim=-1, keepdim=True
-    )
-    st.session_state.ja_image_features = ja_image_features / ja_image_features.norm(
-        dim=-1, keepdim=True
-    )
+    with st.spinner("Loading models and data, please wait..."):
+        load_image_features()


-if (
-    "ml_image_features" not in st.session_state
-    or "ja_image_features" not in st.session_state
-):
-    with st.spinner("Loading models and data, please wait..."):
-        init()
+if "images_info" not in st.session_state:
+    init()


 # The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
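Because both the image features and the text features are L2-normalized, string_search() (not shown in this diff) can rank the whole collection with a single matrix product. A minimal sketch of that idea, assuming normalized inputs (rank_images is a hypothetical helper, not part of app.py):

    import torch

    def rank_images(text_features: torch.Tensor, image_features: torch.Tensor, top_k: int = 20):
        # With both sides L2-normalized, the dot product equals cosine similarity.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        scores = (text_features @ image_features.T).squeeze(0)  # shape: (num_images,)
        values, indices = scores.topk(min(top_k, scores.shape[0]))
        return values.tolist(), indices.tolist()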
@@ -191,6 +199,7 @@ def visualize_gradcam(viz_image_id):
     image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
     image_response = requests.get(image_url)
     image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF"])
+    image = image.convert("RGB")

     img_dim = 224
     if st.session_state.active_model == "M-CLIP (multiple languages)":
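The added image.convert("RGB") guards against palette ("P") and RGBA images, which GIFs commonly decode to and which would otherwise reach the CLIP preprocessing with the wrong number of channels. An illustrative check (example.gif is a placeholder filename):

    from PIL import Image

    img = Image.open("example.gif")  # placeholder path
    print(img.mode)                  # often "P" or "RGBA" for a GIF
    img = img.convert("RGB")
    print(img.mode)                  # "RGB": three channels, as the preprocess expects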
@@ -198,62 +207,141 @@ def visualize_gradcam(viz_image_id):

     orig_img_dims = image.size

-    altered_image = image.resize((img_dim, img_dim), Image.LANCZOS)
+    ##### If the features are based on tiled image slices
+    tile_behavior = None
+
+    if st.session_state.vision_mode == "tiled":
+        scaled_dims = [img_dim, img_dim]
+
+        if orig_img_dims[0] > orig_img_dims[1]:
+            scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
+            if scale_ratio > 1:
+                scaled_dims = [scale_ratio * img_dim, img_dim]
+                tile_behavior = "width"
+        elif orig_img_dims[0] < orig_img_dims[1]:
+            scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
+            if scale_ratio > 1:
+                scaled_dims = [img_dim, scale_ratio * img_dim]
+                tile_behavior = "height"
+
+        resized_image = image.resize(scaled_dims, Image.LANCZOS)
+
+        if tile_behavior == "width":
+            image_tiles = []
+            for x in range(0, scale_ratio):
+                box = (x * img_dim, 0, (x + 1) * img_dim, img_dim)
+                image_tiles.append(resized_image.crop(box))
+
+        elif tile_behavior == "height":
+            image_tiles = []
+            for y in range(0, scale_ratio):
+                box = (0, y * img_dim, img_dim, (y + 1) * img_dim)
+                image_tiles.append(resized_image.crop(box))
+
+        else:
+            image_tiles = [resized_image]
+
+    elif st.session_state.vision_mode == "stretched":
+        image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]
+
+    else:  # vision_mode == "cropped"
+        if orig_img_dims[0] > orig_img_dims[1]:
+            scale_factor = orig_img_dims[0] / orig_img_dims[1]
+            resized_img_dims = (round(scale_factor * img_dim), img_dim)
+            resized_img = image.resize(resized_img_dims)
+        elif orig_img_dims[0] < orig_img_dims[1]:
+            scale_factor = orig_img_dims[1] / orig_img_dims[0]
+            resized_img_dims = (img_dim, round(scale_factor * img_dim))
+        else:
+            resized_img_dims = (img_dim, img_dim)
+
+        resized_img = image.resize(resized_img_dims)
+
+        left = round((resized_img_dims[0] - img_dim) / 2)
+        top = round((resized_img_dims[1] - img_dim) / 2)
+        x_right = round(resized_img_dims[0] - img_dim) - left
+        x_bottom = round(resized_img_dims[1] - img_dim) - top
+        right = resized_img_dims[0] - x_right
+        bottom = resized_img_dims[1] - x_bottom
+
+        # Crop the center of the image
+        image_tiles = [resized_img.crop((left, top, right, bottom))]
+
+    image_visualizations = []

     if st.session_state.active_model == "M-CLIP (multiple languages)":
-        p_image = (
-            st.session_state.ml_image_preprocess(altered_image)
-            .unsqueeze(0)
-            .to(st.session_state.device)
-        )
-
         # Sometimes used for token importance viz
         tokenized_text = st.session_state.ml_tokenizer.tokenize(
             st.session_state.search_field_value
         )
-        image_model = st.session_state.ml_image_model
-        # tokenize = st.session_state.ml_tokenizer.tokenize

         text_features = st.session_state.ml_model.forward(
             st.session_state.search_field_value, st.session_state.ml_tokenizer
         )

-        vis_t = interpret_vit(
-            p_image.type(st.session_state.ml_image_model.dtype),
-            text_features,
-            st.session_state.ml_image_model.visual,
-            st.session_state.device,
-            img_dim=img_dim,
-        )
-
-    else:
-        p_image = (
-            st.session_state.ja_image_preprocess(altered_image)
-            .unsqueeze(0)
-            .to(st.session_state.device)
-        )
-
+        image_model = st.session_state.ml_image_model
+        # tokenize = st.session_state.ml_tokenizer.tokenize
+        image_model.eval()
+
+        for altered_image in image_tiles:
+            image_model.zero_grad()
+
+            p_image = (
+                st.session_state.ml_image_preprocess(altered_image)
+                .unsqueeze(0)
+                .to(st.session_state.device)
+            )
+
+            vis_t = interpret_vit(
+                p_image.type(st.session_state.ml_image_model.dtype),
+                text_features,
+                image_model.visual,
+                st.session_state.device,
+                img_dim=img_dim,
+            )
+
+            image_visualizations.append(vis_t)
+
+    else:
         # Sometimes used for token importance viz
         tokenized_text = st.session_state.ja_tokenizer.tokenize(
             st.session_state.search_field_value
         )
-        image_model = st.session_state.ja_image_model

         t_text = st.session_state.ja_tokenizer(
             st.session_state.search_field_value, return_tensors="pt"
         )
         text_features = st.session_state.ja_model.get_text_features(**t_text)

-        vis_t = interpret_vit(
-            p_image.type(st.session_state.ja_image_model.dtype),
-            text_features,
-            st.session_state.ja_image_model.visual,
-            st.session_state.device,
-            img_dim=img_dim,
-        )
+        image_model = st.session_state.ja_image_model
+        image_model.eval()
+
+        for altered_image in image_tiles:
+            image_model.zero_grad()
+
+            p_image = (
+                st.session_state.ja_image_preprocess(altered_image)
+                .unsqueeze(0)
+                .to(st.session_state.device)
+            )
+
+            vis_t = interpret_vit(
+                p_image.type(st.session_state.ja_image_model.dtype),
+                text_features,
+                image_model.visual,
+                st.session_state.device,
+                img_dim=img_dim,
+            )
+
+            image_visualizations.append(vis_t)

     transform = ToPILImage()
-    vis_img = transform(vis_t)
+
+    vis_images = [transform(vis_t) for vis_t in image_visualizations]
+
+    if st.session_state.vision_mode == "cropped":
+        resized_img.paste(vis_images[0], (left, top))
+        vis_images = [resized_img]

     if orig_img_dims[0] > orig_img_dims[1]:
         scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
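A quick check of the tiling arithmetic above with illustrative dimensions (values not from the app): a 1200x400 landscape image gives scale_ratio = 3, is resized to 672x224, and is cut into three 224x224 tiles, each of which gets its own Grad-CAM pass.

    orig_w, orig_h, img_dim = 1200, 400, 224        # illustrative only
    scale_ratio = round(orig_w / orig_h)            # 3
    scaled_dims = (scale_ratio * img_dim, img_dim)  # (672, 224)
    boxes = [(x * img_dim, 0, (x + 1) * img_dim, img_dim) for x in range(scale_ratio)]
    # -> [(0, 0, 224, 224), (224, 0, 448, 224), (448, 0, 672, 224)]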
@@ -262,14 +350,27 @@ def visualize_gradcam(viz_image_id):
         scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
         scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT]

-    st.session_state.activations_image = vis_img.resize(scaled_dims)
+    if tile_behavior == "width":
+        vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
+        for x, v_img in enumerate(vis_images):
+            vis_image.paste(v_img, (x * img_dim, 0))
+        st.session_state.activations_image = vis_image.resize(scaled_dims)
+
+    elif tile_behavior == "height":
+        vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
+        for y, v_img in enumerate(vis_images):
+            vis_image.paste(v_img, (0, y * img_dim))
+        st.session_state.activations_image = vis_image.resize(scaled_dims)
+
+    else:
+        st.session_state.activations_image = vis_images[0].resize(scaled_dims)

     image_io = BytesIO()
     st.session_state.activations_image.save(image_io, "PNG")
     dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode("ascii")

     st.html(
-        f"""<div style="display: flex; flex-direction: column; align-items: center">
+        f"""<div style="display: flex; flex-direction: column; align-items: center;">
         <img src="{dataurl}" />
         </div>"""
     )
@@ -326,7 +427,11 @@ def visualize_gradcam(viz_image_id):
     st.table(st.session_state.text_table_df)


-@st.dialog(" ", width="small")
+def format_vision_mode(mode_stub):
+    return f"Vision mode: {mode_stub.capitalize()}"
+
+
+@st.dialog(" ", width="large")
 def image_modal(vis_image_id):
     visualize_gradcam(vis_image_id)

@@ -363,7 +468,7 @@ st.markdown(
     unsafe_allow_html=True,
 )

-search_row = st.columns([45, 10, 13, 7, 25], vertical_alignment="center")
+search_row = st.columns([45, 5, 1, 15, 1, 8, 25], vertical_alignment="center")
 with search_row[0]:
     search_field = st.text_input(
         label="search",
@@ -379,8 +484,20 @@ with search_row[1]:
 with search_row[2]:
     st.empty()
 with search_row[3]:
-    st.markdown("**CLIP Model:**")
+    st.selectbox(
+        "Vision mode:",
+        options=["tiled", "stretched", "cropped"],
+        key="vision_mode",
+        help="How to consider images that aren't square",
+        on_change=load_image_features,
+        format_func=format_vision_mode,
+        label_visibility="collapsed",
+    )
 with search_row[4]:
+    st.empty()
+with search_row[5]:
+    st.markdown("**CLIP Model:**")
+with search_row[6]:
     st.radio(
         "CLIP Model",
         options=["M-CLIP (multiple languages)", "J-CLIP (日本語)"],
 