File size: 3,998 Bytes
104a2dd
 
 
 
2a6b8e9
104a2dd
 
 
 
aac57a1
104a2dd
dad7fe2
 
 
104a2dd
 
dad7fe2
104a2dd
 
aac57a1
104a2dd
 
 
 
 
dad7fe2
104a2dd
 
 
 
 
 
 
045d37d
104a2dd
 
dad7fe2
 
104a2dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dad7fe2
104a2dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aac57a1
 
 
 
 
 
 
104a2dd
 
 
 
dad7fe2
 
104a2dd
 
 
 
 
 
aac57a1
 
 
 
104a2dd
 
 
 
 
 
 
 
 
 
 
aac57a1
104a2dd
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from timm import create_model
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F
import glob

# Pretrained CaiT checkpoint loaded through timm and put into eval mode once
# at import time; the functions below all reuse this single instance.
CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
# Preprocessing pipeline derived from the checkpoint's own pretrained config,
# so inputs are prepared with the settings stored alongside the weights.
TRANSFORM = timm.data.create_transform(
    **timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
)

# Divisor used to turn image height/width (pixels) into attention-grid
# dimensions. NOTE(review): presumably the model's patch size - confirm.
PATCH_SIZE = 16


def create_attn_extractor(block_id=0):
    """Build an extractor that returns class-attention softmax scores.

    The extractor wraps ``CAIT_MODEL`` and exposes the single graph node
    ``blocks_token_only.{block_id}.attn.softmax``.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    softmax_node = f"blocks_token_only.{block_id}.attn.softmax"
    # PatchEmbed is registered as a leaf module so the tracer treats it as an
    # opaque call instead of tracing through it.
    return create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[softmax_node],
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )


def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepare per-head class-attention maps so that they can be visualized.

    Args:
        image: Tensor whose dimensions 2 and 3 are read as height and width
            (assumed to be a ``(batch, channels, height, width)`` batch).
        attn_score_dict: Mapping from node name to attention tensor, as
            returned by the feature extractor.
        block_key: Key of the attention tensor to read from
            ``attn_score_dict``.

    Returns:
        Tensor of shape ``(heads, h_featmap * PATCH_SIZE,
        w_featmap * PATCH_SIZE)`` holding one upsampled map per head.

    Raises:
        KeyError: If ``block_key`` is not present in ``attn_score_dict``.
    """
    h_featmap = image.shape[2] // PATCH_SIZE
    w_featmap = image.shape[3] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    nh = attention_scores.shape[1]  # Number of attention heads.

    # First batch item, every head, first query row; key position 0 is
    # dropped so only the patch positions remain.
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    # Lay the flat patch scores out as a (rows, cols) = (height, width) grid.
    # The order matters for non-square inputs: swapping the two would
    # scramble the map before it is resized to (height, width) below.
    attentions = attentions.reshape(nh, h_featmap, w_featmap)

    # Upsample each patch grid back to pixel resolution. ``3`` is the Pillow
    # integer code for bicubic interpolation.
    attentions = F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        interpolation=3,
    )

    return attentions


def generate_plot(processed_map):
    """Generate a single-row figure with one attention map per head.

    Args:
        processed_map: Tensor indexed by head along dimension 0; each slice
            is converted with ``.numpy()`` and drawn with ``imshow``.

    Returns:
        The matplotlib figure containing one subplot per attention head.
    """
    num_heads = processed_map.shape[0]
    # Size the grid from the data instead of assuming four heads, and keep
    # ``axes`` two-dimensional so a single head does not collapse it into a
    # bare Axes object. At least one column is requested because matplotlib
    # rejects an empty grid.
    fig, axes = plt.subplots(
        nrows=1, ncols=max(num_heads, 1), figsize=(13, 13), squeeze=False
    )
    row = axes[0]

    for head_idx in range(num_heads):
        row[head_idx].imshow(processed_map[head_idx].numpy())
        row[head_idx].title.set_text(f"Attention head: {head_idx}")
        row[head_idx].axis("off")

    if num_heads == 0:
        # Nothing to draw: hide the placeholder axes rather than show an
        # empty frame.
        row[0].axis("off")

    fig.tight_layout()
    return fig

def serialize_images(processed_map):
    """Serialize each attention map to ``attention_map_{i}.png``.

    Every head is rendered on its own figure, titled with its index, and
    written to the current working directory.

    Args:
        processed_map: Tensor indexed by head along dimension 0; each slice
            is converted with ``.numpy()`` before plotting.

    Returns:
        List of the file paths written, in head order.
    """
    saved_paths = []
    for i in range(processed_map.shape[0]):
        fig, ax = plt.subplots()
        try:
            ax.imshow(processed_map[i].numpy())
            ax.set_title(f"Attention head: {i}")
            ax.axis("off")
            path = f"attention_map_{i}.png"
            # savefig (rather than imsave) keeps the title in the output.
            fig.savefig(path, bbox_inches="tight")
            saved_paths.append(path)
        finally:
            # Close explicitly: pyplot keeps every figure alive otherwise,
            # which accumulates across requests in a long-running app.
            plt.close(fig)
    return saved_paths


def generate_class_attn_map(image, block_id=0):
    """Collate the utilities above to generate class attention maps.

    Args:
        image: Input image accepted by ``TRANSFORM`` (the interface passes a
            PIL image).
        block_id: Index of the ``blocks_token_only`` block whose attention
            is visualized. Coerced to ``int`` so a float coming from a UI
            slider still yields a valid node name.

    Returns:
        List of image file paths, one per attention head, in head order.
    """
    # A value such as 0.0 would otherwise format as "blocks_token_only.0.0".
    block_id = int(block_id)

    image_tensor = TRANSFORM(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)

    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)

    serialize_images(processed_cls_attn_map)
    # Build the expected paths from the head count instead of globbing the
    # working directory: a glob would also return stale files left by earlier
    # runs, and lexical sorting would misorder indices >= 10.
    # NOTE(review): the file names are fixed, so concurrent requests would
    # overwrite each other's images - confirm whether that matters here.
    all_attn_img_paths = [
        f"attention_map_{i}.png" for i in range(processed_cls_attn_map.shape[0])
    ]
    return all_attn_img_paths


# Text shown above and below the demo.
title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."

# Gradio interface: takes an image plus a block index and shows the resulting
# per-head attention maps in a gallery.
iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        # Top-level component namespace, consistent with gr.Slider below;
        # the ``gr.inputs`` namespace is deprecated.
        gr.Image(type="pil", label="Input Image"),
        # The slider offers block indices 0 and 1 only.
        gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
    ],
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title=title,
    article=article,
    allow_flagging="never",
    # With caching enabled the example is run through the function ahead of
    # time, so ./bird.png must exist next to this script.
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch()