import gradio
import gradio_image_annotation
import gradio_imageslider
import spaces
import torch

import src.SegmentAnything2Assist as SegmentAnything2Assist

# Annotation preloaded into the image annotator.  Boxes labelled '+' / '-' are
# click points (zero-size boxes); the box with an empty label is the bounding
# box prompt (see __post_process_annotator_inputs).
example_image_annotation = {
    "image": "assets/cars.jpg",
    "boxes": [
        {"label": "+", "color": (0, 255, 0), "xmin": 886, "ymin": 551, "xmax": 886, "ymax": 551},
        {"label": "-", "color": (255, 0, 0), "xmin": 1239, "ymin": 576, "xmax": 1239, "ymax": 576},
        {"label": "-", "color": (255, 0, 0), "xmin": 610, "ymin": 574, "xmax": 610, "ymax": 574},
        {"label": "", "color": (0, 0, 255), "xmin": 254, "ymin": 466, "xmax": 1347, "ymax": 1047},
    ],
}

# Print trace messages from the callbacks below.
VERBOSE = True

# Model used by every callback; replaced by __change_base_model.
segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
    model_name="sam2_hiera_tiny", device=torch.device("cuda")
)

# Prompts derived from the annotator and the last generated mask/segment.
# NOTE(review): these are plain module globals, so all sessions served by this
# process share them; per-session gradio.State would be needed to isolate
# concurrent users.
__image_point_coords = None
__image_point_labels = None
__image_box = None
__current_mask = None
__current_segment = None


def __change_base_model(model_name, device):
    """Replace the global model with ``model_name`` loaded on ``device``.

    On success an info toast is shown.  On failure the previously loaded
    model is kept and a ``gradio.Error`` is raised so the UI shows it.
    """
    global segment_anything2assist
    try:
        new_model = SegmentAnything2Assist.SegmentAnything2Assist(
            model_name=model_name, device=torch.device(device)
        )
    except Exception as error:
        if VERBOSE:
            print(f"SegmentAnything2AssistApp::__change_base_model::Failed: {error!r}")
        # gradio.Error is only displayed when it is raised, not when constructed.
        raise gradio.Error("Model could not be changed", duration=5) from error
    segment_anything2assist = new_model
    gradio.Info(f"Model changed to {model_name} on {device}", duration=5)


def __post_process_annotator_inputs(value):
    """Convert the annotator value into SAM prompts stored in module globals.

    Boxes labelled '+' / '-' become positive (1) / negative (0) points located
    at the centre of the drawn box.  The first box with an empty label becomes
    the box prompt; further unlabeled boxes and other labels are ignored.
    Empty prompt lists are stored as None.  The cached mask/segment is cleared
    because it no longer matches the annotations.

    Args:
        value: The image_annotator value (a dict with a "boxes" list), or
            None when the image has been cleared.
    """
    global __image_point_coords, __image_point_labels, __image_box
    global __current_mask, __current_segment

    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Called.")

    __current_mask, __current_segment = None, None

    point_coords = []
    point_labels = []
    box_prompt = None

    # The annotator reports None once the image is removed.
    boxes = (value.get("boxes") or []) if value else []
    for box in boxes:
        label = box["label"]
        if label == "":
            if box_prompt is None:
                box_prompt = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]
        elif label in ("+", "-"):
            # A point prompt is the centre of whatever box the user dragged.
            point_coords.append([
                int((box["xmin"] + box["xmax"]) / 2),
                int((box["ymin"] + box["ymax"]) / 2),
            ])
            point_labels.append(1 if label == "+" else 0)

    __image_box = box_prompt
    __image_point_coords = point_coords or None
    __image_point_labels = point_labels or None

    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Done.")


def __select_output(image, output_mode):
    """Return the ImageSlider value [image, mask|segment] for ``output_mode``.

    Reads the cached __current_mask / __current_segment globals.
    """
    if output_mode == "Mask":
        return [image, __current_mask]
    if output_mode == "Segment":
        return [image, __current_segment]
    gradio.Warning("This is an issue, please report the problem!", duration=5)
    return gradio_imageslider.ImageSlider(render=True)


