File size: 4,765 Bytes
6718d11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
'''
Grad-CAM visualization utilities

- Based on https://keras.io/examples/vision/grad_cam/

---
- 2021-12-18 jkang first created 
- 2022-01-16 
    - copied from https://huggingface.co/spaces/jkang/demo-gradcam-imagenet/blob/main/utils.py
    - updated for artist/trend classifier
'''
import matplotlib.cm as cm

import os
import re
from glob import glob
import numpy as np
import tensorflow as tf
# Short aliases: the Keras API bundled with TensorFlow, and its backend module
tfk = tf.keras
K = tf.keras.backend

# Disable GPU for testing
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'


def get_imagenet_classes():
    '''Return (idx2lab, lab2idx) dictionaries for all 1000 ImageNet classes.

    The class indices 0..999 are themselves fed to `decode_predictions` as
    fake "scores" with top=1000, so every class is decoded exactly once and
    the score slot of each decoded entry holds that class's own index.
    - idx2lab: {class index: label}
    - lab2idx: {label: class index} (inverse of idx2lab)
    '''
    fake_scores = np.arange(1000)[np.newaxis, :]
    decoded = tfk.applications.imagenet_utils.decode_predictions(
        fake_scores, top=1000)[0]

    idx2lab = {}
    for _wnid, label, index in decoded:
        idx2lab[index] = label

    lab2idx = {}
    for index, label in idx2lab.items():
        lab2idx[label] = index
    return idx2lab, lab2idx


def search_by_name(str_part, lab2idx=None):
    '''Search class labels by (partial) regular-expression match

    - str_part: regex pattern, matched anywhere in each label via `re.search`
    - lab2idx: optional {label: index} mapping to search. When omitted, the
        module-level `lab2idx` is used if one exists; otherwise the mapping
        is built with `get_imagenet_classes()`.
    Returns a list of (label, index) tuples in the mapping's iteration order,
    or an empty list when nothing matches.
    '''
    if lab2idx is None:
        # A global `lab2idx` is only defined when this file runs as a script;
        # fall back to building the mapping so importers don't hit NameError.
        lab2idx = globals().get('lab2idx')
        if lab2idx is None:
            lab2idx = get_imagenet_classes()[1]
    return [(label, idx) for label, idx in lab2idx.items()
            if re.search(str_part, label)]


def get_xception_model(last_conv_layer_name="block14_sepconv2_act"):
    '''Build an ImageNet-weighted Xception model plus its Grad-CAM companion

    - last_conv_layer_name: name of the layer whose activations Grad-CAM
        explains (default: "block14_sepconv2_act", as originally hard-coded)
    Returns (model, grad_model, preprocessor, decode_predictions), where
    `grad_model` maps an input batch to
    [last_conv_layer_activations, model_predictions].
    '''
    xception = tfk.applications.xception

    model = xception.Xception(weights='imagenet')
    grad_model = tfk.models.Model(
        # `model.inputs` is already a list; wrapping it again
        # ([model.inputs]) creates a nested input structure that newer
        # Keras versions report as not matching the inputs it is called with.
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output,
                 model.output],
    )
    return (model, grad_model,
            xception.preprocess_input, xception.decode_predictions)


def get_img_4d_array(image_file, image_size=(299, 299)):
    '''Read an image file and return it as a batch holding one image

    - image_file: path of the image to load
    - image_size: (height, width) the image is resized to while loading
    Returns a float32 array shaped (1, height, width, channels),
    e.g. (1, 299, 299, 3) with the defaults.
    '''
    image_utils = tfk.preprocessing.image
    pil_img = image_utils.load_img(image_file, target_size=image_size)
    return image_utils.img_to_array(pil_img)[np.newaxis, ...]


