sayakpaul HF staff commited on
Commit
104a2dd
·
1 Parent(s): b450510

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import timm
5
+ from timm import create_model
6
+ from timm.models.layers import PatchEmbed
7
+ from torchvision.models.feature_extraction import create_feature_extractor
8
+ from torchvision.transforms import functional as F
9
+
10
+ cait_model = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
11
+ transform = timm.data.create_transform(
12
+ **timm.data.resolve_data_config(cait_model.pretrained_cfg)
13
+ )
14
+
15
+ patch_size = 16
16
+
17
+
18
+ def create_attn_extractor(model, block_id=0):
19
+ """Creates a model that produces the softmax attention scores.
20
+ References:
21
+ https://github.com/huggingface/pytorch-image-models/discussions/926
22
+ """
23
+ feature_extractor = create_feature_extractor(
24
+ cait_model,
25
+ return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
26
+ tracer_kwargs={"leaf_modules": [PatchEmbed]},
27
+ )
28
+ return feature_extractor
29
+
30
+
31
+ def get_cls_attention_map(
32
+ image, attn_score_dict=out, block_key="blocks_token_only.0.attn.softmax"
33
+ ):
34
+ """Prepares attention maps so that they can be visualized."""
35
+ w_featmap = image.shape[3] // patch_size
36
+ h_featmap = image.shape[2] // patch_size
37
+
38
+ attention_scores = attn_score_dict[block_key]
39
+ nh = attention_scores.shape[1] # Number of attention heads.
40
+
41
+ # Taking the representations from CLS token.
42
+ attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)
43
+ print(attentions.shape)
44
+
45
+ # Reshape the attention scores to resemble mini patches.
46
+ attentions = attentions.reshape(nh, w_featmap, h_featmap)
47
+ print(attentions.shape)
48
+
49
+ # Resize the attention patches to 224x224 (224: 14x16)
50
+ attentions = F.resize(
51
+ attentions,
52
+ size=(h_featmap * patch_size, w_featmap * patch_size),
53
+ interpolation=3,
54
+ )
55
+ print(attentions.shape)
56
+
57
+ return attentions
58
+
59
+
60
+ def generate_plot(processed_map):
61
+ """Generates a class attention map plot."""
62
+ fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
63
+ img_count = 0
64
+
65
+ for i in range(processed_map.shape[0]):
66
+ if img_count < processed_map.shape[0]:
67
+ axes[i].imshow(processed_map[img_count].numpy())
68
+ axes[i].title.set_text(f"Attention head: {img_count}")
69
+ axes[i].axis("off")
70
+ img_count += 1
71
+
72
+ fig.tight_layout()
73
+ return fig
74
+
75
+
76
+ def generate_class_attn_map(image, block_id=0):
77
+ """Collates the above utilities together for generating
78
+ a class attention map."""
79
+ image_tensor = transform(image).unsqueeze(0)
80
+ feature_extractor = create_attn_extractor(cait_model, block_id)
81
+
82
+ with torch.no_grad():
83
+ out = feature_extractor(image_tensor)
84
+
85
+ block_key = f"blocks_token_only.{block_id}.attn.softmax"
86
+ processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
87
+ return generate_plot(processed_cls_attn_map)
88
+
89
+
90
+ title = "Class Attention Maps"
91
+ article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."
92
+
93
+ iface = gr.Interface(
94
+ generate_class_attn_map,
95
+ inputs=[
96
+ gr.inputs.Image(type="pil", label="Input Image"),
97
+ gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
98
+ ],
99
+ outputs=[gr.Plot(type="auto").style()],
100
+ title=title,
101
+ article=article,
102
+ allow_flagging="never",
103
+ cache_examples=True,
104
+ examples=[["./bird.png", 0]],
105
+ )
106
+ iface.launch()