File size: 2,540 Bytes
7754b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import numpy as np
import cv2

# rule 5 from paper
def avg_heads(cam, grad):
    """Weight attention maps by their gradients and average over heads.

    Both tensors are reshaped to 4-D, using the size of ``cam``'s third-from-last
    axis as the second dimension (presumably the number of attention
    heads -- confirm against the model). Their element-wise product is
    clamped to be non-negative and then averaged over that dimension.

    Args:
        cam: Attention tensor with at least three dimensions.
        grad: Gradient tensor holding the same number of elements per
            sample as ``cam``.

    Returns:
        A 3-D tensor: the clamped product averaged over dimension 1.
    """
    num_heads, rows, cols = cam.shape[-3:]
    cam = cam.reshape(-1, num_heads, rows, cols)
    grad = grad.reshape(-1, num_heads, grad.shape[-2], grad.shape[-1])
    # Negative contributions are dropped before the mean is taken.
    weighted = (grad * cam).clamp(min=0)
    return weighted.mean(dim=1)

# rule 6 from paper
def apply_self_attention_rules(R_ss, cam_ss):
    """Return the relevance update ``cam_ss @ R_ss``.

    Args:
        R_ss: Current relevance tensor.
        cam_ss: Attention-derived tensor that is multiplied on the left.

    Returns:
        The matrix product of ``cam_ss`` and ``R_ss``; the caller is
        responsible for adding it to the running relevance.
    """
    return torch.matmul(cam_ss, R_ss)

def upscale_relevance(relevance, *, grid_size=14, scale_factor=16):
    """Upsample per-patch relevance to image resolution, scaled to [0, 1].

    The input is reshaped to ``(-1, 1, grid_size, grid_size)``, bilinearly
    interpolated by ``scale_factor``, and then min-max normalized
    independently for every sample.

    Args:
        relevance: Tensor whose element count is a multiple of
            ``grid_size * grid_size``.
        grid_size: Side length of the square patch grid. Defaults to 14.
        scale_factor: Upsampling factor passed to ``interpolate``.
            Defaults to 16, which gives 224x224 maps for a 14x14 grid.

    Returns:
        Tensor of shape ``(batch, 1, H, W)`` where ``H`` and ``W`` are the
        interpolated sizes. Each sample lies in [0, 1]; a sample whose
        values are all equal maps to all zeros instead of NaN.
    """
    relevance = relevance.reshape(-1, 1, grid_size, grid_size)
    relevance = torch.nn.functional.interpolate(
        relevance, scale_factor=scale_factor, mode='bilinear'
    )
    height, width = relevance.shape[-2:]

    # normalize between 0 and 1, one sample at a time
    flat = relevance.reshape(relevance.shape[0], -1)
    low = flat.min(1, keepdim=True)[0]
    high = flat.max(1, keepdim=True)[0]
    span = high - low
    # A constant map has zero range; divide by 1 there to avoid 0/0 = NaN.
    span = torch.where(span > 0, span, torch.ones_like(span))
    flat = (flat - low) / span

    return flat.reshape(-1, 1, height, width)

def generate_relevance(model, input, index=None):
    """Compute an upscaled relevance map for every sample in a batch.

    Runs the model with ``register_hook=True``, sums the selected class
    scores, and for each block combines the attention map with the
    gradient of that sum to update a running relevance matrix that starts
    as the identity.

    Args:
        model: Callable as ``model(input, register_hook=True)``. It must
            expose ``zero_grad()`` and ``blocks``, where each block offers
            ``attn.get_attention_map()`` and ``attn.attention_map``.
        input: Batch tensor; its first dimension is the batch size.
        index: Target class index per sample (tensor, sequence, or int).
            When ``None``, the arg-max over the last output axis is used.

    Returns:
        The result of ``upscale_relevance`` applied to row 0, columns 1
        onward of the final relevance matrix (presumably the class-token
        row over patch tokens -- confirm against the model).
    """
    # a batch of samples
    batch_size = input.shape[0]
    output = model(input, register_hook=True)
    if index is None:
        index = output.detach().argmax(dim=-1)
    index = torch.as_tensor(index, device=output.device).long()

    # Select each sample's target logit; the sum is the differentiated scalar.
    one_hot = torch.zeros(batch_size, output.shape[-1], device=output.device)
    one_hot[torch.arange(batch_size, device=output.device), index] = 1
    score = torch.sum(one_hot * output)
    model.zero_grad()

    first_map = model.blocks[0].attn.get_attention_map()
    num_tokens = first_map.shape[-1]
    # Build R on the attention map's device and dtype, not a fixed CUDA device.
    R = torch.eye(num_tokens, dtype=first_map.dtype, device=first_map.device)
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for blk in model.blocks:
        # retain_graph keeps the graph alive for the remaining blocks.
        grad = torch.autograd.grad(score, [blk.attn.attention_map], retain_graph=True)[0]
        cam = avg_heads(blk.attn.get_attention_map(), grad)
        R = R + apply_self_attention_rules(R, cam)
    relevance = R[:, 0, 1:]
    return upscale_relevance(relevance)

# create heatmap from mask on image
def show_cam_on_image(img, mask):
    """Overlay a JET-colored rendering of ``mask`` on ``img``.

    Args:
        img: Image array; it is cast to float32 and added to the heatmap.
        mask: Array scaled by 255 and cast to uint8 before color mapping
            (so values are presumably in [0, 1] -- confirm at call sites).

    Returns:
        The float sum of heatmap and image, divided by its maximum value.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    # Rescale so the brightest value of the blend is exactly 1.
    return overlay / np.max(overlay)


def get_image_with_relevance(image, relevance):
    """Return the image rescaled to [0, 255] and weighted by relevance.

    Both tensors are permuted with ``(1, 2, 0)``, i.e. the first axis is
    moved last (presumably channels-first to channels-last -- confirm at
    call sites).

    Args:
        image: 3-D image tensor.
        relevance: 3-D relevance tensor that broadcasts against ``image``
            after the permutation.

    Returns:
        NumPy array holding the min-max normalized image, multiplied by
        255 and by ``relevance``. An image whose values are all equal
        yields zeros instead of NaN.
    """
    image = image.permute(1, 2, 0)
    relevance = relevance.permute(1, 2, 0)
    low = image.min()
    span = image.max() - low
    # A constant image has zero range; divide by 1 there to avoid 0/0 = NaN.
    span = torch.where(span > 0, span, torch.ones_like(span))
    image = (image - low) / span
    image = 255 * image
    vis = image * relevance
    return vis.detach().cpu().numpy()