File size: 3,988 Bytes
f740d84
 
d57d2f2
dfbc387
 
 
 
f740d84
 
 
 
 
9c45667
f740d84
 
 
 
 
 
 
 
 
 
dfbc387
 
 
11bce97
f740d84
 
 
11bce97
f740d84
 
dfbc387
f740d84
 
 
 
dfbc387
 
 
 
f740d84
 
dfbc387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f740d84
dfbc387
 
 
f740d84
 
 
dfbc387
 
f740d84
dfbc387
 
 
f740d84
 
 
 
 
dfbc387
f740d84
 
b037fad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
from tensorflow.keras.applications.mobilenet_v2 import (
    preprocess_input,
    decode_predictions,
)
import matplotlib.pyplot as plt
from alibi.explainers import IntegratedGradients
from alibi.datasets import load_cats
from alibi.utils.visualization import visualize_image_attr
import numpy as np
from PIL import Image, ImageFilter
import io
import time
import os
import copy
import pickle
import datetime
import urllib.request
import gradio as gr


# Example images offered through the Gradio `examples` list: (source URL, local path).
_EXAMPLE_IMAGES = (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    ("https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg", "./dog.jpg"),
)

# NOTE: as before, `url` and `path_input` remain bound to the last pair afterwards.
for url, path_input in _EXAMPLE_IMAGES:
    # Reuse a previously downloaded copy instead of hitting the network on every start.
    if os.path.exists(path_input):
        continue
    # Wikimedia's User-Agent policy asks clients to identify themselves; generic
    # library user agents may be rejected, so send an explicit one.
    request = urllib.request.Request(
        url, headers={"User-Agent": "xai-integrated-gradients-demo/1.0 (gradio)"}
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        payload = response.read()
    # Write only after the full body was read, so a failed download never leaves
    # a truncated file that the existence check above would then keep reusing.
    with open(path_input, "wb") as fh:
        fh.write(payload)

# Integrated Gradients settings, passed straight through to alibi's explainer:
# number of interpolation steps, the integration rule, and how many interpolated
# samples are evaluated per internal batch.
internal_batch_size = 50
n_steps = 50
method = "gausslegendre"

# ImageNet-pretrained MobileNetV2 classifier shared by the explainer and do_process.
model = keras_model(weights="imagenet")

# One explainer instance is built at import time and reused for every request.
ig = IntegratedGradients(
    model,
    n_steps=n_steps,
    method=method,
    internal_batch_size=internal_batch_size,
)


def _baseline_pixels(baseline, shape, dtype):
    """Return a non-blur Integrated Gradients baseline in raw pixel units.

    The array is built in the same 0-255 pixel space as the uploaded image, so
    that after ``preprocess_input`` it sits where the model expects it. This
    also makes the preview shown to the user match the baseline actually used.

    Args:
        baseline: ``"white"`` or ``"black"``; any other value yields uniform
            random noise (the ``"random"`` choice).
        shape: Shape of the returned array (the batched instance shape).
        dtype: NumPy dtype of the returned array.

    Returns:
        Array of ``shape`` and ``dtype`` with values in [0, 255].
    """
    if baseline == "white":
        return np.full(shape, 255.0, dtype=dtype)
    if baseline == "black":
        # Pixel 0 is black. Building zeros *after* MobileNetV2 preprocessing
        # (which rescales to [-1, 1]) would instead be mid-gray.
        return np.zeros(shape, dtype=dtype)
    # Uniform noise over the full pixel range, matching what the preview shows.
    return (np.random.random_sample(shape) * 255.0).astype(dtype)


def do_process(img, baseline):
    """Classify ``img`` and explain the top prediction with Integrated Gradients.

    Args:
        img: PIL image from the Gradio input (configured as 224x224 RGB).
        baseline: One of ``"random"``, ``"black"``, ``"white"``, ``"blur"``;
            unrecognised values fall back to random noise.

    Returns:
        Tuple ``(heatmap, baseline_image, scores)``:
        - the rendered attribution overlay as a PIL image;
        - a PIL preview of the baseline that was used;
        - a dict mapping the top-3 ImageNet labels to their scores, rounded
          to two decimals.

    Raises:
        ValueError: if no image was supplied.
    """
    if img is None:
        raise ValueError("Please upload an image.")

    instance = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))
    preds = model.predict(instance)
    top3 = decode_predictions(preds, top=3)[0]
    dctPreds = {label: round(float(score), 2) for _, label, score in top3}
    # Explain the class the model itself ranked first.
    predictions = preds.argmax(axis=1)

    # Build every baseline in pixel space first; the preview is rendered from it
    # before preprocessing, because preprocess_input may rescale float input in place.
    if baseline == "blur":
        img_flt = img.filter(ImageFilter.GaussianBlur(5))
        pixels = np.expand_dims(image.img_to_array(img_flt), axis=0)
    else:
        pixels = _baseline_pixels(baseline, instance.shape, instance.dtype)
        img_flt = Image.fromarray(np.uint8(np.squeeze(pixels)))
    baselines = preprocess_input(pixels)

    explanation = ig.explain(instance, baselines=baselines, target=predictions)
    attrs = explanation.attributions[0]
    fig, _ = visualize_image_attr(
        attr=attrs.squeeze(),
        original_image=img,
        method="blended_heat_map",
        sign="all",
        show_colorbar=True,
        title=baseline,
        plt_fig_axis=None,
        use_pyplot=False,
    )
    fig.tight_layout()
    # Render the figure to an in-memory PNG so it can be returned as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img_res = Image.open(buf)
    return img_res, img_flt, dctPreds


# Inputs: a 224x224 RGB upload (passed to do_process as a PIL image) and the
# baseline choice; the dropdown choices match the branches in do_process.
input_im = gr.inputs.Image(
    shape=(224, 224), image_mode="RGB", invert_colors=False, source="upload", type="pil"
)
input_drop = gr.inputs.Dropdown(
    label="Baseline (default: random)",
    choices=["random", "black", "white", "blur"],
    default="random",
    type="value",
)

# Outputs, in the same order as do_process's return tuple.
output_img = gr.outputs.Image(label="Output of Integrated Gradients", type="pil")
output_base = gr.outputs.Image(label="Baseline image", type="pil")
output_label = gr.outputs.Label(label="Classification results", num_top_classes=3)

title = "XAI - Integrated gradients"
# The classifier loaded above is MobileNetV2, not ResNet.
description = (
    "Playground: Integrated gradients for a MobileNetV2 model trained on Imagenet "
    "dataset. Tools: Alibi, TF, Gradio."
)
# Example rows: (image path, baseline); the images are downloaded at startup.
examples = [["./cat.jpg", "blur"], ["./dog.jpg", "random"]]
article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
iface = gr.Interface(
    fn=do_process,
    inputs=[input_im, input_drop],
    outputs=[output_img, output_base, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

iface.launch(debug=True)