sayakpaul HF staff commited on
Commit
cdee39b
1 Parent(s): ec36a5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import cv2
2
  import gradio as gr
3
  import numpy as np
4
  import tensorflow as tf
@@ -14,13 +13,9 @@ _RESOLUTION = 224
14
  def get_model() -> tf.keras.Model:
15
  """Initiates a tf.keras.Model from HF Hub."""
16
  inputs = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
17
- hub_module = from_pretrained_keras(
18
- "probing-vits/cait_xxs24_224_classification"
19
- )
20
 
21
- logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(
22
- inputs, training=False
23
- )
24
 
25
  return tf.keras.Model(
26
  inputs, [logits, sa_atn_score_dict, ca_atn_score_dict]
@@ -38,17 +33,15 @@ def show_plot(image):
38
  _, _, ca_atn_score_dict = _MODEL.predict(preprocessed_image)
39
 
40
  # Compute the saliency map and superimpose.
41
- result_first_block = utils.get_cls_attention_map(
42
  preprocessed_image, ca_atn_score_dict, block_key="ca_ffn_block_0_att"
43
  )
44
- heatmap = cv2.applyColorMap(
45
- np.uint8(255 * result_first_block), cv2.COLORMAP_CIVIDIS
46
- )
47
- heatmap = np.float32(heatmap)
 
48
 
49
- saliency_map = heatmap + original_image
50
- saliency_map = np.clip(saliency_map, 0.0, 255.0).astype(np.uint8)
51
- return Image.fromarray(saliency_map)
52
 
53
 
54
  title = "Generate Class Saliency Plots"
@@ -57,10 +50,10 @@ article = "Class saliency maps as investigated in [Going deeper with Image Trans
57
  iface = gr.Interface(
58
  show_plot,
59
  inputs=gr.inputs.Image(type="pil", label="Input Image"),
60
- outputs="image",
61
  title=title,
62
  article=article,
63
  allow_flagging="never",
64
- examples=[["./butterfly.jpg"]],
65
  )
66
- iface.launch()
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import tensorflow as tf
 
13
  def get_model() -> tf.keras.Model:
14
  """Initiates a tf.keras.Model from HF Hub."""
15
  inputs = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
16
+ hub_module = from_pretrained_keras("probing-vits/cait_xxs24_224_classification")
 
 
17
 
18
+ logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(inputs, training=False)
 
 
19
 
20
  return tf.keras.Model(
21
  inputs, [logits, sa_atn_score_dict, ca_atn_score_dict]
 
33
  _, _, ca_atn_score_dict = _MODEL.predict(preprocessed_image)
34
 
35
  # Compute the saliency map and superimpose.
36
+ saliency_attention = utils.get_cls_attention_map(
37
  preprocessed_image, ca_atn_score_dict, block_key="ca_ffn_block_0_att"
38
  )
39
+ fig = plt.figure()
40
+ plt.imshow(original_image.astype("int32"))
41
+ plt.imshow(saliency_attention.squeeze(), cmap="cividis", alpha=0.9)
42
+ plt.axis("off")
43
+ return fig
44
 
 
 
 
45
 
46
 
47
  title = "Generate Class Saliency Plots"
 
50
  iface = gr.Interface(
51
  show_plot,
52
  inputs=gr.inputs.Image(type="pil", label="Input Image"),
53
+ outputs=gr.outputs.Plot(type="auto"),
54
  title=title,
55
  article=article,
56
  allow_flagging="never",
57
+ examples=[["./butterfly_cropped.png"]],
58
  )
59
+ iface.launch(debug=True)