from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions
from alibi.explainers import IntegratedGradients
from alibi.utils.visualization import visualize_image_attr
import numpy as np
from PIL import Image, ImageFilter
import io
import urllib.request
import gradio as gr


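# Download two example images (cat and dog) used below as Gradio examples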
url = "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
path_input = "./cat.jpg"
urllib.request.urlretrieve(url, filename=path_input)

url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"
path_input = "./dog.jpg"
urllib.request.urlretrieve(url, filename=path_input)

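# MobileNetV2 classifier pretrained on ImageNet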
model = keras_model(weights='imagenet')

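# Integrated Gradients explainer: 50 integration steps with Gauss-Legendre quadrature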
n_steps = 50
method = "gausslegendre"
internal_batch_size = 50
ig = IntegratedGradients(model,
                         n_steps=n_steps,
                         method=method,
                         internal_batch_size=internal_batch_size)

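# Classify the uploaded image, build the chosen baseline, and compute IG attributions;
# returns the attribution heat map, the baseline image, and the top-3 predictions.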
def do_process(img, baseline):
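  # Convert the 224x224 PIL image to a (1, 224, 224, 3) batch and apply MobileNetV2 preprocessing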
  instance = image.img_to_array(img)
  instance = np.expand_dims(instance, axis=0)
  instance = preprocess_input(instance)
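  # Run the classifier: keep the top-3 labels for display and the argmax class as the IG target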
  preds = model.predict(instance)
  lstPreds = decode_predictions(preds, top=3)[0]
  dctPreds = {lstPreds[i][1]: round(float(lstPreds[i][2]),2) for i in range(len(lstPreds))}
  predictions = preds.argmax(axis=1)
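  # Build the IG baseline (white, black, blurred input, or random noise);
  # img_flt holds the baseline rendered as a PIL image for display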
  if baseline == 'white':
    baselines = np.ones(instance.shape).astype(instance.dtype)
    img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
  elif baseline == 'black':
    baselines = np.zeros(instance.shape).astype(instance.dtype)
    img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
  elif baseline == 'blur':
    img_flt = img.filter(ImageFilter.GaussianBlur(5))
    baselines = image.img_to_array(img_flt)
    baselines = np.expand_dims(baselines, axis=0)
    baselines = preprocess_input(baselines)
  else:
    baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
    img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
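  # Attribute the predicted class to the input pixels via Integrated Gradients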
  explanation = ig.explain(instance,
                         baselines=baselines,
                         target=predictions)
  attrs = explanation.attributions[0]
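  # Overlay the attributions on the original image as a blended heat map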
  fig, ax = visualize_image_attr(attr=attrs.squeeze(), original_image=img, method='blended_heat_map',
                      sign='all', show_colorbar=True, title=baseline,
                      plt_fig_axis=None, use_pyplot=False)
  fig.tight_layout()
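  # Render the matplotlib figure into a PIL image so Gradio can display it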
  buf = io.BytesIO()
  fig.savefig(buf)
  buf.seek(0)
  img_res = Image.open(buf)
  return img_res, img_flt, dctPreds
  
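# Gradio input widgets: a 224x224 RGB image upload and a baseline selector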
input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB',
                           invert_colors=False, source='upload',
                           type='pil')
input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
                                choices=['random', 'black', 'white', 'blur'],
                                default='random', type='value')

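# Gradio output widgets: attribution heat map, baseline image, and top-3 classification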
output_img = gr.outputs.Image(label='Output of Integrated Gradients', type='pil')
output_base = gr.outputs.Image(label='Baseline image', type='pil')
output_label = gr.outputs.Label(label='Classification results', num_top_classes=3)

title = "XAI - Integrated Gradients"
description = "Playground: Integrated Gradients for a MobileNetV2 model trained on the ImageNet dataset. Tools: Alibi, TensorFlow, Gradio."
examples = [['./cat.jpg', 'blur'],['./dog.jpg', 'random']]
article="<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab recipes for computer vision - Dr. Mohamed Elawady</a></p>"
iface = gr.Interface(
    fn=do_process, 
    inputs=[input_im, input_drop], 
    outputs=[output_img, output_base, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples
)

iface.launch(debug=True)