SkalskiP commited on
Commit
1a1d05a
1 Parent(s): 242f627

Added 'Visualizer' class and mask refinement to utils.py

Browse files

This commit includes the addition of the 'Visualizer' class in utils.py, which provides several methods to annotate images with bounding boxes, masks, polygons, and labels. It also includes new utility functions for refining and filtering masks based on their relative area.

Files changed (2) hide show
  1. app.py +22 -6
  2. utils.py +117 -1
app.py CHANGED
@@ -8,10 +8,11 @@ import supervision as sv
8
 
9
  from typing import List
10
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
11
- from utils import refine_mask
12
 
13
  HOME = os.getenv("HOME")
14
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
15
 
16
  SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
17
  # SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
@@ -27,19 +28,34 @@ MARKDOWN = """
27
  </h1>
28
  """
29
 
30
- sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
31
- mask_generator = SamAutomaticMaskGenerator(sam)
32
 
33
 
34
  def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
35
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  image_input = gr.Image(
39
  label="Input",
40
- type="numpy")
 
41
  checkbox_annotation_mode = gr.CheckboxGroup(
42
- choices=["Mark", "Mask", "Box"],
43
  value=['Mark'],
44
  label="Annotation Mode")
45
  image_output = gr.Image(
 
8
 
9
  from typing import List
10
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
11
+ from utils import postprocess_masks, Visualizer
12
 
13
  HOME = os.getenv("HOME")
14
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15
+ MINIMUM_AREA_THRESHOLD = 0.01
16
 
17
  SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
18
  # SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
 
28
  </h1>
29
  """
30
 
31
+ SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
32
+ VISUALIZER = Visualizer()
33
 
34
 
35
def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
    """Segment ``image`` with SAM and draw the requested annotation overlays.

    Parameters:
        image (np.ndarray): Input RGB image as delivered by the Gradio widget.
        annotation_mode (List[str]): Subset of {"Mark", "Polygon", "Mask",
            "Box"} choosing which overlays to render.

    Returns:
        np.ndarray: Annotated RGB image.
    """
    # NOTE(review): the generator is rebuilt on every request; the heavy SAM
    # model itself is shared via the module-level SAM object.
    generator = SamAutomaticMaskGenerator(SAM)
    sam_result = generator.generate(image=image)

    detections = postprocess_masks(
        detections=sv.Detections.from_sam(sam_result),
        area_threshold=MINIMUM_AREA_THRESHOLD)

    # The visualizer draws in BGR (OpenCV convention): convert in, annotate,
    # convert back to RGB for Gradio.
    annotated = VISUALIZER.visualize(
        image=cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
        detections=detections,
        with_box="Box" in annotation_mode,
        with_mask="Mask" in annotation_mode,
        with_polygon="Polygon" in annotation_mode,
        with_label="Mark" in annotation_mode)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
51
 
52
 
53
  image_input = gr.Image(
54
  label="Input",
55
+ type="numpy",
56
+ height=512)
57
  checkbox_annotation_mode = gr.CheckboxGroup(
58
+ choices=["Mark", "Polygon", "Mask", "Box"],
59
  value=['Mark'],
60
  label="Annotation Mode")
61
  image_output = gr.Image(
utils.py CHANGED
@@ -1,6 +1,55 @@
1
  import cv2
2
 
3
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def refine_mask(
@@ -36,4 +85,71 @@ def refine_mask(
36
  mask, [contour], -1, (0 if mode == 'islands' else 255), -1
37
  )
38
 
39
- return np.where(mask > 0, 1, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
 
3
  import numpy as np
4
+ import supervision as sv
5
+
6
+
7
class Visualizer:
    """Renders detection overlays: boxes, masks, polygons, and index labels.

    Every annotator is configured with ``ColorLookup.INDEX`` so each detection
    is colored by its position in the ``sv.Detections`` collection.
    """

    def __init__(
        self,
        line_thickness: int = 2,
        mask_opacity: float = 0.1,
        text_scale: float = 0.5
    ) -> None:
        by_index = sv.ColorLookup.INDEX
        self.box_annotator = sv.BoundingBoxAnnotator(
            color_lookup=by_index,
            thickness=line_thickness)
        self.mask_annotator = sv.MaskAnnotator(
            color_lookup=by_index,
            opacity=mask_opacity)
        self.polygon_annotator = sv.PolygonAnnotator(
            color_lookup=by_index,
            thickness=line_thickness)
        self.label_annotator = sv.LabelAnnotator(
            color_lookup=by_index,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=text_scale)

    def visualize(
        self,
        image: np.ndarray,
        detections: sv.Detections,
        with_box: bool,
        with_mask: bool,
        with_polygon: bool,
        with_label: bool
    ) -> np.ndarray:
        """Return a copy of ``image`` with the selected overlays drawn."""
        annotated = image.copy()
        overlay_steps = (
            (with_box, self.box_annotator),
            (with_mask, self.mask_annotator),
            (with_polygon, self.polygon_annotator),
        )
        for enabled, annotator in overlay_steps:
            if enabled:
                annotated = annotator.annotate(
                    scene=annotated, detections=detections)
        if with_label:
            # Labels are simply the detection indices: "0", "1", ...
            labels = [str(index) for index in range(len(detections))]
            annotated = self.label_annotator.annotate(
                scene=annotated, detections=detections, labels=labels)
        return annotated
53
 
54
 
55
  def refine_mask(
 
85
  mask, [contour], -1, (0 if mode == 'islands' else 255), -1
86
  )
87
 
88
+ return np.where(mask > 0, 1, 0).astype(bool)
89
+
90
+
91
def filter_masks_by_relative_area(
    masks: np.ndarray,
    min_relative_area: float = 0.02,
    max_relative_area: float = 1.0
) -> np.ndarray:
    """
    Filters out masks based on their relative area.

    Parameters:
        masks (np.ndarray): A 3D array of shape ``(N, H, W)`` where each slice
            along the first axis is one binary mask.
        min_relative_area (float): Minimum fraction of the image a mask must
            cover to be kept.
        max_relative_area (float): Maximum fraction of the image a mask may
            cover to be kept.

    Returns:
        np.ndarray: The subset of ``masks`` whose relative area lies within
        ``[min_relative_area, max_relative_area]`` (inclusive on both ends).
    """
    image_area = masks.shape[1] * masks.shape[2]
    relative_areas = masks.sum(axis=(1, 2)) / image_area
    keep = (
        (relative_areas >= min_relative_area)
        & (relative_areas <= max_relative_area)
    )
    return masks[keep]
114
+
115
+
116
def postprocess_masks(
    detections: sv.Detections,
    area_threshold: float = 0.02,
    min_relative_area: float = 0.02,
    max_relative_area: float = 1.0
) -> sv.Detections:
    """
    Post-processes the masks of detection objects by removing small islands,
    filling small holes, and dropping masks whose relative area falls outside
    the allowed range.

    Parameters:
        detections (sv.Detections): Detection objects to be filtered.
        area_threshold (float): Relative-area threshold passed to
            ``refine_mask`` when removing islands and filling holes.
        min_relative_area (float): Minimum relative area for a mask to be kept.
        max_relative_area (float): Maximum relative area for a mask to be kept.

    Returns:
        sv.Detections: New detections built from the refined and filtered
        masks, with bounding boxes recomputed from the surviving masks.
    """
    # Copy so the caller's detections are left untouched.
    masks = detections.mask.copy()
    for i in range(len(masks)):
        # Remove small disconnected islands first, then fill small holes.
        masks[i] = refine_mask(
            mask=masks[i],
            area_threshold=area_threshold,
            mode='islands'
        )
        masks[i] = refine_mask(
            mask=masks[i],
            area_threshold=area_threshold,
            mode='holes'
        )
    masks = filter_masks_by_relative_area(
        masks=masks,
        min_relative_area=min_relative_area,
        max_relative_area=max_relative_area)

    # Boxes must be regenerated: refinement may shrink mask extents and
    # filtering may drop entries entirely.
    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks
    )