sayakpaul HF staff commited on
Commit
aac57a1
1 Parent(s): 2a6b8e9

better viz.

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -7,6 +7,7 @@ from timm import create_model
7
  from timm.models.layers import PatchEmbed
8
  from torchvision.models.feature_extraction import create_feature_extractor
9
  from torchvision.transforms import functional as F
 
10
 
11
  cait_model = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
12
  transform = timm.data.create_transform(
@@ -16,7 +17,7 @@ transform = timm.data.create_transform(
16
  patch_size = 16
17
 
18
 
19
- def create_attn_extractor(model, block_id=0):
20
  """Creates a model that produces the softmax attention scores.
21
  References:
22
  https://github.com/huggingface/pytorch-image-models/discussions/926
@@ -73,6 +74,13 @@ def generate_plot(processed_map):
73
  fig.tight_layout()
74
  return fig
75
 
 
 
 
 
 
 
 
76
 
77
  def generate_class_attn_map(image, block_id=0):
78
  """Collates the above utilities together for generating
@@ -85,7 +93,10 @@ def generate_class_attn_map(image, block_id=0):
85
 
86
  block_key = f"blocks_token_only.{block_id}.attn.softmax"
87
  processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
88
- return generate_plot(processed_cls_attn_map)
 
 
 
89
 
90
 
91
  title = "Class Attention Maps"
@@ -97,7 +108,7 @@ iface = gr.Interface(
97
  gr.inputs.Image(type="pil", label="Input Image"),
98
  gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
99
  ],
100
- outputs=[gr.Plot(type="auto").style()],
101
  title=title,
102
  article=article,
103
  allow_flagging="never",
 
7
  from timm.models.layers import PatchEmbed
8
  from torchvision.models.feature_extraction import create_feature_extractor
9
  from torchvision.transforms import functional as F
10
+ import glob
11
 
12
  cait_model = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
13
  transform = timm.data.create_transform(
 
17
  patch_size = 16
18
 
19
 
20
+ def create_attn_extractor(block_id=0):
21
  """Creates a model that produces the softmax attention scores.
22
  References:
23
  https://github.com/huggingface/pytorch-image-models/discussions/926
 
74
  fig.tight_layout()
75
  return fig
76
 
77
def serialize_images(processed_map):
    """Serializes per-head attention maps to PNG files on disk.

    Args:
        processed_map: Tensor of attention maps, one per attention head,
            indexed along dim 0 (assumed shape (num_heads, H, W) — each
            slice must be convertible via `.numpy()` for `imshow`).

    Side effects:
        Writes one file per head named ``attention_map_{i}.png`` in the
        current working directory (these are later collected with
        ``glob.glob("attention_map_*.png")``).
    """
    for i in range(processed_map.shape[0]):
        plt.imshow(processed_map[i].numpy())
        # Fix: `plt.tile` does not exist (AttributeError); `plt.title` is intended.
        plt.title(f"Attention head: {i}")
        # Fix: original called `plt.imsave(fname=...)` without the required `arr`
        # argument and without the f-string prefix, so it would crash (and, with
        # the literal "{i}", overwrite a single file). `savefig` saves the current
        # figure including the title.
        plt.savefig(f"attention_map_{i}.png")
        # Close the figure so state does not leak into the next head's plot.
        plt.close()
84
 
85
  def generate_class_attn_map(image, block_id=0):
86
  """Collates the above utilities together for generating
 
93
 
94
  block_key = f"blocks_token_only.{block_id}.attn.softmax"
95
  processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
96
+
97
+ serialize_images(processed_cls_attn_map)
98
+ all_attn_img_paths = sorted(glob.glob("attention_map_*.png"))
99
+ return all_attn_img_paths
100
 
101
 
102
  title = "Class Attention Maps"
 
108
  gr.inputs.Image(type="pil", label="Input Image"),
109
  gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
110
  ],
111
+ outputs=gr.Gallery().style(grid=[2], height="auto"),
112
  title=title,
113
  article=article,
114
  allow_flagging="never",