Spaces:
Runtime error
Runtime error
Upload utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import tensorflow as tf
|
5 |
+
from PIL import Image
|
6 |
+
from tensorflow import keras
|
7 |
+
|
8 |
+
# Per-channel ImageNet statistics, expressed for inputs in the 0-255 range.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

RESOLUTION = 224
PATCH_SIZE = 16


# Shared preprocessing layers: center-crop to the model resolution, then
# normalize with ImageNet mean/variance (Normalization expects variance).
crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
norm_layer = keras.layers.Normalization(
    mean=[m * 255 for m in _IMAGENET_MEAN],
    variance=[(s * 255) ** 2 for s in _IMAGENET_STD],
)
|
17 |
+
|
18 |
+
|
19 |
+
def preprocess_image(orig_image: Image.Image, size: int):
    """Resize, center-crop, and normalize an image for model input.

    Follows the standard ImageNet evaluation transform: bicubic-resize to
    256/224 of the target size, then center-crop and normalize.

    Args:
        orig_image: Input PIL image. (Fixed: the annotation previously named
            the ``PIL.Image`` module instead of the ``Image.Image`` class.)
        size: Target square resolution the resize step is scaled from.

    Returns:
        A tuple ``(orig_image, preprocessed)`` where ``preprocessed`` is the
        normalized batch array produced by ``norm_layer``.
    """
    image = np.array(orig_image)
    # Add a leading batch axis so the resize/crop layers see a rank-4 tensor.
    image_batched = tf.expand_dims(image, 0)
    resize_size = int((256 / 224) * size)
    image_resized = tf.image.resize(
        image_batched, (resize_size, resize_size), method="bicubic"
    )
    # NOTE(review): crop_layer always crops to RESOLUTION (224), so a `size`
    # other than RESOLUTION still yields a 224x224 output — confirm intended.
    image_cropped = crop_layer(image_resized)
    return orig_image, norm_layer(image_cropped).numpy()
|
29 |
+
|
30 |
+
|
31 |
+
# Reference:
|
32 |
+
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
|
33 |
+
|
34 |
+
|
35 |
+
def get_cls_attention_map(
    preprocessed_image: np.ndarray,
    attn_score_dict: Dict[str, np.ndarray],
    block_key="ca_ffn_block_0_att",
):
    """Utility to generate class-attention map modeling spatial-class relationships.

    Reference:
    https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
    """
    # Number of patches along each spatial axis of the preprocessed input
    # (assumes channels-last layout: batch, height, width, channels).
    num_patches_w = preprocessed_image.shape[2] // PATCH_SIZE
    num_patches_h = preprocessed_image.shape[1] // PATCH_SIZE

    scores = attn_score_dict[block_key]
    num_heads = scores.shape[1]

    # Attention paid by the CLS token to every patch token, per head
    # (index 0 selects the CLS query; 1: drops the CLS->CLS entry).
    cls_attn = scores[0, :, 0, 1:].reshape(num_heads, -1)

    # Lay the per-head scores out on the patch grid, then move heads to the
    # channel axis so tf.image.resize treats them as image channels.
    cls_attn = cls_attn.reshape(num_heads, num_patches_w, num_patches_h)
    cls_attn = cls_attn.transpose((1, 2, 0))

    # Upsample the patch-level maps back to pixel resolution (e.g. 14*16=224).
    upsampled = tf.image.resize(
        cls_attn,
        size=(num_patches_h * PATCH_SIZE, num_patches_w * PATCH_SIZE),
        method="bicubic",
    )
    return upsampled.numpy()
|