"""Gradio front end for SegmentAnything2Assist.

Builds a two-tab UI: "Image Segmentation" (prompted by annotated points and a
box) and "Auto Segmentation" (automatic mask generation). All model work is
delegated to the ``SegmentAnything2Assist`` helper; this module only wires
Gradio components to it.
"""

import gradio
import gradio_image_annotation
import gradio_imageslider
import spaces
import torch

import src.SegmentAnything2Assist as SegmentAnything2Assist

# Initial value of the image annotator: one "+" point, two "-" points
# (zero-area boxes) and one unlabeled box.
example_image_annotation = {
    "image": "assets/cars.jpg",
    "boxes": [
        {
            "label": "+",
            "color": (0, 255, 0),
            "xmin": 886,
            "ymin": 551,
            "xmax": 886,
            "ymax": 551,
        },
        {
            "label": "-",
            "color": (255, 0, 0),
            "xmin": 1239,
            "ymin": 576,
            "xmax": 1239,
            "ymax": 576,
        },
        {
            "label": "-",
            "color": (255, 0, 0),
            "xmin": 610,
            "ymin": 574,
            "xmax": 610,
            "ymax": 574,
        },
        {
            "label": "",
            "color": (0, 0, 255),
            "xmin": 254,
            "ymin": 466,
            "xmax": 1347,
            "ymax": 1047,
        },
    ],
}

# VERBOSE enables progress prints; DEBUG shows the debug accordions and fills
# the debug data frames.
VERBOSE = True
DEBUG = False

# NOTE(review): the initial model is built on CPU while the "Device Choice"
# dropdown below defaults to "cuda" - confirm this mismatch is intended.
segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
    model_name="sam2_hiera_tiny", device=torch.device("cpu")
)


def __change_base_model(model_name, device):
    """Replace the module-level model with ``model_name`` on ``device``.

    Args:
        model_name: Model identifier forwarded to ``SegmentAnything2Assist``.
        device: Device string forwarded to ``torch.device``.

    Raises:
        gradio.Error: If the new model could not be constructed. The previous
            model stays in place in that case.
    """
    global segment_anything2assist
    gradio.Info(f"Changing model to {model_name} on {device}", duration=3)
    try:
        segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
            model_name=model_name, device=torch.device(device)
        )
    except Exception as error:
        # gradio.Error is an exception: it only reaches the user when raised.
        raise gradio.Error("Model could not be changed", duration=5) from error
    gradio.Info(f"Model has been changed to {model_name} on {device}", duration=5)