def make_gradcam_heatmap(grad_model, img_array, pred_idx=None):
    '''Compute a Grad-CAM heatmap for a single image

    - grad_model: model mapping an input batch to
        [last_conv_activations, predictions] (see `get_xception_model`)
    - img_array: 4d array holding one preprocessed image (batch of 1)
    - pred_idx: index of the class to explain (eg. out of the 1000 imagenet
        classes); if None, argmax is chosen from prediction
    Returns (heatmap, pred_idx, predictions):
    - heatmap: 2d tensor (conv_height, conv_width) with values in [0, 1];
        all zeros when no location contributes positively to the class
    - pred_idx: the explained class index as a Python int
    - predictions: the model output for the image as a squeezed numpy array
    '''
    # Get gradient of pred class w.r.t. last conv activation
    with tf.GradientTape() as tape:
        last_conv_act, predictions = grad_model(img_array)
        if pred_idx is None:
            pred_idx = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_idx]  # (1, n_classes) => (1,)

    # d(class_channel) / d(last_conv_act), same shape as last_conv_act
    grads = tape.gradient(class_channel, last_conv_act)
    # One weight per channel: mean over batch and spatial axes,
    # eg. (1,10,10,2048) => (2048,)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Channel-weighted sum: (H,W,C) x (C,1) => (H,W,1) => (H,W).
    # Squeeze only the last axis so 1-pixel spatial dims are preserved.
    heatmap = last_conv_act[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap, axis=-1)

    # Keep positive evidence only, then normalize to [0, 1].
    # divide_no_nan yields 0 (not NaN) when nothing is positive.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    # int() accepts a scalar tensor (argmax case) as well as a plain int
    # passed by the caller, which has no `.numpy()` method.
    return heatmap, int(pred_idx), predictions.numpy().squeeze()


def align_image_with_heatmap(img_array, heatmap, alpha=0.3, cmap='jet'):
    '''Overlay a Grad-CAM heatmap on the image it was computed for

    - img_array: image as a 4d (1, H, W, C) or 3d (H, W, C) array; pixel
        values are multiplied by 255 below, so they are expected in [0, 1]
        (values outside are clipped)
    - heatmap: 2d array/tensor with values in [0, 1], eg. the first element
        returned by `make_gradcam_heatmap()`
    - alpha: weight of the colored heatmap in the overlay
    - cmap: matplotlib colormap name (or Colormap object)
    Returns the overlaid image as a PIL image.
    '''
    import matplotlib  # for the colormap registry

    if img_array.ndim == 4:
        img_array = img_array[0]  # drop only the batch axis: 4d => 3d

    # Rescale to 0-255 range. Clip (and clear NaN) first: casting
    # out-of-range floats to uint8 silently wraps or is undefined.
    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float32))
    heatmap_scaled = np.uint8(255 * np.clip(heatmap, 0, 1))
    img_array_scaled = np.uint8(255 * np.clip(img_array, 0, 1))

    if isinstance(cmap, str):
        # cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
        # prefer the registry, falling back for old versions without it.
        registry = getattr(matplotlib, 'colormaps', None)
        colormap = registry[cmap] if registry is not None else cm.get_cmap(cmap)
    else:
        colormap = cmap
    # 256-entry RGB lookup table (alpha channel dropped)
    colors = colormap(np.arange(256))[:, :3]
    # Fancy indexing: each uint8 heatmap value selects its RGB row,
    # turning the (h, w) heatmap into an (h, w, 3) color image
    heatmap_colored = colors[heatmap_scaled]

    # Upsample the colored heatmap to the image size
    # (PIL's resize takes (width, height))
    image_utils = tfk.preprocessing.image
    heatmap_colored = (image_utils.array_to_img(heatmap_colored)  # array => PIL
                       .resize((img_array.shape[1], img_array.shape[0])))
    heatmap_colored = image_utils.img_to_array(heatmap_colored)  # PIL => array

    # Overlay image with heatmap; array_to_img (default scale=True)
    # rescales the sum into a displayable range
    overlaid_img = heatmap_colored * alpha + img_array_scaled
    return image_utils.array_to_img(overlaid_img)


if __name__ == '__main__':
    # Test GradCAM: explain the first example image and save the overlay
    examples = sorted(glob(os.path.join('examples', '*.jpg')))
    if not examples:
        raise SystemExit("no '*.jpg' files found in ./examples")
    idx2lab, lab2idx = get_imagenet_classes()

    model, grad_model, preprocessor, decode_predictions = get_xception_model()

    raw_img_4d_array = get_img_4d_array(examples[0])  # pixel values 0-255
    # Preprocess a copy: Keras preprocess_input may overwrite its input
    img_4d_array = preprocessor(raw_img_4d_array.copy())

    # make_gradcam_heatmap returns (heatmap, pred_idx, predictions)
    heatmap, pred_idx, _ = make_gradcam_heatmap(
        grad_model, img_4d_array, pred_idx=None)
    print('predicted class:', idx2lab[pred_idx])

    # Overlay on the original (not preprocessed) image, scaled to [0, 1]
    # because align_image_with_heatmap multiplies pixel values by 255
    img_pil = align_image_with_heatmap(
        raw_img_4d_array / 255.0, heatmap, alpha=0.3, cmap='jet')

    img_pil.save('test.jpg')
    print('done')