sayakpaul HF staff commited on
Commit
ed1b9ba
·
1 Parent(s): 76fdd25

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from huggingface_hub.keras_mixin import from_pretrained_keras
6
+ from PIL import Image
7
+
8
+ import utils
9
+
10
+ _RESOLUTION = 224
11
+
12
+
13
+ def get_model() -> tf.keras.Model:
14
+ """Initiates a tf.keras.Model from HF Hub."""
15
+ inputs = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
16
+ hub_module = from_pretrained_keras(
17
+ "probing-vits/cait_xxs24_224_classification"
18
+ )
19
+
20
+ logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(
21
+ inputs, training=False
22
+ )
23
+
24
+ return tf.keras.Model(
25
+ inputs, [logits, sa_atn_score_dict, ca_atn_score_dict]
26
+ )
27
+
28
+
29
+ _MODEL = get_model()
30
+
31
+
32
+ def show_plot(image):
33
+ """Function to be called when user hits submit on the UI."""
34
+ original_image, preprocessed_image = utils.preprocess_image(
35
+ image, _RESOLUTION
36
+ )
37
+ _, _, ca_atn_score_dict = _MODEL.predict(preprocessed_image)
38
+
39
+ # Compute the saliency map and superimpose.
40
+ result_first_block = utils.get_cls_attention_map(
41
+ image, ca_atn_score_dict, block_key="ca_ffn_block_0_att"
42
+ )
43
+ heatmap = cv2.applyColorMap(
44
+ np.uint8(255 * result_first_block), cv2.COLORMAP_CIVIDIS
45
+ )
46
+ heatmap = np.float32(heatmap) / 255
47
+
48
+ original_image = original_image / 255.0
49
+ saliency_map = heatmap + original_image
50
+ saliency_map = saliency_map / np.max(saliency_map)
51
+ return Image.fromarray(saliency_map)
52
+
53
+
54
+ title = "Generate Class Saliency Plots"
55
+ article = "Class saliency maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.)."
56
+
57
+ iface = gr.Interface(
58
+ show_plot,
59
+ inputs=gr.inputs.Image(type="pil", label="Input Image"),
60
+ outputs="image",
61
+ title=title,
62
+ article=article,
63
+ allow_flagging="never",
64
+ examples=[["./butterfly.jpg"]],
65
+ )
66
+ iface.launch()