sayakpaul HF staff commited on
Commit
ced60bc
1 Parent(s): 23a9607

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import tensorflow_hub as hub
4
+ from PIL import Image
5
+
6
+ import utils
7
+
8
+ _RESOLUTION = 224
9
+ _MODEL_PATH = "gs://cait-tf/cait_xxs24_224"
10
+
11
+
12
+ def get_model() -> tf.keras.Model:
13
+ """Initiates a tf.keras.Model from TF-Hub."""
14
+ inputs = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
15
+ hub_module = hub.KerasLayer(_MODEL_PATH)
16
+
17
+ logits, sa_atn_score_dict, ca_atn_score_dict = hub_module(inputs)
18
+
19
+ return tf.keras.Model(
20
+ inputs, [logits, sa_atn_score_dict, ca_atn_score_dict]
21
+ )
22
+
23
+
24
+ _MODEL = get_model()
25
+
26
+
27
+ def show_plot(image):
28
+ """Function to be called when user hits submit on the UI."""
29
+ _, preprocessed_image = utils.preprocess_image(
30
+ image, "deit_tiny_patch16_224"
31
+ )
32
+ _, _, ca_atn_score_dict = _MODEL.predict(preprocessed_image)
33
+
34
+ result_first_block = utils.get_cls_attention_map(
35
+ image, ca_atn_score_dict, block_key="ca_ffn_block_0_att"
36
+ )
37
+ result_second_block = utils.get_cls_attention_map(
38
+ image, ca_atn_score_dict, block_key="ca_ffn_block_1_att"
39
+ )
40
+ return Image.fromarray(result_first_block), Image.fromarray(
41
+ result_second_block
42
+ )
43
+
44
+
45
+ title = "Generate Class Attention Plots"
46
+ article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.)."
47
+
48
+ iface = gr.Interface(
49
+ show_plot,
50
+ inputs=gr.inputs.Image(type="pil", label="Input Image"),
51
+ outputs="image",
52
+ title=title,
53
+ article=article,
54
+ allow_flagging="never",
55
+ examples=[["./butterfly.jpg"]],
56
+ )
57
+ iface.launch()