sayakpaul HF staff commited on
Commit
f03296b
1 Parent(s): 12e48bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -23,6 +23,22 @@ def get_model() -> tf.keras.Model:
23
  _MODEL = get_model()
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def show_plot(image):
27
  """Function to be called when user hits submit on the UI."""
28
  _, preprocessed_image = utils.preprocess_image(
@@ -36,9 +52,7 @@ def show_plot(image):
36
  result_second_block = utils.get_cls_attention_map(
37
  preprocessed_image, ca_atn_score_dict, block_key="ca_ffn_block_1_att"
38
  )
39
- return Image.fromarray(result_first_block), Image.fromarray(
40
- result_second_block
41
- )
42
 
43
 
44
  title = "Generate Class Attention Plots"
@@ -47,7 +61,7 @@ article = "Class attention maps as investigated in [Going deeper with Image Tran
47
  iface = gr.Interface(
48
  show_plot,
49
  inputs=gr.inputs.Image(type="pil", label="Input Image"),
50
- outputs="image",
51
  title=title,
52
  article=article,
53
  allow_flagging="never",
 
23
  _MODEL = get_model()
24
 
25
 
26
+ def plot(attentions: np.ndarray):
27
+ """Plots the attention maps from individual attention heads."""
28
+ fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
29
+ img_count = 0
30
+
31
+ for i in range(attentions.shape[-1]):
32
+ if img_count < attentions.shape[-1]:
33
+ axes[i].imshow(attentions[:, :, img_count])
34
+ axes[i].title.set_text(f"Attention head: {img_count}")
35
+ axes[i].axis("off")
36
+ img_count += 1
37
+
38
+ fig.tight_layout()
39
+ return fig
40
+
41
+
42
  def show_plot(image):
43
  """Function to be called when user hits submit on the UI."""
44
  _, preprocessed_image = utils.preprocess_image(
 
52
  result_second_block = utils.get_cls_attention_map(
53
  preprocessed_image, ca_atn_score_dict, block_key="ca_ffn_block_1_att"
54
  )
55
+ return plot(result_first_block), plot(result_second_block)
 
 
56
 
57
 
58
  title = "Generate Class Attention Plots"
 
61
  iface = gr.Interface(
62
  show_plot,
63
  inputs=gr.inputs.Image(type="pil", label="Input Image"),
64
+ outputs="plot",
65
  title=title,
66
  article=article,
67
  allow_flagging="never",