mawady commited on
Commit
dfbc387
1 Parent(s): 66b4c88

format code

Browse files
Files changed (1) hide show
  1. app.py +73 -57
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import tensorflow as tf
2
  from tensorflow.keras.preprocessing import image
3
  from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
4
- from tensorflow.keras.applications.mobilenet_v2 import preprocess_input, decode_predictions
 
 
 
5
  import matplotlib.pyplot as plt
6
  from alibi.explainers import IntegratedGradients
7
  from alibi.datasets import load_cats
@@ -18,7 +21,9 @@ import urllib.request
18
  import gradio as gr
19
 
20
 
21
- url = "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
 
 
22
  path_input = "./cat.jpg"
23
  urllib.request.urlretrieve(url, filename=path_input)
24
 
@@ -26,77 +31,88 @@ url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"
26
  path_input = "./dog.jpg"
27
  urllib.request.urlretrieve(url, filename=path_input)
28
 
29
- model = keras_model(weights='imagenet')
30
 
31
  n_steps = 50
32
  method = "gausslegendre"
33
  internal_batch_size = 50
34
- ig = IntegratedGradients(model,
35
- n_steps=n_steps,
36
- method=method,
37
- internal_batch_size=internal_batch_size)
38
 
39
  def do_process(img, baseline):
40
- instance = image.img_to_array(img)
41
- instance = np.expand_dims(instance, axis=0)
42
- instance = preprocess_input(instance)
43
- preds = model.predict(instance)
44
- lstPreds = decode_predictions(preds, top=3)[0]
45
- dctPreds = {lstPreds[i][1]: round(float(lstPreds[i][2]),2) for i in range(len(lstPreds))}
46
- predictions = preds.argmax(axis=1)
47
- if baseline == 'white':
48
- baselines = bls = np.ones(instance.shape).astype(instance.dtype)
49
- img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
50
- elif baseline == 'black':
51
- baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
52
- img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
53
- elif baseline == 'blur':
54
- img_flt = img.filter(ImageFilter.GaussianBlur(5))
55
- baselines = image.img_to_array(img_flt)
56
- baselines = np.expand_dims(baselines, axis=0)
57
- baselines = preprocess_input(baselines)
58
- else:
59
- baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
60
- img_flt = Image.fromarray(np.uint8(np.squeeze(baselines)*255))
61
- explanation = ig.explain(instance,
62
- baselines=baselines,
63
- target=predictions)
64
- attrs = explanation.attributions[0]
65
- fig, ax = visualize_image_attr(attr=attrs.squeeze(), original_image=img, method='blended_heat_map',
66
- sign='all', show_colorbar=True, title=baseline,
67
- plt_fig_axis=None, use_pyplot=False)
68
- fig.tight_layout()
69
- buf = io.BytesIO()
70
- fig.savefig(buf)
71
- buf.seek(0)
72
- img_res = Image.open(buf)
73
- return img_res, img_flt, dctPreds
74
-
75
- input_im = gr.inputs.Image(shape=(224, 224), image_mode='RGB',
76
- invert_colors=False, source="upload",
77
- type="pil")
78
- input_drop = gr.inputs.Dropdown(label='Baseline (default: random)',
79
- choices=['random', 'black', 'white', 'blur'], default='random', type='value')
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- output_img = gr.outputs.Image(label='Output of Integrated Gradients', type='pil')
82
- output_base = gr.outputs.Image(label='Baseline image', type='pil')
83
- output_label = gr.outputs.Label(label='Classification results', num_top_classes=3)
84
 
85
  title = "XAI - Integrated gradients"
86
  description = "Playground: Integrated gradients for a ResNet model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
87
- examples = [['./cat.jpg', 'blur'],['./dog.jpg', 'random']]
88
- article="<p style='text-align: center'><a href='https://github.com/mawady/colab-recipes-cv' target='_blank'>Colab recipes for computer vision - Dr. Mohamed Elawady</a></p>"
89
  iface = gr.Interface(
90
- fn=do_process,
91
- inputs=[input_im, input_drop],
92
- outputs=[output_img,output_base,output_label],
93
  live=False,
94
  interpretation=None,
95
  title=title,
96
  description=description,
97
  article=article,
98
- examples=examples
99
  )
