# xqt's picture
# REF: SAM2 AMG and the corresponding test case.
# f91c3fb
import typing
import os
import sam2.sam2_image_predictor
import tqdm
import requests
import torch
import numpy
import sam2.build_sam
import sam2.automatic_mask_generator
from .Plugin import YOLOv10Plugin
import cv2
# Registry of the supported SAM2 checkpoints, keyed by model name. Each entry
# records where the checkpoint is downloaded from ("download_url"), where it is
# stored locally ("model_path") and the name of its configuration file
# ("config_file"). All checkpoints share the same URL prefix and local
# directory, so only the model name and config file vary per entry.
SAM2_MODELS = {
    model_name: {
        "download_url": (
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
            f"{model_name}.pt"
        ),
        "model_path": f".tmp/checkpoints/{model_name}.pt",
        "config_file": config_file,
    }
    for model_name, config_file in (
        ("sam2_hiera_tiny", "sam2_hiera_t.yaml"),
        ("sam2_hiera_small", "sam2_hiera_s.yaml"),
        ("sam2_hiera_base_plus", "sam2_hiera_b+.yaml"),
        ("sam2_hiera_large", "sam2_hiera_l.yaml"),
    )
}
class SegmentAnything2Assist:
    """
    Helper that wraps a SAM2 checkpoint for mask generation.

    On construction it resolves the checkpoint location for the requested
    model, optionally downloads the checkpoint, and builds the SAM2 model when
    the checkpoint file exists. The instance then offers automatic mask
    generation, prompted (points / box) mask prediction, YOLOv10-box-prompted
    prediction, and small helpers to apply masks to images.

    Attributes set by ``__init__``:
        sam_model_name, configuration, config_file, device, download_url,
        model_path, verbose, YOLOv10Model, and ``sam2`` (the built model, or
        ``None`` when the checkpoint file is not available).
    """

    def __init__(
        self,
        sam_model_name: (
            str
            | typing.Literal[
                "sam2_hiera_tiny",
                "sam2_hiera_small",
                "sam2_hiera_base_plus",
                "sam2_hiera_large",
            ]
        ) = "sam2_hiera_small",
        configuration: (
            str | typing.Literal["Automatic Mask Generator", "Image"]
        ) = "Automatic Mask Generator",
        download_url: str | None = None,
        model_path: str | None = None,
        download: bool = True,
        device: str | torch.device = torch.device("cpu"),
        verbose: bool = True,
        YOLOv10Model: YOLOv10Plugin.YOLOv10Plugin | None = None,
    ) -> None:
        """
        Resolve the checkpoint, optionally download it, and build SAM2.

        Args:
            sam_model_name: Key of ``SAM2_MODELS`` selecting the model.
            configuration: Either "Automatic Mask Generator" or "Image". It is
                validated and stored, but not otherwise used by this class.
            download_url: Overrides the registry download URL when given.
            model_path: Overrides the registry checkpoint path when given.
            download: When True, download the checkpoint unless the file
                already exists.
            device: Device passed to the SAM2 builder.
            verbose: When True, print progress information.
            YOLOv10Model: Optional detector used by
                ``generate_mask_from_image_with_yolo``.

        Raises:
            AssertionError: If ``sam_model_name`` or ``configuration`` is not
                one of the supported values.
        """
        assert (
            sam_model_name in SAM2_MODELS.keys()
        ), f"`sam_model_name` should be either one of {list(SAM2_MODELS.keys())}"
        assert configuration in ["Automatic Mask Generator", "Image"]

        registry_entry = SAM2_MODELS[sam_model_name]
        self.sam_model_name = sam_model_name
        self.configuration = configuration
        self.config_file = registry_entry["config_file"]
        self.device = device
        self.download_url = (
            download_url if download_url is not None else registry_entry["download_url"]
        )
        self.model_path = (
            model_path if model_path is not None else registry_entry["model_path"]
        )

        # A bare file name has an empty dirname, and os.makedirs("") raises;
        # in that case the checkpoint lives in the current working directory.
        checkpoint_directory = os.path.dirname(self.model_path)
        if checkpoint_directory:
            os.makedirs(checkpoint_directory, exist_ok=True)

        self.verbose = verbose
        if self.verbose:
            print(
                f"SegmentAnything2Assist::__init__::Model Name: {self.sam_model_name}"
            )
            print(
                f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}"
            )
            print(
                f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}"
            )
            print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}")
            print(
                f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}"
            )

        if download:
            self.__download_model()

        # `sam2` stays None when the checkpoint file is missing; the mask
        # generation methods rely on that to detect an unloaded model.
        self.sam2 = None
        if self.__load_model():
            if self.verbose:
                print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
        else:
            if self.verbose:
                print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")

        self.YOLOv10Model = YOLOv10Model

    def is_model_available(self) -> bool:
        """
        Report whether the checkpoint file exists at ``self.model_path``.

        Returns:
            bool: True when the path exists, False otherwise.
        """
        ret = os.path.exists(self.model_path)
        if self.verbose:
            print(f"SegmentAnything2Assist::is_model_available::{ret}")
        return ret

    def __load_model(self) -> bool:
        """
        Build SAM2 from the checkpoint if the checkpoint file exists.

        Returns:
            bool: True when the model was built and stored in ``self.sam2``,
            False when the checkpoint file is missing (``self.sam2`` is left
            untouched).
        """
        if not self.is_model_available():
            return False
        self.sam2 = sam2.build_sam.build_sam2(
            config_file=self.config_file,
            ckpt_path=self.model_path,
            device=self.device,
        )
        return True

    def __download_model(self, force: bool = False) -> bool:
        """
        Download the checkpoint from ``self.download_url`` to ``self.model_path``.

        The data is streamed into a temporary ``<model_path>.part`` file that
        is moved into place only after the download completed, so a failed or
        interrupted download never leaves a truncated checkpoint behind.

        Args:
            force: Download even when the checkpoint file already exists.

        Returns:
            bool: True when a download took place, False when it was skipped.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if not force and self.is_model_available():
            print(f"{self.model_path} already exists. Skipping download.")
            return False

        partial_path = f"{self.model_path}.part"
        try:
            with requests.get(
                self.download_url, stream=True, timeout=60
            ) as response:
                # Fail loudly instead of saving an HTTP error page as a checkpoint.
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))
                with open(partial_path, "wb") as file, tqdm.tqdm(
                    total=total_size, unit="B", unit_scale=True
                ) as progress_bar:
                    for data in response.iter_content(chunk_size=1024 * 1024):
                        file.write(data)
                        progress_bar.update(len(data))
            os.replace(partial_path, self.model_path)
        finally:
            # Only present here when the download did not complete.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        return True

    def generate_automatic_masks(
        self,
        image: numpy.ndarray,
        points_per_side=10,
        points_per_batch=32,
        pred_iou_thresh=0.8,
        stability_score_thresh=0.95,
        stability_score_offset=1.0,
        mask_threshold=0.0,
        box_nms_thresh=0.7,
        crop_n_layers=0,
        crop_nms_thresh=0.7,
        crop_overlay_ratio=512 / 1500,
        crop_n_points_downscale_factor=1,
        min_mask_region_area=0,
        use_m2m=False,
        multimask_output=True,
    ) -> typing.Optional[
        typing.Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]
    ]:
        """
        Generates automatic masks from the given image.

        All tuning parameters are forwarded to
        ``sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator``;
        ``crop_overlay_ratio`` is forwarded as the generator's
        ``crop_overlap_ratio`` argument.

        Args:
            image: Image handed to the generator's ``generate`` method.

        Returns:
            typing.Tuple: Four numpy arrays where:
                - segmentation_masks: Numpy array shape (N, H, W, C) where N is the number of masks, H is the height of the image, W is the width of the image, and C is the number of channels. Each N is a binary (0 / 255) mask of the image of shape (H, W, C).
                - bbox_masks: Numpy array of shape (N, 4) where N is the number of masks and 4 is the bounding box coordinates. Each mask is a bounding box of shape (x, y, w, h).
                - predicted_iou: Numpy array of shape (N,) where N is the number of masks. Each value is the predicted IOU of the mask.
                - stability_score: Numpy array of shape (N,) where N is the number of masks. Each value is the stability score of the mask.
            ``None`` is returned instead when SAM2 is not loaded.
        """
        if self.sam2 is None:
            print(
                "SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
            )
            return None

        generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
            model=self.sam2,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            stability_score_offset=stability_score_offset,
            mask_threshold=mask_threshold,
            box_nms_thresh=box_nms_thresh,
            crop_n_layers=crop_n_layers,
            crop_nms_thresh=crop_nms_thresh,
            # The generator's keyword is `crop_overlap_ratio`; passing the
            # value as `crop_overlay_ratio` would not configure the overlap.
            crop_overlap_ratio=crop_overlay_ratio,
            crop_n_points_downscale_factor=crop_n_points_downscale_factor,
            min_mask_region_area=min_mask_region_area,
            use_m2m=use_m2m,
            multimask_output=multimask_output,
        )
        masks = generator.generate(image)

        # Turn every boolean segmentation into a 0 / 255 three-channel image.
        segmentation_masks = [
            cv2.cvtColor(
                numpy.where(mask["segmentation"], 255, 0).astype(numpy.uint8),
                cv2.COLOR_GRAY2BGR,
            )
            for mask in masks
        ]
        bbox_masks = [mask["bbox"] for mask in masks]
        predicted_iou = [mask["predicted_iou"] for mask in masks]
        stability_score = [mask["stability_score"] for mask in masks]
        return (
            numpy.array(segmentation_masks, dtype=numpy.uint8),
            numpy.array(bbox_masks, dtype=numpy.uint32),
            numpy.array(predicted_iou, dtype=numpy.float32),
            numpy.array(stability_score, dtype=numpy.float32),
        )

    def __build_image_predictor(
        self, image, mask_threshold, max_hole_area, max_sprinkle_area
    ):
        """Create a SAM2 image predictor for ``self.sam2`` and attach ``image`` to it."""
        predictor = sam2.sam2_image_predictor.SAM2ImagePredictor(
            self.sam2,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        predictor.set_image(image)
        return predictor

    @staticmethod
    def __predict_single_mask(predictor, point_coords, point_labels, box):
        """Run one single-mask prediction on a prepared predictor; prompts that are None stay None."""
        masks_chw, mask_iou, _low_res_logits = predictor.predict(
            point_coords=(
                numpy.array(point_coords) if point_coords is not None else None
            ),
            point_labels=(
                numpy.array(point_labels) if point_labels is not None else None
            ),
            box=numpy.array(box) if box is not None else None,
            multimask_output=False,
        )
        return masks_chw, mask_iou

    def generate_masks_from_image(
        self,
        image,
        point_coords,
        point_labels,
        box,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
    ) -> typing.Tuple[numpy.ndarray, numpy.ndarray]:
        """
        Generates masks from the given image.

        Args:
            image: Image handed to the predictor's ``set_image``.
            point_coords: Point prompt coordinates, or None.
            point_labels: Labels of the point prompts, or None.
            box: Box prompt, or None.
            mask_threshold: Forwarded to ``SAM2ImagePredictor``.
            max_hole_area: Forwarded to ``SAM2ImagePredictor``.
            max_sprinkle_area: Forwarded to ``SAM2ImagePredictor``.

        Returns:
            typing.Tuple: Two numpy arrays where:
                - masks_chw: Numpy array shape (1, H, W) for the mask, H is the height of the image, and W is the width of the image.
                - mask_iou: Numpy array of shape (1,) for IOU of the mask.
        """
        predictor = self.__build_image_predictor(
            image, mask_threshold, max_hole_area, max_sprinkle_area
        )
        return self.__predict_single_mask(predictor, point_coords, point_labels, box)

    def apply_mask_to_image(self, image, mask):
        """
        Binarize ``mask`` and cut the masked region out of ``image``.

        Args:
            image: Image the mask is applied to.
            mask: Array-like mask; every value greater than 0 counts as
                foreground.

        Returns:
            typing.Tuple: ``(mask, segment)`` where ``mask`` is the uint8
            0 / 255 mask and ``segment`` is ``image`` AND-ed with that mask.
        """
        mask = numpy.array(mask)
        mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
        segment = cv2.bitwise_and(image, image, mask=mask)
        return mask, segment

    def apply_auto_mask_to_image(self, image, auto_list, masks, bboxes):
        """
        Combine selected automatic masks and draw their bounding boxes.

        Args:
            image: Image the masks belong to.
            auto_list: Indices into ``masks`` / ``bboxes`` to include.
            masks: Three-channel (BGR) masks, indexable by the entries of
                ``auto_list``.
            bboxes: Bounding boxes as (x, y, w, h), indexable the same way.

        Returns:
            typing.Tuple: ``(image_with_bounding_boxes, all_masks,
            image_with_segments)`` where the first is a copy of ``image`` with
            a randomly colored, 1-based numbered box per selected mask, the
            second is the union of the selected masks (uint8), and the third
            is ``image`` AND-ed with that union. When ``auto_list`` is empty
            the union is an all-zero mask.
        """
        image_with_bounding_boxes = image.copy()
        all_masks = None
        for mask_index in auto_list:
            mask = cv2.cvtColor(masks[mask_index], cv2.COLOR_BGR2GRAY)
            bbox = bboxes[mask_index]
            if all_masks is None:
                all_masks = mask
            else:
                all_masks = cv2.bitwise_or(all_masks, mask)

            random_color = numpy.random.randint(0, 255, size=3)
            image_with_bounding_boxes = cv2.rectangle(
                image_with_bounding_boxes,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])),
                random_color.tolist(),
                2,
            )
            image_with_bounding_boxes = cv2.putText(
                image_with_bounding_boxes,
                f"{mask_index + 1}",
                (int(bbox[0]), int(bbox[1]) - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                random_color.tolist(),
                2,
            )

        if all_masks is None:
            # Nothing selected: use an empty mask instead of failing on None.
            all_masks = numpy.zeros(image.shape[:2], dtype=numpy.uint8)
        all_masks = all_masks.astype(numpy.uint8)
        image_with_segments = cv2.bitwise_and(image, image, mask=all_masks)
        return image_with_bounding_boxes, all_masks, image_with_segments

    def generate_mask_from_image_with_yolo(
        self,
        image,
        YOLOv10Model: YOLOv10Plugin.YOLOv10Plugin | None = None,
        YOLOv10ModelName: str | None = None,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
    ):
        """
        Detect objects with YOLOv10 and predict one SAM2 mask per detected box.

        When the instance has no detector yet, exactly one of ``YOLOv10Model``
        and ``YOLOv10ModelName`` must be given; the resulting detector is
        stored on the instance and reused by later calls.

        Args:
            image: Image passed to the detector and to SAM2.
            YOLOv10Model: Detector to use when the instance has none.
            YOLOv10ModelName: Model name used to construct a detector when
                the instance has none.
            mask_threshold: Forwarded to ``SAM2ImagePredictor``.
            max_hole_area: Forwarded to ``SAM2ImagePredictor``.
            max_sprinkle_area: Forwarded to ``SAM2ImagePredictor``.

        Returns:
            The detector's results, where every entry additionally carries
            "mask_chw" (the predicted mask with its leading axis squeezed out)
            and "mask_iou".

        Raises:
            AssertionError: If the instance has no detector and not exactly
                one of ``YOLOv10Model`` / ``YOLOv10ModelName`` is provided.
        """
        if self.YOLOv10Model is None:
            assert bool(YOLOv10Model) != bool(
                YOLOv10ModelName
            ), "Either YOLOv10Model or YOLOv10ModelName should be provided."
            if YOLOv10Model is not None:
                self.YOLOv10Model = YOLOv10Model
            if YOLOv10ModelName is not None:
                self.YOLOv10Model = YOLOv10Plugin.YOLOv10Plugin(
                    yolo_model_name=YOLOv10ModelName
                )

        results = self.YOLOv10Model.detect(image)

        # The predictor (and its image embedding) is the same for every box,
        # so it is built once, and only if there is at least one detection.
        predictor = None
        for index, result in enumerate(results):
            if predictor is None:
                predictor = self.__build_image_predictor(
                    image, mask_threshold, max_hole_area, max_sprinkle_area
                )
            mask_chw, mask_iou = self.__predict_single_mask(
                predictor,
                point_coords=None,
                point_labels=None,
                box=result["box"],
            )
            results[index]["mask_chw"] = numpy.squeeze(mask_chw, 0)
            results[index]["mask_iou"] = mask_iou
        return results