SoM / utils.py
SkalskiP's picture
Add IoU filter function and update Dockerfile & app.py
e0f6bc4
raw
history blame
6.73 kB
import cv2
import numpy as np
import supervision as sv
class Visualizer:
    """
    Draws supervision detections onto an image.

    Boxes, masks, polygons and numeric index labels can each be toggled per call
    to `visualize`. All annotators are configured with `sv.ColorLookup.INDEX`,
    i.e. colors are picked by detection index.
    """

    def __init__(
        self,
        line_thickness: int = 2,
        mask_opacity: float = 0.1,
        text_scale: float = 0.5
    ) -> None:
        """
        Build the annotators used by `visualize`.

        Parameters:
            line_thickness (int): Stroke thickness for boxes and polygons.
            mask_opacity (float): Opacity of the mask overlay.
            text_scale (float): Scale of the index label text.
        """
        by_index = sv.ColorLookup.INDEX

        self.box_annotator = sv.BoundingBoxAnnotator(
            color_lookup=by_index,
            thickness=line_thickness,
        )
        self.mask_annotator = sv.MaskAnnotator(
            color_lookup=by_index,
            opacity=mask_opacity,
        )
        self.polygon_annotator = sv.PolygonAnnotator(
            color_lookup=by_index,
            thickness=line_thickness,
        )
        # Labels are anchored at sv.Position.CENTER_OF_MASS of each detection.
        self.label_annotator = sv.LabelAnnotator(
            color_lookup=by_index,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=text_scale,
        )

    def visualize(
        self,
        image: np.ndarray,
        detections: sv.Detections,
        with_box: bool,
        with_mask: bool,
        with_polygon: bool,
        with_label: bool
    ) -> np.ndarray:
        """
        Return a copy of `image` with the selected annotation layers drawn on it.

        Layers are drawn in the order box, mask, polygon, label; the input image
        itself is not modified (a copy is annotated).

        Parameters:
            image (np.ndarray): Image to annotate.
            detections (sv.Detections): Detections to draw.
            with_box (bool): Draw bounding boxes.
            with_mask (bool): Draw mask overlays.
            with_polygon (bool): Draw mask outlines as polygons.
            with_label (bool): Draw each detection's index ("0", "1", ...).

        Returns:
            np.ndarray: The annotated image.
        """
        canvas = image.copy()

        shape_layers = (
            (with_box, self.box_annotator),
            (with_mask, self.mask_annotator),
            (with_polygon, self.polygon_annotator),
        )
        for enabled, annotator in shape_layers:
            if enabled:
                canvas = annotator.annotate(scene=canvas, detections=detections)

        if with_label:
            # Label each detection with its position in `detections`.
            index_labels = [str(index) for index in range(len(detections))]
            canvas = self.label_annotator.annotate(
                scene=canvas, detections=detections, labels=index_labels)

        return canvas
def refine_mask(
    mask: np.ndarray,
    area_threshold: float,
    mode: str = 'islands'
) -> np.ndarray:
    """
    Refines a mask by removing small islands or filling small holes based on area
    threshold.

    Parameters:
        mask (np.ndarray): Input 2D binary mask (boolean or 0/1 values).
        area_threshold (float): Threshold for relative area to remove or fill
            features. In 'islands' mode an island's contour area is compared with
            the total foreground pixel count; in 'holes' mode a hole's contour
            area is compared with the total number of pixels in the mask.
        mode (str): Operation mode ('islands' for removing islands, 'holes' for
            filling holes).

    Returns:
        np.ndarray: Refined boolean mask.

    Raises:
        ValueError: If `mode` is neither 'islands' nor 'holes'.
    """
    if mode not in ('islands', 'holes'):
        raise ValueError(f"mode must be 'islands' or 'holes', got {mode!r}")
    remove_islands = mode == 'islands'

    mask = np.uint8(mask * 255)
    # RETR_CCOMP organises contours into two levels: outer boundaries of
    # connected components (parent == -1) and boundaries of holes
    # (parent != -1). Using it for both modes lets each mode act only on its own
    # contour level, and it also reports islands lying inside another
    # component's hole, which RETR_EXTERNAL would never return.
    contours, hierarchy = cv2.findContours(
        mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE
    )
    if hierarchy is None:
        # No contours at all: the mask is entirely background.
        return mask > 0

    total_area = cv2.countNonZero(mask) if remove_islands else mask.size
    fill_value = 0 if remove_islands else 255
    # hierarchy has shape (1, N, 4) with rows [next, previous, first_child, parent].
    for contour, (_, _, _, parent) in zip(contours, hierarchy[0]):
        is_hole = parent != -1
        # Islands mode only touches outer boundaries; holes mode only holes.
        if is_hole == remove_islands:
            continue
        relative_area = cv2.contourArea(contour) / total_area
        if relative_area < area_threshold:
            # thickness=-1 fills the contour interior with fill_value.
            cv2.drawContours(mask, [contour], -1, fill_value, -1)
    return mask > 0
