barry-ravichandran committed on
Commit 12c6662 · 1 Parent(s): f15fe03

Add app and requirements files

Files changed (3)
  1. app.py +596 -0
  2. gr_component_state.py +103 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,596 @@
+ # ---
+ # jupyter:
+ #   jupytext:
+ #     text_representation:
+ #       extension: .py
+ #       format_name: light
+ #       format_version: '1.5'
+ #     jupytext_version: 1.15.2
+ #   kernelspec:
+ #     display_name: Python 3
+ #     language: python
+ #     name: python3
+ # ---
+
+ # # Gradio Example <a name="XAITK-Saliency-Gradio-Example"></a>
+ # This notebook makes use of the saliency generation example found in the base ``xaitk-saliency`` repo [here](https://github.com/XAITK/xaitk-saliency/blob/master/examples/OcclusionSaliency.ipynb), and explores integrating ``xaitk-saliency`` with ``Gradio`` to create an interactive interface for computing saliency maps.
+ #
+ # ## Test Image <a name="Test-Image-Gradio"></a>
+
+ # +
+ import ast
+ import os
+ import urllib.request
+
+ import PIL.Image
+ import matplotlib.pyplot as plt  # type: ignore
+ import numpy as np
+
+ import gradio as gr
+ from gradio import (  # type: ignore
+     AnnotatedImage, Button, Checkbox, Column, Dropdown, Image, Label,
+     Number, Plot, Row, SelectData, Slider, Tab, TabItem, Tabs, Textbox
+ )
+ from gradio import components as gr_components  # type: ignore
+ from gradio import processing_utils as gr_processing_utils  # type: ignore
+
+ # +
+ # State variables for Image Classification
+ from gr_component_state import (  # type: ignore
+     img_cls_model_name, img_cls_saliency_algo_name, window_size_state, stride_state, debiased_state,
+ )
+
+ # State functions for Image Classification
+ from gr_component_state import (  # type: ignore
+     select_img_cls_model, select_img_cls_saliency_algo, enter_window_size, enter_stride, check_debiased
+ )
+
+ # State variables for Object Detection
+ from gr_component_state import (  # type: ignore
+     obj_det_model_name, obj_det_saliency_algo_name, occlusion_grid_state
+ )
+
+ # State functions for Object Detection
+ from gr_component_state import (  # type: ignore
+     select_obj_det_model, select_obj_det_saliency_algo, enter_occlusion_grid_size
+ )
+
+ # Common state variables
+ from gr_component_state import (  # type: ignore
+     threads_state, num_masks_state, spatial_res_state, p1_state, seed_state
+ )
+
+ # Common state functions
+ from gr_component_state import (  # type: ignore
+     select_threads, enter_num_masks, enter_spatial_res, select_p1, enter_seed
+ )
+
+ import torch
+ import torch.nn.functional
+ import torchvision.transforms as transforms
+ import torchvision.models as models
+
+ from smqtk_classifier.interfaces.classify_image import ClassifyImage
+ from smqtk_detection.impls.detect_image_objects.resnet_frcnn import ResNetFRCNN
+ from smqtk_detection.interfaces.detect_image_objects import DetectImageObjects
+ from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.slidingwindow import SlidingWindowStack
+ from xaitk_saliency.impls.gen_image_classifier_blackbox_sal.rise import RISEStack
+ from xaitk_saliency.impls.gen_object_detector_blackbox_sal.drise import RandomGridStack, DRISEStack
+ from xaitk_saliency.interfaces.gen_object_detector_blackbox_sal import GenerateObjectDetectorBlackboxSaliency
+
+ # Use JPEG format for inline visualizations here.
+ # %config InlineBackend.figure_format = "jpeg"
+
+ os.makedirs('data', exist_ok=True)
+ test_image_filename = 'data/catdog.jpg'
+ urllib.request.urlretrieve('https://farm1.staticflickr.com/74/202734059_fcce636dcd_z.jpg', test_image_filename)
+ plt.figure(figsize=(12, 8))
+ plt.axis('off')
+ _ = plt.imshow(PIL.Image.open(test_image_filename))
+ # -
+
+ # ## Initialize state variables for Gradio components <a name="Global-State-Gradio"></a>
+ # Gradio expects state to be held in a mutable container, either a list or a dict, depending on the use case. In this demo each setting lives in a module-level list defined in ``gr_component_state.py`` (imported above): the current value is always the last element, and the event handlers append a new value whenever the user changes the corresponding component.
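+
+ # A minimal sketch of that pattern (``brightness_state`` and ``set_brightness`` are illustrative names, not part of this demo):
+
+ # +
+ brightness_state = [1.0]
+
+ def set_brightness(val):
+     # Append rather than overwrite so earlier values stay inspectable;
+     # readers take the current value as brightness_state[-1].
+     brightness_state.append(val)
+     return val
+ # -
+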
+ # ## Helper Functions <a name="Helper-Functions-Gradio"></a>
+ # The functions defined in the cell below set up the model, saliency algorithm, class labels, and image transforms needed for the demo.
+
+ CUDA_AVAILABLE = torch.cuda.is_available()
+
+ model_input_size = (224, 224)
+ model_mean = [0.485, 0.456, 0.406]
+ model_loader = transforms.Compose([
+     transforms.ToPILImage(),
+     transforms.Resize(model_input_size),
+     transforms.ToTensor(),
+     transforms.Normalize(
+         mean=model_mean,
+         std=[0.229, 0.224, 0.225]
+     ),
+ ])
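+
+ # A quick check of what ``model_loader`` produces (illustrative; ``_dummy_img`` is a placeholder, and any HxWx3 uint8 array works):
+
+ # +
+ _dummy_img = np.zeros((480, 640, 3), dtype=np.uint8)
+ # ToPILImage -> Resize -> ToTensor -> Normalize yields a CHW float tensor.
+ assert tuple(model_loader(_dummy_img).shape) == (3, 224, 224)
+ # -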
+
+ def get_sal_labels(classes_file, custom_categories_list=None):
+     if not os.path.isfile(classes_file):
+         url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
+         _ = urllib.request.urlretrieve(url, classes_file)
+
+     with open(classes_file, "r") as f:
+         categories = [s.strip() for s in f.readlines()]
+
+     if custom_categories_list is not None:
+         sal_class_labels = custom_categories_list
+     else:
+         sal_class_labels = categories
+
+     sal_class_idxs = [categories.index(lbl) for lbl in sal_class_labels]
+
+     return sal_class_labels, sal_class_idxs
+
+ def get_det_sal_labels(classes_file, custom_categories_list=None):
+     if not os.path.isfile(classes_file):
+         url = "https://raw.githubusercontent.com/matlab-deep-learning/Object-Detection-Using-Pretrained-YOLO-v2/main/%2Bhelper/coco-classes.txt"
+         _ = urllib.request.urlretrieve(url, classes_file)
+
+     with open(classes_file, "r") as f:
+         categories = [s.strip() for s in f.readlines()]
+
+     if custom_categories_list is not None:
+         sal_obj_labels = custom_categories_list
+     else:
+         sal_obj_labels = categories
+
+     sal_obj_idxs = [categories.index(lbl) for lbl in sal_obj_labels]
+
+     return sal_obj_labels, sal_obj_idxs
+
+ def get_model(model_choice):
+     if model_choice == "ResNet-18":
+         model = models.resnet18(pretrained=True)
+     else:
+         model = models.resnet50(pretrained=True)
+     model = model.eval()
+     if CUDA_AVAILABLE:
+         model = model.cuda()
+
+     return model
+
+ def get_detection_model(model_choice):
+     if model_choice == "Faster-RCNN":
+         blackbox_detector = ResNetFRCNN(
+             box_thresh=0.05,
+             img_batch_size=1,
+             use_cuda=CUDA_AVAILABLE
+         )
+     else:
+         raise ValueError(f"Unknown model choice: {model_choice}")
+
+     return blackbox_detector
+
+ def get_saliency_algo(sal_choice):
+     if sal_choice == "RISE":
+         gen_sal = RISEStack(
+             n=num_masks_state[-1],
+             s=spatial_res_state[-1],
+             p1=p1_state[-1],
+             seed=seed_state[-1],
+             threads=threads_state[-1],
+             debiased=debiased_state[-1]
+         )
+     elif sal_choice == "SlidingWindowStack":
+         gen_sal = SlidingWindowStack(
+             window_size=ast.literal_eval(window_size_state[-1]),
+             stride=ast.literal_eval(stride_state[-1]),
+             threads=threads_state[-1]
+         )
+     else:
+         raise ValueError(f"Unknown saliency algorithm: {sal_choice}")
+
+     return gen_sal
+
+ def get_detection_saliency_algo(sal_choice):
+     if sal_choice == "RandomGridStack":
+         gen_sal = RandomGridStack(
+             n=num_masks_state[-1],
+             s=ast.literal_eval(occlusion_grid_state[-1]),
+             p1=p1_state[-1],
+             threads=threads_state[-1],
+             seed=seed_state[-1],
+         )
+     elif sal_choice == "DRISE":
+         gen_sal = DRISEStack(
+             n=num_masks_state[-1],
+             s=spatial_res_state[-1],
+             p1=p1_state[-1],
+             seed=seed_state[-1],
+             threads=threads_state[-1]
+         )
+     else:
+         raise ValueError(f"Unknown saliency algorithm: {sal_choice}")
+
+     return gen_sal
+
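+ # For reference, the configured generators are applied as callables against a
+ # blackbox model, e.g. (illustrative sketch using names defined later in this file):
+ #
+ #     gen = get_saliency_algo("RISE")
+ #     gen.fill = blackbox_fill
+ #     sal_maps = gen(ref_image, blackbox_classifier)  # shape: (n_classes, H, W)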
+
+ data_path = "./data"
+ if not os.path.exists(data_path):
+     os.makedirs(data_path)
+
+ # Setup imagenet classes and ClassifyImage for generating classification saliency
+
+ classes_file = os.path.join(data_path, "imagenet_classes.txt")
+ sal_class_labels, sal_class_idxs = get_sal_labels(classes_file)
+
+ class TorchResnet(ClassifyImage):
+
+     modified_class_labels = []
+
+     def get_labels(self):
+         return self.modified_class_labels
+
+     def set_labels(self, class_labels):
+         self.modified_class_labels = [lbl for lbl in class_labels]
+
+     @torch.no_grad()
+     def classify_images(self, image_iter):
+         # Input may either be an NDArray, or some arbitrary iterable of NDArray images.
+         model = get_model(img_cls_model_name[-1])
+
+         for img in image_iter:
+             image_tensor = model_loader(img).unsqueeze(0)
+             if CUDA_AVAILABLE:
+                 image_tensor = image_tensor.cuda()
+
+             feature_vec = model(image_tensor)
+             # Converting feature extractor output to probabilities.
+             class_conf = torch.nn.functional.softmax(feature_vec, dim=1).cpu().detach().numpy().squeeze()
+             # Only return the confidences for the focus classes
+             yield dict(zip(sal_class_labels, class_conf[sal_class_idxs]))
+
+     def get_config(self):
+         # Required by a parent class.
+         return {}
+
+ blackbox_classifier = TorchResnet()
+ blackbox_fill = np.uint8(np.asarray(model_mean) * 255).tolist()
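+
+ # An illustrative sanity check: the blackbox classifier can be run directly on the
+ # test image, yielding one dict of class -> confidence per input (``_test_img`` and
+ # ``_test_conf`` are placeholder names):
+
+ # +
+ blackbox_classifier.set_labels(sal_class_labels)
+ _test_img = np.asarray(PIL.Image.open(test_image_filename))
+ _test_conf = next(blackbox_classifier.classify_images([_test_img]))
+ print(sorted(_test_conf.items(), key=lambda t: t[1], reverse=True)[:3])
+ # -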
+
+ # Setup COCO object classes for generating detection saliency
+
+ obj_classes_file = os.path.join(data_path, "coco_classes.txt")
+ sal_obj_labels, sal_obj_idxs = get_det_sal_labels(obj_classes_file)
+
+
+ # Modify textbox parameters based on chosen saliency algorithm
+ def show_textbox_parameters(choice):
+     if choice == 'RISE':
+         return Textbox.update(visible=False), Textbox.update(visible=False), Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=True)
+     elif choice == 'SlidingWindowStack':
+         return Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=False), Textbox.update(visible=False), Textbox.update(visible=False)
+     elif choice == "RandomGridStack":
+         return Textbox.update(visible=True), Textbox.update(visible=False), Textbox.update(visible=True), Textbox.update(visible=True)
+     elif choice == "DRISE":
+         return Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=True), Textbox.update(visible=False)
+     else:
+         raise ValueError(f"Unknown saliency algorithm: {choice}")
+
+ # Modify slider parameters based on chosen saliency algorithm
+ def show_slider_parameters(choice):
+     if choice in ('RISE', 'RandomGridStack', 'DRISE'):
+         return Slider.update(visible=True), Slider.update(visible=True)
+     elif choice == 'SlidingWindowStack':
+         return Slider.update(visible=True), Slider.update(visible=False)
+     else:
+         raise ValueError(f"Unknown saliency algorithm: {choice}")
+
+ # Modify checkbox parameters based on chosen saliency algorithm
+ def show_debiased_checkbox(choice):
+     if choice == 'RISE':
+         return Checkbox.update(visible=True)
+     elif choice in ('SlidingWindowStack', 'RandomGridStack', 'DRISE'):
+         return Checkbox.update(visible=False)
+     else:
+         raise ValueError(f"Unknown saliency algorithm: {choice}")
+
+ # Function that is called after clicking the "Classify" button in the demo
+ def predict(x, top_n_classes):
+     image_tensor = model_loader(x).unsqueeze(0)
+     if CUDA_AVAILABLE:
+         image_tensor = image_tensor.cuda()
+     model = get_model(img_cls_model_name[-1])
+     feature_vec = model(image_tensor)
+     class_conf = torch.nn.functional.softmax(feature_vec, dim=1).cpu().detach().numpy().squeeze()
+     labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
+     final_labels = dict(sorted(labels, key=lambda t: t[1], reverse=True)[:top_n_classes])
+
+     return final_labels, Dropdown.update(choices=list(final_labels))
+
+ # Interpretation function for image classification that runs the selected saliency
+ # algorithm and generates the class-wise saliency map visualizations
+ def interpretation_function(image: np.ndarray,
+                             labels: dict,
+                             nth_class: str,
+                             img_alpha,
+                             sal_alpha,
+                             sal_range_min,
+                             sal_range_max):
+     sal_generator = get_saliency_algo(img_cls_saliency_algo_name[-1])
+     sal_generator.fill = blackbox_fill
+     labels_list = [i['label'] for i in labels['confidences']]
+     blackbox_classifier.set_labels(labels_list)
+     sal_maps = sal_generator(image, blackbox_classifier)
+     nth_class_index = blackbox_classifier.get_labels().index(nth_class)
+     scores = sal_maps[nth_class_index, :, :]
+     fig = visualize_saliency_plot(image,
+                                   sal_maps[nth_class_index, :, :],
+                                   img_alpha,
+                                   sal_alpha,
+                                   sal_range_min,
+                                   sal_range_max)
+
+     scores = np.clip(scores, sal_range_min, sal_range_max)
+
+     return {"original": gr_processing_utils.encode_array_to_base64(image),
+             "interpretation": scores.tolist()}, fig
+
+ def visualize_saliency_plot(image: np.ndarray,
+                             class_sal_map: np.ndarray,
+                             img_alpha,
+                             sal_alpha,
+                             sal_range_min,
+                             sal_range_max):
+     colorbar_kwargs = {
+         "fraction": 0.046 * (image.shape[0] / image.shape[1]),
+         "pad": 0.04,
+     }
+     fig = plt.figure()
+     plt.imshow(image, alpha=img_alpha)
+     plt.imshow(
+         np.clip(class_sal_map, sal_range_min, sal_range_max),
+         cmap='jet', alpha=sal_alpha
+     )
+     plt.clim(sal_range_min, sal_range_max)
+     plt.colorbar(**colorbar_kwargs)
+     plt.title("Saliency Map")
+     plt.axis('off')
+     plt.close(fig)
+
+     return fig
+
+ # Generate top-n object detection predictions on the input image
+ def run_detect(input_img: np.ndarray, num_detections: int):
+     detect_model = get_detection_model(obj_det_model_name[-1])
+     preds = list(list(detect_model([input_img]))[0])
+     n_preds = len(preds)
+     n_classes = len(preds[0][1])
+
+     bboxes = np.empty((n_preds, 4), dtype=np.float32)
+     scores = np.empty((n_preds, n_classes), dtype=np.float32)
+     max_scores_index = np.empty((n_preds, 1), dtype=int)
+     labels = None
+     final_label = []
+     for i, (bbox, score_dict) in enumerate(preds):
+         bboxes[i] = (*bbox.min_vertex, *bbox.max_vertex)
+         score_list = list(score_dict.values())
+         scores[i] = score_list
+         max_scores_index[i] = score_list.index(max(score_list))
+         if labels is None:
+             labels = list(score_dict.keys())
+         label_name = str(labels[int(max_scores_index[i, 0])])
+         conf_score = str(round(score_list[int(max_scores_index[i, 0])], 4))
+         label_with_score = str(i) + " : " + label_name + " - " + conf_score
+         final_label.append(label_with_score)
+
+     bboxes_list = bboxes.astype(int).tolist()
+
+     return (input_img, list(zip(bboxes_list, final_label))[:num_detections]), Dropdown.update(choices=final_label[:num_detections])
+
+ # Run the saliency algorithm on the object detection predictions and generate
+ # the corresponding visualizations
+ def run_detect_saliency(input_img: np.ndarray,
+                         num_predictions,
+                         obj_label,
+                         img_alpha,
+                         sal_alpha,
+                         sal_range_min,
+                         sal_range_max):
+     detect_model = get_detection_model(obj_det_model_name[-1])
+     img_preds = list(list(detect_model([input_img]))[0])
+     ref_preds = img_preds[:int(num_predictions)]
+     ref_bboxes = []
+     ref_scores = []
+     for det in ref_preds:
+         bbox = det[0]
+         ref_bboxes.append([
+             *bbox.min_vertex,
+             *bbox.max_vertex,
+         ])
+
+         score_dict = det[1]
+         ref_scores.append(list(score_dict.values()))
+
+     ref_bboxes = np.array(ref_bboxes)
+     ref_scores = np.array(ref_scores)
+
+     print(f"Ref bboxes: {ref_bboxes.shape}")
+     print(f"Ref scores: {ref_scores.shape}")
+
+     sal_generator = get_detection_saliency_algo(obj_det_saliency_algo_name[-1])
+     sal_generator.fill = blackbox_fill
+
+     sal_maps = gen_det_saliency(input_img, detect_model, sal_generator, ref_bboxes, ref_scores)
+     print(f"Saliency maps: {sal_maps.shape}")
+
+     nth_class_index = int(obj_label.split(' : ')[0])
+     fig = visualize_saliency_plot(input_img,
+                                   sal_maps[nth_class_index, :, :],
+                                   img_alpha,
+                                   sal_alpha,
+                                   sal_range_min,
+                                   sal_range_max)
+
+     return fig
+
+ def gen_det_saliency(input_img: np.ndarray,
+                      blackbox_detector: DetectImageObjects,
+                      sal_map_generator: GenerateObjectDetectorBlackboxSaliency,
+                      ref_bboxes: np.ndarray,
+                      ref_scores: np.ndarray
+                      ):
+     # ``generate`` occludes the image, re-runs the detector, and returns one
+     # HxW saliency map per reference detection.
+     sal_maps = sal_map_generator.generate(
+         input_img,
+         ref_bboxes,
+         ref_scores,
+         blackbox_detector,
+     )
+
+     return sal_maps
+
+ # Event handler that populates the dropdown list of classes based on the
+ # Label/AnnotatedImage components' output
+ def map_labels(evt: SelectData):
+     return str(evt.value)
+
+ with gr.Blocks() as demo:
+     with Tab("Image Classification"):
+         with Row():
+             with Column(scale=0.5):
+                 drop_list = Dropdown(value=img_cls_model_name[-1], choices=["ResNet-18", "ResNet-50"], label="Choose Model", interactive=True)
+             with Column(scale=0.5):
+                 drop_list_sal = Dropdown(value=img_cls_saliency_algo_name[-1], choices=["SlidingWindowStack", "RISE"], label="Choose Saliency Algorithm", interactive=True)
+         with Row():
+             with Column(scale=0.33):
+                 window_size = Textbox(value=window_size_state[-1], label="Tuple of window size values (Press Enter to submit the input)", interactive=True, visible=False)
+                 masks = Number(value=num_masks_state[-1], label="Number of Random Masks (Press Enter to submit the input)", interactive=True, visible=False, precision=0)
+             with Column(scale=0.33):
+                 stride = Textbox(value=stride_state[-1], label="Tuple of stride values (Press Enter to submit the input)", interactive=True, visible=False)
+                 spatial_res = Number(value=spatial_res_state[-1], label="Spatial Resolution of Masking Grid (Press Enter to submit the input)", interactive=True, visible=False, precision=0)
+             with Column(scale=0.33):
+                 threads = Slider(value=threads_state[-1], label="Threads", interactive=True, visible=False)
+         with Row():
+             with Column(scale=0.33):
+                 seed = Number(value=seed_state[-1], label="Seed (Press Enter to submit the input)", interactive=True, visible=False, precision=0)
+             with Column(scale=0.33):
+                 p1 = Slider(value=p1_state[-1], label="P1", interactive=True, visible=False, minimum=0, maximum=1, step=0.1)
+             with Column(scale=0.33):
+                 debiased = Checkbox(value=debiased_state[-1], label="Debiased", interactive=True, visible=False)
+         with Row():
+             with Column():
+                 input_img = Image(label="Saliency Map Generation", shape=(640, 480))
+                 num_classes = Slider(value=2, label="Top-N class labels", interactive=True, visible=True)
+                 classify = Button("Classify")
+             with Column():
+                 class_label = Label(label="Predicted Class")
+             with Column():
+                 with Row():
+                     class_name = Dropdown(label="Class to compute saliency", interactive=True, visible=True)
+                 with Row():
+                     img_alpha = Slider(value=0.7, label="Image Opacity", interactive=True, visible=True, minimum=0, maximum=1, step=0.1)
+                     sal_alpha = Slider(value=0.3, label="Saliency Map Opacity", interactive=True, visible=True, minimum=0, maximum=1, step=0.1)
+                 with Row():
+                     min_sal_range = Slider(value=0, label="Minimum Saliency Value", interactive=True, visible=True, minimum=-1, maximum=1, step=0.05)
+                     max_sal_range = Slider(value=1, label="Maximum Saliency Value", interactive=True, visible=True, minimum=-1, maximum=1, step=0.05)
+                 with Row():
+                     generate_saliency = Button("Generate Saliency")
+             with Column():
+                 with Tabs():
+                     with TabItem("Display interpretation with plot"):
+                         interpretation_plot = Plot()
+                     with TabItem("Display interpretation with built-in component"):
+                         interpretation = gr_components.Interpretation(input_img)
+
+     with Tab("Object Detection"):
+         with Row():
+             with Column(scale=0.5):
+                 drop_list_detect_model = Dropdown(value=obj_det_model_name[-1], choices=["Faster-RCNN"], label="Choose Model", interactive=True)
+             with Column(scale=0.5):
+                 drop_list_detect_sal = Dropdown(value=obj_det_saliency_algo_name[-1], choices=["RandomGridStack", "DRISE"], label="Choose Saliency Algorithm", interactive=True)
+         with Row():
+             with Column(scale=0.33):
+                 masks_detect = Number(value=num_masks_state[-1], label="Number of Random Masks (Press Enter to submit the input)", interactive=True, visible=False, precision=0)
+                 occlusion_grid_size = Textbox(value=occlusion_grid_state[-1], label="Tuple of occlusion grid size values (Press Enter to submit the input)", interactive=True, visible=False)
+                 spatial_res_detect = Number(value=spatial_res_state[-1], label="Spatial Resolution of Masking Grid (Press Enter to submit the input)", interactive=True, visible=False, precision=0)
+             with Column(scale=0.33):
+                 seed_detect = Number(value=seed_state[-1], label="Seed (Press Enter to submit the input)", interactive=True, visible=False, precision=0)
+                 p1_detect = Slider(value=p1_state[-1], label="P1", interactive=True, visible=False, minimum=0, maximum=1, step=0.1)
+             with Column(scale=0.33):
+                 threads_detect = Slider(value=threads_state[-1], label="Threads", interactive=True, visible=False)
+         with Row():
+             with Column():
+                 input_img_detect = Image(label="Saliency Map Generation", shape=(640, 480))
+                 num_detections = Slider(value=2, label="Top-N detections", interactive=True, visible=True)
+                 detection = Button("Run Detection Algorithm")
+             with Column():
+                 detect_label = AnnotatedImage(label="Detections")
+             with Column():
+                 with Row():
+                     class_name_det = Dropdown(label="Detection to compute saliency", interactive=True, visible=True)
+                 with Row():
+                     img_alpha_det = Slider(value=0.7, label="Image Opacity", interactive=True, visible=True, minimum=0, maximum=1, step=0.1)
+                     sal_alpha_det = Slider(value=0.3, label="Saliency Map Opacity", interactive=True, visible=True, minimum=0, maximum=1, step=0.1)
+                 with Row():
+                     min_sal_range_det = Slider(value=0.95, label="Minimum Saliency Value", interactive=True, visible=True, minimum=0.80, maximum=1, step=0.05)
+                     max_sal_range_det = Slider(value=1, label="Maximum Saliency Value", interactive=True, visible=True, minimum=0.80, maximum=1, step=0.05)
+                 with Row():
+                     generate_det_saliency = Button("Generate Saliency")
+             with Column():
+                 with Tabs():
+                     with TabItem("Display saliency map plot"):
+                         det_saliency_plot = Plot()
+
+     # Image Classification dropdown list event listeners
+     drop_list.select(select_img_cls_model, drop_list, drop_list)
+     drop_list_sal.select(select_img_cls_saliency_algo, drop_list_sal, drop_list_sal)
+     drop_list_sal.change(show_textbox_parameters, drop_list_sal, [window_size, stride, masks, spatial_res, seed])
+     drop_list_sal.change(show_slider_parameters, drop_list_sal, [threads, p1])
+     drop_list_sal.change(show_debiased_checkbox, drop_list_sal, debiased)
+
+     # Image Classification textbox, slider and checkbox event listeners
+     window_size.submit(enter_window_size, window_size, window_size)
+     masks.submit(enter_num_masks, masks, masks)
+     stride.submit(enter_stride, stride, stride)
+     spatial_res.submit(enter_spatial_res, spatial_res, spatial_res)
+     seed.submit(enter_seed, seed, seed)
+     threads.change(select_threads, threads, threads)
+     p1.change(select_p1, p1, p1)
+     debiased.change(check_debiased, debiased, debiased)
+
+     # Image Classification prediction and saliency generation event listeners
+     classify.click(predict, [input_img, num_classes], [class_label, class_name])
+     class_label.select(map_labels, None, class_name)
+     generate_saliency.click(interpretation_function, [input_img, class_label, class_name, img_alpha, sal_alpha, min_sal_range, max_sal_range], [interpretation, interpretation_plot])
+
+     # Object Detection dropdown list event listeners
+     drop_list_detect_model.select(select_obj_det_model, drop_list_detect_model, drop_list_detect_model)
+     drop_list_detect_sal.select(select_obj_det_saliency_algo, drop_list_detect_sal, drop_list_detect_sal)
+     drop_list_detect_sal.change(show_slider_parameters, drop_list_detect_sal, [threads_detect, p1_detect])
+     drop_list_detect_sal.change(show_textbox_parameters, drop_list_detect_sal, [masks_detect, spatial_res_detect, seed_detect, occlusion_grid_size])
+
+     # Object detection textbox and slider event listeners
+     masks_detect.submit(enter_num_masks, masks_detect, masks_detect)
+     occlusion_grid_size.submit(enter_occlusion_grid_size, occlusion_grid_size, occlusion_grid_size)
+     spatial_res_detect.submit(enter_spatial_res, spatial_res_detect, spatial_res_detect)
+     seed_detect.submit(enter_seed, seed_detect, seed_detect)
+     threads_detect.change(select_threads, threads_detect, threads_detect)
+     p1_detect.change(select_p1, p1_detect, p1_detect)
+
+     # Object detection prediction, class selection and saliency generation event listeners
+     detection.click(run_detect, [input_img_detect, num_detections], [detect_label, class_name_det])
+     detect_label.select(map_labels, None, class_name_det)
+     generate_det_saliency.click(run_detect_saliency, [input_img_detect, num_detections, class_name_det, img_alpha_det, sal_alpha_det, min_sal_range_det, max_sal_range_det], det_saliency_plot)
+
+ demo.launch()
gr_component_state.py ADDED
@@ -0,0 +1,103 @@
+ # Choice of image classification model
+ img_cls_model_name = ['ResNet-50']
+
+ # Choice of object detection model
+ obj_det_model_name = ['Faster-RCNN']
+
+ # Choice of image classification saliency algorithm
+ img_cls_saliency_algo_name = ['RISE']
+
+ # Choice of object detection saliency algorithm
+ obj_det_saliency_algo_name = ['DRISE']
+
+ # Number of threads to utilize when generating masks
+ threads_state = [4]
+
+ # Window size for the SlidingWindowStack algorithm
+ window_size_state = ['(50,50)']
+
+ # Stride for the SlidingWindowStack algorithm
+ stride_state = ['(20,20)']
+
+ # Number of random masks for the RISEStack/DRISEStack algorithms
+ num_masks_state = [200]
+
+ # Spatial resolution of the masking grid for the RISEStack/DRISEStack algorithms
+ spatial_res_state = [8]
+
+ # Probability of a grid cell being set to 1 (otherwise 0)
+ p1_state = [0.5]
+
+ # Random seed to allow for reproducibility
+ seed_state = [0]
+
+ # Debiased option for the RISEStack/DRISEStack saliency algorithms
+ debiased_state = [True]
+
+ # Occlusion grid cell size in pixels for the RandomGridStack algorithm
+ occlusion_grid_state = ['(128,128)']
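+
+ # Note: tuple-valued settings (window size, stride, occlusion grid size) are kept as
+ # strings because they come from Gradio Textbox components; app.py parses them with
+ # ast.literal_eval. A minimal sketch of that round trip:
+ #
+ #     import ast
+ #     ast.literal_eval(window_size_state[-1])  # -> (50, 50)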
+
+
+ def select_img_cls_model(model_choice):
+     img_cls_model_name.append(model_choice)
+     return model_choice
+
+
+ def select_obj_det_model(model_choice):
+     obj_det_model_name.append(model_choice)
+     return model_choice
+
+
+ def select_img_cls_saliency_algo(sal_choice):
+     img_cls_saliency_algo_name.append(sal_choice)
+     return sal_choice
+
+
+ def select_obj_det_saliency_algo(sal_choice):
+     obj_det_saliency_algo_name.append(sal_choice)
+     return sal_choice
+
+
+ def select_threads(threads):
+     threads_state.append(threads)
+     return threads
+
+
+ def enter_window_size(val):
+     window_size_state.append(val)
+     return val
+
+
+ def enter_stride(val):
+     stride_state.append(val)
+     return val
+
+
+ def enter_num_masks(val):
+     num_masks_state.append(val)
+     return val
+
+
+ def enter_spatial_res(val):
+     spatial_res_state.append(val)
+     return val
+
+
+ def select_p1(prob):
+     p1_state.append(prob)
+     return prob
+
+
+ def enter_seed(seed):
+     seed_state.append(seed)
+     return seed
+
+
+ def check_debiased(debiased):
+     debiased_state.append(debiased)
+     return debiased
+
+
+ def enter_occlusion_grid_size(val):
+     occlusion_grid_state.append(val)
+     return val
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ xaitk-saliency==0.6.1
+ torch==1.9.0
+ torchvision==0.10.0
+ matplotlib
+ urllib3
+ Pillow
+ gradio==3.28.1