# Hugging Face Space (status: Running on Zero)
import gradio | |
import gradio_image_annotation | |
import gradio_imageslider | |
import spaces | |
import torch | |
import src.SegmentAnything2Assist as SegmentAnything2Assist | |
# Default payload shown in the image annotator: the sample image plus one
# positive point ("+"), two negative points ("-") and one unlabeled box ("").
# Point prompts are stored as zero-area boxes (xmin == xmax, ymin == ymax).
example_image_annotation = {
    "image": "assets/cars.jpg",
    "boxes": [
        {
            "label": label,
            "color": color,
            "xmin": xmin,
            "ymin": ymin,
            "xmax": xmax,
            "ymax": ymax,
        }
        for label, color, (xmin, ymin, xmax, ymax) in (
            ("+", (0, 255, 0), (886, 551, 886, 551)),
            ("-", (255, 0, 0), (1239, 576, 1239, 576)),
            ("-", (255, 0, 0), (610, 574, 610, 574)),
            ("", (0, 0, 255), (254, 466, 1347, 1047)),
        )
    ],
}
# VERBOSE gates the console trace prints in the callbacks below.
VERBOSE = True
# DEBUG opens/shows the "Debug" accordions and fills their tables.
DEBUG = False

# Model instance shared by every callback; created on CPU when the module is
# imported and replaced by __change_base_model when the dropdowns change.
segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
    device=torch.device("cpu"),
    model_name="sam2_hiera_tiny",
)
def __change_base_model(model_name, device):
    """Replace the shared model with a new one built from the dropdown choices.

    Args:
        model_name: Model identifier passed to ``SegmentAnything2Assist``.
        device: Device string handed to ``torch.device``.

    Raises:
        gradio.Error: If the new model cannot be constructed. The previously
            loaded model is left in place in that case.
    """
    global segment_anything2assist
    gradio.Info(f"Changing model to {model_name} on {device}", duration=3)
    try:
        new_model = SegmentAnything2Assist.SegmentAnything2Assist(
            model_name=model_name, device=torch.device(device)
        )
    except Exception as error:
        # gradio.Error is an exception type: it only reaches the UI when it is
        # raised, so constructing it without `raise` silently did nothing.
        raise gradio.Error("Model could not be changed", duration=5) from error

    # Rebind only after construction succeeded so a failure keeps the old model.
    segment_anything2assist = new_model
    gradio.Info(f"Model has been changed to {model_name} on {device}", duration=5)
def __post_process_annotator_inputs(value):
    """Convert the annotator payload into point and box prompts.

    Boxes labelled "+" or "-" are treated as point prompts located at the
    centre of the drawn box (label 1 for "+", 0 for "-"). A box with an empty
    label is treated as the box prompt; when several are present the last one
    in the list is used.

    Args:
        value: Annotator payload, a dict with a ``"boxes"`` list whose items
            carry ``"label"``, ``"xmin"``, ``"ymin"``, ``"xmax"``, ``"ymax"``.
            ``None`` or a missing/empty ``"boxes"`` entry yields no prompts.

    Returns:
        Tuple ``(point_coords, point_labels, box)``. Each element is ``None``
        when no prompt of that kind was found; otherwise ``point_coords`` is a
        list of ``[x, y]``, ``point_labels`` a list of 0/1 and ``box`` is
        ``[xmin, ymin, xmax, ymax]``.
    """
    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Called.")

    image_point_coords = []
    image_point_labels = []
    image_box = None

    # Tolerate an empty annotator (None payload, or no "boxes" drawn yet).
    boxes = (value or {}).get("boxes") or []
    for box in boxes:
        label = box.get("label", "")
        if label == "":
            image_box = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]
        elif label in ("+", "-"):
            # Collapse the drawn rectangle to its centre point.
            center_x = int((box["xmin"] + box["xmax"]) / 2)
            center_y = int((box["ymin"] + box["ymax"]) / 2)
            image_point_coords.append([center_x, center_y])
            image_point_labels.append(1 if label == "+" else 0)

    if VERBOSE:
        print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Done.")

    # Empty prompt lists are reported as None rather than [].
    return (
        image_point_coords or None,
        image_point_labels or None,
        image_box,
    )