@spaces.GPU(duration=60)
def __generate_mask(value, mask_threshold, max_hole_area, max_sprinkle_area, image_output_mode):
    """Run prompted segmentation on the annotated image.

    The prompts are re-derived from ``value`` first so they always match the
    current annotations.  The first returned mask is applied to the image and
    cached for __change_output_mode.

    Returns:
        The ImageSlider value [image, mask] or [image, segment], depending on
        ``image_output_mode``.
    """
    global __current_mask, __current_segment
    global __image_point_coords, __image_point_labels, __image_box
    global segment_anything2assist

    # Force post processing of annotated image
    __post_process_annotator_inputs(value)

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Called.")

    if not value or value.get("image") is None:
        gradio.Warning("Please provide an image first.", duration=5)
        return gradio_imageslider.ImageSlider(render=True)

    mask_chw, mask_iou = segment_anything2assist.generate_masks_from_image(
        value["image"],
        __image_point_coords,
        __image_point_labels,
        __image_box,
        mask_threshold,
        max_hole_area,
        max_sprinkle_area,
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")

    __current_mask, __current_segment = segment_anything2assist.apply_mask_to_image(
        value["image"], mask_chw[0]
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")

    return __select_output(value["image"], image_output_mode)


def __change_output_mode(image_input, radio):
    """Switch the displayed result between mask and segment without rerunning SAM.

    If the cache was invalidated (annotations changed), a warning asks the user
    to generate the mask again.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::__change_output_mode::Called.")
    if __current_mask is None or __current_segment is None:
        gradio.Warning("Configuration was changed, generate the mask again", duration=5)
        return gradio_imageslider.ImageSlider(render=True)
    return __select_output(image_input["image"], radio)


def __generate_multi_mask_output(image, auto_list, auto_mode, auto_bbox_mode):
    """Render the automatic masks selected in the Mask List.

    Args:
        image: The input image.
        auto_list: Selected Mask List entries; these are 1-based labels.
        auto_mode: "Mask" shows the mask, anything else shows the segment.
        auto_bbox_mode: When true, the left image is the one returned with
            bounding boxes.

    Returns:
        The ImageSlider value [left image, mask or segment].
    """
    # Convert the 1-based UI labels back to 0-based mask indices.
    mask_indices = [int(i) - 1 for i in auto_list]
    image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(
        image, mask_indices
    )
    output_1 = image_with_bbox if auto_bbox_mode else image
    output_2 = mask if auto_mode == "Mask" else segment
    return [output_1, output_2]


@spaces.GPU(duration=60)
def __generate_auto_mask(
    image,
    points_per_side,
    points_per_batch,
    pred_iou_thresh,
    stability_score_thresh,
    stability_score_offset,
    mask_threshold,
    box_nms_thresh,
    crop_n_layers,
    crop_nms_thresh,
    crop_overlay_ratio,
    crop_n_points_downscale_factor,
    min_mask_region_area,
    use_m2m,
    multimask_output,
    output_mode,
):
    """Run automatic mask generation and populate the Mask List.

    All parameters except ``image`` and ``output_mode`` are passed through,
    in order, to ``generate_automatic_masks``.

    Returns:
        A tuple (ImageSlider value/update, Mask List update, Show Bounding Box
        update).  When no masks are produced, the list and checkbox are
        disabled and a warning is shown.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")

    auto_masks = segment_anything2assist.generate_automatic_masks(
        image,
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
        stability_score_offset,
        mask_threshold,
        box_nms_thresh,
        crop_n_layers,
        crop_nms_thresh,
        crop_overlay_ratio,
        crop_n_points_downscale_factor,
        min_mask_region_area,
        use_m2m,
        multimask_output,
    )

    # len() rather than truthiness: the container type is defined elsewhere
    # and may not support bool() (e.g. an array).
    if len(auto_masks) == 0:
        gradio.Warning("No masks generated, please tweak the advanced parameters.", duration=5)
        return (
            gradio_imageslider.ImageSlider(),
            gradio.CheckboxGroup([], value=[], label="Mask List", interactive=False),
            gradio.Checkbox(value=False, label="Show Bounding Box", interactive=False),
        )

    # 1-based labels, matching the "- 1" conversion in __generate_multi_mask_output.
    choices = [str(i) for i in range(1, len(auto_masks) + 1)]
    returning_image = __generate_multi_mask_output(image, ["1"], output_mode, False)
    return (
        returning_image,
        gradio.CheckboxGroup(choices, value=["1"], label="Mask List", interactive=True),
        gradio.Checkbox(value=False, label="Show Bounding Box", interactive=True),
    )


with gradio.Blocks() as base_app:
    gradio.Markdown("# SegmentAnything2Assist")
    with gradio.Row():
        with gradio.Column():
            base_model_choice = gradio.Dropdown(
                ["sam2_hiera_large", "sam2_hiera_small", "sam2_hiera_base_plus", "sam2_hiera_tiny"],
                value="sam2_hiera_tiny",
                label="Model Choice",
            )
        with gradio.Column():
            base_gpu_choice = gradio.Dropdown(["cpu", "cuda"], value="cuda", label="Device Choice")
    base_model_choice.change(__change_base_model, inputs=[base_model_choice, base_gpu_choice])
    base_gpu_choice.change(__change_base_model, inputs=[base_model_choice, base_gpu_choice])

    with gradio.Tab(label="Image Segmentation", id="image_tab") as image_tab:
        gradio.Markdown("Image Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Image Annotation Documentation", open=False):
                gradio.Markdown("""
                    Image annotation allows you to mark specific regions of an image with labels.
                    In this app, you can annotate an image by drawing boxes and assigning labels to them.
                    Boxes labelled '+' or '-' are used as positive or negative points at the centre of the box,
                    and a box with an empty label is used as the bounding box of the object (only the first one is used).
                    To annotate an image, simply click and drag to draw a box around the desired region.
                    You can add multiple boxes with different labels.
                    Once you have annotated the image, click the 'Generate Mask' button to generate a mask based on the annotations.
                    The mask can be either a binary mask or a segmented mask, depending on the selected output mode.
                    You can switch between the output modes using the radio buttons without regenerating the mask.
                    If you make any changes to the annotations, you need to regenerate the mask by clicking the button again.
                    Note that the advanced options allow you to adjust the SAM mask threshold, maximum hole area, and maximum sprinkle area.
                    These options control the sensitivity and accuracy of the segmentation process.
                    Experiment with different settings to achieve the desired results.
                    """)
            image_input = gradio_image_annotation.image_annotator(example_image_annotation)
            with gradio.Accordion("Advanced Options", open=False):
                image_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label="SAM Mask Threshold")
                image_generate_SAM_max_hole_area = gradio.Slider(0, 1000, 0, label="SAM Max Hole Area")
                image_generate_SAM_max_sprinkle_area = gradio.Slider(0, 1000, 0, label="SAM Max Sprinkle Area")
            image_generate_mask_button = gradio.Button("Generate Mask")
            image_output = gradio_imageslider.ImageSlider()
            image_output_mode = gradio.Radio(["Segment", "Mask"], value="Segment", label="Output Mode")

        image_input.change(__post_process_annotator_inputs, inputs=[image_input])
        image_generate_mask_button.click(
            __generate_mask,
            inputs=[
                image_input,
                image_generate_SAM_mask_threshold,
                image_generate_SAM_max_hole_area,
                image_generate_SAM_max_sprinkle_area,
                image_output_mode,
            ],
            outputs=[image_output],
        )
        image_output_mode.change(
            __change_output_mode,
            inputs=[image_input, image_output_mode],
            outputs=[image_output],
        )

    with gradio.Tab(label="Auto Segmentation", id="auto_tab"):
        gradio.Markdown("Auto Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Auto Annotation Documentation", open=False):
                gradio.Markdown("""
                    Automatic segmentation generates masks for the whole image without any annotations.
                    Click 'Generate Auto Mask' to run it; the advanced options are passed to the automatic mask generator.
                    Once masks are generated, tick entries in the 'Mask List' to choose which masks are displayed,
                    enable 'Show Bounding Box' to show the image with bounding boxes, and use 'Output Mode' to switch between the segment and the mask.
                    If no masks are generated, tweak the advanced parameters and try again.
                    """)
            auto_input = gradio.Image("assets/cars.jpg")
            with gradio.Accordion("Advanced Options", open=False):
                auto_generate_SAM_points_per_side = gradio.Slider(1, 64, 32, step=1, label="Points Per Side", interactive=True)
                auto_generate_SAM_points_per_batch = gradio.Slider(1, 64, 32, step=1, label="Points Per Batch", interactive=True)
                # No explicit step: a step of 1 on a 0..1 range only allowed 0 or 1.
                auto_generate_SAM_pred_iou_thresh = gradio.Slider(0.0, 1.0, 0.8, label="Pred IOU Threshold", interactive=True)
                auto_generate_SAM_stability_score_thresh = gradio.Slider(0.0, 1.0, 0.95, label="Stability Score Threshold", interactive=True)
                auto_generate_SAM_stability_score_offset = gradio.Slider(0.0, 1.0, 1.0, label="Stability Score Offset", interactive=True)
                auto_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label="Mask Threshold", interactive=True)
                auto_generate_SAM_box_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label="Box NMS Threshold", interactive=True)
                auto_generate_SAM_crop_n_layers = gradio.Slider(0, 10, 0, step=1, label="Crop N Layers", interactive=True)
                auto_generate_SAM_crop_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label="Crop NMS Threshold", interactive=True)
                auto_generate_SAM_crop_overlay_ratio = gradio.Slider(0.0, 1.0, 512 / 1500, label="Crop Overlap Ratio", interactive=True)
                # Integer factor, so restrict the slider to whole numbers.
                auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(1, 10, 1, step=1, label="Crop N Points Downscale Factor", interactive=True)
                auto_generate_SAM_min_mask_region_area = gradio.Slider(0, 1000, 0, label="Min Mask Region Area", interactive=True)
                auto_generate_SAM_use_m2m = gradio.Checkbox(label="Use M2M", interactive=True)
                auto_generate_SAM_multimask_output = gradio.Checkbox(value=True, label="Multi Mask Output", interactive=True)
            auto_generate_button = gradio.Button("Generate Auto Mask")
            with gradio.Row():
                with gradio.Column():
                    auto_output_mode = gradio.Radio(["Segment", "Mask"], value="Segment", label="Output Mode", interactive=True)
                    auto_output_list = gradio.CheckboxGroup([], value=[], label="Mask List", interactive=False)
                    auto_output_bbox = gradio.Checkbox(value=False, label="Show Bounding Box", interactive=False)
                with gradio.Column(scale=3):
                    auto_output = gradio_imageslider.ImageSlider()

        auto_generate_button.click(
            __generate_auto_mask,
            inputs=[
                auto_input,
                auto_generate_SAM_points_per_side,
                auto_generate_SAM_points_per_batch,
                auto_generate_SAM_pred_iou_thresh,
                auto_generate_SAM_stability_score_thresh,
                auto_generate_SAM_stability_score_offset,
                auto_generate_SAM_mask_threshold,
                auto_generate_SAM_box_nms_thresh,
                auto_generate_SAM_crop_n_layers,
                auto_generate_SAM_crop_nms_thresh,
                auto_generate_SAM_crop_overlay_ratio,
                auto_generate_SAM_crop_n_points_downscale_factor,
                auto_generate_SAM_min_mask_region_area,
                auto_generate_SAM_use_m2m,
                auto_generate_SAM_multimask_output,
                auto_output_mode,
            ],
            outputs=[auto_output, auto_output_list, auto_output_bbox],
        )
        # Any change to the selection, bbox toggle or output mode re-renders the slider.
        multi_mask_inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox]
        auto_output_list.change(__generate_multi_mask_output, inputs=multi_mask_inputs, outputs=[auto_output])
        auto_output_bbox.change(__generate_multi_mask_output, inputs=multi_mask_inputs, outputs=[auto_output])
        auto_output_mode.change(__generate_multi_mask_output, inputs=multi_mask_inputs, outputs=[auto_output])

if __name__ == "__main__":
    base_app.launch()