broadwell committed
Commit 254b186
1 Parent(s): c3473c5

Upload 3 files

CLIP_Explainability/image_utils.py CHANGED
@@ -1,22 +1,38 @@
 import numpy as np
 import cv2
+from matplotlib import pyplot as plt
+
+
+def get_mpl_colormap(cmap_name):
+    cmap = plt.get_cmap(cmap_name)
+
+    # Initialize the matplotlib color map
+    sm = plt.cm.ScalarMappable(cmap=cmap)
+
+    # Obtain linear color range
+    color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]
+
+    return color_range.reshape(256, 1, 3)
+
 
 def show_cam_on_image(img, mask, neg_saliency=False):
-
     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
-
+
     heatmap = np.float32(heatmap) / 255
     cam = heatmap + np.float32(img)
     cam = cam / np.max(cam)
     return cam
 
+
 def show_overlapped_cam(img, neg_mask, pos_mask):
-    neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), cv2.COLORMAP_RAINBOW)
-    pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), cv2.COLORMAP_JET)
+    # neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), cv2.COLORMAP_RAINBOW)
+    # pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), cv2.COLORMAP_JET)
+    neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), get_mpl_colormap("Blues"))
+    pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), get_mpl_colormap("Reds"))
     neg_heatmap = np.float32(neg_heatmap) / 255
     pos_heatmap = np.float32(pos_heatmap) / 255
     # try different options: sum, average, ...
     heatmap = neg_heatmap + pos_heatmap
     cam = heatmap + np.float32(img)
     cam = cam / np.max(cam)
-    return cam
+    return cam
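
For context, a minimal usage sketch of the reworked show_overlapped_cam (the example inputs and import path below are assumptions, not part of this commit): negative saliency is now colored with a matplotlib "Blues" lookup table and positive saliency with "Reds", both built by the new get_mpl_colormap helper.

import numpy as np
from CLIP_Explainability.image_utils import show_overlapped_cam

# Hypothetical inputs: a float32 RGB image in [0, 1] and two saliency masks in [0, 1].
img = np.random.rand(224, 224, 3).astype(np.float32)
neg_mask = np.random.rand(224, 224).astype(np.float32)
pos_mask = np.random.rand(224, 224).astype(np.float32)

# Each mask is mapped through its colormap, the two heatmaps are summed,
# added to the image, and the result is renormalized to [0, 1].
overlay = show_overlapped_cam(img, neg_mask, pos_mask)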
CLIP_Explainability/rn_cam.py CHANGED
@@ -1,7 +1,8 @@
 import torch
 import numpy as np
-from PIL import Image
-import matplotlib.pyplot as plt
+
+# from PIL import Image
+# import matplotlib.pyplot as plt
 import cv2
 import re
 
@@ -22,7 +23,7 @@ def rn_relevance(
     cam = method(
         model=img_encoder,
         target_layers=target_layers,
-        use_cuda=torch.cuda.is_available(),
+        use_cuda=torch.cuda.is_available() and device != "cpu",
     )
 
     if neg_saliency:
@@ -127,12 +128,13 @@ def rn_perword_relevance(
     masked_text = re.sub(masked_word, "", text)
     masked_text = clip_tokenizer(masked_text).to(device)
 
-    image_features = clip_model.encode_image(image)
+    # image_features = clip_model.encode_image(image)
     main_text_features = clip_model.encode_text(main_text)
     masked_text_features = clip_model.encode_text(masked_text)
 
-    image_features_norm = image_features.norm(dim=-1, keepdim=True)
-    image_features_new = image_features / image_features_norm
+    # image_features_norm = image_features.norm(dim=-1, keepdim=True)
+    # image_features_new = image_features / image_features_norm
+
     main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
     main_text_features_new = main_text_features / main_text_features_norm
 
@@ -146,10 +148,10 @@ def rn_perword_relevance(
     cam = method(
         model=clip_model.visual,
         target_layers=target_layers,
-        use_cuda=torch.cuda.is_available(),
+        use_cuda=torch.cuda.is_available() and device != "cpu",
     )
 
-    image_features = clip_model.visual(image)
+    # image_features = clip_model.visual(image)
 
     image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
         0
@@ -175,7 +177,7 @@ def rn_perword_relevance(
     image = image[0].permute(1, 2, 0).data.cpu().numpy()
     image = (image - image.min()) / (image.max() - image.min())
 
-    return image_relevance, image
+    return image_relevance
 
 
 def interpret_perword_rn(
@@ -189,7 +191,7 @@ def interpret_perword_rn(
     data_only=False,
     img_dim=224,
 ):
-    image_relevance, image = rn_perword_relevance(
+    image_relevance = rn_perword_relevance(
         image,
         text,
         clip_model,
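
The use_cuda change in both CAM constructors means the saliency method stays on the CPU whenever the caller passes device="cpu", even on a machine with a working GPU. A minimal sketch of the guard (the helper name is illustrative):

import torch

def pick_use_cuda(device: str) -> bool:
    # New behavior: use CUDA only if a GPU is available AND the caller did not request CPU.
    return torch.cuda.is_available() and device != "cpu"

print(pick_use_cuda("cpu"))     # False, even when torch.cuda.is_available() is True
print(pick_use_cuda("cuda:0"))  # True only on a machine with CUDA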
CLIP_Explainability/vit_cam.py CHANGED
@@ -1,7 +1,8 @@
 import torch
 import numpy as np
-from PIL import Image
-import matplotlib.pyplot as plt
+
+# from PIL import Image
+# import matplotlib.pyplot as plt
 import cv2
 import regex as re
 
@@ -71,7 +72,7 @@ def vit_block_vis(
     cam = cam[0].permute(1, 2, 0).data.cpu().numpy()
     cam = np.float32(cam)
 
-    plt.imshow(cam)
+    # plt.imshow(cam)
 
     return new_score
 
@@ -90,8 +91,10 @@ def vit_relevance(
 
     image_features_norm = image_features.norm(dim=-1, keepdim=True)
     image_features_new = image_features / image_features_norm
+
     target_features_norm = target_features.norm(dim=-1, keepdim=True)
     target_features_new = target_features / target_features_norm
+
     similarity = image_features_new[0].dot(target_features_new[0])
     if neg_saliency:
         objective = 1 - similarity
@@ -154,7 +157,7 @@ def vit_relevance(
     image = image[0].permute(1, 2, 0).data.cpu().numpy()
     image = (image - image.min()) / (image.max() - image.min())
 
-    return image_relevance, image
+    return image_relevance, image, image_features, similarity
 
 
 def interpret_vit(
@@ -166,7 +169,7 @@ def interpret_vit(
     neg_saliency=False,
     img_dim=224,
 ):
-    image_relevance, image = vit_relevance(
+    image_relevance, image, image_features, similarity = vit_relevance(
         image,
         target_features,
         img_encoder,
@@ -180,14 +183,14 @@ def interpret_vit(
     vis = np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
 
-    return vis
+    return vis, image_features, similarity
     # plt.imshow(vis)
 
 
 def interpret_vit_overlapped(
     image, target_features, img_encoder, device, method="last grad", img_dim=224
 ):
-    pos_image_relevance, _ = vit_relevance(
+    pos_image_relevance, _, image_features, similarity = vit_relevance(
         image,
         target_features,
         img_encoder,
@@ -196,7 +199,7 @@ def interpret_vit_overlapped(
         neg_saliency=False,
         img_dim=img_dim,
     )
-    neg_image_relevance, image = vit_relevance(
+    neg_image_relevance, image, _, _ = vit_relevance(
         image,
         target_features,
         img_encoder,
@@ -210,19 +213,18 @@ def interpret_vit_overlapped(
     vis = np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
 
-    return vis
+    return vis, image_features, similarity
     # plt.imshow(vis)
 
 
 def vit_perword_relevance(
-    image,
+    image_features,
     text,
     clip_model,
     clip_tokenizer,
     device,
     masked_word="",
     use_last_grad=True,
-    data_only=False,
     img_dim=224,
 ):
     clip_model.eval()
@@ -232,12 +234,13 @@ def vit_perword_relevance(
     masked_text = re.sub(masked_word, "", text)
     masked_text = clip_tokenizer(masked_text).to(device)
 
-    image_features = clip_model.encode_image(image)
     main_text_features = clip_model.encode_text(main_text)
     masked_text_features = clip_model.encode_text(masked_text)
 
+    # image_features = clip_model.encode_image(image)
     image_features_norm = image_features.norm(dim=-1, keepdim=True)
     image_features_new = image_features / image_features_norm
+
     main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
     main_text_features_new = main_text_features / main_text_features_norm
 
@@ -290,38 +293,9 @@ def vit_perword_relevance(
         image_relevance.max() - image_relevance.min()
     )
 
-    if data_only:
-        return image_relevance
-
-    image = image[0].permute(1, 2, 0).data.cpu().numpy()
-    image = (image - image.min()) / (image.max() - image.min())
-
-    return image_relevance, image
-
-
-def interpret_perword_vit(
-    image,
-    text,
-    clip_model,
-    clip_tokenizer,
-    device,
-    masked_word="",
-    use_last_grad=True,
-    img_dim=224,
-):
-    image_relevance, image = vit_perword_relevance(
-        image,
-        text,
-        clip_model,
-        clip_tokenizer,
-        device,
-        masked_word,
-        use_last_grad,
-        img_dim=img_dim,
-    )
-    vis = show_cam_on_image(image, image_relevance)
-    vis = np.uint8(255 * vis)
-    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
-
-    return vis
-    # plt.imshow(vis)
+    # image = image[0].permute(1, 2, 0).data.cpu().numpy()
+    # image = (image - image.min()) / (image.max() - image.min())
+
+    # return image_relevance, image
+
+    return image_relevance
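
A minimal sketch of the new calling convention in this file (the CLIP model, tokenizer, inputs, and any interpret_vit arguments not visible in this diff are assumptions): interpret_vit now also returns the encoded image features and the image-text similarity, and vit_perword_relevance consumes those precomputed features instead of re-encoding the image.

# Hypothetical usage; image, text_features, clip_model, clip_tokenizer, and device
# are assumed to be prepared elsewhere with a CLIP ViT model.
vis, image_features, similarity = interpret_vit(
    image, text_features, clip_model.visual, device
)

# The cached image_features can be reused for per-word relevance, which no
# longer calls clip_model.encode_image itself:
word_relevance = vit_perword_relevance(
    image_features,
    "a photo of a dog",
    clip_model,
    clip_tokenizer,
    device,
    masked_word="dog",
)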