Spaces:
Sleeping
Sleeping
Commit
·
5defe06
1
Parent(s):
7984e9a
Updated app to support latest gradio version
Browse files
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🐢
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: other
|
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.7.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: other
|
app.py
CHANGED
@@ -313,7 +313,7 @@ def predict(x,top_n_classes):
|
|
313 |
labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
|
314 |
final_labels = dict(sorted(labels, key=lambda t: t[1],reverse=True)[:top_n_classes])
|
315 |
|
316 |
-
return final_labels, Dropdown
|
317 |
|
318 |
# Interpretation function for image classification that implements the selected saliency algorithm and generates the class-wise saliency map visualizations
|
319 |
def interpretation_function(image: np.ndarray,
|
@@ -326,22 +326,18 @@ def interpretation_function(image: np.ndarray,
|
|
326 |
|
327 |
sal_generator = get_saliency_algo(img_cls_saliency_algo_name[-1])
|
328 |
sal_generator.fill = blackbox_fill
|
329 |
-
labels_list =
|
330 |
blackbox_classifier.set_labels(labels_list)
|
331 |
sal_maps = sal_generator(image, blackbox_classifier)
|
332 |
nth_class_index = blackbox_classifier.get_labels().index(nth_class)
|
333 |
-
scores = sal_maps[nth_class_index,:,:]
|
334 |
fig = visualize_saliency_plot(image,
|
335 |
sal_maps[nth_class_index,:,:],
|
336 |
img_alpha,
|
337 |
sal_alpha,
|
338 |
sal_range_min,
|
339 |
sal_range_max)
|
340 |
-
|
341 |
-
scores = np.clip(scores, sal_range_min, sal_range_max)
|
342 |
|
343 |
-
return
|
344 |
-
"interpretation": scores.tolist()}, fig
|
345 |
|
346 |
def visualize_saliency_plot(image: np.ndarray,
|
347 |
class_sal_map: np.ndarray,
|
@@ -394,7 +390,7 @@ def run_detect(input_img: np.ndarray, num_detections: int):
|
|
394 |
|
395 |
bboxes_list = bboxes[:,:].astype(int).tolist()
|
396 |
|
397 |
-
return (input_img, list(zip([f for f in bboxes_list], [l for l in final_label]))[:num_detections]), Dropdown
|
398 |
|
399 |
# Run saliency algorithm on the object detect predictions and generate corresponding visualizations
|
400 |
def run_detect_saliency(input_img: np.ndarray,
|
@@ -490,7 +486,7 @@ with gr.Blocks() as demo:
|
|
490 |
debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=False)
|
491 |
with Row():
|
492 |
with Column():
|
493 |
-
input_img = Image(label="Saliency Map Generation",
|
494 |
num_classes = Slider(value=2,label="Top-N class labels", interactive=True,visible=True)
|
495 |
classify = Button("Classify")
|
496 |
with Column():
|
@@ -510,8 +506,6 @@ with gr.Blocks() as demo:
|
|
510 |
with Tabs():
|
511 |
with TabItem("Display interpretation with plot"):
|
512 |
interpretation_plot = Plot()
|
513 |
-
with TabItem("Display interpretation with built-in component"):
|
514 |
-
interpretation = gr_components.Interpretation(input_img)
|
515 |
|
516 |
with Tab("Object Detection"):
|
517 |
with Row():
|
@@ -531,7 +525,7 @@ with gr.Blocks() as demo:
|
|
531 |
threads_detect = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=False)
|
532 |
with Row():
|
533 |
with Column():
|
534 |
-
input_img_detect = Image(label="Saliency Map Generation",
|
535 |
num_detections = Slider(value=2,label="Top-N detections", interactive=True,visible=True)
|
536 |
detection = Button("Run Detection Algorithm")
|
537 |
with Column():
|
@@ -572,7 +566,7 @@ with gr.Blocks() as demo:
|
|
572 |
# Image Classification prediction and saliency generation event listeners
|
573 |
classify.click(predict, [input_img, num_classes], [class_label,class_name])
|
574 |
class_label.select(map_labels,None,class_name)
|
575 |
-
generate_saliency.click(interpretation_function, [input_img, class_label, class_name, img_alpha, sal_alpha, min_sal_range, max_sal_range], [
|
576 |
|
577 |
# Object Detection dropdown list event listeners
|
578 |
drop_list_detect_model.select(select_obj_det_model,drop_list_detect_model,drop_list_detect_model)
|
|
|
313 |
labels = list(zip(sal_class_labels, class_conf[sal_class_idxs].tolist()))
|
314 |
final_labels = dict(sorted(labels, key=lambda t: t[1],reverse=True)[:top_n_classes])
|
315 |
|
316 |
+
return final_labels, Dropdown(choices=list(final_labels),label="Class to compute saliency",interactive=True,visible=True)
|
317 |
|
318 |
# Interpretation function for image classification that implements the selected saliency algorithm and generates the class-wise saliency map visualizations
|
319 |
def interpretation_function(image: np.ndarray,
|
|
|
326 |
|
327 |
sal_generator = get_saliency_algo(img_cls_saliency_algo_name[-1])
|
328 |
sal_generator.fill = blackbox_fill
|
329 |
+
labels_list = labels.keys()
|
330 |
blackbox_classifier.set_labels(labels_list)
|
331 |
sal_maps = sal_generator(image, blackbox_classifier)
|
332 |
nth_class_index = blackbox_classifier.get_labels().index(nth_class)
|
|
|
333 |
fig = visualize_saliency_plot(image,
|
334 |
sal_maps[nth_class_index,:,:],
|
335 |
img_alpha,
|
336 |
sal_alpha,
|
337 |
sal_range_min,
|
338 |
sal_range_max)
|
|
|
|
|
339 |
|
340 |
+
return fig
|
|
|
341 |
|
342 |
def visualize_saliency_plot(image: np.ndarray,
|
343 |
class_sal_map: np.ndarray,
|
|
|
390 |
|
391 |
bboxes_list = bboxes[:,:].astype(int).tolist()
|
392 |
|
393 |
+
return (input_img, list(zip([f for f in bboxes_list], [l for l in final_label]))[:num_detections]), Dropdown(choices=[l for l in final_label][:num_detections],label="Detection to compute saliency",interactive=True,visible=True)
|
394 |
|
395 |
# Run saliency algorithm on the object detect predictions and generate corresponding visualizations
|
396 |
def run_detect_saliency(input_img: np.ndarray,
|
|
|
486 |
debiased = Checkbox(value=debiased_state[-1],label="Debiased", interactive=True, visible=False)
|
487 |
with Row():
|
488 |
with Column():
|
489 |
+
input_img = Image(label="Saliency Map Generation", width=640, height=480)
|
490 |
num_classes = Slider(value=2,label="Top-N class labels", interactive=True,visible=True)
|
491 |
classify = Button("Classify")
|
492 |
with Column():
|
|
|
506 |
with Tabs():
|
507 |
with TabItem("Display interpretation with plot"):
|
508 |
interpretation_plot = Plot()
|
|
|
|
|
509 |
|
510 |
with Tab("Object Detection"):
|
511 |
with Row():
|
|
|
525 |
threads_detect = Slider(value=threads_state[-1],label="Threads",interactive=True,visible=False)
|
526 |
with Row():
|
527 |
with Column():
|
528 |
+
input_img_detect = Image(label="Saliency Map Generation", width=640, height=480)
|
529 |
num_detections = Slider(value=2,label="Top-N detections", interactive=True,visible=True)
|
530 |
detection = Button("Run Detection Algorithm")
|
531 |
with Column():
|
|
|
566 |
# Image Classification prediction and saliency generation event listeners
|
567 |
classify.click(predict, [input_img, num_classes], [class_label,class_name])
|
568 |
class_label.select(map_labels,None,class_name)
|
569 |
+
generate_saliency.click(interpretation_function, [input_img, class_label, class_name, img_alpha, sal_alpha, min_sal_range, max_sal_range], [interpretation_plot])
|
570 |
|
571 |
# Object Detection dropdown list event listeners
|
572 |
drop_list_detect_model.select(select_obj_det_model,drop_list_detect_model,drop_list_detect_model)
|