# fossil_app/explanations.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from xplique.attributions import (Saliency, GradientInput, IntegratedGradients, SmoothGrad, VarGrad,
                                  SquareGrad, GradCAM, Occlusion, Rise, GuidedBackprop,
                                  GradCAMPP, Lime, KernelShap, SobolAttributionMethod, HsicAttributionMethod)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube

from inference_resnet import inference_resnet_finer, preprocess, _clever_crop

BATCH_SIZE = 1
def show(img, p=False, **kwargs):
    """Display an image or attribution map with matplotlib, normalizing as needed."""
    img = np.array(img, dtype=np.float32)

    # drop a leading batch dimension of size 1
    if img.shape[0] == 1:
        img = img[0]

    # single channel: drop the channel axis so matplotlib applies a colormap
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    # three channels: reverse the channel order (BGR -> RGB) for display
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]

    # rescale values into [0, 1]
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        img /= img.max()

    # optionally clip to the [p, 100-p] percentile range to suppress outliers
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))

    plt.imshow(img, **kwargs)
    plt.axis('off')
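
# Example usage of show() (a minimal sketch; `image` and `heatmap` are hypothetical
# inputs: an (H, W, 3) image and an (H, W) attribution map):
#   plt.figure()
#   show(image)                                # base image
#   show(heatmap, p=1, cmap='jet', alpha=0.4)  # percentile-clipped overlay
#   plt.show()
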
def explain(model, input_image, size=600, n_classes=171):
    """
    Generate attribution maps for the top-5 predicted classes of a single image.

    :param model: The model to explain; its second output is assumed to be the class logits.
    :param input_image: The input image to explain.
    :param size: Side length (in pixels) the image is cropped and resized to.
    :param n_classes: Number of classes in the classifier head.
    :return: Paths of the saved explanation images (a single path if only one was produced).
    """
    # we only need the classification head of the model (its second output)
    class_model = tf.keras.Model(model.input, model.output[1])
    explainers = [
        # active: Sobol, RISE, HSIC, Saliency; alternatives kept for reference:
        # IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
        # SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
        # GradCAM(class_model),
        SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
        Rise(class_model, nb_samples=5000, batch_size=BATCH_SIZE, grid_size=15,
             preservation_probability=0.5),
        HsicAttributionMethod(class_model,
                              grid_size=7, nb_design=1500,
                              sampler=LatinHypercube(binary=True)),
        Saliency(class_model),
    ]
    # crop the input to a square, tracking how many times it was tiled
    cropped, repetitions = _clever_crop(input_image, (size, size))
    size_repetitions = int(size // (repetitions.numpy() + 1))
    X = preprocess(cropped, size=size)

    # pick the five most probable classes to explain
    predictions = class_model.predict(np.array([X]))
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    X = np.expand_dims(X, 0)
    explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e + 1}/{len(explainers)}')
        for i, Y in enumerate(top_5_indices):
            Y = tf.one_hot([Y], n_classes)
            print(f'{i + 1}/{len(top_5_indices)}')
            phi = np.abs(explainer(X, Y))[0]
            # average attribution over channels if it is returned per channel
            if len(phi.shape) == 3:
                phi = np.mean(phi, -1)
            # start a fresh figure so overlays do not accumulate across iterations
            plt.figure()
            show(X[0][:, size_repetitions:2 * size_repetitions, :])
            show(phi[:, size_repetitions:2 * size_repetitions], p=1, alpha=0.4)
            plt.savefig(f'phi_{e}{i}.png')
            plt.close()
            explanations.append(f'phi_{e}{i}.png')
    print('Done')
    # unwrap when only a single explanation was produced
    if len(explanations) == 1:
        explanations = explanations[0]
    return explanations
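
# Minimal usage sketch (hypothetical paths and loader; assumes a Keras model whose
# second output is the class logits, as explain() above requires):
if __name__ == '__main__':
    model = tf.keras.models.load_model('fossil_model.h5')  # hypothetical checkpoint path
    image = plt.imread('fossil_sample.png')[:, :, :3]      # hypothetical RGB test image
    paths = explain(model, image)
    print(paths)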