def filter_masks_by_relative_area(
    masks: np.ndarray,
    min_relative_area: float = 0.02,
    max_relative_area: float = 1.0
) -> np.ndarray:
    """
    Keeps only the masks whose relative area lies within a closed range.

    The relative area of a mask is its number of foreground pixels divided by
    H * W. A mask is kept when
    min_relative_area <= relative area <= max_relative_area.

    Parameters:
        masks (np.ndarray): A 3D numpy array of shape (N, H, W); masks[i] is the
            i-th mask.
        min_relative_area (float): Minimum relative area (inclusive) for keeping
            a mask.
        max_relative_area (float): Maximum relative area (inclusive) for keeping
            a mask.

    Returns:
        np.ndarray: The kept masks, in their original order, with shape
        (K, H, W) where K <= N.
    """
    height, width = masks.shape[1], masks.shape[2]
    coverage = masks.sum(axis=(1, 2)) / (height * width)
    in_range = (coverage >= min_relative_area) & (coverage <= max_relative_area)
    return masks[in_range]
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Computes the Intersection over Union (IoU) of two masks.

    Parameters:
        mask1, mask2 (np.ndarray): Two mask arrays of broadcast-compatible
            shapes; non-zero / True elements count as part of a mask.

    Returns:
        float: The IoU of the two masks, in [0, 1]. Returns 0.0 when the union
        is empty (both masks have no foreground).
    """
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0.0
    # Convert from np.float64 so the return type matches the annotation.
    return float(intersection / union)
def filter_highly_overlapping_masks(
    masks: np.ndarray,
    iou_threshold: float
) -> np.ndarray:
    """
    Greedily removes masks that overlap too much with an earlier kept mask.

    Masks are visited in order. A mask is dropped when its IoU with some earlier
    mask that is still kept is strictly greater than `iou_threshold`, so earlier
    masks take priority over later ones. A mask that has been dropped no longer
    suppresses anything.

    Parameters:
        masks (np.ndarray): A 3D numpy array with shape (N, H, W), where N is the
            number of masks, and H and W are the height and width of the masks.
        iou_threshold (float): The IoU threshold above which masks will be
            considered as overlapping.

    Returns:
        np.ndarray: The surviving masks, in their original order.
    """
    count = masks.shape[0]
    keep = np.ones(count, dtype=bool)
    for anchor in range(count):
        if not keep[anchor]:
            continue
        for candidate in range(anchor + 1, count):
            if not keep[candidate]:
                continue
            overlap = compute_iou(masks[anchor, :, :], masks[candidate, :, :])
            if overlap > iou_threshold:
                keep[candidate] = False
    return masks[keep]
def postprocess_masks(
    detections: sv.Detections,
    area_threshold: float = 0.01,
    min_relative_area: float = 0.01,
    max_relative_area: float = 1.0,
    iou_threshold: float = 0.9
) -> sv.Detections:
    """
    Post-processes the masks of detection objects.

    The pipeline runs in four steps:
      1. Remove small islands from each mask (`refine_mask`, mode='islands').
      2. Fill small holes in each mask (`refine_mask`, mode='holes').
      3. Drop masks whose relative area is outside
         [min_relative_area, max_relative_area].
      4. Drop masks whose IoU with an earlier kept mask exceeds `iou_threshold`.

    Parameters:
        detections (sv.Detections): Detection objects to be filtered. Their
            `mask` attribute must not be None.
        area_threshold (float): Threshold for relative area to remove islands or
            fill holes.
        min_relative_area (float): Minimum relative area threshold for detections.
        max_relative_area (float): Maximum relative area threshold for detections.
        iou_threshold (float): The IoU threshold above which masks will be
            considered as overlapping.

    Returns:
        sv.Detections: New detections built only from the post-processed masks;
        boxes are recomputed from the masks with `sv.mask_to_xyxy`. Any other
        fields of the input detections are not carried over.
    """
    # Work on a copy so the caller's detections are left untouched.
    masks = detections.mask.copy()
    for index, mask in enumerate(masks):
        without_islands = refine_mask(
            mask=mask,
            area_threshold=area_threshold,
            mode='islands'
        )
        masks[index] = refine_mask(
            mask=without_islands,
            area_threshold=area_threshold,
            mode='holes'
        )
    masks = filter_masks_by_relative_area(
        masks=masks,
        min_relative_area=min_relative_area,
        max_relative_area=max_relative_area)
    masks = filter_highly_overlapping_masks(
        masks=masks,
        iou_threshold=iou_threshold)
    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks
    )