# Segment-Anything-2-Assist / src/SegmentAnything2Assist.py
# Revision f3d3559 by xqt: "REF: Uses internal variable for auto mask and image segmentation."
# (File size at this revision: 9.3 kB.)
import typing
import os
import sam2.sam2_image_predictor
import tqdm
import requests
import torch
import numpy
import pickle
import sam2.build_sam
import sam2.automatic_mask_generator
import cv2
# Registry of known SAM2 (Hiera) checkpoints, keyed by model name.
# Each entry holds:
#   download_url -- where the checkpoint is fetched from,
#   model_path   -- where it is cached locally (a path relative to the CWD),
#   config_file  -- the SAM2 config name handed to the model builder.
SAM2_MODELS = dict(
    sam2_hiera_tiny=dict(
        download_url="https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
        model_path=".tmp/checkpoints/sam2_hiera_tiny.pt",
        config_file="sam2_hiera_t.yaml",
    ),
    sam2_hiera_small=dict(
        download_url="https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
        model_path=".tmp/checkpoints/sam2_hiera_small.pt",
        config_file="sam2_hiera_s.yaml",
    ),
    sam2_hiera_base_plus=dict(
        download_url="https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
        model_path=".tmp/checkpoints/sam2_hiera_base_plus.pt",
        config_file="sam2_hiera_b+.yaml",
    ),
    sam2_hiera_large=dict(
        download_url="https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
        model_path=".tmp/checkpoints/sam2_hiera_large.pt",
        config_file="sam2_hiera_l.yaml",
    ),
)
class SegmentAnything2Assist:
def __init__(
self,
model_name: (
str
| typing.Literal[
"sam2_hiera_tiny",
"sam2_hiera_small",
"sam2_hiera_base_plus",
"sam2_hiera_large",
]
) = "sam2_hiera_small",
configuration: (
str | typing.Literal["Automatic Mask Generator", "Image"]
) = "Automatic Mask Generator",
download_url: str | None = None,
model_path: str | None = None,
download: bool = True,
device: str | torch.device = torch.device("cpu"),
verbose: bool = True,
) -> None:
assert (
model_name in SAM2_MODELS.keys()
), f"`model_name` should be either one of {list(SAM2_MODELS.keys())}"
assert configuration in ["Automatic Mask Generator", "Image"]
self.model_name = model_name
self.configuration = configuration
self.config_file = SAM2_MODELS[model_name]["config_file"]
self.device = device
self.download_url = (
download_url
if download_url is not None
else SAM2_MODELS[model_name]["download_url"]
)
self.model_path = (
model_path
if model_path is not None
else SAM2_MODELS[model_name]["model_path"]
)
os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
self.verbose = verbose
if self.verbose:
print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}")
print(
f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}"
)
print(
f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}"
)
print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}")
print(
f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}"
)
if download:
self.download_model()
if self.is_model_available():
self.sam2 = sam2.build_sam.build_sam2(
config_file=self.config_file,
ckpt_path=self.model_path,
device=self.device,
)
if self.verbose:
print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
else:
self.sam2 = None
if self.verbose:
print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")
def is_model_available(self) -> bool:
ret = os.path.exists(self.model_path)
if self.verbose:
print(f"SegmentAnything2Assist::is_model_available::{ret}")
return ret
def load_model(self) -> None:
if self.is_model_available():
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
def download_model(self, force: bool = False) -> None:
if not force and self.is_model_available():
print(f"{self.model_path} already exists. Skipping download.")
return
response = requests.get(self.download_url, stream=True)
total_size = int(response.headers.get("content-length", 0))
with open(self.model_path, "wb") as file, tqdm.tqdm(
total=total_size, unit="B", unit_scale=True
) as progress_bar:
for data in response.iter_content(chunk_size=1024):
file.write(data)
progress_bar.update(len(data))
def generate_automatic_masks(
self,
image,
points_per_side=32,
points_per_batch=32,
pred_iou_thresh=0.8,
stability_score_thresh=0.95,
stability_score_offset=1.0,
mask_threshold=0.0,
box_nms_thresh=0.7,
crop_n_layers=0,
crop_nms_thresh=0.7,
crop_overlay_ratio=512 / 1500,
crop_n_points_downscale_factor=1,
min_mask_region_area=0,
use_m2m=False,
multimask_output=True,
):
if self.sam2 is None:
print(
"SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
)
return None
generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
model=self.sam2,
points_per_side=points_per_side,
points_per_batch=points_per_batch,
pred_iou_thresh=pred_iou_thresh,
stability_score_thresh=stability_score_thresh,
stability_score_offset=stability_score_offset,
mask_threshold=mask_threshold,
box_nms_thresh=box_nms_thresh,
crop_n_layers=crop_n_layers,
crop_nms_thresh=crop_nms_thresh,
crop_overlay_ratio=crop_overlay_ratio,
crop_n_points_downscale_factor=crop_n_points_downscale_factor,
min_mask_region_area=min_mask_region_area,
use_m2m=use_m2m,
multimask_output=multimask_output,
)
masks = generator.generate(image)
segmentation_masks = [mask for mask in masks]
segmentation_masks = [
numpy.where(mask["segmentation"] == True, 255, 0).astype(numpy.uint8)
for mask in segmentation_masks
]
segmentation_masks = [
cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
]
bbox_masks = [mask["bbox"] for mask in masks]
return masks, segmentation_masks, bbox_masks
def generate_masks_from_image(
self,
image,
point_coords,
point_labels,
box,
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
):
generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
self.sam2,
mask_threshold=mask_threshold,
max_hole_area=max_hole_area,
max_sprinkle_area=max_sprinkle_area,
)
generator.set_image(image)
masks_chw, mask_iou, mask_low_logits = generator.predict(
point_coords=(
numpy.array(point_coords) if point_coords is not None else None
),
point_labels=(
numpy.array(point_labels) if point_labels is not None else None
),
box=numpy.array(box) if box is not None else None,
multimask_output=False,
)
return masks_chw, mask_iou
def apply_mask_to_image(self, image, mask):
mask = numpy.array(mask)
mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
segment = cv2.bitwise_and(image, image, mask=mask)
return mask, segment
def apply_auto_mask_to_image(self, image, auto_list, masks, bboxes):
image_with_bounding_boxes = image.copy()
all_masks = None
cv2.imwrite(".tmp/mask_2.png", masks[3])
for _ in auto_list:
mask = masks[_]
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
bbox = bboxes[_]
if all_masks is None:
all_masks = mask
else:
all_masks = cv2.bitwise_or(all_masks, mask)
cv2.imwrite(".tmp/mask_3.png", masks[3])
random_color = numpy.random.randint(0, 255, size=3)
image_with_bounding_boxes = cv2.rectangle(
image_with_bounding_boxes,
(int(bbox[0]), int(bbox[1])),
(int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])),
random_color.tolist(),
2,
)
image_with_bounding_boxes = cv2.putText(
image_with_bounding_boxes,
f"{_ + 1}",
(int(bbox[0]), int(bbox[1]) - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
random_color.tolist(),
2,
)
all_masks = all_masks.astype(numpy.uint8)
image_with_segments = cv2.bitwise_and(image, image, mask=all_masks)
return image_with_bounding_boxes, all_masks, image_with_segments