import tensorflow as tf
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions
import matplotlib.pyplot as plt
from alibi.explainers import IntegratedGradients
from alibi.datasets import load_cats
from alibi.utils.visualization import visualize_image_attr
import numpy as np
from PIL import Image, ImageFilter
import io
import time
import os
import copy
import pickle
import datetime
import urllib.request
import gradio as gr

# Example images offered in the UI; they are downloaded next to the script so
# the relative paths used in `examples` below resolve.
EXAMPLE_IMAGES = {
    "./cat.jpg": "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
    "./dog.jpg": "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
}
for path_input, url in EXAMPLE_IMAGES.items():
    urllib.request.urlretrieve(url, filename=path_input)

model = keras_model(weights='imagenet')

# Integrated Gradients configuration.
n_steps = 50
method = "gausslegendre"
internal_batch_size = 50
ig = IntegratedGradients(model,
                         n_steps=n_steps,
                         method=method,
                         internal_batch_size=internal_batch_size)


def _make_baseline(img, instance, baseline):
    """Build the IG baseline for ``instance`` plus a PIL preview of it.

    Every baseline is first built in raw pixel space ([0, 255]) and then run
    through the same ``preprocess_input`` that was applied to ``instance``.
    This keeps 'black', 'white' and 'random' meaningful in the model's input
    space: per the Keras docs, MobileNetV2's ``preprocess_input`` rescales
    pixels to [-1, 1], so a raw 0.0 array would be mid-gray, not black.

    Args:
        img: the PIL input image (only used by the 'blur' baseline).
        instance: the preprocessed model input; its shape and dtype are
            reused for the synthetic baselines.
        baseline: 'black', 'white', 'blur'; any other value gives 'random'.

    Returns:
        (baselines, preview) -- the preprocessed baseline batch and a PIL
        image of the baseline in pixel space, for display.
    """
    if baseline == 'blur':
        preview = img.filter(ImageFilter.GaussianBlur(5))
        pixels = np.expand_dims(image.img_to_array(preview), axis=0)
    else:
        if baseline == 'white':
            pixels = np.full(instance.shape, 255.0, dtype=instance.dtype)
        elif baseline == 'black':
            pixels = np.zeros(instance.shape, dtype=instance.dtype)
        else:  # 'random' (the UI default)
            pixels = (np.random.random_sample(instance.shape) * 255.0).astype(instance.dtype)
        preview = Image.fromarray(np.uint8(np.squeeze(pixels, axis=0)))
    # preprocess_input may rescale a float array in place; give it a copy so
    # `pixels` is never mutated behind our back.
    baselines = preprocess_input(pixels.copy())
    return baselines, preview


def _figure_to_image(fig):
    """Render a matplotlib figure into an in-memory buffer and open it with PIL."""
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    return Image.open(buf)


def do_process(img, baseline):
    """Classify ``img`` and explain the top prediction with Integrated Gradients.

    Args:
        img: PIL RGB image (the Gradio input resizes uploads to 224x224).
        baseline: 'random', 'black', 'white' or 'blur'; any other value falls
            back to 'random'.

    Returns:
        (img_res, img_flt, dctPreds) -- the blended attribution heat map, a
        preview of the baseline used, and ``{label: score}`` for the top-3
        decoded predictions (scores rounded to 2 decimals).
    """
    instance = np.expand_dims(image.img_to_array(img), axis=0)
    instance = preprocess_input(instance)

    preds = model.predict(instance)
    top3 = decode_predictions(preds, top=3)[0]
    dctPreds = {label: round(float(score), 2) for _, label, score in top3}
    # Explain the arg-max class for each item in the batch.
    predictions = preds.argmax(axis=1)

    baselines, img_flt = _make_baseline(img, instance, baseline)

    explanation = ig.explain(instance, baselines=baselines, target=predictions)
    attrs = explanation.attributions[0]

    fig, ax = visualize_image_attr(attr=attrs.squeeze(),
                                   original_image=img,
                                   method='blended_heat_map',
                                   sign='all',
                                   show_colorbar=True,
                                   title=baseline,
                                   plt_fig_axis=None,
                                   use_pyplot=False)
    img_res = _figure_to_image(fig)
    return img_res, img_flt, dctPreds


input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB', invert_colors=False,
                           source="upload", type="pil")
input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
                                choices=['random', 'black', 'white', 'blur'],
                                default='random', type='value')
output_img = gr.outputs.Image(label='Output of Integrated Gradients', type='pil')
output_base = gr.outputs.Image(label='Baseline image', type='pil')
output_label = gr.outputs.Label(label='Classification results', num_top_classes=3)

title = "XAI - Integrated gradients"
description = ("Playground: Integrated gradients for a MobileNetV2 model trained on "
               "Imagenet dataset. Tools: Alibi, TF, Gradio.")
examples = [['./cat.jpg', 'blur'], ['./dog.jpg', 'random']]
article = "Colab recipes for computer vision - Dr. Mohamed Elawady"

iface = gr.Interface(
    fn=do_process,
    inputs=[input_im, input_drop],
    outputs=[output_img, output_base, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

if __name__ == "__main__":
    iface.launch(share=True)