Spaces:
Runtime error
Runtime error
import tensorflow as tf | |
from tensorflow.keras.preprocessing import image | |
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model | |
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions | |
import matplotlib.pyplot as plt | |
from alibi.explainers import IntegratedGradients | |
from alibi.datasets import load_cats | |
from alibi.utils.visualization import visualize_image_attr | |
import numpy as np | |
from PIL import Image, ImageFilter | |
import io | |
import time | |
import os | |
import copy | |
import pickle | |
import datetime | |
import urllib.request | |
import gradio as gr | |
# Example images offered in the demo UI, as (source URL, local path) pairs.
EXAMPLE_IMAGES = [
    ("https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg", "./cat.jpg"),
    ("https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg", "./dog.jpg"),
]

# Wikimedia asks clients to identify themselves with a descriptive User-Agent;
# requests carrying only a generic library default may be rejected.
DOWNLOAD_HEADERS = {
    "User-Agent": "xai-integrated-gradients-demo/1.0 (https://github.com/mawady/colab-recipes-cv)"
}

# The loop variables deliberately keep the names `url` and `path_input`, so the
# module-level values after the loop are the same as before (the last example).
for url, path_input in EXAMPLE_IMAGES:
    if os.path.exists(path_input):
        # Already fetched on a previous start; avoid a redundant network round trip.
        continue
    request = urllib.request.Request(url, headers=DOWNLOAD_HEADERS)
    # Download to a temporary name first so an interrupted transfer never leaves
    # a truncated file that a later start would mistake for a complete one.
    partial_path = path_input + ".part"
    with urllib.request.urlopen(request, timeout=60) as response, open(partial_path, "wb") as out_file:
        out_file.write(response.read())
    os.replace(partial_path, path_input)
# Classifier under explanation: the network imported above as `keras_model`,
# loaded with its ImageNet weights.
model = keras_model(weights="imagenet")

# Integrated Gradients settings, kept as module-level names so they stay
# individually inspectable: number of integration steps, integration rule,
# and the batch size used internally by the explainer.
n_steps, method, internal_batch_size = 50, "gausslegendre", 50

# One explainer instance is built at start-up and reused for every request.
ig = IntegratedGradients(
    model,
    n_steps=n_steps,
    method=method,
    internal_batch_size=internal_batch_size,
)
def _to_model_input(pil_img):
    """Turn a PIL image into a preprocessed batch of one image for `model`."""
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    return preprocess_input(batch)


def _make_baseline_image(img, baseline, pixel_shape):
    """Build the requested baseline as a displayable PIL image in 0-255 pixel space."""
    if baseline == 'blur':
        return img.filter(ImageFilter.GaussianBlur(5))
    if baseline == 'white':
        pixels = np.full(pixel_shape, 255, dtype=np.uint8)
    elif baseline == 'black':
        pixels = np.zeros(pixel_shape, dtype=np.uint8)
    else:
        # Any other value (including the UI default 'random') yields uniform noise.
        pixels = np.random.randint(0, 256, size=pixel_shape, dtype=np.uint8)
    return Image.fromarray(pixels)


def do_process(img, baseline):
    """Classify an image and explain the top prediction with Integrated Gradients.

    Args:
        img: PIL image to classify (the UI supplies it already resized).
        baseline: Name of the reference image for Integrated Gradients:
            'white', 'black', 'blur', or anything else for random noise.

    Returns:
        A tuple ``(img_res, img_flt, dctPreds)`` where ``img_res`` is the
        rendered attribution heat map as a PIL image, ``img_flt`` is the
        baseline as a PIL image, and ``dctPreds`` maps the top-3 class names
        to their scores rounded to two decimals.

    Raises:
        ValueError: If no image was provided.
    """
    if img is None:
        raise ValueError("An input image is required.")

    instance = _to_model_input(img)
    preds = model.predict(instance)
    dctPreds = {
        label: round(float(score), 2)
        for _, label, score in decode_predictions(preds, top=3)[0]
    }
    predictions = preds.argmax(axis=1)

    # The baseline is created as a pixel image and then pushed through the SAME
    # preprocessing as the input. Building it directly in model-input space was
    # wrong: preprocess_input rescales pixels, so an all-zeros tensor is not a
    # black image and [0, 1) noise covers only part of the input range, while
    # the baseline shown to the user suggested otherwise.
    img_flt = _make_baseline_image(img, baseline, instance.shape[1:])
    baselines = _to_model_input(img_flt)

    explanation = ig.explain(instance,
                             baselines=baselines,
                             target=predictions)
    attrs = explanation.attributions[0]

    # Hand the visualizer an array rather than a PIL object, since it performs
    # array arithmetic on the original image.
    fig, ax = visualize_image_attr(attr=attrs.squeeze(), original_image=np.asarray(img),
                                   method='blended_heat_map', sign='all',
                                   show_colorbar=True, title=baseline,
                                   plt_fig_axis=None, use_pyplot=False)
    fig.tight_layout()

    # Render the figure into memory and reopen it as a PIL image for the UI.
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img_res = Image.open(buf)
    return img_res, img_flt, dctPreds
# NOTE(review): `gr.inputs` / `gr.outputs` are legacy Gradio namespaces; confirm
# the pinned Gradio version still provides them before upgrading the dependency.

# Input widgets: the image to explain (delivered to `do_process` as a PIL image,
# resized to 224x224) and the choice of Integrated Gradients baseline.
input_im = gr.inputs.Image(
    shape=(224, 224),
    image_mode='RGB',
    invert_colors=False,
    source="upload",
    type="pil",
)
input_drop = gr.inputs.Dropdown(
    label='Baseline (default: random)',
    choices=['random', 'black', 'white', 'blur'],
    default='random',
    type='value',
)

# Output widgets, in the same order as the tuple returned by `do_process`.
output_img = gr.outputs.Image(label='Output of Integrated Gradients', type='pil')
output_base = gr.outputs.Image(label='Baseline image', type='pil')
output_label = gr.outputs.Label(label='Classification results', num_top_classes=3)

# Page text. The description names MobileNetV2 because that is the network this
# file actually loads; it previously said "ResNet".
title = "XAI - Integrated gradients"
description = "Playground: Integrated gradients for a MobileNetV2 model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
examples = [['./cat.jpg', 'blur'], ['./dog.jpg', 'random']]
article = "<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab recipes for computer vision - Dr. Mohamed Elawady</a></p>"

# Wire the widgets to `do_process` and start the web app.
iface = gr.Interface(
    fn=do_process,
    inputs=[input_im, input_drop],
    outputs=[output_img, output_base, output_label],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch(share=True)