Spaces:
Runtime error
Runtime error
File size: 3,998 Bytes
104a2dd 2a6b8e9 104a2dd aac57a1 104a2dd dad7fe2 104a2dd dad7fe2 104a2dd aac57a1 104a2dd dad7fe2 104a2dd 045d37d 104a2dd dad7fe2 104a2dd dad7fe2 104a2dd aac57a1 104a2dd dad7fe2 104a2dd aac57a1 104a2dd aac57a1 104a2dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import functools
import glob

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from timm import create_model
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F
# Pretrained timm checkpoint "cait_xxs24_224.fb_dist_in1k"; .eval() returns the
# model itself, so calling it separately leaves CAIT_MODEL bound to the same object.
CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True)
CAIT_MODEL.eval()

# Preprocessing pipeline derived from the checkpoint's own pretrained config,
# so resizing/cropping/normalization match what the weights were trained with.
_CAIT_DATA_CONFIG = timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
TRANSFORM = timm.data.create_transform(**_CAIT_DATA_CONFIG)

# Side length (pixels) of one image patch; used to convert between image size
# and the patch grid. Assumed to match this checkpoint's patch embedding.
PATCH_SIZE = 16
@functools.lru_cache(maxsize=8)
def create_attn_extractor(block_id=0):
    """Build a model that produces the softmax attention scores of one block.

    The returned torch.fx graph module, when called on an image batch, returns
    a dict keyed by ``"blocks_token_only.{block_id}.attn.softmax"``.

    Tracing the model is comparatively expensive, so the extractor is built
    once per ``block_id`` and memoized. The returned module is shared between
    calls, so callers must not mutate it.

    Args:
        block_id: Index into ``CAIT_MODEL.blocks_token_only``. Coerced to
            ``int`` so that a float such as ``1.0`` still yields a valid node
            name.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    return create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[f"blocks_token_only.{int(block_id)}.attn.softmax"],
        # Keep PatchEmbed opaque to the FX tracer (presumably because its
        # internals are not symbolically traceable).
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )
def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Prepare class-token attention maps so that they can be visualized.

    Args:
        image: The preprocessed batch that was fed to the model, indexed as
            ``(batch, channels, height, width)``; only its spatial size is used.
        attn_score_dict: Mapping from feature-extractor node name to that
            node's softmax attention tensor, indexed as
            ``(batch, heads, query, key)``.
        block_key: Key in ``attn_score_dict`` whose scores are visualized.

    Returns:
        Tensor of shape
        ``(heads, h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)`` holding,
        for the first image of the batch, the attention from the class token
        to every patch, upsampled back to pixel resolution.
    """
    h_featmap = image.shape[2] // PATCH_SIZE
    w_featmap = image.shape[3] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    num_heads = attention_scores.shape[1]

    # Query 0 is the class token; key 0 is the class token itself, so it is
    # dropped to keep only the patch tokens. Result: (heads, num_patches).
    attentions = attention_scores[0, :, 0, 1:]

    # Patch tokens are (assumed) flattened row by row, so the grid is
    # (height, width), not (width, height). For square inputs this is identical.
    attentions = attentions.reshape(num_heads, h_featmap, w_featmap)

    # Upsample each head's patch grid back to pixel space (e.g. 14x14 -> 224x224).
    attentions = F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        interpolation=F.InterpolationMode.BICUBIC,
    )
    return attentions
def generate_plot(processed_map):
    """Generate a class attention map plot with one subplot per head.

    Args:
        processed_map: Per-head attention maps indexed as
            ``(head, height, width)``, e.g. the output of
            ``get_cls_attention_map``; each entry must support ``.numpy()``.

    Returns:
        The matplotlib ``Figure`` with one titled, axis-less image per head.
    """
    num_heads = processed_map.shape[0]
    # One column per head (at least one so plt.subplots stays valid);
    # squeeze=False keeps `axes` 2-D even when there is a single head.
    fig, axes = plt.subplots(
        nrows=1, ncols=max(num_heads, 1), figsize=(13, 13), squeeze=False
    )
    for head_idx in range(num_heads):
        ax = axes[0, head_idx]
        ax.imshow(processed_map[head_idx].numpy())
        ax.title.set_text(f"Attention head: {head_idx}")
        ax.axis("off")
    fig.tight_layout()
    return fig
def serialize_images(processed_map):
    """Save each attention head's map as a titled PNG in the working directory.

    Files are named ``attention_map_{i}.png``, where ``i`` is the head index.

    Args:
        processed_map: Per-head attention maps indexed as
            ``(head, height, width)``; each entry must support ``.numpy()``.

    Returns:
        List of the written file paths, ordered by head index.
    """
    paths = []
    for i in range(processed_map.shape[0]):
        fname = f"attention_map_{i}.png"
        # A dedicated figure per head: drawing on pyplot's implicit current
        # figure would stack images on the same axes and never free memory.
        fig, ax = plt.subplots()
        try:
            ax.imshow(processed_map[i].numpy())
            ax.set_title(f"Attention head: {i}")
            ax.axis("off")
            # savefig (unlike imsave) keeps the title in the written image.
            fig.savefig(fname, bbox_inches="tight")
        finally:
            plt.close(fig)
        paths.append(fname)
    return paths
def generate_class_attn_map(image, block_id=0):
    """Collate the above utilities to generate class attention map images.

    Args:
        image: Input PIL image (the Gradio ``Image`` input uses ``type="pil"``).
        block_id: Index of the class-attention block to visualize. Coerced to
            ``int`` because slider values may arrive as floats.

    Returns:
        Sorted list of ``attention_map_*.png`` paths in the working directory.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    # A float such as 1.0 would otherwise produce the invalid node name
    # "blocks_token_only.1.0.attn.softmax".
    block_id = int(block_id)

    # PNG inputs may be RGBA or greyscale; the model expects 3-channel RGB.
    image_tensor = TRANSFORM(image.convert("RGB")).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)
    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
    serialize_images(processed_cls_attn_map)
    all_attn_img_paths = sorted(glob.glob("attention_map_*.png"))
    return all_attn_img_paths
title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."

# Gradio UI: an input image plus a slider selecting the class-attention block
# (IDs 0 and 1 are offered); the output gallery shows one image per head.
iface = gr.Interface(
    generate_class_attn_map,
    inputs=[
        # gr.Image replaces the deprecated gr.inputs.Image namespace and
        # matches the component API already used for gr.Slider below.
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
    ],
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title=title,
    article=article,
    allow_flagging="never",
    # Example outputs are computed once at startup and served from cache.
    cache_examples=True,
    examples=[["./bird.png", 0]],
)
iface.launch()
|