import re

import cv2
import numpy as np
import torch

from .image_utils import show_cam_on_image, show_overlapped_cam


def rn_relevance(
    image,
    target_features,
    img_encoder,
    method,
    device,
    neg_saliency=False,
    img_dim=224,
):
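    """Compute a CAM relevance map for a ResNet-style CLIP image encoder.

    `method` is expected to be a Grad-CAM-style class whose instances take an
    image `input_tensor` and a `target_encoding` keyword; the CAM is computed
    against `target_features` (negated when `neg_saliency` is True). Returns
    the relevance map (img_dim x img_dim) and the input image
    (img_dim x img_dim x 3), both as NumPy arrays normalized to [0, 1].
    """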
|
    # Hook the last block of the final ResNet stage.
    target_layers = [img_encoder.layer4[-1]]

    cam = method(
        model=img_encoder,
        target_layers=target_layers,
        use_cuda=torch.cuda.is_available() and device != "cpu",
    )

    # Negating the target flips the CAM toward regions that argue against it.
    target_encoding = -target_features if neg_saliency else target_features

    image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
        0
    ].squeeze()
    image_relevance = torch.from_numpy(image_relevance).float()

    # Upsample the CAM from feature-map resolution to the input resolution.
    resize_dim = image_relevance.shape[0]
    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=img_dim, mode="bilinear"
    )
    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
    # Min-max normalize to [0, 1]; the epsilon guards against a constant map.
    image_relevance = (image_relevance - image_relevance.min()) / (
        1e-7 + image_relevance.max() - image_relevance.min()
    )

    # Convert the input tensor to an HWC image in [0, 1] for visualization.
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())

    return image_relevance, image
|
|
def interpret_rn(
    image,
    target_features,
    img_encoder,
    method,
    device,
    neg_saliency=False,
    img_dim=224,
):
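    """Render the CAM from rn_relevance as a uint8 BGR overlay image."""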
|
    image_relevance, image = rn_relevance(
        image,
        target_features,
        img_encoder,
        method,
        device,
        neg_saliency=neg_saliency,
        img_dim=img_dim,
    )
    vis = show_cam_on_image(image, image_relevance, neg_saliency=neg_saliency)
    # show_cam_on_image returns a float image in [0, 1]; convert for OpenCV.
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)

    return vis
|
|
def interpret_rn_overlapped(
    image, target_features, img_encoder, method, device, img_dim=224
):
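    """Overlay the positive and negative saliency maps on the same image."""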
|
    pos_image_relevance, _ = rn_relevance(
        image,
        target_features,
        img_encoder,
        method,
        device,
        neg_saliency=False,
        img_dim=img_dim,
    )
    neg_image_relevance, image = rn_relevance(
        image,
        target_features,
        img_encoder,
        method,
        device,
        neg_saliency=True,
        img_dim=img_dim,
    )

    vis = show_overlapped_cam(image, neg_image_relevance, pos_image_relevance)
    # show_overlapped_cam returns a float image in [0, 1]; convert for OpenCV.
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)

    return vis
|
|
def rn_perword_relevance(
    image,
    text,
    clip_model,
    clip_tokenizer,
    method,
    device,
    masked_word="",
    data_only=False,
    img_dim=224,
):
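    """Compute a relevance map for a single word in `text`.

    The word's contribution is isolated by encoding the caption with and
    without `masked_word` and using the difference of the normalized text
    embeddings as the CAM target. Returns only the relevance map when
    `data_only` is True, otherwise the map together with the normalized image.
    """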
|
    clip_model.eval()

    main_text = clip_tokenizer(text).to(device)

    # Remove the masked word from the caption; escape it so that regex
    # metacharacters inside the word are treated literally.
    masked_text = re.sub(re.escape(masked_word), "", text)
    masked_text = clip_tokenizer(masked_text).to(device)

    main_text_features = clip_model.encode_text(main_text)
    masked_text_features = clip_model.encode_text(masked_text)

    # L2-normalize both embeddings; their difference isolates the masked
    # word's contribution and serves as the CAM target.
    main_text_features_norm = main_text_features.norm(dim=-1, keepdim=True)
    main_text_features_new = main_text_features / main_text_features_norm

    masked_text_features_norm = masked_text_features.norm(dim=-1, keepdim=True)
    masked_text_features_new = masked_text_features / masked_text_features_norm

    target_encoding = main_text_features_new - masked_text_features_new

    # Hook the last block of the final ResNet stage of the visual encoder.
    target_layers = [clip_model.visual.layer4[-1]]

    cam = method(
        model=clip_model.visual,
        target_layers=target_layers,
        use_cuda=torch.cuda.is_available() and device != "cpu",
    )

    image_relevance = cam(input_tensor=image, target_encoding=target_encoding)[
        0
    ].squeeze()
    image_relevance = torch.from_numpy(image_relevance).float()

    # Upsample the CAM from feature-map resolution to the input resolution.
    resize_dim = image_relevance.shape[0]
    image_relevance = image_relevance.reshape(1, 1, resize_dim, resize_dim)
    image_relevance = torch.nn.functional.interpolate(
        image_relevance, size=img_dim, mode="bilinear"
    )
    image_relevance = image_relevance.reshape(img_dim, img_dim).data.cpu().numpy()
    # Min-max normalize to [0, 1]; the epsilon guards against a constant map.
    image_relevance = (image_relevance - image_relevance.min()) / (
        1e-7 + image_relevance.max() - image_relevance.min()
    )

    if data_only:
        return image_relevance

    # Convert the input tensor to an HWC image in [0, 1] and return it
    # alongside the map so callers can overlay the CAM on it.
    image = image[0].permute(1, 2, 0).data.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())

    return image_relevance, image
|
|
def interpret_perword_rn(
    image,
    text,
    clip_model,
    clip_tokenizer,
    method,
    device,
    masked_word="",
    data_only=False,
    img_dim=224,
):
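    """Render the per-word CAM as a uint8 BGR overlay, or return the raw map."""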
|
    result = rn_perword_relevance(
        image,
        text,
        clip_model,
        clip_tokenizer,
        method,
        device,
        masked_word,
        data_only=data_only,
        img_dim=img_dim,
    )
    if data_only:
        return result

    # Overlay the CAM on the normalized image, matching interpret_rn.
    image_relevance, image = result
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)

    return vis
|
|
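# Usage sketch (illustrative): assumes the OpenAI CLIP package with the RN50
# backbone, whose `visual` module exposes `layer4` as the functions above
# require, and a Grad-CAM-style class that accepts a `target_encoding`
# keyword. `TargetEncodingCAM` below is a hypothetical placeholder for that
# class, not a real library API.
if __name__ == "__main__":
    import clip
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, preprocess = clip.load("RN50", device=device)

    image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
    tokens = clip.tokenize(["a photo of a dog on the beach"]).to(device)
    with torch.no_grad():
        target_features = clip_model.encode_text(tokens)

    vis = interpret_rn(
        image,
        target_features,
        clip_model.visual,
        method=TargetEncodingCAM,  # hypothetical CAM variant (see note above)
        device=device,
    )
    cv2.imwrite("saliency.png", vis)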