# Attention Rollout visualization for ViT image classifiers.
import torch
import torch.nn.functional as F
from math import sqrt
from torch import Tensor
from transformers import ViTForImageClassification
@torch.no_grad()
def attention_rollout(
    images: Tensor,
    vit: "ViTForImageClassification",
    discard_ratio: float = 0.9,
    head_fusion: str = "mean",
    device: str = "cpu",
) -> Tensor:
    """Perform Attention Rollout on a batch of images (https://arxiv.org/pdf/2005.00928.pdf).

    Args:
        images: Input batch of shape (B, C, H, W). Assumes square images
            (H == W) whose side is an exact multiple of the patch-grid side.
        vit: HuggingFace ViT model; called with ``output_attentions=True`` and
            expected to return per-layer attentions of shape (B, heads, P, P).
        discard_ratio: Fraction of the lowest attention entries zeroed out per
            layer and per sample. The class token's self-attention (flattened
            index 0) is never discarded.
        head_fusion: How to fuse attention heads: "mean", "max" or "min".
        device: Device for the identity matrix; must match the device of the
            attention tensors returned by ``vit``.

    Returns:
        Tensor of shape (B, H, W): per-image attention masks, each normalized
        to a maximum of 1.

    Raises:
        ValueError: If ``head_fusion`` is not one of "mean", "max", "min".
    """
    # Forward pass, collecting the attention maps of every layer.
    attentions = vit(images, output_attentions=True).attentions
    B, _, H, W = images.shape  # batch, channels, height, width
    P = attentions[0].size(-1)  # number of tokens (image patches + class token)
    eye = torch.eye(P, device=device)
    mask = eye.clone()
    for attention in attentions:  # one (B, heads, P, P) map per layer
        if head_fusion == "mean":
            attention_heads_fused = attention.mean(dim=1)
        elif head_fusion == "max":
            attention_heads_fused = attention.max(dim=1)[0]
        elif head_fusion == "min":
            attention_heads_fused = attention.min(dim=1)[0]
        else:
            # Was `raise "..."` — raising a string is a TypeError in Python 3.
            raise ValueError(f"Attention head fusion type not supported: {head_fusion!r}")
        # Drop the lowest attentions, independently for every sample. The
        # original wrote `flat[0, indices] = 0`, so only the first batch
        # element was ever filtered.
        flat = attention_heads_fused.view(B, -1)
        _, indices = flat.topk(int(flat.size(-1) * discard_ratio), dim=-1, largest=False)
        # Zero the selected entries in-place, then restore the class token's
        # self-attention (flattened index 0) so it is never discarded.
        cls_value = flat[:, 0].clone()
        flat.scatter_(-1, indices, 0.0)
        flat[:, 0] = cls_value
        # Add the residual connection (identity) and re-normalize the rows.
        a = (attention_heads_fused + eye) / 2
        a = a / a.sum(dim=-1, keepdim=True)
        mask = a @ mask
    # Total attention flowing from the class token to each image patch.
    mask = mask[:, 0, 1:]
    # Normalize each sample independently (identical to a global max for a
    # single image, and batch-correct for larger batches).
    mask = mask / mask.amax(dim=-1, keepdim=True)
    N = int(sqrt(P))  # patch-grid side; relies on P - 1 being a perfect square
    S = H // N  # patch size in pixels; assumes H is divisible by N
    mask = mask.reshape(B, 1, N, N)
    mask = F.interpolate(mask, scale_factor=S)
    return mask.reshape(B, H, W)