import os
import pickle
import typing

import cv2
import numpy
import requests
import torch
import tqdm

import sam2.automatic_mask_generator
import sam2.build_sam
import sam2.sam2_image_predictor
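
# Download URL, local checkpoint cache path, and model config file for each supported SAM2 variant.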
SAM2_MODELS = {
    "sam2_hiera_tiny": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_tiny.pt",
        "config_file": "sam2_hiera_t.yaml",
    },
    "sam2_hiera_small": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_small.pt",
        "config_file": "sam2_hiera_s.yaml",
    },
    "sam2_hiera_base_plus": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_base_plus.pt",
        "config_file": "sam2_hiera_b+.yaml",
    },
    "sam2_hiera_large": {
        "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
        "model_path": ".tmp/checkpoints/sam2_hiera_large.pt",
        "config_file": "sam2_hiera_l.yaml",
    },
}
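
# Convenience wrapper around SAM2: downloads a checkpoint, builds the model, and exposes
# automatic mask generation, prompted (point/box) segmentation, and simple visualization helpers.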
class SegmentAnything2Assist:
    def __init__(
        self,
        model_name: (
            str
            | typing.Literal[
                "sam2_hiera_tiny",
                "sam2_hiera_small",
                "sam2_hiera_base_plus",
                "sam2_hiera_large",
            ]
        ) = "sam2_hiera_small",
        configuration: (
            str | typing.Literal["Automatic Mask Generator", "Image"]
        ) = "Automatic Mask Generator",
        download_url: str | None = None,
        model_path: str | None = None,
        download: bool = True,
        device: str | torch.device = torch.device("cpu"),
        verbose: bool = True,
    ) -> None:
        assert (
            model_name in SAM2_MODELS
        ), f"`model_name` should be one of {list(SAM2_MODELS.keys())}"
        assert configuration in ["Automatic Mask Generator", "Image"]

        self.model_name = model_name
        self.configuration = configuration
        self.config_file = SAM2_MODELS[model_name]["config_file"]
        self.device = device
        self.download_url = (
            download_url
            if download_url is not None
            else SAM2_MODELS[model_name]["download_url"]
        )
        self.model_path = (
            model_path
            if model_path is not None
            else SAM2_MODELS[model_name]["model_path"]
        )
        os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
        self.verbose = verbose

        if self.verbose:
            print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}")
            print(
                f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}"
            )
            print(
                f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}"
            )
            print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}")
            print(
                f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}"
            )

        if download:
            self.download_model()

        if self.is_model_available():
            self.sam2 = sam2.build_sam.build_sam2(
                config_file=self.config_file,
                ckpt_path=self.model_path,
                device=self.device,
            )
            if self.verbose:
                print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
        else:
            self.sam2 = None
            if self.verbose:
                print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")

    def is_model_available(self) -> bool:
        ret = os.path.exists(self.model_path)
        if self.verbose:
            print(f"SegmentAnything2Assist::is_model_available::{ret}")
        return ret

    def load_model(self) -> None:
        # Mirror __init__: build SAM2 from the config file and the downloaded checkpoint.
        if self.is_model_available():
            self.sam2 = sam2.build_sam.build_sam2(
                config_file=self.config_file, ckpt_path=self.model_path, device=self.device
            )

    def download_model(self, force: bool = False) -> None:
        if not force and self.is_model_available():
            print(f"{self.model_path} already exists. Skipping download.")
            return

        # Stream the checkpoint to disk while showing a progress bar.
        response = requests.get(self.download_url, stream=True)
        total_size = int(response.headers.get("content-length", 0))
        with open(self.model_path, "wb") as file, tqdm.tqdm(
            total=total_size, unit="B", unit_scale=True
        ) as progress_bar:
            for data in response.iter_content(chunk_size=1024):
                file.write(data)
                progress_bar.update(len(data))

    def generate_automatic_masks(
        self,
        image,
        points_per_side=32,
        points_per_batch=32,
        pred_iou_thresh=0.8,
        stability_score_thresh=0.95,
        stability_score_offset=1.0,
        mask_threshold=0.0,
        box_nms_thresh=0.7,
        crop_n_layers=0,
        crop_nms_thresh=0.7,
        crop_overlay_ratio=512 / 1500,
        crop_n_points_downscale_factor=1,
        min_mask_region_area=0,
        use_m2m=False,
        multimask_output=True,
    ):
        if self.sam2 is None:
            print(
                "SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
            )
            return None

        generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
            model=self.sam2,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            stability_score_offset=stability_score_offset,
            mask_threshold=mask_threshold,
            box_nms_thresh=box_nms_thresh,
            crop_n_layers=crop_n_layers,
            crop_nms_thresh=crop_nms_thresh,
            crop_overlap_ratio=crop_overlay_ratio,
            crop_n_points_downscale_factor=crop_n_points_downscale_factor,
            min_mask_region_area=min_mask_region_area,
            use_m2m=use_m2m,
            multimask_output=multimask_output,
        )
        masks = generator.generate(image)

        # Convert each boolean segmentation into a 3-channel 0/255 mask for display.
        segmentation_masks = [
            numpy.where(mask["segmentation"], 255, 0).astype(numpy.uint8)
            for mask in masks
        ]
        segmentation_masks = [
            cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
        ]
        bbox_masks = [mask["bbox"] for mask in masks]

        return masks, segmentation_masks, bbox_masks

    def generate_masks_from_image(
        self,
        image,
        point_coords,
        point_labels,
        box,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
    ):
        if self.sam2 is None:
            print(
                "SegmentAnything2Assist::generate_masks_from_image::SAM2 is not loaded."
            )
            return None, None

        generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
            self.sam2,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        generator.set_image(image)

        masks_chw, mask_iou, mask_low_logits = generator.predict(
            point_coords=(
                numpy.array(point_coords) if point_coords is not None else None
            ),
            point_labels=(
                numpy.array(point_labels) if point_labels is not None else None
            ),
            box=numpy.array(box) if box is not None else None,
            multimask_output=False,
        )

        return masks_chw, mask_iou

    def apply_mask_to_image(self, image, mask):
        # Binarize the mask and use it to cut the corresponding segment out of the image.
        mask = numpy.array(mask)
        mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
        segment = cv2.bitwise_and(image, image, mask=mask)
        return mask, segment

    def apply_auto_mask_to_image(self, image, auto_list, masks, bboxes):
        image_with_bounding_boxes = image.copy()
        all_masks = None

        for index in auto_list:
            mask = cv2.cvtColor(masks[index], cv2.COLOR_BGR2GRAY)
            bbox = bboxes[index]

            # Accumulate the selected masks into a single binary mask.
            if all_masks is None:
                all_masks = mask
            else:
                all_masks = cv2.bitwise_or(all_masks, mask)

            # Draw the bounding box and its 1-based index in a random color.
            random_color = numpy.random.randint(0, 255, size=3)
            image_with_bounding_boxes = cv2.rectangle(
                image_with_bounding_boxes,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])),
                random_color.tolist(),
                2,
            )
            image_with_bounding_boxes = cv2.putText(
                image_with_bounding_boxes,
                f"{index + 1}",
                (int(bbox[0]), int(bbox[1]) - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                random_color.tolist(),
                2,
            )

        if all_masks is None:
            all_masks = numpy.zeros(image.shape[:2], dtype=numpy.uint8)
        all_masks = all_masks.astype(numpy.uint8)
        image_with_segments = cv2.bitwise_and(image, image, mask=all_masks)
        return image_with_bounding_boxes, all_masks, image_with_segments
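

# A minimal usage sketch (not part of the original module): it assumes a local image at the
# hypothetical path "example.jpg" and illustrative mask indices/output paths; those names are
# assumptions, not values from this repository. It shows how the class above could be wired
# together for automatic and point-prompted segmentation.
if __name__ == "__main__":
    assistant = SegmentAnything2Assist(
        model_name="sam2_hiera_tiny", download=True, device=torch.device("cpu")
    )

    image = cv2.imread("example.jpg")  # hypothetical input image (BGR, as OpenCV loads it)
    if image is not None and assistant.sam2 is not None:
        # Automatic mask generation: raw SAM2 records, BGR masks, and bounding boxes.
        masks, segmentation_masks, bbox_masks = assistant.generate_automatic_masks(image)

        # Overlay the first few masks (if any) onto the image.
        selected = list(range(min(3, len(segmentation_masks))))
        boxed, combined_mask, segments = assistant.apply_auto_mask_to_image(
            image, selected, segmentation_masks, bbox_masks
        )
        cv2.imwrite(".tmp/annotated.png", boxed)  # hypothetical output path

        # Prompted segmentation from a single foreground point at the image center.
        h, w = image.shape[:2]
        masks_chw, mask_iou = assistant.generate_masks_from_image(
            image, point_coords=[[w // 2, h // 2]], point_labels=[1], box=None
        )
        print("Predicted IoU:", mask_iou)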