def __generate_mask(
    value,
    mask_threshold,
    max_hole_area,
    max_sprinkle_area,
    image_output_mode,
):
    """Segment the annotated image and build the outputs for the image tab.

    Args:
        value: Annotator payload; ``value["image"]`` is passed to the model
            and the boxes are converted to prompts by
            ``__post_process_annotator_inputs``.
        mask_threshold: Forwarded to ``generate_masks_from_image``.
        max_hole_area: Forwarded to ``generate_masks_from_image``.
        max_sprinkle_area: Forwarded to ``generate_masks_from_image``.
        image_output_mode: ``"Mask"`` or ``"Segment"``; selects which image is
            paired with the input in the slider.

    Returns:
        5-tuple ``(slider_value, point_table, box_table, mask, segment)``
        matching the outputs wired to the "Generate Mask" button.
    """
    global segment_anything2assist

    # Force post processing of annotated image
    image_point_coords, image_point_labels, image_box = __post_process_annotator_inputs(
        value
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Called.")
    # The second return value is not used by the UI.
    mask_chw, _ = segment_anything2assist.generate_masks_from_image(
        value["image"],
        image_point_coords,
        image_point_labels,
        image_box,
        mask_threshold,
        max_hole_area,
        max_sprinkle_area,
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")
    # Only the first returned mask is rendered.
    current_mask, current_segment = segment_anything2assist.apply_mask_to_image(
        value["image"], mask_chw[0]
    )

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")

    # The debug tables stay empty unless DEBUG is enabled.
    box_table = gradio.DataFrame(value=[[]])
    point_table = gradio.DataFrame(value=[[]])
    if DEBUG:
        # Either prompt kind may be None (nothing drawn); fall back to an
        # empty row instead of iterating over / wrapping None.
        box_table = gradio.DataFrame(
            value=[image_box or []],
            label="Box",
            interactive=False,
            headers=["XMin", "YMin", "XMax", "YMax"],
        )
        point_rows = [
            [point_label, point[0], point[1]]
            for point_label, point in zip(
                image_point_labels or [], image_point_coords or []
            )
        ]
        point_table = gradio.DataFrame(
            value=point_rows or [[]],
            label="Point Coords",
            interactive=False,
            headers=["Label", "X", "Y"],
        )

    if image_output_mode == "Mask":
        slider_value = [value["image"], current_mask]
    elif image_output_mode == "Segment":
        slider_value = [value["image"], current_segment]
    else:
        gradio.Warning("This is an issue, please report the problem!", duration=5)
        slider_value = gradio_imageslider.ImageSlider(render=True)

    return (
        slider_value,
        point_table,
        box_table,
        current_mask,
        current_segment,
    )
def __change_output_mode(image_input, radio, __current_mask, __current_segment):
    """Swap the slider between the stored mask and the stored segment.

    Args:
        image_input: Annotator payload; ``image_input["image"]`` is shown on
            one side of the slider.
        radio: Selected output mode, ``"Mask"`` or ``"Segment"``.
        __current_mask: Mask image produced by the last generation, or None.
        __current_segment: Segment image produced by the last generation, or
            None.

    Returns:
        ``[input image, mask-or-segment]`` for the slider, or a fresh
        ``ImageSlider`` when nothing has been generated yet or the mode is
        not recognised.
    """
    if VERBOSE:
        # Previously logged "__generate_mask", a copy-paste slip that made
        # the trace point at the wrong callback.
        print("SegmentAnything2AssistApp::__change_output_mode::Called.")

    if __current_mask is None or __current_segment is None:
        gradio.Warning("Configuration was changed, generate the mask again", duration=5)
        return gradio_imageslider.ImageSlider(render=True)

    if radio == "Mask":
        return [image_input["image"], __current_mask]
    if radio == "Segment":
        return [image_input["image"], __current_segment]

    gradio.Warning("This is an issue, please report the problem!", duration=5)
    return gradio_imageslider.ImageSlider(render=True)
def __generate_multi_mask_output(
    image, auto_list, auto_mode, auto_bbox_mode, masks, bboxes
):
    """Compose the auto-segmentation slider output for the selected masks.

    Args:
        image: Input image of the auto-segmentation tab.
        auto_list: Selected entries of the "Mask List" checkbox group
            (strings holding integers).
        auto_mode: ``"Mask"`` shows the combined mask; anything else shows the
            segment.
        auto_bbox_mode: When truthy, the left image is the one with bounding
            boxes drawn; otherwise the plain input image.
        masks: Masks to combine. Items coming from the gallery component are
            tuples whose first element is the image.
        bboxes: Bounding boxes, forwarded to ``apply_auto_mask_to_image``.

    Returns:
        ``[left_image, right_image]`` for the slider, or an empty
        ``ImageSlider`` when there are no masks to combine.
    """
    global segment_anything2assist

    # The change events can fire before anything was generated, in which case
    # the gallery value is empty and masks[0] below would raise.
    if masks is None or len(masks) == 0:
        return gradio_imageslider.ImageSlider()

    # When value from gallery is called, it is a tuple
    if isinstance(masks[0], tuple):
        masks = [mask[0] for mask in masks]

    # NOTE(review): the checkbox choices are generated starting at "0", so the
    # "- 1" maps "0" to index -1; confirm against apply_auto_mask_to_image.
    selected_indices = [int(i) - 1 for i in (auto_list or [])]
    image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(
        image, selected_indices, masks, bboxes
    )

    output_1 = image_with_bbox if auto_bbox_mode else image
    output_2 = mask if auto_mode == "Mask" else segment
    return [output_1, output_2]
def __generate_auto_mask(
    image,
    points_per_side,
    points_per_batch,
    pred_iou_thresh,
    stability_score_thresh,
    stability_score_offset,
    mask_threshold,
    box_nms_thresh,
    crop_n_layers,
    crop_nms_thresh,
    crop_overlay_ratio,
    crop_n_points_downscale_factor,
    min_mask_region_area,
    use_m2m,
    multimask_output,
    output_mode,
):
    """Run automatic mask generation and refresh the auto-segmentation tab.

    Every argument except ``output_mode`` is handed, in order, to
    ``segment_anything2assist.generate_automatic_masks``. ``output_mode`` is
    used to render the initial slider output.

    Returns:
        5-tuple ``(slider_value, mask_list, bbox_checkbox, gallery,
        bbox_table)`` matching the outputs wired to the "Generate Auto Mask"
        button. When no masks are produced the controls are returned empty
        and non-interactive and a warning is shown.
    """
    global segment_anything2assist

    if VERBOSE:
        print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")

    # Generator settings, kept in the positional order the model wrapper expects.
    generator_settings = (
        points_per_side,
        points_per_batch,
        pred_iou_thresh,
        stability_score_thresh,
        stability_score_offset,
        mask_threshold,
        box_nms_thresh,
        crop_n_layers,
        crop_nms_thresh,
        crop_overlay_ratio,
        crop_n_points_downscale_factor,
        min_mask_region_area,
        use_m2m,
        multimask_output,
    )
    auto_masks, masks, bboxes = segment_anything2assist.generate_automatic_masks(
        image, *generator_settings
    )
    box_headers = ["XMin", "YMin", "XMax", "YMax"]

    if len(auto_masks) > 0:
        # One checkbox per generated mask; the first entry starts selected.
        mask_choices = [str(index) for index in range(len(auto_masks))]
        initial_output = __generate_multi_mask_output(
            image, ["0"], output_mode, False, masks, bboxes
        )
        mask_list = gradio.CheckboxGroup(
            mask_choices, value=["0"], label="Mask List", interactive=True
        )
        bbox_toggle = gradio.Checkbox(
            value=False, label="Show Bounding Box", interactive=True
        )
        gallery = gradio.Gallery(
            masks, label="Output Gallery", interactive=True, type="numpy"
        )
        bbox_table = gradio.DataFrame(
            value=bboxes,
            label="Box",
            interactive=False,
            headers=box_headers,
            type="array",
        )
        return (initial_output, mask_list, bbox_toggle, gallery, bbox_table)

    # Nothing was produced: tell the user and hand back emptied controls.
    gradio.Warning(
        "No masks generated, please tweak the advanced parameters.", duration=5
    )
    empty_slider = gradio_imageslider.ImageSlider()
    empty_mask_list = gradio.CheckboxGroup(
        [], value=[], label="Mask List", interactive=False
    )
    disabled_bbox_toggle = gradio.Checkbox(
        value=False, label="Show Bounding Box", interactive=False
    )
    empty_gallery = gradio.Gallery(
        None, label="Output Gallery", interactive=False, type="numpy"
    )
    empty_bbox_table = gradio.DataFrame(
        value=[[]],
        label="Box",
        interactive=False,
        headers=box_headers,
    )
    return (
        empty_slider,
        empty_mask_list,
        disabled_bbox_toggle,
        empty_gallery,
        empty_bbox_table,
    )
# UI definition: a model/device selector on top, then one tab for prompted
# image segmentation and one tab for automatic segmentation.
with gradio.Blocks() as base_app:
    gradio.Markdown("# SegmentAnything2Assist")
    with gradio.Row():
        with gradio.Column():
            base_model_choice = gradio.Dropdown(
                [
                    "sam2_hiera_large",
                    "sam2_hiera_small",
                    "sam2_hiera_base_plus",
                    "sam2_hiera_tiny",
                ],
                value="sam2_hiera_tiny",
                label="Model Choice",
            )
        with gradio.Column():
            # Defaults to "cpu" so the dropdown reflects the model that is
            # instantiated on CPU when the module is imported.
            base_gpu_choice = gradio.Dropdown(
                ["cpu", "cuda"], value="cpu", label="Device Choice"
            )

    # Changing either dropdown rebuilds the model from both current values.
    for model_selector in (base_model_choice, base_gpu_choice):
        model_selector.change(
            __change_base_model, inputs=[base_model_choice, base_gpu_choice]
        )

    # Image Segmentation
    with gradio.Tab(label="Image Segmentation", id="image_tab") as image_tab:
        gradio.Markdown("Image Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Image Annotation Documentation", open=False):
                gradio.Markdown(
                    """
                    Image annotation allows you to mark specific regions of an image with labels.
                    In this app, you can annotate an image by drawing boxes and assigning labels to them.
                    The labels can be either '+' or '-'.
                    To annotate an image, simply click and drag to draw a box around the desired region.
                    You can add multiple boxes with different labels.
                    Once you have annotated the image, click the 'Generate Mask' button to generate a mask based on the annotations.
                    The mask can be either a binary mask or a segmented mask, depending on the selected output mode.
                    You can switch between the output modes using the radio buttons.
                    If you make any changes to the annotations or the output mode, you need to regenerate the mask by clicking the button again.
                    Note that the advanced options allow you to adjust the SAM mask threshold, maximum hole area, and maximum sprinkle area.
                    These options control the sensitivity and accuracy of the segmentation process.
                    Experiment with different settings to achieve the desired results.
                    """
                )
            image_input = gradio_image_annotation.image_annotator(
                example_image_annotation
            )
            with gradio.Accordion("Advanced Options", open=False):
                image_generate_SAM_mask_threshold = gradio.Slider(
                    0.0, 1.0, 0.0, label="SAM Mask Threshold"
                )
                image_generate_SAM_max_hole_area = gradio.Slider(
                    0, 1000, 0, label="SAM Max Hole Area"
                )
                image_generate_SAM_max_sprinkle_area = gradio.Slider(
                    0, 1000, 0, label="SAM Max Sprinkle Area"
                )
            image_generate_mask_button = gradio.Button("Generate Mask")
            with gradio.Row():
                with gradio.Column():
                    image_output_mode = gradio.Radio(
                        ["Segment", "Mask"], value="Segment", label="Output Mode"
                    )
                with gradio.Column(scale=3):
                    image_output = gradio_imageslider.ImageSlider()

        # Hidden unless DEBUG is set. The two images also hold the last
        # mask/segment so the output mode can be switched without re-running.
        with gradio.Accordion("Debug", open=DEBUG, visible=DEBUG):
            __image_point_coords = gradio.DataFrame(
                value=[["+", 886, 551], ["-", 1239, 576]],
                label="Point Coords",
                interactive=False,
                headers=["Label", "X", "Y"],
            )
            __image_box = gradio.DataFrame(
                value=[[254, 466, 1347, 1047]],
                label="Box",
                interactive=False,
                headers=["XMin", "YMin", "XMax", "YMax"],
            )
            __current_mask = gradio.Image(label="Current Mask", interactive=False)
            __current_segment = gradio.Image(
                label="Current Segment", interactive=False
            )

        # image_input.change(__post_process_annotator_inputs, inputs = [image_input])
        image_generate_mask_button.click(
            __generate_mask,
            inputs=[
                image_input,
                image_generate_SAM_mask_threshold,
                image_generate_SAM_max_hole_area,
                image_generate_SAM_max_sprinkle_area,
                image_output_mode,
            ],
            outputs=[
                image_output,
                __image_point_coords,
                __image_box,
                __current_mask,
                __current_segment,
            ],
        )
        image_output_mode.change(
            __change_output_mode,
            inputs=[
                image_input,
                image_output_mode,
                __current_mask,
                __current_segment,
            ],
            outputs=[image_output],
        )

    # Auto Segmentation
    with gradio.Tab(label="Auto Segmentation", id="auto_tab"):
        gradio.Markdown("Auto Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Auto Annotation Documentation", open=False):
                gradio.Markdown(
                    """
                    """
                )
            auto_input = gradio.Image("assets/cars.jpg")
            # Slider steps are passed by keyword. The integer-valued settings
            # use step=1; the 0-1 thresholds rely on the default step.
            with gradio.Accordion("Advanced Options", open=False):
                auto_generate_SAM_points_per_side = gradio.Slider(
                    1, 64, 12, step=1, label="Points Per Side", interactive=True
                )
                auto_generate_SAM_points_per_batch = gradio.Slider(
                    1, 64, 32, step=1, label="Points Per Batch", interactive=True
                )
                # A step of 1 on a 0-1 range would only allow the values 0
                # and 1, so this threshold uses the default step like the
                # other threshold sliders.
                auto_generate_SAM_pred_iou_thresh = gradio.Slider(
                    0.0, 1.0, 0.8, label="Pred IOU Threshold", interactive=True
                )
                auto_generate_SAM_stability_score_thresh = gradio.Slider(
                    0.0, 1.0, 0.95, label="Stability Score Threshold", interactive=True
                )
                auto_generate_SAM_stability_score_offset = gradio.Slider(
                    0.0, 1.0, 1.0, label="Stability Score Offset", interactive=True
                )
                auto_generate_SAM_mask_threshold = gradio.Slider(
                    0.0, 1.0, 0.0, label="Mask Threshold", interactive=True
                )
                auto_generate_SAM_box_nms_thresh = gradio.Slider(
                    0.0, 1.0, 0.7, label="Box NMS Threshold", interactive=True
                )
                auto_generate_SAM_crop_n_layers = gradio.Slider(
                    0, 10, 0, step=1, label="Crop N Layers", interactive=True
                )
                auto_generate_SAM_crop_nms_thresh = gradio.Slider(
                    0.0, 1.0, 0.7, label="Crop NMS Threshold", interactive=True
                )
                auto_generate_SAM_crop_overlay_ratio = gradio.Slider(
                    0.0, 1.0, 512 / 1500, label="Crop Overlay Ratio", interactive=True
                )
                auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(
                    1, 10, 1, label="Crop N Points Downscale Factor", interactive=True
                )
                auto_generate_SAM_min_mask_region_area = gradio.Slider(
                    0, 1000, 0, label="Min Mask Region Area", interactive=True
                )
                auto_generate_SAM_use_m2m = gradio.Checkbox(
                    label="Use M2M", interactive=True
                )
                auto_generate_SAM_multimask_output = gradio.Checkbox(
                    value=True, label="Multi Mask Output", interactive=True
                )
            auto_generate_button = gradio.Button("Generate Auto Mask")
            with gradio.Row():
                with gradio.Column():
                    auto_output_mode = gradio.Radio(
                        ["Segment", "Mask"],
                        value="Segment",
                        label="Output Mode",
                        interactive=True,
                    )
                    auto_output_list = gradio.CheckboxGroup(
                        [], value=[], label="Mask List", interactive=False
                    )
                    auto_output_bbox = gradio.Checkbox(
                        value=False, label="Show Bounding Box", interactive=False
                    )
                with gradio.Column(scale=3):
                    auto_output = gradio_imageslider.ImageSlider()

        # Hidden unless DEBUG is set. These components also carry the
        # generated masks and boxes between callbacks.
        with gradio.Accordion("Debug", open=DEBUG, visible=DEBUG):
            __auto_output_gallery = gradio.Gallery(
                None, label="Output Gallery", interactive=False, type="numpy"
            )
            __auto_bbox = gradio.DataFrame(
                value=[[]],
                label="Box",
                interactive=False,
                headers=["XMin", "YMin", "XMax", "YMax"],
            )

        auto_generate_button.click(
            __generate_auto_mask,
            inputs=[
                auto_input,
                auto_generate_SAM_points_per_side,
                auto_generate_SAM_points_per_batch,
                auto_generate_SAM_pred_iou_thresh,
                auto_generate_SAM_stability_score_thresh,
                auto_generate_SAM_stability_score_offset,
                auto_generate_SAM_mask_threshold,
                auto_generate_SAM_box_nms_thresh,
                auto_generate_SAM_crop_n_layers,
                auto_generate_SAM_crop_nms_thresh,
                auto_generate_SAM_crop_overlay_ratio,
                auto_generate_SAM_crop_n_points_downscale_factor,
                auto_generate_SAM_min_mask_region_area,
                auto_generate_SAM_use_m2m,
                auto_generate_SAM_multimask_output,
                auto_output_mode,
            ],
            outputs=[
                auto_output,
                auto_output_list,
                auto_output_bbox,
                __auto_output_gallery,
                __auto_bbox,
            ],
        )

        # The mask list, the bounding-box toggle and the output mode all
        # re-render the slider from the same set of inputs.
        for auto_refresh_trigger in (
            auto_output_list,
            auto_output_bbox,
            auto_output_mode,
        ):
            auto_refresh_trigger.change(
                __generate_multi_mask_output,
                inputs=[
                    auto_input,
                    auto_output_list,
                    auto_output_mode,
                    auto_output_bbox,
                    __auto_output_gallery,
                    __auto_bbox,
                ],
                outputs=[auto_output],
            )

# Start the Gradio server only when the file is run as a script.
if __name__ == "__main__":
    base_app.launch()