import typing
import os
import sam2.sam2_image_predictor
import tqdm
import requests
import torch
import numpy
import sam2.build_sam
import sam2.automatic_mask_generator
from .Plugin import YOLOv10Plugin
import cv2

# Known SAM2 checkpoints: where to download each one, where to store it
# locally, and which SAM2 config name to build it with.
SAM2_MODELS = {
    "sam2_hiera_tiny": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_tiny.pt",
        "config_file": "sam2_hiera_t.yaml",
    },
    "sam2_hiera_small": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_small.pt",
        "config_file": "sam2_hiera_s.yaml",
    },
    "sam2_hiera_base_plus": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_base_plus.pt",
        "config_file": "sam2_hiera_b+.yaml",
    },
    "sam2_hiera_large": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_large.pt",
        "config_file": "sam2_hiera_l.yaml",
    },
}


class SegmentAnything2Assist:
    """Convenience wrapper around a SAM2 model.

    Handles checkpoint download, model construction, automatic mask
    generation, point/box-prompted mask prediction, and YOLOv10-box-prompted
    segmentation.
    """

    def __init__(
        self,
        sam_model_name: (
            str
            | typing.Literal[
                "sam2_hiera_tiny",
                "sam2_hiera_small",
                "sam2_hiera_base_plus",
                "sam2_hiera_large",
            ]
        ) = "sam2_hiera_small",
        configuration: (
            str | typing.Literal["Automatic Mask Generator", "Image"]
        ) = "Automatic Mask Generator",
        download_url: str | None = None,
        model_path: str | None = None,
        download: bool = True,
        device: str | torch.device = torch.device("cpu"),
        verbose: bool = True,
        YOLOv10Model: YOLOv10Plugin.YOLOv10Plugin | None = None,
    ) -> None:
        """Configure the wrapper and, if possible, load the SAM2 model.

        Args:
            sam_model_name: Key of ``SAM2_MODELS`` selecting the checkpoint
                and config.
            configuration: Either ``"Automatic Mask Generator"`` or
                ``"Image"``. It is stored on the instance; this class does
                not branch on it.
            download_url: Overrides the checkpoint URL from ``SAM2_MODELS``.
            model_path: Overrides the local checkpoint path from
                ``SAM2_MODELS``.
            download: Download the checkpoint when it is not present
                locally.
            device: Device passed to ``build_sam2``.
            verbose: Print diagnostic messages.
            YOLOv10Model: Optional detector used by
                ``generate_mask_from_image_with_yolo``.

        Raises:
            AssertionError: Unknown ``sam_model_name`` or ``configuration``.
            requests.RequestException: The checkpoint download failed.
        """
        assert (
            sam_model_name in SAM2_MODELS
        ), f"`sam_model_name` should be either one of {list(SAM2_MODELS.keys())}"
        assert configuration in ["Automatic Mask Generator", "Image"]

        self.sam_model_name = sam_model_name
        self.configuration = configuration
        self.config_file = SAM2_MODELS[sam_model_name]["config_file"]
        self.device = device

        self.download_url = (
            download_url
            if download_url is not None
            else SAM2_MODELS[sam_model_name]["download_url"]
        )
        self.model_path = (
            model_path
            if model_path is not None
            else SAM2_MODELS[sam_model_name]["model_path"]
        )
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when the path has one (a bare "model.pt" lives in the CWD).
        model_dir = os.path.dirname(self.model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        self.verbose = verbose

        if self.verbose:
            print(
                f"SegmentAnything2Assist::__init__::Model Name: {self.sam_model_name}"
            )
            print(
                f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}"
            )
            print(
                f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}"
            )
            print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}")
            print(
                f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}"
            )

        if download:
            self.__download_model()

        if self.__load_model():
            if self.verbose:
                print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
        else:
            self.sam2 = None
            if self.verbose:
                print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")

        self.YOLOv10Model = YOLOv10Model

    def is_model_available(self) -> bool:
        """Return True if a file exists at ``self.model_path``."""
        ret = os.path.exists(self.model_path)
        if self.verbose:
            print(f"SegmentAnything2Assist::is_model_available::{ret}")
        return ret

    def __load_model(self) -> bool:
        """Build ``self.sam2`` from the checkpoint at ``self.model_path``.

        Returns:
            True if the checkpoint exists and the model was built. Otherwise
            False, and ``self.sam2`` is left untouched.
        """
        if not self.is_model_available():
            return False
        # ``sam2.build_sam`` is a module; the builder function is
        # ``build_sam2`` and it needs the config name alongside the checkpoint.
        self.sam2 = sam2.build_sam.build_sam2(
            config_file=self.config_file,
            ckpt_path=self.model_path,
            device=self.device,
        )
        return True

    def __download_model(self, force: bool = False) -> bool:
        """Download the checkpoint from ``self.download_url``.

        Args:
            force: Re-download even if ``self.model_path`` already exists.

        Returns:
            True if a download happened, False if it was skipped.

        Raises:
            requests.RequestException: Network failure or non-2xx response.
                No partial file is left behind in that case.
        """
        if not force and self.is_model_available():
            if self.verbose:
                print(f"{self.model_path} already exists. Skipping download.")
            return False

        # Stream into a sibling ".part" file and rename only on success. An
        # interrupted or failed download must never leave a truncated file
        # at model_path, because is_model_available() would accept it.
        partial_path = f"{self.model_path}.part"
        try:
            with requests.get(self.download_url, stream=True, timeout=60) as response:
                # Without this, an HTTP error page would be saved as the model.
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))
                with open(partial_path, "wb") as file, tqdm.tqdm(
                    total=total_size, unit="B", unit_scale=True
                ) as progress_bar:
                    for data in response.iter_content(chunk_size=1 << 20):
                        file.write(data)
                        progress_bar.update(len(data))
            os.replace(partial_path, self.model_path)
        except BaseException:
            # Clean up the incomplete file (also on Ctrl-C), then re-raise.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        return True

    def generate_automatic_masks(
        self,
        image: numpy.ndarray,
        points_per_side=10,
        points_per_batch=32,
        pred_iou_thresh=0.8,
        stability_score_thresh=0.95,
        stability_score_offset=1.0,
        mask_threshold=0.0,
        box_nms_thresh=0.7,
        crop_n_layers=0,
        crop_nms_thresh=0.7,
        crop_overlay_ratio=512 / 1500,
        crop_n_points_downscale_factor=1,
        min_mask_region_area=0,
        use_m2m=False,
        multimask_output=True,
    ) -> typing.Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray] | None:
        """
        Generates automatic masks from the given image.

        All keyword arguments are forwarded to
        ``SAM2AutomaticMaskGenerator``. ``crop_overlay_ratio`` keeps its
        historical name here and is passed as the generator's
        ``crop_overlap_ratio``.

        Returns:
            None if SAM2 is not loaded. Otherwise a tuple of four numpy
            arrays:
                - segmentation_masks: uint8 array of shape (N, H, W, 3),
                  where N is the number of masks. Each mask is 0/255 and
                  replicated over 3 channels (BGR).
                - bbox_masks: uint32 array of shape (N, 4). Each row is a
                  bounding box (x, y, w, h).
                - predicted_iou: float32 array of shape (N,) holding the
                  predicted IoU of each mask.
                - stability_score: float32 array of shape (N,) holding the
                  stability score of each mask.
            When no masks are found, N is 0 and the array ranks above are
            kept.
        """
        if self.sam2 is None:
            print(
                "SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
            )
            return None

        generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
            model=self.sam2,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            stability_score_offset=stability_score_offset,
            mask_threshold=mask_threshold,
            box_nms_thresh=box_nms_thresh,
            crop_n_layers=crop_n_layers,
            crop_nms_thresh=crop_nms_thresh,
            # The generator's keyword is "overlap". The old "crop_overlay_ratio="
            # spelling was silently swallowed and the user's value ignored.
            crop_overlap_ratio=crop_overlay_ratio,
            crop_n_points_downscale_factor=crop_n_points_downscale_factor,
            min_mask_region_area=min_mask_region_area,
            use_m2m=use_m2m,
            multimask_output=multimask_output,
        )
        masks = generator.generate(image)

        if not masks:
            # numpy.array([]) would collapse to shape (0,); keep the documented
            # ranks so callers can index and concatenate uniformly.
            height, width = image.shape[:2]
            return (
                numpy.zeros((0, height, width, 3), dtype=numpy.uint8),
                numpy.zeros((0, 4), dtype=numpy.uint32),
                numpy.zeros((0,), dtype=numpy.float32),
                numpy.zeros((0,), dtype=numpy.float32),
            )

        # Boolean masks -> 0/255 uint8 -> 3-channel BGR images.
        segmentation_masks = [
            cv2.cvtColor(
                numpy.where(mask["segmentation"], 255, 0).astype(numpy.uint8),
                cv2.COLOR_GRAY2BGR,
            )
            for mask in masks
        ]
        bbox_masks = [mask["bbox"] for mask in masks]
        predicted_iou = [mask["predicted_iou"] for mask in masks]
        stability_score = [mask["stability_score"] for mask in masks]

        return (
            numpy.array(segmentation_masks, dtype=numpy.uint8),
            numpy.array(bbox_masks, dtype=numpy.uint32),
            numpy.array(predicted_iou, dtype=numpy.float32),
            numpy.array(stability_score, dtype=numpy.float32),
        )

    def generate_masks_from_image(
        self,
        image,
        point_coords,
        point_labels,
        box,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
    ) -> typing.Tuple[numpy.ndarray, numpy.ndarray]:
        """
        Generates a single mask from the given image and prompts.

        Args:
            image: Image passed to ``SAM2ImagePredictor.set_image``.
            point_coords: Point prompts, or None. Converted with
                ``numpy.array``.
            point_labels: Labels for ``point_coords``, or None.
            box: Box prompt, or None.
            mask_threshold, max_hole_area, max_sprinkle_area: Forwarded to
                ``SAM2ImagePredictor``.

        Returns:
            typing.Tuple: Two numpy arrays where:
                - masks_chw: Numpy array shape (1, H, W) for the mask, H is
                  the height of the image, and W is the width of the image.
                - mask_iou: Numpy array of shape (1,) for IOU of the mask.
        """
        generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
            self.sam2,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        generator.set_image(image)

        masks_chw, mask_iou, _low_res_logits = generator.predict(
            point_coords=(
                numpy.array(point_coords) if point_coords is not None else None
            ),
            point_labels=(
                numpy.array(point_labels) if point_labels is not None else None
            ),
            box=numpy.array(box) if box is not None else None,
            multimask_output=False,
        )

        return masks_chw, mask_iou

    def apply_mask_to_image(self, image, mask):
        """Cut out the masked region of ``image``.

        Args:
            image: Image to mask. Presumably it has the same height and width
                as ``mask`` (required by ``cv2.bitwise_and``); verify at the
                call site.
            mask: Any array-like. Values > 0 are treated as foreground.

        Returns:
            (mask, segment): the binarized 0/255 uint8 mask, and ``image``
            with every pixel outside the mask set to 0.
        """
        mask = numpy.array(mask)
        mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
        segment = cv2.bitwise_and(image, image, mask=mask)
        return mask, segment

    def apply_auto_mask_to_image(self, image, auto_list, masks, bboxes):
        """Combine selected automatic masks and annotate their boxes.

        Args:
            image: Source image.
            auto_list: Indices into ``masks``/``bboxes`` to apply. May be
                empty.
            masks: 3-channel masks, as returned by
                ``generate_automatic_masks``.
            bboxes: (x, y, w, h) boxes, as returned by
                ``generate_automatic_masks``.

        Returns:
            (image_with_bounding_boxes, all_masks, image_with_segments):
            a copy of ``image`` with a randomly coloured box and 1-based
            label drawn per selected mask, the union of the selected masks
            as a single-channel uint8 image, and ``image`` restricted to
            that union.
        """
        image_with_bounding_boxes = image.copy()
        # Start from an empty mask so an empty ``auto_list`` yields an
        # all-black result instead of crashing on ``None.astype``.
        all_masks = numpy.zeros(image.shape[:2], dtype=numpy.uint8)

        for index in auto_list:
            mask = cv2.cvtColor(masks[index], cv2.COLOR_BGR2GRAY)
            bbox = bboxes[index]
            all_masks = cv2.bitwise_or(all_masks, mask)

            color = numpy.random.randint(0, 255, size=3).tolist()
            top_left = (int(bbox[0]), int(bbox[1]))
            bottom_right = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
            image_with_bounding_boxes = cv2.rectangle(
                image_with_bounding_boxes, top_left, bottom_right, color, 2
            )
            image_with_bounding_boxes = cv2.putText(
                image_with_bounding_boxes,
                f"{index + 1}",
                (int(bbox[0]), int(bbox[1]) - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color,
                2,
            )

        all_masks = all_masks.astype(numpy.uint8)
        image_with_segments = cv2.bitwise_and(image, image, mask=all_masks)
        return image_with_bounding_boxes, all_masks, image_with_segments

    def generate_mask_from_image_with_yolo(
        self,
        image,
        YOLOv10Model: YOLOv10Plugin.YOLOv10Plugin | None = None,
        YOLOv10ModelName: str | None = None,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
    ):
        """Detect objects with YOLOv10 and segment each detection with SAM2.

        If no detector was configured on the instance, exactly one of
        ``YOLOv10Model`` (a ready detector) or ``YOLOv10ModelName`` (used to
        build one) must be given. The chosen detector is cached on the
        instance. When a detector is already configured, both arguments are
        ignored.

        Returns:
            The detector's result list. Each result dict is updated in place
            with ``"mask_chw"`` (the mask with the leading axis squeezed)
            and ``"mask_iou"``. Each result is assumed to carry a ``"box"``
            usable as a SAM2 box prompt.

        Raises:
            AssertionError: No detector is configured and not exactly one of
                the two arguments was provided.
        """
        if self.YOLOv10Model is None:
            assert bool(YOLOv10Model) != bool(
                YOLOv10ModelName
            ), "Either YOLOv10Model or YOLOv10ModelName should be provided."

            if YOLOv10Model is not None:
                # Previously a self-assignment, which dropped the argument and
                # left the detector as None.
                self.YOLOv10Model = YOLOv10Model

            if YOLOv10ModelName is not None:
                self.YOLOv10Model = YOLOv10Plugin.YOLOv10Plugin(
                    yolo_model_name=YOLOv10ModelName
                )

        results = self.YOLOv10Model.detect(image)
        for result in results:
            mask_chw, mask_iou = self.generate_masks_from_image(
                image,
                point_coords=None,
                point_labels=None,
                box=result["box"],
                mask_threshold=mask_threshold,
                max_hole_area=max_hole_area,
                max_sprinkle_area=max_sprinkle_area,
            )
            result["mask_chw"] = numpy.squeeze(mask_chw, 0)
            result["mask_iou"] = mask_iou

        return results