def __post_process_annotator_inputs(value):
    """Convert annotator boxes into point and box prompts.

    Boxes labeled "+" or "-" are collapsed to their integer center and become
    points labeled 1 or 0 respectively. The first box with an empty label is
    used as the box prompt; later unlabeled boxes and boxes with any other
    label are ignored.

    Args:
        value: Annotator value; a mapping with a "boxes" list whose items
            carry "label", "xmin", "ymin", "xmax" and "ymax".

    Returns:
        A tuple ``(point_coords, point_labels, box)``. Each element is ``None``
        when no corresponding annotation exists; otherwise ``point_coords`` is
        a list of ``[x, y]``, ``point_labels`` a list of 0/1 and ``box`` is
        ``[xmin, ymin, xmax, ymax]``.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Called.")

    point_coords = []
    point_labels = []
    image_box = None

    for box in value["boxes"]:
        label = box["label"]
        if label == "":
            # Only the first unlabeled box is kept as the box prompt.
            if image_box is None:
                image_box = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]
        elif label in ("+", "-"):
            center_x = int((box["xmin"] + box["xmax"]) / 2)
            center_y = int((box["ymin"] + box["ymax"]) / 2)
            point_coords.append([center_x, center_y])
            point_labels.append(1 if label == "+" else 0)

    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Done.")

    return point_coords or None, point_labels or None, image_box


@spaces.GPU(duration=60)
def __generate_mask(
    value,
    mask_threshold,
    max_hole_area,
    max_sprinkle_area,
    image_output_mode,
):
    """Generate a mask for the annotated image and build the UI outputs.

    Args:
        value: Annotator value with "image" and "boxes" entries.
        mask_threshold: Forwarded to ``generate_masks_from_image``.
        max_hole_area: Forwarded to ``generate_masks_from_image``.
        max_sprinkle_area: Forwarded to ``generate_masks_from_image``.
        image_output_mode: "Mask" or "Segment"; selects what is shown next to
            the input image in the slider.

    Returns:
        A 5-tuple matching the click handler outputs: slider value, point
        coordinates data frame, box data frame, current mask, current segment.
    """
    # Force post processing of annotated image
    image_point_coords, image_point_labels, image_box = __post_process_annotator_inputs(
        value
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Called.")
    # The second return value (named mask_iou in the original) is not used.
    mask_chw, _mask_iou = segment_anything2assist.generate_masks_from_image(
        value["image"],
        image_point_coords,
        image_point_labels,
        image_box,
        mask_threshold,
        max_hole_area,
        max_sprinkle_area,
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")

    # Only the first generated mask is applied to the image.
    current_mask, current_segment = segment_anything2assist.apply_mask_to_image(
        value["image"], mask_chw[0]
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")

    box_frame = gradio.DataFrame(value=[[]])
    points_frame = gradio.DataFrame(value=[[]])
    if DEBUG:
        # The prompts are None when nothing was annotated; fall back to the
        # empty placeholder rows instead of failing on None.
        box_frame = gradio.DataFrame(
            value=[image_box] if image_box is not None else [[]],
            label="Box",
            interactive=False,
            headers=["XMin", "YMin", "XMax", "YMax"],
        )
        point_rows = [
            [point_label, point[0], point[1]]
            for point_label, point in zip(
                image_point_labels or [], image_point_coords or []
            )
        ]
        points_frame = gradio.DataFrame(
            value=point_rows or [[]],
            label="Point Coords",
            interactive=False,
            headers=["Label", "X", "Y"],
        )

    if image_output_mode == "Mask":
        slider_value = [value["image"], current_mask]
    elif image_output_mode == "Segment":
        slider_value = [value["image"], current_segment]
    else:
        gradio.Warning("This is an issue, please report the problem!", duration=5)
        slider_value = gradio_imageslider.ImageSlider(render=True)

    return (
        slider_value,
        points_frame,
        box_frame,
        current_mask,
        current_segment,
    )


def __change_output_mode(image_input, radio, __current_mask, __current_segment):
    """Switch the image slider between the stored mask and segment.

    Args:
        image_input: Annotator value; its "image" entry is the left image.
        radio: Selected output mode, "Mask" or "Segment".
        __current_mask: Mask stored by the last generation, or ``None``.
        __current_segment: Segment stored by the last generation, or ``None``.

    Returns:
        ``[image, mask_or_segment]`` for the slider, or a fresh ``ImageSlider``
        when nothing has been generated yet or the mode is unknown.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::__change_output_mode::Called.")

    if __current_mask is None or __current_segment is None:
        gradio.Warning("Configuration was changed, generate the mask again", duration=5)
        return gradio_imageslider.ImageSlider(render=True)

    if radio == "Mask":
        return [image_input["image"], __current_mask]
    if radio == "Segment":
        return [image_input["image"], __current_segment]

    gradio.Warning("This is an issue, please report the problem!", duration=5)
    return gradio_imageslider.ImageSlider(render=True)


def __generate_multi_mask_output(
    image, auto_list, auto_mode, auto_bbox_mode, masks, bboxes
):
    """Compose the auto-segmentation slider output for the selected masks.

    Args:
        image: Input image.
        auto_list: Selected entries of the "Mask List" checkbox group
            (strings holding integers).
        auto_mode: "Mask" shows the combined mask; any other value shows the
            segment.
        auto_bbox_mode: When truthy, the left image is the one with bounding
            boxes drawn instead of the plain input.
        masks: Generated masks, either raw or as gallery tuples.
        bboxes: Bounding boxes forwarded to ``apply_auto_mask_to_image``.

    Returns:
        ``[left_image, right_image]`` for the slider, or a fresh ``ImageSlider``
        when no masks are available.
    """
    # The output-mode radio is interactive before any masks exist, so this
    # handler can run with an empty gallery.
    if masks is None or len(masks) == 0:
        gradio.Warning("No masks available, generate the auto mask first.", duration=5)
        return gradio_imageslider.ImageSlider()

    # When value from gallery is called, it is a tuple
    if isinstance(masks[0], tuple):
        masks = [mask[0] for mask in masks]

    # NOTE(review): choices are "0".."n-1", so this maps "0" to index -1.
    # Kept as in the original; confirm against apply_auto_mask_to_image.
    selected_indices = [int(i) - 1 for i in auto_list]
    image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(
        image, selected_indices, masks, bboxes
    )

    output_1 = image_with_bbox if auto_bbox_mode else image
    output_2 = mask if auto_mode == "Mask" else segment
    return [output_1, output_2]


