barry-ravichandran commited on
Commit
5defe06
·
1 Parent(s): 7984e9a

Updated app to support latest gradio version

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +7 -13
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🐢
4
  colorFrom: yellow
5
  colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.28.1
8
  app_file: app.py
9
  pinned: false
10
  license: other
 
4
  colorFrom: yellow
5
  colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.7.1
8
  app_file: app.py
9
  pinned: false
10
  license: other
app.py CHANGED
@@ -313,7 +313,7 @@ def predict(x,top_n_classes):
313
  labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
314
  final_labels = dict(sorted(labels, key=lambda t: t[1],reverse=True)[:top_n_classes])
315
 
316
- return final_labels, Dropdown.update(choices=list(final_labels))
317
 
318
  # Interpretation function for image classification that implements the selected saliency algorithm and generates the class-wise saliency map visualizations
319
  def interpretation_function(image: np.ndarray,
@@ -326,22 +326,18 @@ def interpretation_function(image: np.ndarray,
326
 
327
  sal_generator = get_saliency_algo(img_cls_saliency_algo_name[-1])
328
  sal_generator.fill = blackbox_fill
329
- labels_list = [i['label'] for i in labels['confidences']]
330
  blackbox_classifier.set_labels(labels_list)
331
  sal_maps = sal_generator(image, blackbox_classifier)
332
  nth_class_index = blackbox_classifier.get_labels().index(nth_class)
333
- scores = sal_maps[nth_class_index,:,:]
334
  fig = visualize_saliency_plot(image,
335
  sal_maps[nth_class_index,:,:],
336
  img_alpha,
337
  sal_alpha,
338
  sal_range_min,
339
  sal_range_max)
340
-
341
- scores = np.clip(scores, sal_range_min, sal_range_max)
342
 
343
- return {"original": gr_processing_utils.encode_array_to_base64(image),
344
- "interpretation": scores.tolist()}, fig
345
 
346
  def visualize_saliency_plot(image: np.ndarray,
347
  class_sal_map: np.ndarray,
@@ -394,7 +390,7 @@ def run_detect(input_img: np.ndarray, num_detections: int):
394
 
395
  bboxes_list = bboxes[:,:].astype(int).tolist()
396
 
397
- return (input_img, list(zip([f for f in bboxes_list], [l for l in final_label]))[:num_detections]), Dropdown.update(choices=[l for l in final_label][:num_detections])
398
 
399
  # Run saliency algorithm on the object detect predictions and generate corresponding visualizations
400
  def run_detect_saliency(input_img: np.ndarray,
@@ -490,7 +486,7 @@ with gr.Blocks() as demo:
490
  debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=False)
491
  with Row():
492
  with Column():
493
- input_img = Image(label="Saliency Map Generation", shape=(640, 480))
494
  num_classes = Slider(value=2,label="Top-N class labels", interactive=True,visible=True)
495
  classify = Button("Classify")
496
  with Column():
@@ -510,8 +506,6 @@ with gr.Blocks() as demo:
510
  with Tabs():
511
  with TabItem("Display interpretation with plot"):
512
  interpretation_plot = Plot()
513
- with TabItem("Display interpretation with built-in component"):
514
- interpretation = gr_components.Interpretation(input_img)
515
 
516
  with Tab("Object Detection"):
517
  with Row():
@@ -531,7 +525,7 @@ with gr.Blocks() as demo:
531
  threads_detect = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=False)
532
  with Row():
533
  with Column():
534
- input_img_detect = Image(label="Saliency Map Generation", shape=(640, 480))
535
  num_detections = Slider(value=2,label="Top-N detections", interactive=True,visible=True)
536
  detection = Button("Run Detection Algorithm")
537
  with Column():
@@ -572,7 +566,7 @@ with gr.Blocks() as demo:
572
  # Image Classification prediction and saliency generation event listeners
573
  classify.click(predict, [input_img, num_classes], [class_label,class_name])
574
  class_label.select(map_labels,None,class_name)
575
- generate_saliency.click(interpretation_function, [input_img, class_label, class_name, img_alpha, sal_alpha, min_sal_range, max_sal_range], [interpretation,interpretation_plot])
576
 
577
  # Object Detection dropdown list event listeners
578
  drop_list_detect_model.select(select_obj_det_model,drop_list_detect_model,drop_list_detect_model)
 
313
  labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
314
  final_labels = dict(sorted(labels, key=lambda t: t[1],reverse=True)[:top_n_classes])
315
 
316
+ return final_labels, Dropdown(choices=list(final_labels),label="Class to compute saliency",interactive=True,visible=True)
317
 
318
  # Interpretation function for image classification that implements the selected saliency algorithm and generates the class-wise saliency map visualizations
319
  def interpretation_function(image: np.ndarray,
 
326
 
327
  sal_generator = get_saliency_algo(img_cls_saliency_algo_name[-1])
328
  sal_generator.fill = blackbox_fill
329
+ labels_list = labels.keys()
330
  blackbox_classifier.set_labels(labels_list)
331
  sal_maps = sal_generator(image, blackbox_classifier)
332
  nth_class_index = blackbox_classifier.get_labels().index(nth_class)
 
333
  fig = visualize_saliency_plot(image,
334
  sal_maps[nth_class_index,:,:],
335
  img_alpha,
336
  sal_alpha,
337
  sal_range_min,
338
  sal_range_max)
 
 
339
 
340
+ return fig
 
341
 
342
  def visualize_saliency_plot(image: np.ndarray,
343
  class_sal_map: np.ndarray,
 
390
 
391
  bboxes_list = bboxes[:,:].astype(int).tolist()
392
 
393
+ return (input_img, list(zip([f for f in bboxes_list], [l for l in final_label]))[:num_detections]), Dropdown(choices=[l for l in final_label][:num_detections],label="Detection to compute saliency",interactive=True,visible=True)
394
 
395
  # Run saliency algorithm on the object detect predictions and generate corresponding visualizations
396
  def run_detect_saliency(input_img: np.ndarray,
 
486
  debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=False)
487
  with Row():
488
  with Column():
489
+ input_img = Image(label="Saliency Map Generation", width=640, height=480)
490
  num_classes = Slider(value=2,label="Top-N class labels", interactive=True,visible=True)
491
  classify = Button("Classify")
492
  with Column():
 
506
  with Tabs():
507
  with TabItem("Display interpretation with plot"):
508
  interpretation_plot = Plot()
 
 
509
 
510
  with Tab("Object Detection"):
511
  with Row():
 
525
  threads_detect = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=False)
526
  with Row():
527
  with Column():
528
+ input_img_detect = Image(label="Saliency Map Generation", width=640, height=480)
529
  num_detections = Slider(value=2,label="Top-N detections", interactive=True,visible=True)
530
  detection = Button("Run Detection Algorithm")
531
  with Column():
 
566
  # Image Classification prediction and saliency generation event listeners
567
  classify.click(predict, [input_img, num_classes], [class_label,class_name])
568
  class_label.select(map_labels,None,class_name)
569
+ generate_saliency.click(interpretation_function, [input_img, class_label, class_name, img_alpha, sal_alpha, min_sal_range, max_sal_range], [interpretation_plot])
570
 
571
  # Object Detection dropdown list event listeners
572
  drop_list_detect_model.select(select_obj_det_model,drop_list_detect_model,drop_list_detect_model)