Upload 3 files

Changed files:
- CLIP_Explainability/image_utils.py  +21 -5
- CLIP_Explainability/rn_cam.py       +12 -10
- CLIP_Explainability/vit_cam.py      +19 -45
CLIP_Explainability/image_utils.py
CHANGED
@@ -1,22 +1,38 @@
 import numpy as np
 import cv2
+from matplotlib import pyplot as plt
+
+
+def get_mpl_colormap(cmap_name):
+    cmap = plt.get_cmap(cmap_name)
+
+    # Initialize the matplotlib color map
+    sm = plt.cm.ScalarMappable(cmap=cmap)
+
+    # Obtain linear color range
+    color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]
+
+    return color_range.reshape(256, 1, 3)
+
 
 def show_cam_on_image(img, mask, neg_saliency=False):
-
     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
-
+
     heatmap = np.float32(heatmap) / 255
     cam = heatmap + np.float32(img)
     cam = cam / np.max(cam)
     return cam
 
+
 def show_overlapped_cam(img, neg_mask, pos_mask):
-    neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), cv2.COLORMAP_RAINBOW)
-    pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), cv2.COLORMAP_JET)
+    # neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), cv2.COLORMAP_RAINBOW)
+    # pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), cv2.COLORMAP_JET)
+    neg_heatmap = cv2.applyColorMap(np.uint8(255 * neg_mask), get_mpl_colormap("Blues"))
+    pos_heatmap = cv2.applyColorMap(np.uint8(255 * pos_mask), get_mpl_colormap("Reds"))
    neg_heatmap = np.float32(neg_heatmap) / 255
     pos_heatmap = np.float32(pos_heatmap) / 255
     # try different options: sum, average, ...
     heatmap = neg_heatmap + pos_heatmap
     cam = heatmap + np.float32(img)
     cam = cam / np.max(cam)
-    return cam
+    return cam
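The new get_mpl_colormap helper turns a named matplotlib colormap into a 256x1x3 uint8 BGR lookup table, which cv2.applyColorMap accepts in place of a built-in COLORMAP_* constant; show_overlapped_cam now draws negative saliency with "Blues" and positive saliency with "Reds" instead of COLORMAP_RAINBOW/COLORMAP_JET. A minimal usage sketch, assuming the repository root is on PYTHONPATH (the toy arrays and output filename are illustrative, not part of the commit):

    import numpy as np
    import cv2

    from CLIP_Explainability.image_utils import get_mpl_colormap, show_overlapped_cam

    # Toy inputs: an RGB image and two saliency masks, all scaled to [0, 1].
    img = np.random.rand(224, 224, 3).astype(np.float32)
    pos_mask = np.random.rand(224, 224).astype(np.float32)
    neg_mask = np.random.rand(224, 224).astype(np.float32)

    # get_mpl_colormap("Blues") returns a (256, 1, 3) uint8 BGR table, the format
    # cv2.applyColorMap expects for a user-defined colormap.
    assert get_mpl_colormap("Blues").shape == (256, 1, 3)

    # Blue (negative) and red (positive) heatmaps blended onto the image, in [0, 1].
    overlay = show_overlapped_cam(img, neg_mask, pos_mask)
    cv2.imwrite("overlapped_cam.png", np.uint8(255 * overlay))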
CLIP_Explainability/rn_cam.py
CHANGED
@@ -1,7 +1,8 @@
 import torch
 import numpy as np
-from PIL import Image
-import matplotlib.pyplot as plt
+
+# from PIL import Image
+# import matplotlib.pyplot as plt
 import cv2
 import re
 
@@ -22,7 +23,7 @@ def rn_relevance(
     cam = method(
         model=img_encoder,
         target_layers=target_layers,
-        use_cuda=torch.cuda.is_available(),
+        use_cuda=torch.cuda.is_available() and device != "cpu",
     )
 
     if neg_saliency:
@@ -127,12 +128,13 @@ def rn_perword_relevance(
     masked_text = re.sub(masked_word, "", text)
     masked_text = clip_tokenizer(masked_text).to(device)
 
-    image_features = clip_model.encode_image(image)
+    # image_features = clip_model.encode_image(image)
     main_text_features = clip_model.encode_text(main_text)
     masked_text_features = clip_model.encode_text(masked_text)
 
-    image_features_norm = image_features.norm(dim=-1, keepdim=True)
-    image_features_new = image_features / image_features_norm
+    # image_features_norm = image_features.norm(dim=-1, keepdim=True)
+    # image_features_new = image_features / image_features_norm
+
     main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
     main_text_features_new = main_text_features / main_text_features_norm
 
@@ -146,10 +148,10 @@ def rn_perword_relevance(
     cam = method(
         model=clip_model.visual,
         target_layers=target_layers,
-        use_cuda=torch.cuda.is_available(),
+        use_cuda=torch.cuda.is_available() and device != "cpu",
     )
 
-    image_features = clip_model.visual(image)
+    # image_features = clip_model.visual(image)
 
     image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
         0
@@ -175,7 +177,7 @@ def rn_perword_relevance(
     image = image[0].permute(1, 2, 0).data.cpu().numpy()
     image = (image - image.min()) / (image.max() - image.min())
 
-    return image_relevance
+    return image_relevance
 
 
 def interpret_perword_rn(
@@ -189,7 +191,7 @@ def interpret_perword_rn(
     data_only=False,
     img_dim=224,
 ):
-    image_relevance
+    image_relevance = rn_perword_relevance(
         image,
         text,
         clip_model,
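Three things change in rn_cam.py: the CAM constructors now receive use_cuda=torch.cuda.is_available() and device != "cpu", so passing device="cpu" keeps the CAM on the CPU even on a machine with a GPU; the unused image_features computations in rn_perword_relevance are commented out; and interpret_perword_rn now assigns the result of rn_perword_relevance instead of leaving a bare image_relevance expression. A small sketch of the guard's behaviour (the helper name is illustrative; the commit only adds the boolean expression inline):

    import torch

    def pick_use_cuda(device: str) -> bool:
        # Guard added in this commit: enable CUDA inside the CAM object only when a
        # GPU is actually available AND the caller did not explicitly request "cpu".
        return torch.cuda.is_available() and device != "cpu"

    # On a CUDA machine:     pick_use_cuda("cuda") -> True,  pick_use_cuda("cpu") -> False
    # On a CPU-only machine: both calls return False.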
CLIP_Explainability/vit_cam.py
CHANGED
@@ -1,7 +1,8 @@
 import torch
 import numpy as np
-from PIL import Image
-import matplotlib.pyplot as plt
+
+# from PIL import Image
+# import matplotlib.pyplot as plt
 import cv2
 import regex as re
 
@@ -71,7 +72,7 @@ def vit_block_vis(
     cam = cam[0].permute(1, 2, 0).data.cpu().numpy()
     cam = np.float32(cam)
 
-    plt.imshow(cam)
+    # plt.imshow(cam)
 
     return new_score
 
@@ -90,8 +91,10 @@ def vit_relevance(
 
     image_features_norm = image_features.norm(dim=-1, keepdim=True)
     image_features_new = image_features / image_features_norm
+
     target_features_norm = target_features.norm(dim=-1, keepdim=True)
     target_features_new = target_features / target_features_norm
+
     similarity = image_features_new[0].dot(target_features_new[0])
     if neg_saliency:
         objective = 1 - similarity
@@ -154,7 +157,7 @@ def vit_relevance(
     image = image[0].permute(1, 2, 0).data.cpu().numpy()
     image = (image - image.min()) / (image.max() - image.min())
 
-    return image_relevance, image
+    return image_relevance, image, image_features, similarity
 
 
 def interpret_vit(
@@ -166,7 +169,7 @@ def interpret_vit(
     neg_saliency=False,
     img_dim=224,
 ):
-    image_relevance, image = vit_relevance(
+    image_relevance, image, image_features, similarity = vit_relevance(
        image,
        target_features,
        img_encoder,
@@ -180,14 +183,14 @@ def interpret_vit(
     vis = np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
 
-    return vis
+    return vis, image_features, similarity
     # plt.imshow(vis)
 
 
 def interpret_vit_overlapped(
     image, target_features, img_encoder, device, method="last grad", img_dim=224
 ):
-    pos_image_relevance, _ = vit_relevance(
+    pos_image_relevance, _, image_features, similarity = vit_relevance(
         image,
         target_features,
         img_encoder,
@@ -196,7 +199,7 @@ def interpret_vit_overlapped(
         neg_saliency=False,
         img_dim=img_dim,
     )
-    neg_image_relevance, image = vit_relevance(
+    neg_image_relevance, image, _, _ = vit_relevance(
         image,
         target_features,
         img_encoder,
@@ -210,19 +213,18 @@ def interpret_vit_overlapped(
     vis = np.uint8(255 * vis)
     vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
 
-    return vis
+    return vis, image_features, similarity
     # plt.imshow(vis)
 
 
 def vit_perword_relevance(
-    image,
+    image_features,
     text,
     clip_model,
     clip_tokenizer,
     device,
     masked_word="",
     use_last_grad=True,
-    data_only=False,
     img_dim=224,
 ):
     clip_model.eval()
@@ -232,12 +234,13 @@ def vit_perword_relevance(
     masked_text = re.sub(masked_word, "", text)
     masked_text = clip_tokenizer(masked_text).to(device)
 
-    image_features = clip_model.encode_image(image)
     main_text_features = clip_model.encode_text(main_text)
     masked_text_features = clip_model.encode_text(masked_text)
 
+    # image_features = clip_model.encode_image(image)
     image_features_norm = image_features.norm(dim=-1, keepdim=True)
     image_features_new = image_features / image_features_norm
+
     main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
     main_text_features_new = main_text_features / main_text_features_norm
 
@@ -290,38 +293,9 @@
         image_relevance.max() - image_relevance.min()
     )
 
-
-
 
-    image = image[0].permute(1, 2, 0).data.cpu().numpy()
-    image = (image - image.min()) / (image.max() - image.min())
-
-    return image_relevance, image
-
-
-def interpret_perword_vit(
-    image,
-    text,
-    clip_model,
-    clip_tokenizer,
-    device,
-    masked_word="",
-    use_last_grad=True,
-    img_dim=224,
-):
-    image_relevance, image = vit_perword_relevance(
-        image,
-        text,
-        clip_model,
-        clip_tokenizer,
-        device,
-        masked_word,
-        use_last_grad,
-        img_dim=img_dim,
-    )
-    vis = show_cam_on_image(image, image_relevance)
-    vis = np.uint8(255 * vis)
-    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
+    # image = image[0].permute(1, 2, 0).data.cpu().numpy()
+    # image = (image - image.min()) / (image.max() - image.min())
+
+    # return image_relevance, image
 
-    return
-    # plt.imshow(vis)
+    return image_relevance
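In vit_cam.py, vit_relevance now also returns the image embedding and the image-text similarity, so interpret_vit and interpret_vit_overlapped return (vis, image_features, similarity) instead of just vis; vit_perword_relevance takes a precomputed image_features tensor as its first argument rather than encoding the image itself, and the old interpret_perword_vit wrapper is removed. A hedged sketch of how a caller might adapt, assuming an OpenAI-CLIP style model and preprocessing and the positional argument order suggested by the diff (the model name, example image, and prompt are assumptions, not shown in the commit):

    import clip  # assumption: the openai/CLIP package
    import torch
    from PIL import Image

    from CLIP_Explainability.vit_cam import interpret_vit, vit_perword_relevance

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
    text = "a dog playing in the park"
    target_features = model.encode_text(clip.tokenize([text]).to(device))

    # interpret_vit previously returned only `vis`; it now also hands back the image
    # features and the image-text similarity computed inside vit_relevance.
    vis, image_features, similarity = interpret_vit(
        image, target_features, model.visual, device
    )

    # vit_perword_relevance now expects those precomputed image features instead of
    # the raw image tensor, so the image encoder is not re-run for every masked word.
    word_relevance = vit_perword_relevance(
        image_features, text, model, clip.tokenize, device, masked_word="dog"
    )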