# File size: 3,385 Bytes
# c4b2b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import sys

import cv2
import numpy as np
import torch

from imagenet_class_indices import CLS2IDX

sys.path.append("Transformer-Explainability")


from baselines.ViT.ViT_explanation_generator import LRP, Baselines
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit


def show_cam_on_image(img, mask):
    """Blend a JET colour map of ``mask`` onto ``img`` and rescale by the peak.

    Args:
        img: Float image array; ``np.float32`` is applied to it before the
            blend. Presumably (H, W, 3) in [0, 1] — confirm at call sites.
        mask: 2-D array that is multiplied by 255 and cast to ``uint8`` before
            colour-mapping, so values outside [0, 1] would wrap around.

    Returns:
        float32 array: ``colour_map / 255 + img`` divided by its own maximum.
    """
    colour_u8 = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    # NOTE(review): OpenCV colour maps come out in BGR channel order, while the
    # image is presumably RGB; confirm the caller compensates for this.
    blended = np.float32(colour_u8) / 255 + np.float32(img)
    return blended / blended.max()


# Build the pretrained `vit_base_patch16_224` models once, at import time:
#   * an LRP-instrumented copy wrapped by `LRP`, used for relevance propagation;
#   * a plain copy wrapped by `Baselines`, used for GradCAM / attention rollout.
# Both are moved to the GPU and switched to evaluation mode (`eval()` returns
# the module itself, so the calls can be chained).
model = vit_LRP(pretrained=True).cuda().eval()
attribution_generator = LRP(model)
model_baseline = vit(pretrained=True).cuda().eval()
baselines_generator = Baselines(model_baseline)


def _min_max_normalize(array):
    """Linearly rescale a numpy array to [0, 1]; a constant array maps to zeros.

    Guarding the zero span avoids the 0/0 division that would otherwise fill
    the result with NaNs when every value is equal (e.g. an all-zero map).
    """
    lo = array.min()
    span = array.max() - lo
    if span == 0:
        return np.zeros_like(array)
    return (array - lo) / span


def generate_visualization(
    original_image, class_index=None, method="transformer_attribution", LRP=True
):
    """Render an attribution heatmap for ``original_image`` overlaid on the image.

    Args:
        original_image: 3-D image tensor, presumably (C, H, W) with a 224x224
            spatial size (the reshapes below hard-code 224). It is batched with
            ``unsqueeze(0)`` and moved to CUDA before being explained; its
            ``permute(1, 2, 0)`` copy is used as the overlay background.
        class_index: Target class forwarded as ``index`` to the LRP and
            GradCAM generators (rollout does not receive it).
        method: LRP method name when ``LRP`` is true. With ``LRP=False``,
            ``"gradcam"`` selects ``generate_cam_attn`` and any other value
            selects ``generate_rollout``. ``"full"`` is treated as an already
            224x224 map; any other method as a 14x14 map upsampled 16x.
        LRP: Use the LRP-instrumented ``attribution_generator`` when true,
            otherwise ``baselines_generator``. Note: this parameter shadows the
            module-level ``LRP`` class inside this function.

    Returns:
        ``uint8`` array from ``show_cam_on_image`` scaled by 255 and passed
        through ``cv2.COLOR_RGB2BGR`` (presumably (224, 224, 3)).
    """
    batch = original_image.unsqueeze(0).cuda()
    if LRP:
        attribution = attribution_generator.generate_LRP(
            batch, method=method, index=class_index
        ).detach()
    elif method == "gradcam":
        attribution = baselines_generator.generate_cam_attn(
            batch, index=class_index
        ).detach()
    else:
        attribution = baselines_generator.generate_rollout(batch).detach()

    if method != "full":
        # Token-level map on a 14x14 grid; bilinearly upsample 16x to 224x224.
        attribution = torch.nn.functional.interpolate(
            attribution.reshape(1, 1, 14, 14), scale_factor=16, mode="bilinear"
        )
    attribution = _min_max_normalize(attribution.reshape(224, 224).cpu().numpy())

    # detach() guards against an input tensor that requires grad (numpy()
    # refuses such tensors).
    image = _min_max_normalize(
        original_image.detach().permute(1, 2, 0).cpu().numpy()
    )
    vis = np.uint8(255 * show_cam_on_image(image, attribution))
    # NOTE(review): show_cam_on_image adds a BGR colour map to this (presumably
    # RGB) image, so this conversion leaves the heatmap in RGB order and the
    # photo in BGR; confirm which layout callers expect before changing it.
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)


def print_top_classes(predictions, **kwargs):
    """Print the five highest-scoring classes for the first row of ``predictions``.

    Args:
        predictions: Score tensor indexed as ``[sample, class]`` (presumably
            raw logits); only sample 0 is reported. Class ids are mapped to
            names through ``CLS2IDX``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Each printed line shows the class id, its name padded to the longest of the
    five names, the raw score, and its softmax probability (over dim 1) as a
    percentage.
    """
    prob = torch.softmax(predictions, dim=1)
    class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
    # Look each name up once; pad to a common width so the columns line up.
    names = [CLS2IDX[cls_idx] for cls_idx in class_indices]
    max_str_len = max((len(name) for name in names), default=0)

    print("Top 5 classes:")
    for cls_idx, name in zip(class_indices, names):
        output_string = "\t{} : {}".format(cls_idx, name)
        output_string += " " * (max_str_len - len(name)) + "\t\t"
        output_string += "value = {:.3f}\t prob = {:.1f}%".format(
            predictions[0, cls_idx], 100 * prob[0, cls_idx]
        )
        print(output_string)