File size: 1,896 Bytes
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn.functional as F

from math import sqrt
from torch import Tensor
from transformers import ViTForImageClassification


@torch.no_grad()
def attention_rollout(
    images: Tensor,
    vit: ViTForImageClassification,
    discard_ratio: float = 0.9,
    head_fusion: str = "mean",
    device: str = "cpu",
) -> Tensor:
    """Compute Attention Rollout saliency maps for a batch of images.

    Implements Attention Rollout (https://arxiv.org/pdf/2005.00928.pdf). For
    every layer, the attention heads are fused and the weakest entries are
    discarded. The identity (residual connection) is then mixed in, rows are
    re-normalised, and the layers are multiplied together. The class-token row,
    restricted to the patch tokens, is the saliency map.

    Args:
        images: Batch of images of shape ``(B, C, H, W)``, passed unchanged to
            ``vit``.
        vit: Model called as ``vit(images, output_attentions=True)``. Its
            ``.attentions`` must be a sequence of ``(B, heads, P, P)`` tensors,
            where ``P = 1 + N * N`` (a class token followed by an ``N x N``
            patch grid).
        discard_ratio: Fraction of each image's fused ``P x P`` attention map
            that is set to zero, lowest values first. The class-token
            self-attention entry is never discarded.
        head_fusion: How heads are combined: ``"mean"``, ``"max"`` or ``"min"``.
        device: Device the returned maps are moved to. Intermediate work runs
            on the device of the model's attention tensors.

    Returns:
        Tensor of shape ``(B, H, W)``. Each image's map is scaled so that its
        own maximum is 1; an all-zero map stays all-zero.

    Raises:
        ValueError: If ``head_fusion`` is not supported, or if the number of
            patch tokens is not a perfect square.
    """
    fusions = {
        "mean": lambda att: att.mean(dim=1),
        "max": lambda att: att.amax(dim=1),
        "min": lambda att: att.amin(dim=1),
    }
    # Validate before paying for a forward pass.
    if head_fusion not in fusions:
        raise ValueError(
            f"Attention head fusion type {head_fusion!r} not supported; "
            f"expected one of {sorted(fusions)}"
        )
    fuse_heads = fusions[head_fusion]

    # Forward pass; only the per-layer attention maps are used.
    attentions = vit(images, output_attentions=True).attentions

    B, _, H, W = images.shape  # batch size, channels, height, width
    P = attentions[0].size(-1)  # number of tokens: class token + patches
    grid = int(sqrt(P - 1))
    if grid * grid != P - 1:
        raise ValueError(f"Expected a square patch grid, got {P - 1} patch tokens")

    # Build the identity on the attentions' own device so the model may live
    # anywhere; the final result is moved to ``device`` below.
    identity = torch.eye(P, device=attentions[0].device)

    rollout = identity
    for attention in attentions:
        fused = fuse_heads(attention)  # (B, P, P)

        # Zero the lowest-weighted entries of EACH image independently, but
        # keep flat index 0 (the class token attending to itself).
        flat = fused.reshape(B, -1)
        num_discard = int(flat.size(-1) * discard_ratio)
        if num_discard > 0:
            _, lowest = flat.topk(num_discard, dim=-1, largest=False)
            pruned = flat.scatter(-1, lowest, 0.0)
            pruned[:, 0] = flat[:, 0]
            flat = pruned
        fused = flat.reshape(B, P, P)

        # Account for the residual connection, then re-normalise each row.
        a = (fused + identity) / 2
        a = a / a.sum(dim=-1, keepdim=True)
        rollout = a @ rollout

    # Total attention flowing from the class token to each image patch.
    mask = rollout[:, 0, 1:]  # (B, P - 1)

    # Normalise per image; a zero peak is replaced by 1 so the map stays zero
    # instead of turning into NaNs.
    peak = mask.amax(dim=-1, keepdim=True)
    mask = mask / peak.masked_fill(peak == 0, 1)

    # Upsample the patch grid to the input resolution (interpolate's default
    # mode is nearest-neighbour).
    mask = mask.reshape(B, 1, grid, grid)
    mask = F.interpolate(mask, size=(H, W))
    return mask.reshape(B, H, W).to(device)