# Scraped from a Hugging Face Space page (status shown: "Runtime error"; file size 2,099 bytes).
# The page's git-blame gutter (commit hashes) and line-number column have been omitted.
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from huggingface_hub.keras_mixin import from_pretrained_keras
from PIL import Image

import utils
# Side length of the square input image. It is used for the Keras Input layer
# in get_model and for utils.preprocess_image in show_plot. It matches the
# "224" in the Hub checkpoint name.
_RESOLUTION = 224
def get_model() -> tf.keras.Model:
    """Wrap the CaiT classifier from the HF Hub in a functional Keras model.

    The wrapper re-exposes every output of the Hub module. Callers can then
    read the attention scores alongside the logits.

    Returns:
        A ``tf.keras.Model`` taking ``(_RESOLUTION, _RESOLUTION, 3)`` inputs
        and producing ``[logits, sa_scores, ca_scores]``. Judging by their
        names, these are the classifier logits plus the self-attention and
        class-attention score dicts.
    """
    image_input = tf.keras.Input((_RESOLUTION, _RESOLUTION, 3))
    cait = from_pretrained_keras("probing-vits/cait_xxs24_224_classification")
    # Call the Hub module in inference mode (training=False).
    logits, sa_scores, ca_scores = cait(image_input, training=False)
    model_outputs = [logits, sa_scores, ca_scores]
    return tf.keras.Model(image_input, model_outputs)
# Build the model once at import time so every show_plot call reuses it.
_MODEL = get_model()
def plot(attentions: np.ndarray):
    """Plot each attention head's map side by side in a single-row figure.

    Args:
        attentions: Array indexed as ``attentions[:, :, head]``, with heads
            along the last axis. It is presumably (height, width, num_heads)
            as returned by ``utils.get_cls_attention_map``; TODO confirm.

    Returns:
        The matplotlib ``Figure`` with one subplot per attention head.

    Raises:
        ValueError: If ``attentions`` has no heads along its last axis.
    """
    num_heads = attentions.shape[-1]
    if num_heads == 0:
        raise ValueError("attentions must contain at least one attention head")
    # Size the grid from the data instead of a hard-coded 4 columns. Other
    # head counts then neither raise IndexError nor leave blank axes.
    # squeeze=False keeps `axes` 2-D even when there is only one head.
    fig, axes = plt.subplots(
        nrows=1, ncols=num_heads, figsize=(13, 13), squeeze=False
    )
    for head, ax in enumerate(axes[0]):
        ax.imshow(attentions[:, :, head])
        ax.title.set_text(f"Attention head: {head}")
        ax.axis("off")
    fig.tight_layout()
    return fig
def show_plot(image):
    """Gradio callback: build class-attention plots for one uploaded image.

    Args:
        image: The image delivered by the Gradio input component. It is a
            PIL image, because the interface uses ``type="pil"``.

    Returns:
        Two matplotlib figures, one each for the class-attention blocks
        ``ca_ffn_block_0_att`` and ``ca_ffn_block_1_att``.
    """
    _, model_input = utils.preprocess_image(image, _RESOLUTION)
    # Only the class-attention scores (third model output) are plotted.
    _, _, ca_scores = _MODEL.predict(model_input)
    maps = [
        utils.get_cls_attention_map(model_input, ca_scores, block_key=block)
        for block in ("ca_ffn_block_0_att", "ca_ffn_block_1_att")
    ]
    return tuple(plot(attention_map) for attention_map in maps)
title = "Generate Class Attention Plots"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.)."

# show_plot returns two figures (one per class-attention block), so the
# interface needs two plot outputs. A single "plot" output does not match
# the callback's return arity.
iface = gr.Interface(
    show_plot,
    # gr.inputs.* is the deprecated component namespace (removed in Gradio 4);
    # gr.Image is the current equivalent.
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Plot(label="Class attention: block 0"),
        gr.Plot(label="Class attention: block 1"),
    ],
    title=title,
    article=article,
    allow_flagging="never",
    examples=[["./butterfly.jpg"]],
)
iface.launch()
|