sayakpaul (HF staff) committed
Commit 3e32c41
1 Parent(s): 047edad

Upload utils.py

Files changed (1)
utils.py +66 -0
utils.py ADDED
@@ -0,0 +1,66 @@
+ from typing import Dict
+
+ import numpy as np
+ import tensorflow as tf
+ from PIL import Image
+ from tensorflow import keras
+
+ RESOLUTION = 224
+ PATCH_SIZE = 16
+
+
+ crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
+ norm_layer = keras.layers.Normalization(
+     mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],  # ImageNet mean, scaled to [0, 255].
+     variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],  # ImageNet std squared.
+ )
+
+
+ def preprocess_image(orig_image: Image.Image, size: int):
+     """Resizes, center-crops, and normalizes a PIL image for the model."""
+     image = np.array(orig_image)
+     image_resized = tf.expand_dims(image, 0)
+     # Resize so that the subsequent center crop keeps the standard 256:224 ratio.
+     resize_size = int((256 / 224) * size)
+     image_resized = tf.image.resize(
+         image_resized, (resize_size, resize_size), method="bicubic"
+     )
+     image_resized = crop_layer(image_resized)
+     # Returns the cropped image (for display) and its normalized version (for the model).
+     return image_resized.numpy().squeeze(), norm_layer(image_resized).numpy()
+
+
+ # Reference:
+ # https://github.com/facebookresearch/dino/blob/main/visualize_attention.py
+
+
+ def get_cls_attention_map(
+     preprocessed_image: np.ndarray,
+     attn_score_dict: Dict[str, np.ndarray],
+     block_key="ca_ffn_block_0_att",  # First class-attention block by default.
+ ):
+     """Generates a class-attention saliency map modeling spatial-class relationships."""
+     w_featmap = preprocessed_image.shape[2] // PATCH_SIZE
+     h_featmap = preprocessed_image.shape[1] // PATCH_SIZE
+
+     attention_scores = attn_score_dict[block_key]
+     nh = attention_scores.shape[1]  # Number of attention heads.
+
+     # Attention from the CLS token (query 0) to every image patch, per head.
+     attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)
+
+     # Reshape the flat patch scores into the 2D feature-map grid.
+     attentions = attentions.reshape(nh, w_featmap, h_featmap)
+
+     # Average over heads, then min-max normalize to [0, 1].
+     attentions = np.mean(attentions, axis=0)
+     attentions = (attentions - attentions.min()) / (
+         attentions.max() - attentions.min()
+     )
+     attentions = np.expand_dims(attentions, -1)
+
+     # Upsample the patch grid back to the input resolution (14 x 16 = 224).
+     attentions = tf.image.resize(
+         attentions,
+         size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
+         method="bicubic",
+     )
+
+     return attentions.numpy()
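For context, a minimal usage sketch (not part of the commit). It assumes a CaiT-style Keras model whose forward pass also returns the per-block attention scores as a dict keyed like "ca_ffn_block_0_att"; load_cait_model is a hypothetical user-provided loader, and only preprocess_image, get_cls_attention_map, and RESOLUTION come from the file above.

from PIL import Image

from utils import RESOLUTION, get_cls_attention_map, preprocess_image

model = load_cait_model()  # Hypothetical loader; not part of utils.py.

image = Image.open("cat.jpg").convert("RGB")
cropped, preprocessed = preprocess_image(image, RESOLUTION)

# Assumed output convention: (logits, {block_name: (1, heads, N, N) scores}).
logits, attn_scores = model(preprocessed, training=False)
attn_scores = {k: v.numpy() for k, v in attn_scores.items()}

attn_map = get_cls_attention_map(preprocessed, attn_scores)
print(attn_map.shape)  # (224, 224, 1)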
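And a small overlay sketch with matplotlib, assuming cropped and attn_map from the snippet above; the colormap and alpha are illustrative choices, not anything prescribed by the file.

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(cropped.astype("uint8"))  # (224, 224, 3) pixels in [0, 255].
axes[0].set_title("Input")
axes[1].imshow(cropped.astype("uint8"))
axes[1].imshow(attn_map.squeeze(), cmap="inferno", alpha=0.6)  # (224, 224) map in [0, 1].
axes[1].set_title("CLS attention")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()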