Spaces: Running on Zero
import unittest | |
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist | |
import cv2 | |
class TestSegmentAnything2Assist(unittest.TestCase):
    """Integration tests for ``SegmentAnything2Assist``.

    The tests build real models on ``device="cpu"`` with ``download=True``, so
    the first run may fetch checkpoints and can be slow.
    """

    def setUp(self) -> None:
        return super().setUp()

    def tearDown(self) -> None:
        return super().tearDown()

    def _loading_all_sam_model_types(self):
        """Check that every SAM2 model type can be constructed and loaded.

        The leading underscore keeps unittest from collecting this method, so
        it only runs when invoked explicitly (presumably because it downloads
        every checkpoint variant).
        """
        all_sam_model_types = (
            "sam2_hiera_tiny",
            "sam2_hiera_small",
            "sam2_hiera_base_plus",
            "sam2_hiera_large",
        )
        for sam_model_type in all_sam_model_types:
            # subTest reports each model type separately instead of stopping
            # at the first failing variant.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Constructing from an explicit local checkpoint must not raise.
                # NOTE(review): assumes the download above stores checkpoints
                # under .tmp/checkpoints/ — confirm against the loader.
                SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )

                # A directory is not a checkpoint file and must be rejected.
                # NOTE(review): narrow `Exception` once the loader's error type
                # is known.
                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    def test_generate_automatic_mask(self):
        """Generate automatic masks for a sample image and check the outputs agree."""
        image = cv2.imread("test/assets/liberty.jpg")
        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail here with a clear message rather than deep
        # inside the model.
        self.assertIsNotNone(image, "could not read test/assets/liberty.jpg")

        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )
        masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)

        # An empty result would make every per-mask check below vacuous.
        self.assertGreater(len(masks), 0)
        # The three outputs are parallel sequences: one entry per mask.
        self.assertEqual(len(masks), len(segmentation_masks))
        self.assertEqual(len(masks), len(bboxes))

        for index, segmentation_mask in enumerate(segmentation_masks):
            with self.subTest(index=index):
                self.assertEqual(segmentation_mask.shape, image.shape)