@spaces.GPU(duration=60)
def __generate_auto_mask(
    image,
    points_per_side,
    points_per_batch,
    pred_iou_thresh,
    stability_score_thresh,
    stability_score_offset,
    mask_threshold,
    box_nms_thresh,
    crop_n_layers,
    crop_nms_thresh,
    crop_overlay_ratio,
    crop_n_points_downscale_factor,
    min_mask_region_area,
    use_m2m,
    multimask_output,
    output_mode,
):
    """Run automatic mask generation and build the auto-tab outputs.

    All parameters except ``output_mode`` are forwarded positionally, in
    order, to ``generate_automatic_masks``. ``output_mode`` ("Mask" or
    "Segment") selects the right-hand slider image.

    Returns:
        A 5-tuple matching the click handler outputs: slider value, mask list
        checkbox group, bounding-box checkbox, mask gallery, box data frame.
        When no masks were generated the controls are returned empty and
        non-interactive.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")

    auto_masks, masks, bboxes = segment_anything2assist.generate_automatic_masks(
        image,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
        stability_score_offset,
        mask_threshold,
        box_nms_thresh,
        crop_n_layers,
        crop_nms_thresh,
        crop_overlay_ratio,
        crop_n_points_downscale_factor,
        min_mask_region_area,
        use_m2m,
        multimask_output,
    )

    if len(auto_masks) == 0:
        gradio.Warning(
            "No masks generated, please tweak the advanced parameters.", duration=5
        )
        return (
            gradio_imageslider.ImageSlider(),
            gradio.CheckboxGroup([], value=[], label="Mask List", interactive=False),
            gradio.Checkbox(value=False, label="Show Bounding Box", interactive=False),
            gradio.Gallery(
                None, label="Output Gallery", interactive=False, type="numpy"
            ),
            gradio.DataFrame(
                value=[[]],
                label="Box",
                interactive=False,
                headers=["XMin", "YMin", "XMax", "YMax"],
            ),
        )

    choices = [str(i) for i in range(len(auto_masks))]
    # Start with only the first entry selected and bounding boxes hidden.
    returning_image = __generate_multi_mask_output(
        image, ["0"], output_mode, False, masks, bboxes
    )
    return (
        returning_image,
        gradio.CheckboxGroup(
            choices, value=["0"], label="Mask List", interactive=True
        ),
        gradio.Checkbox(value=False, label="Show Bounding Box", interactive=True),
        gradio.Gallery(masks, label="Output Gallery", interactive=True, type="numpy"),
        gradio.DataFrame(
            value=bboxes,
            label="Box",
            interactive=False,
            headers=["XMin", "YMin", "XMax", "YMax"],
            type="array",
        ),
    )


# NOTE(review): the layout nesting below was reconstructed from a
# whitespace-collapsed source - confirm against the rendered UI.
with gradio.Blocks() as base_app:
    gradio.Markdown("# SegmentAnything2Assist")
    with gradio.Row():
        with gradio.Column():
            base_model_choice = gradio.Dropdown(
                [
                    "sam2_hiera_large",
                    "sam2_hiera_small",
                    "sam2_hiera_base_plus",
                    "sam2_hiera_tiny",
                ],
                value="sam2_hiera_tiny",
                label="Model Choice",
            )
        with gradio.Column():
            base_gpu_choice = gradio.Dropdown(
                ["cpu", "cuda"], value="cuda", label="Device Choice"
            )

    # Changing either dropdown rebuilds the model with both current choices.
    for base_choice in (base_model_choice, base_gpu_choice):
        base_choice.change(
            __change_base_model, inputs=[base_model_choice, base_gpu_choice]
        )

    # Image Segmentation
    with gradio.Tab(label="Image Segmentation", id="image_tab") as image_tab:
        gradio.Markdown("Image Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Image Annotation Documentation", open=False):
                gradio.Markdown(
                    """
                    Image annotation allows you to mark specific regions of an image with labels.
                    In this app, you can annotate an image by drawing boxes and assigning labels to them.
                    The labels can be either '+' or '-'.
                    To annotate an image, simply click and drag to draw a box around the desired region.
                    You can add multiple boxes with different labels.
                    Once you have annotated the image, click the 'Generate Mask' button to generate a mask based on the annotations.
                    The mask can be either a binary mask or a segmented mask, depending on the selected output mode.
                    You can switch between the output modes using the radio buttons.
                    If you make any changes to the annotations or the output mode, you need to regenerate the mask by clicking the button again.
                    Note that the advanced options allow you to adjust the SAM mask threshold, maximum hole area, and maximum sprinkle area.
                    These options control the sensitivity and accuracy of the segmentation process.
                    Experiment with different settings to achieve the desired results.
                    """
                )
            image_input = gradio_image_annotation.image_annotator(
                example_image_annotation
            )
            with gradio.Accordion("Advanced Options", open=False):
                image_generate_SAM_mask_threshold = gradio.Slider(
                    0.0, 1.0, 0.0, label="SAM Mask Threshold"
                )
                image_generate_SAM_max_hole_area = gradio.Slider(
                    0, 1000, 0, label="SAM Max Hole Area"
                )
                image_generate_SAM_max_sprinkle_area = gradio.Slider(
                    0, 1000, 0, label="SAM Max Sprinkle Area"
                )
            image_generate_mask_button = gradio.Button("Generate Mask")
            with gradio.Row():
                with gradio.Column():
                    image_output_mode = gradio.Radio(
                        ["Segment", "Mask"], value="Segment", label="Output Mode"
                    )
                with gradio.Column(scale=3):
                    image_output = gradio_imageslider.ImageSlider()
            with gradio.Accordion("Debug", open=DEBUG, visible=DEBUG):
                __image_point_coords = gradio.DataFrame(
                    value=[["+", 886, 551], ["-", 1239, 576]],
                    label="Point Coords",
                    interactive=False,
                    headers=["Label", "X", "Y"],
                )
                __image_box = gradio.DataFrame(
                    value=[[254, 466, 1347, 1047]],
                    label="Box",
                    interactive=False,
                    headers=["XMin", "YMin", "XMax", "YMax"],
                )
                # These two also carry the last mask/segment between events;
                # __change_output_mode reads them back as inputs.
                __current_mask = gradio.Image(label="Current Mask", interactive=False)
                __current_segment = gradio.Image(
                    label="Current Segment", interactive=False
                )

        # Annotation post-processing is not wired to image_input.change;
        # __generate_mask performs it on every click instead.
        image_generate_mask_button.click(
            __generate_mask,
            inputs=[
                image_input,
                image_generate_SAM_mask_threshold,
                image_generate_SAM_max_hole_area,
                image_generate_SAM_max_sprinkle_area,
                image_output_mode,
            ],
            outputs=[
                image_output,
                __image_point_coords,
                __image_box,
                __current_mask,
                __current_segment,
            ],
        )
        image_output_mode.change(
            __change_output_mode,
            inputs=[
                image_input,
                image_output_mode,
                __current_mask,
                __current_segment,
            ],
            outputs=[image_output],
        )

    # Auto Segmentation
    with gradio.Tab(label="Auto Segmentation", id="auto_tab"):
        gradio.Markdown("Auto Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Auto Annotation Documentation", open=False):
                gradio.Markdown(""" """)
            auto_input = gradio.Image("assets/cars.jpg")
            with gradio.Accordion("Advanced Options", open=False):
                # step is passed by keyword so these calls do not depend on
                # Slider accepting it as a fourth positional argument.
                auto_generate_SAM_points_per_side = gradio.Slider(
                    1, 64, 12, step=1, label="Points Per Side", interactive=True
                )
                auto_generate_SAM_points_per_batch = gradio.Slider(
                    1, 64, 32, step=1, label="Points Per Batch", interactive=True
                )
                # The original passed a step of 1 here, which cannot represent
                # the 0.8 default on a [0, 1] range; no explicit step is set,
                # matching the other threshold sliders.
                auto_generate_SAM_pred_iou_thresh = gradio.Slider(
                    0.0, 1.0, 0.8, label="Pred IOU Threshold", interactive=True
                )
                auto_generate_SAM_stability_score_thresh = gradio.Slider(
                    0.0, 1.0, 0.95, label="Stability Score Threshold", interactive=True
                )
                auto_generate_SAM_stability_score_offset = gradio.Slider(
                    0.0, 1.0, 1.0, label="Stability Score Offset", interactive=True
                )
                auto_generate_SAM_mask_threshold = gradio.Slider(
                    0.0, 1.0, 0.0, label="Mask Threshold", interactive=True
                )
                auto_generate_SAM_box_nms_thresh = gradio.Slider(
                    0.0, 1.0, 0.7, label="Box NMS Threshold", interactive=True
                )
                auto_generate_SAM_crop_n_layers = gradio.Slider(
                    0, 10, 0, step=1, label="Crop N Layers", interactive=True
                )
                auto_generate_SAM_crop_nms_thresh = gradio.Slider(
                    0.0, 1.0, 0.7, label="Crop NMS Threshold", interactive=True
                )
                auto_generate_SAM_crop_overlay_ratio = gradio.Slider(
                    0.0, 1.0, 512 / 1500, label="Crop Overlay Ratio", interactive=True
                )
                auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(
                    1, 10, 1, label="Crop N Points Downscale Factor", interactive=True
                )
                auto_generate_SAM_min_mask_region_area = gradio.Slider(
                    0, 1000, 0, label="Min Mask Region Area", interactive=True
                )
                auto_generate_SAM_use_m2m = gradio.Checkbox(
                    label="Use M2M", interactive=True
                )
                auto_generate_SAM_multimask_output = gradio.Checkbox(
                    value=True, label="Multi Mask Output", interactive=True
                )
            auto_generate_button = gradio.Button("Generate Auto Mask")
            with gradio.Row():
                with gradio.Column():
                    auto_output_mode = gradio.Radio(
                        ["Segment", "Mask"],
                        value="Segment",
                        label="Output Mode",
                        interactive=True,
                    )
                    auto_output_list = gradio.CheckboxGroup(
                        [], value=[], label="Mask List", interactive=False
                    )
                    auto_output_bbox = gradio.Checkbox(
                        value=False, label="Show Bounding Box", interactive=False
                    )
                with gradio.Column(scale=3):
                    auto_output = gradio_imageslider.ImageSlider()
            with gradio.Accordion("Debug", open=DEBUG, visible=DEBUG):
                # The gallery and data frame also carry the generated masks
                # and boxes between events for __generate_multi_mask_output.
                __auto_output_gallery = gradio.Gallery(
                    None, label="Output Gallery", interactive=False, type="numpy"
                )
                __auto_bbox = gradio.DataFrame(
                    value=[[]],
                    label="Box",
                    interactive=False,
                    headers=["XMin", "YMin", "XMax", "YMax"],
                )

        auto_generate_button.click(
            __generate_auto_mask,
            inputs=[
                auto_input,
                auto_generate_SAM_points_per_side,
                auto_generate_SAM_points_per_batch,
                auto_generate_SAM_pred_iou_thresh,
                auto_generate_SAM_stability_score_thresh,
                auto_generate_SAM_stability_score_offset,
                auto_generate_SAM_mask_threshold,
                auto_generate_SAM_box_nms_thresh,
                auto_generate_SAM_crop_n_layers,
                auto_generate_SAM_crop_nms_thresh,
                auto_generate_SAM_crop_overlay_ratio,
                auto_generate_SAM_crop_n_points_downscale_factor,
                auto_generate_SAM_min_mask_region_area,
                auto_generate_SAM_use_m2m,
                auto_generate_SAM_multimask_output,
                auto_output_mode,
            ],
            outputs=[
                auto_output,
                auto_output_list,
                auto_output_bbox,
                __auto_output_gallery,
                __auto_bbox,
            ],
        )

        # Any change to the selection, bounding-box toggle or output mode
        # recomposes the slider from the stored masks and boxes.
        for auto_trigger in (auto_output_list, auto_output_bbox, auto_output_mode):
            auto_trigger.change(
                __generate_multi_mask_output,
                inputs=[
                    auto_input,
                    auto_output_list,
                    auto_output_mode,
                    auto_output_bbox,
                    __auto_output_gallery,
                    __auto_bbox,
                ],
                outputs=[auto_output],
            )

if __name__ == "__main__":
    base_app.launch()