Added 'Visualizer' class and mask refinement to utils.py
Browse filesThis commit includes the addition of the 'Visualizer' class in utils.py, which provides several methods to annotate images with bounding boxes, masks, polygons, and labels. It also included new utility functions for refining and filtering masks based on their relative area.
app.py
CHANGED
@@ -8,10 +8,11 @@ import supervision as sv
|
|
8 |
|
9 |
from typing import List
|
10 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
11 |
-
from utils import
|
12 |
|
13 |
HOME = os.getenv("HOME")
|
14 |
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
|
|
15 |
|
16 |
SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
|
17 |
# SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
|
@@ -27,19 +28,34 @@ MARKDOWN = """
|
|
27 |
</h1>
|
28 |
"""
|
29 |
|
30 |
-
|
31 |
-
|
32 |
|
33 |
|
34 |
def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
|
38 |
image_input = gr.Image(
|
39 |
label="Input",
|
40 |
-
type="numpy"
|
|
|
41 |
checkbox_annotation_mode = gr.CheckboxGroup(
|
42 |
-
choices=["Mark", "Mask", "Box"],
|
43 |
value=['Mark'],
|
44 |
label="Annotation Mode")
|
45 |
image_output = gr.Image(
|
|
|
8 |
|
9 |
from typing import List
|
10 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
11 |
+
from utils import postprocess_masks, Visualizer
|
12 |
|
13 |
HOME = os.getenv("HOME")
|
14 |
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
15 |
+
MINIMUM_AREA_THRESHOLD = 0.01
|
16 |
|
17 |
SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
|
18 |
# SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
|
|
|
28 |
</h1>
|
29 |
"""
|
30 |
|
31 |
+
SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
|
32 |
+
VISUALIZER = Visualizer()
|
33 |
|
34 |
|
35 |
def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
|
36 |
+
mask_generator = SamAutomaticMaskGenerator(SAM)
|
37 |
+
result = mask_generator.generate(image=image)
|
38 |
+
detections = sv.Detections.from_sam(result)
|
39 |
+
detections = postprocess_masks(
|
40 |
+
detections=detections,
|
41 |
+
area_threshold=MINIMUM_AREA_THRESHOLD)
|
42 |
+
bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
43 |
+
annotated_image = VISUALIZER.visualize(
|
44 |
+
image=bgr_image,
|
45 |
+
detections=detections,
|
46 |
+
with_box="Box" in annotation_mode,
|
47 |
+
with_mask="Mask" in annotation_mode,
|
48 |
+
with_polygon="Polygon" in annotation_mode,
|
49 |
+
with_label="Mark" in annotation_mode)
|
50 |
+
return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
|
51 |
|
52 |
|
53 |
image_input = gr.Image(
|
54 |
label="Input",
|
55 |
+
type="numpy",
|
56 |
+
height=512)
|
57 |
checkbox_annotation_mode = gr.CheckboxGroup(
|
58 |
+
choices=["Mark", "Polygon", "Mask", "Box"],
|
59 |
value=['Mark'],
|
60 |
label="Annotation Mode")
|
61 |
image_output = gr.Image(
|
utils.py
CHANGED
@@ -1,6 +1,55 @@
|
|
1 |
import cv2
|
2 |
|
3 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
def refine_mask(
|
@@ -36,4 +85,71 @@ def refine_mask(
|
|
36 |
mask, [contour], -1, (0 if mode == 'islands' else 255), -1
|
37 |
)
|
38 |
|
39 |
-
return np.where(mask > 0, 1, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import cv2
|
2 |
|
3 |
import numpy as np
|
4 |
+
import supervision as sv
|
5 |
+
|
6 |
+
|
7 |
+
class Visualizer:
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
line_thickness: int = 2,
|
12 |
+
mask_opacity: float = 0.1,
|
13 |
+
text_scale: float = 0.5
|
14 |
+
) -> None:
|
15 |
+
self.box_annotator = sv.BoundingBoxAnnotator(
|
16 |
+
color_lookup=sv.ColorLookup.INDEX,
|
17 |
+
thickness=line_thickness)
|
18 |
+
self.mask_annotator = sv.MaskAnnotator(
|
19 |
+
color_lookup=sv.ColorLookup.INDEX,
|
20 |
+
opacity=mask_opacity)
|
21 |
+
self.polygon_annotator = sv.PolygonAnnotator(
|
22 |
+
color_lookup=sv.ColorLookup.INDEX,
|
23 |
+
thickness=line_thickness)
|
24 |
+
self.label_annotator = sv.LabelAnnotator(
|
25 |
+
color_lookup=sv.ColorLookup.INDEX,
|
26 |
+
text_position=sv.Position.CENTER_OF_MASS,
|
27 |
+
text_scale=text_scale)
|
28 |
+
|
29 |
+
def visualize(
|
30 |
+
self,
|
31 |
+
image: np.ndarray,
|
32 |
+
detections: sv.Detections,
|
33 |
+
with_box: bool,
|
34 |
+
with_mask: bool,
|
35 |
+
with_polygon: bool,
|
36 |
+
with_label: bool
|
37 |
+
) -> np.ndarray:
|
38 |
+
annotated_image = image.copy()
|
39 |
+
if with_box:
|
40 |
+
annotated_image = self.box_annotator.annotate(
|
41 |
+
scene=annotated_image, detections=detections)
|
42 |
+
if with_mask:
|
43 |
+
annotated_image = self.mask_annotator.annotate(
|
44 |
+
scene=annotated_image, detections=detections)
|
45 |
+
if with_polygon:
|
46 |
+
annotated_image = self.polygon_annotator.annotate(
|
47 |
+
scene=annotated_image, detections=detections)
|
48 |
+
if with_label:
|
49 |
+
labels = list(map(str, range(len(detections))))
|
50 |
+
annotated_image = self.label_annotator.annotate(
|
51 |
+
scene=annotated_image, detections=detections, labels=labels)
|
52 |
+
return annotated_image
|
53 |
|
54 |
|
55 |
def refine_mask(
|
|
|
85 |
mask, [contour], -1, (0 if mode == 'islands' else 255), -1
|
86 |
)
|
87 |
|
88 |
+
return np.where(mask > 0, 1, 0).astype(bool)
|
89 |
+
|
90 |
+
|
91 |
+
def filter_masks_by_relative_area(
|
92 |
+
masks: np.ndarray,
|
93 |
+
min_relative_area: float = 0.02,
|
94 |
+
max_relative_area: float = 1.0
|
95 |
+
) -> np.ndarray:
|
96 |
+
"""
|
97 |
+
Filters out masks based on their relative area.
|
98 |
+
|
99 |
+
Parameters:
|
100 |
+
masks (np.ndarray): A 3D numpy array where each slice along the third dimension
|
101 |
+
represents a mask.
|
102 |
+
min_relative_area (float): Minimum relative area threshold for keeping a mask.
|
103 |
+
max_relative_area (float): Maximum relative area threshold for keeping a mask.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
np.ndarray: A 3D numpy array of filtered masks.
|
107 |
+
"""
|
108 |
+
mask_areas = masks.sum(axis=(1, 2))
|
109 |
+
total_area = masks.shape[1] * masks.shape[2]
|
110 |
+
relative_areas = mask_areas / total_area
|
111 |
+
min_area_filter = relative_areas >= min_relative_area
|
112 |
+
max_area_filter = relative_areas <= max_relative_area
|
113 |
+
return masks[min_area_filter & max_area_filter]
|
114 |
+
|
115 |
+
|
116 |
+
def postprocess_masks(
|
117 |
+
detections: sv.Detections,
|
118 |
+
area_threshold: float = 0.02,
|
119 |
+
min_relative_area: float = 0.02,
|
120 |
+
max_relative_area: float = 1.0
|
121 |
+
) -> sv.Detections:
|
122 |
+
"""
|
123 |
+
Post-processes the masks of detection objects by removing small islands and filling
|
124 |
+
small holes.
|
125 |
+
|
126 |
+
Parameters:
|
127 |
+
detections (sv.Detections): Detection objects to be filtered.
|
128 |
+
area_threshold (float): Threshold for relative area to remove or fill features.
|
129 |
+
min_relative_area (float): Minimum relative area threshold for detections.
|
130 |
+
max_relative_area (float): Maximum relative area threshold for detections.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
np.ndarray: Post-processed masks.
|
134 |
+
"""
|
135 |
+
masks = detections.mask.copy()
|
136 |
+
for i in range(len(masks)):
|
137 |
+
masks[i] = refine_mask(
|
138 |
+
mask=masks[i],
|
139 |
+
area_threshold=area_threshold,
|
140 |
+
mode='islands'
|
141 |
+
)
|
142 |
+
masks[i] = refine_mask(
|
143 |
+
mask=masks[i],
|
144 |
+
area_threshold=area_threshold,
|
145 |
+
mode='holes'
|
146 |
+
)
|
147 |
+
masks = filter_masks_by_relative_area(
|
148 |
+
masks=masks,
|
149 |
+
min_relative_area=min_relative_area,
|
150 |
+
max_relative_area=max_relative_area)
|
151 |
+
|
152 |
+
return sv.Detections(
|
153 |
+
xyxy=sv.mask_to_xyxy(masks),
|
154 |
+
mask=masks
|
155 |
+
)
|