File size: 3,878 Bytes
1d7c63d
 
 
 
dd58475
 
1d7c63d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd58475
c5343e6
 
1d7c63d
dd58475
 
 
 
 
 
 
c5343e6
0c61c42
 
1d7c63d
 
 
0c61c42
 
 
 
1d7c63d
 
dd58475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c61c42
 
 
 
dd58475
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import xplique
import tensorflow as tf
from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap,SobolAttributionMethod,HsicAttributionMethod)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
# Batch size handed to the batch-based explainers built in `explain`
# (in the active explainer list, only Rise receives it).
BATCH_SIZE = 1

def show(img, p=False, **kwargs):
    """Draw an image or attribution map on the current matplotlib axes.

    The input is converted to a float32 NumPy array and adjusted so that
    ``plt.imshow`` can display it:

    * a leading singleton (batch) dimension is dropped for rank >= 3 input,
    * a trailing single channel is squeezed to a 2-D map,
    * a trailing 3-channel axis has its channel order reversed (presumably
      BGR -> RGB for images coming from ``preprocess`` -- TODO confirm),
    * values outside ``[0, 1]`` are min-max rescaled into ``[0, 1]``,
    * optionally, values are clipped to the ``[p, 100 - p]`` percentiles.

    :param img: Array-like image, e.g. (H, W), (H, W, C) or (1, H, W, C).
    :param p: Percentile for symmetric clipping, or ``False`` to skip it.
    :param kwargs: Forwarded unchanged to ``plt.imshow`` (e.g. ``alpha``).
    """
    img = np.array(img, dtype=np.float32)

    # Drop a leading singleton batch dimension: (1, H, W, C) -> (H, W, C).
    # Restricted to rank >= 3 so a (1, W) 2-D map is not collapsed to 1-D.
    if img.ndim >= 3 and img.shape[0] == 1:
        img = img[0]

    # Only a 3-D array has a channel axis. For a 2-D map the last axis is the
    # width, and a width of 1 or 3 must not be mistaken for channels.
    if img.ndim == 3:
        if img.shape[-1] == 1:
            img = img[:, :, 0]
        elif img.shape[-1] == 3:
            img = img[:, :, ::-1]

    # Min-max normalize into [0, 1] when the data is out of range. A constant
    # image has zero range: shift it to 0 instead of dividing by zero (NaNs).
    lo, hi = img.min(), img.max()
    if hi > 1 or lo < 0:
        img -= lo
        value_range = hi - lo
        if value_range > 0:
            img /= value_range

    # Optional outlier clipping to the [p, 100 - p] percentile band.
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))
    plt.imshow(img, **kwargs)
    plt.axis('off')
    


def explain(model, input_image, size=600, n_classes=171):
    """Generate attribution heatmaps for the model's top-5 predicted classes.

    The image is cropped with ``_clever_crop`` and preprocessed. Every
    configured explainer (Sobol, RISE, HSIC, Saliency) is then run once per
    top-5 class. Each absolute-valued heatmap is overlaid on the input image
    and saved as ``phi_{e}{i}.png`` in the current working directory, where
    ``e`` is the explainer index and ``i`` the rank of the class.

    :param model: Keras model whose ``model.output[1]`` is the classification
        head (the model therefore has at least two outputs).
    :param input_image: Raw input image, passed to ``_clever_crop``.
    :param size: Side length of the square crop / preprocessing target.
    :param n_classes: Number of classes, used to one-hot encode the targets.
    :return: List of saved PNG paths, or the single path when exactly one
        file was produced.
    """
    # Only the classification head is explained.
    class_model = tf.keras.Model(model.input, model.output[1])

    # Order matters: the list index becomes ``e`` in the output filenames.
    explainers = [
        SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
        Rise(class_model, nb_samples=5000, batch_size=BATCH_SIZE, grid_size=15,
             preservation_probability=0.5),
        HsicAttributionMethod(class_model,
                              grid_size=7, nb_design=1500,
                              sampler=LatinHypercube(binary=True)),
        Saliency(class_model),
    ]

    cropped, repetitions = _clever_crop(input_image, (size, size))
    # NOTE(review): `repetitions` presumably counts how often the image was
    # tiled to fill the crop, making [s:2s] the columns of one tile -- confirm
    # against `_clever_crop`.
    size_repetitions = int(size // (repetitions.numpy() + 1))
    # Batch of one, shared by the prediction and every explainer call.
    X = np.expand_dims(preprocess(cropped, size=size), 0)
    predictions = class_model.predict(X)
    # Indices of the 5 highest scores, best first.
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]

    explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e}/{len(explainers)}')
        for i, Y in enumerate(top_5_indices):
            Y = tf.one_hot([Y], n_classes)
            print(f'{i}/{len(top_5_indices)}')
            phi = np.abs(explainer(X, Y))[0]
            # Collapse a per-channel attribution to a single 2-D map.
            if len(phi.shape) == 3:
                phi = np.mean(phi, -1)
            path = f'phi_{e}{i}.png'
            # Use a fresh figure per overlay and always close it. Reusing the
            # implicit pyplot figure accumulated every image drawn so far,
            # which leaked memory within and across calls.
            fig = plt.figure()
            try:
                show(X[0][:, size_repetitions:2 * size_repetitions, :])
                show(phi[:, size_repetitions:2 * size_repetitions], p=1, alpha=0.4)
                fig.savefig(path)
            finally:
                plt.close(fig)
            explanations.append(path)

    print('Done')
    if len(explanations) == 1:
        explanations = explanations[0]
    return explanations