sayakpaul HF staff commited on
Commit
3c1eaee
1 Parent(s): f03296b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -24,19 +24,19 @@ _MODEL = get_model()
24
 
25
 
26
  def plot(attentions: np.ndarray):
27
- """Plots the attention maps from individual attention heads."""
28
  fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
29
- img_count = 0
30
 
31
- for i in range(attentions.shape[-1]):
32
- if img_count < attentions.shape[-1]:
33
- axes[i].imshow(attentions[:, :, img_count])
34
- axes[i].title.set_text(f"Attention head: {img_count}")
35
- axes[i].axis("off")
36
- img_count += 1
37
 
38
- fig.tight_layout()
39
- return fig
40
 
41
 
42
  def show_plot(image):
 
24
 
25
 
26
  def plot(attentions: np.ndarray):
27
+ """Plots the attention maps from individual attention heads."""
28
  fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
29
+ img_count = 0
30
 
31
+ for i in range(attentions.shape[-1]):
32
+ if img_count < attentions.shape[-1]:
33
+ axes[i].imshow(attentions[:, :, img_count])
34
+ axes[i].title.set_text(f"Attention head: {img_count}")
35
+ axes[i].axis("off")
36
+ img_count += 1
37
 
38
+ fig.tight_layout()
39
+ return fig
40
 
41
 
42
  def show_plot(image):