import re
from typing import List

import cv2
import numpy as np
import supervision as sv


class Visualizer:
    """
    Draws boxes, masks, polygons and index labels for detections on an image.

    Every annotator is configured with ``sv.ColorLookup.INDEX``, so colours are
    picked by each detection's position in ``detections`` rather than by its
    class id.
    """

    def __init__(
        self,
        line_thickness: int = 2,
        mask_opacity: float = 0.1,
        text_scale: float = 0.5
    ) -> None:
        """
        Parameters:
            line_thickness (int): Stroke width for box and polygon outlines.
            mask_opacity (float): Opacity passed to the mask annotator.
            text_scale (float): Scale of the label text.
        """
        # NOTE(review): newer supervision releases deprecate
        # ``BoundingBoxAnnotator`` in favour of ``BoxAnnotator``; confirm
        # against the pinned supervision version before upgrading.
        self.box_annotator = sv.BoundingBoxAnnotator(
            color_lookup=sv.ColorLookup.INDEX,
            thickness=line_thickness)
        self.mask_annotator = sv.MaskAnnotator(
            color_lookup=sv.ColorLookup.INDEX,
            opacity=mask_opacity)
        self.polygon_annotator = sv.PolygonAnnotator(
            color_lookup=sv.ColorLookup.INDEX,
            thickness=line_thickness)
        self.label_annotator = sv.LabelAnnotator(
            color_lookup=sv.ColorLookup.INDEX,
            text_position=sv.Position.CENTER_OF_MASS,
            text_scale=text_scale)

    def visualize(
        self,
        image: np.ndarray,
        detections: sv.Detections,
        with_box: bool,
        with_mask: bool,
        with_polygon: bool,
        with_label: bool
    ) -> np.ndarray:
        """
        Returns an annotated copy of ``image``; the input array is not modified.

        Layers are drawn in a fixed order: box, mask, polygon, label.

        Parameters:
            image (np.ndarray): Image to annotate.
            detections (sv.Detections): Detections to draw.
            with_box (bool): Draw bounding boxes.
            with_mask (bool): Draw filled masks.
            with_polygon (bool): Draw mask outlines as polygons.
            with_label (bool): Draw each detection's index ("0", "1", ...) as
                its label, positioned at the mask's centre of mass.

        Returns:
            np.ndarray: The annotated image.
        """
        annotated_image = image.copy()
        if with_box:
            annotated_image = self.box_annotator.annotate(
                scene=annotated_image, detections=detections)
        if with_mask:
            annotated_image = self.mask_annotator.annotate(
                scene=annotated_image, detections=detections)
        if with_polygon:
            annotated_image = self.polygon_annotator.annotate(
                scene=annotated_image, detections=detections)
        if with_label:
            # Labels are plain indices so they can be referred back to by
            # position (see extract_numbers_in_brackets).
            labels = [str(index) for index in range(len(detections))]
            annotated_image = self.label_annotator.annotate(
                scene=annotated_image, detections=detections, labels=labels)
        return annotated_image


def refine_mask(
    mask: np.ndarray,
    area_threshold: float,
    mode: str = 'islands'
) -> np.ndarray:
    """
    Refines a mask by removing small islands or filling small holes based on
    an area threshold.

    Parameters:
        mask (np.ndarray): Input 2D binary mask; any non-zero value is
            treated as foreground.
        area_threshold (float): Relative-area threshold below which a feature
            is removed ('islands') or filled ('holes'). In 'islands' mode the
            area is relative to the mask's foreground pixel count; in 'holes'
            mode it is relative to the total number of pixels.
        mode (str): 'islands' to remove small connected regions, 'holes' to
            fill small holes.

    Returns:
        np.ndarray: Refined boolean mask.

    Raises:
        ValueError: If ``mode`` is neither 'islands' nor 'holes'.
    """
    if mode not in ('islands', 'holes'):
        raise ValueError(
            f"mode must be 'islands' or 'holes', got {mode!r}")

    is_islands = mode == 'islands'
    # findContours needs an 8-bit image and drawContours edits it in place,
    # so build a fresh 0/255 copy. Thresholding with "> 0" (instead of
    # scaling by 255) avoids uint8 wrap-around for non-boolean masks.
    mask = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)

    # RETR_EXTERNAL returns only outermost boundaries (the islands);
    # RETR_CCOMP also returns the boundaries of holes inside them.
    retrieval_mode = cv2.RETR_EXTERNAL if is_islands else cv2.RETR_CCOMP
    # [-2] selects the contours on both OpenCV 3.x (3-tuple) and 4.x (2-tuple).
    contours = cv2.findContours(
        mask, retrieval_mode, cv2.CHAIN_APPROX_SIMPLE)[-2]

    total_area = cv2.countNonZero(mask) if is_islands else mask.size
    fill_value = 0 if is_islands else 255

    for contour in contours:
        relative_area = cv2.contourArea(contour) / total_area
        if relative_area < area_threshold:
            # thickness=-1 fills the whole contour interior.
            cv2.drawContours(mask, [contour], -1, fill_value, -1)

    return mask > 0


