sayakpaul HF staff commited on
Commit
83c742a
1 Parent(s): 893fce9
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -79,6 +79,7 @@ def serialize_images(processed_map):
79
  for i in range(processed_map.shape[0]):
80
  plt.imshow(processed_map[i].numpy())
81
  plt.title(f"Attention head: {i}", fontsize=14)
 
82
  plt.savefig(fname="attention_map_{i}.png")
83
 
84
 
@@ -95,7 +96,8 @@ def generate_class_attn_map(image, block_id=0):
95
  processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
96
 
97
  serialize_images(processed_cls_attn_map)
98
- all_attn_img_paths = sorted(glob.glob("attention_map_*.png"))
 
99
  return all_attn_img_paths
100
 
101
 
@@ -115,4 +117,4 @@ iface = gr.Interface(
115
  cache_examples=True,
116
  examples=[["./bird.png", 0]],
117
  )
118
- iface.launch()
 
79
  for i in range(processed_map.shape[0]):
80
  plt.imshow(processed_map[i].numpy())
81
  plt.title(f"Attention head: {i}", fontsize=14)
82
+ plt.axis("off")
83
  plt.savefig(fname="attention_map_{i}.png")
84
 
85
 
 
96
  processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
97
 
98
  serialize_images(processed_cls_attn_map)
99
+ all_attn_img_paths = sorted(glob.glob("*.png"))
100
+ print(f"Number of images: {len(all_attn_img_paths)}")
101
  return all_attn_img_paths
102
 
103
 
 
117
  cache_examples=True,
118
  examples=[["./bird.png", 0]],
119
  )
120
+ iface.launch(debug=True)