100
 
101
  iface.launch(debug=True)
102
-
 
1
  import tensorflow as tf
2
  from tensorflow.keras.preprocessing import image
3
  from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 as keras_model
4
+ from tensorflow.keras.applications.mobilenet_v2 import (
5
+ preprocess_input,
6
+ decode_predictions,
7
+ )
8
  import matplotlib.pyplot as plt
9
  from alibi.explainers import IntegratedGradients
10
  from alibi.datasets import load_cats
 
21
  import gradio as gr
22
 
23
 
24
+ url = (
25
+ "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
26
+ )
27
  path_input = "./cat.jpg"
28
  urllib.request.urlretrieve(url, filename=path_input)
29
 
 
31
  path_input = "./dog.jpg"
32
  urllib.request.urlretrieve(url, filename=path_input)
33
 
34
+ model = keras_model(weights="imagenet")
35
 
36
  n_steps = 50
37
  method = "gausslegendre"
38
  internal_batch_size = 50
39
+ ig = IntegratedGradients(
40
+ model, n_steps=n_steps, method=method, internal_batch_size=internal_batch_size
41
+ )
42
+
43
 
44
  def do_process(img, baseline):
45
+ instance = image.img_to_array(img)
46
+ instance = np.expand_dims(instance, axis=0)
47
+ instance = preprocess_input(instance)
48
+ preds = model.predict(instance)
49
+ lstPreds = decode_predictions(preds, top=3)[0]
50
+ dctPreds = {
51
+ lstPreds[i][1]: round(float(lstPreds[i][2]), 2) for i in range(len(lstPreds))
52
+ }
53
+ predictions = preds.argmax(axis=1)
54
+ if baseline == "white":
55
+ baselines = bls = np.ones(instance.shape).astype(instance.dtype)
56
+ img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
57
+ elif baseline == "black":
58
+ baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
59
+ img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
60
+ elif baseline == "blur":
61
+ img_flt = img.filter(ImageFilter.GaussianBlur(5))
62
+ baselines = image.img_to_array(img_flt)
63
+ baselines = np.expand_dims(baselines, axis=0)
64
+ baselines = preprocess_input(baselines)
65
+ else:
66
+ baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
67
+ img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
68
+ explanation = ig.explain(instance, baselines=baselines, target=predictions)
69
+ attrs = explanation.attributions[0]
70
+ fig, ax = visualize_image_attr(
71
+ attr=attrs.squeeze(),
72
+ original_image=img,
73
+ method="blended_heat_map",
74
+ sign="all",
75
+ show_colorbar=True,
76
+ title=baseline,
77
+ plt_fig_axis=None,
78
+ use_pyplot=False,
79
+ )
80
+ fig.tight_layout()
81
+ buf = io.BytesIO()
82
+ fig.savefig(buf)
83
+ buf.seek(0)
84
+ img_res = Image.open(buf)
85
+ return img_res, img_flt, dctPreds
86
+
87
+
88
+ input_im = gr.inputs.Image(
89
+ shape=(224, 224), image_mode="RGB", invert_colors=False, source="upload", type="pil"
90
+ )
91
+ input_drop = gr.inputs.Dropdown(
92
+ label="Baseline (default: random)",
93
+ choices=["random", "black", "white", "blur"],
94
+ default="random",
95
+ type="value",
96
+ )
97
 
98
+ output_img = gr.outputs.Image(label="Output of Integrated Gradients", type="pil")
99
+ output_base = gr.outputs.Image(label="Baseline image", type="pil")
100
+ output_label = gr.outputs.Label(label="Classification results", num_top_classes=3)
101
 
102
  title = "XAI - Integrated gradients"
103
  description = "Playground: Integrated gradients for a ResNet model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
104
+ examples = [["./cat.jpg", "blur"], ["./dog.jpg", "random"]]
105
+ article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
106
  iface = gr.Interface(
107
+ fn=do_process,
108
+ inputs=[input_im, input_drop],
109
+ outputs=[output_img, output_base, output_label],
110
  live=False,
111
  interpretation=None,
112
  title=title,
113
  description=description,
114
  article=article,
115
+ examples=examples,
116
  )
117
 
118
  iface.launch(debug=True)