def filter_masks_by_relative_area(
    masks: np.ndarray,
    min_relative_area: float = 0.02,
    max_relative_area: float = 1.0
) -> np.ndarray:
    """
    Filters out masks based on their relative area.

    Parameters:
        masks (np.ndarray): A 3D array of shape (N, H, W); ``masks[i]`` is
            the i-th mask.
        min_relative_area (float): Minimum fraction of the H*W image a mask
            must cover to be kept (inclusive).
        max_relative_area (float): Maximum fraction of the H*W image a mask
            may cover to be kept (inclusive).

    Returns:
        np.ndarray: The masks that pass both thresholds, in original order.
    """
    mask_areas = masks.sum(axis=(1, 2))
    total_area = masks.shape[1] * masks.shape[2]
    relative_areas = mask_areas / total_area
    keep = (relative_areas >= min_relative_area) & \
        (relative_areas <= max_relative_area)
    return masks[keep]


def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Computes the Intersection over Union (IoU) of two masks.

    Parameters:
        mask1, mask2 (np.ndarray): Two mask arrays of the same shape; non-zero
            values count as foreground.

    Returns:
        float: The IoU of the two masks, or 0.0 when both are empty.
    """
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0.0
    return float(intersection / union)


def filter_highly_overlapping_masks(
    masks: np.ndarray,
    iou_threshold: float
) -> np.ndarray:
    """
    Removes masks with high overlap from a set of masks.

    Greedy: masks are visited in order, and each kept mask discards every
    later, still-kept mask whose IoU with it exceeds ``iou_threshold``.

    Parameters:
        masks (np.ndarray): A 3D numpy array with shape (N, H, W), where N is
            the number of masks, and H and W are the height and width of the
            masks.
        iou_threshold (float): The IoU threshold above which masks will be
            considered as overlapping.

    Returns:
        np.ndarray: A 3D numpy array of masks with highly overlapping masks
            removed.
    """
    num_masks = masks.shape[0]
    keep_mask = np.ones(num_masks, dtype=bool)
    for i in range(num_masks):
        # A discarded mask never suppresses others, so skip its whole row.
        if not keep_mask[i]:
            continue
        for j in range(i + 1, num_masks):
            if not keep_mask[j]:
                continue
            if compute_iou(masks[i], masks[j]) > iou_threshold:
                keep_mask[j] = False
    return masks[keep_mask]


def postprocess_masks(
    detections: sv.Detections,
    area_threshold: float = 0.01,
    min_relative_area: float = 0.01,
    max_relative_area: float = 1.0,
    iou_threshold: float = 0.9
) -> sv.Detections:
    """
    Post-processes the masks of detection objects.

    Each mask first has small islands removed and small holes filled. Masks
    are then filtered by relative area, and highly overlapping masks are
    dropped.

    Parameters:
        detections (sv.Detections): Detection objects to be filtered; must
            carry masks.
        area_threshold (float): Threshold for relative area to remove or fill
            features.
        min_relative_area (float): Minimum relative area threshold for
            detections.
        max_relative_area (float): Maximum relative area threshold for
            detections.
        iou_threshold (float): The IoU threshold above which masks will be
            considered as overlapping.

    Returns:
        sv.Detections: New detections holding only the surviving masks and
            boxes derived from them. Other fields of the input detections are
            not carried over.

    Raises:
        ValueError: If ``detections.mask`` is None.
    """
    if detections.mask is None:
        raise ValueError(
            "postprocess_masks requires detections with masks "
            "(detections.mask is None)")

    # Copy so the caller's detections are left untouched.
    masks = detections.mask.copy()
    for i, mask in enumerate(masks):
        without_islands = refine_mask(
            mask=mask, area_threshold=area_threshold, mode='islands')
        masks[i] = refine_mask(
            mask=without_islands, area_threshold=area_threshold, mode='holes')

    masks = filter_masks_by_relative_area(
        masks=masks,
        min_relative_area=min_relative_area,
        max_relative_area=max_relative_area)
    masks = filter_highly_overlapping_masks(
        masks=masks,
        iou_threshold=iou_threshold)

    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks
    )


# Matches a run of digits enclosed in square brackets, e.g. "[12]".
_BRACKETED_NUMBER_PATTERN = re.compile(r'\[(\d+)\]')


def extract_numbers_in_brackets(text: str) -> List[int]:
    """
    Extracts all numbers enclosed in square brackets from a given string.

    Args:
        text (str): The string to be searched.

    Returns:
        List[int]: The integers found within square brackets, in order of
            appearance.
    """
    return [int(num) for num in _BRACKETED_NUMBER_PATTERN.findall(text)]