xqt's picture
UPD: added setup.py for installation
e312782
raw
history blame
2.11 kB
import unittest
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
import cv2
class TestSegmentAnything2Assist(unittest.TestCase):
    """Tests for the ``SegmentAnything2Assist`` wrapper, always run on CPU."""

    def setUp(self) -> None:
        """Defer to ``unittest.TestCase.setUp``; no fixtures of its own."""
        return super().setUp()

    def tearDown(self) -> None:
        """Defer to ``unittest.TestCase.tearDown``; nothing to clean up."""
        return super().tearDown()

    def _loading_all_sam_model_types(self):
        """Check construction of the wrapper for each listed SAM2 model name.

        The method name has no ``test`` prefix, so default unittest discovery
        does not collect it; call it explicitly to run it.

        For every model name this:

        1. constructs the wrapper with ``download=True`` and asserts that
           ``is_model_available()`` is truthy;
        2. constructs it again with ``download=False`` and an explicit
           ``model_path`` under ``.tmp/checkpoints/`` (expected not to raise;
           presumably this is where step 1 stored the checkpoint -- confirm);
        3. asserts that ``model_path="."`` makes construction raise.
        """
        all_sam_models_type = [
            "sam2_hiera_tiny",
            "sam2_hiera_small",
            "sam2_hiera_base_plus",
            "sam2_hiera_large",
        ]

        for sam_model_type in all_sam_models_type:
            # subTest reports which model name failed instead of stopping at
            # the first failure without context.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Construction from an explicit checkpoint path is itself the
                # check here: it must not raise.
                SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )

                # "." is not a checkpoint file, so construction must fail.
                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    def test_generate_automatic_mask(self):
        """Run automatic mask generation on the bundled test image.

        Asserts that the three returned sequences have equal length and that
        the first segmentation mask has the same shape as the input image.
        """
        image_path = "test/assets/liberty.jpg"
        image = cv2.imread(image_path)
        # cv2.imread signals a missing or unreadable file by returning None
        # rather than raising, so fail here with a clear message.
        self.assertIsNotNone(image, f"could not read test image: {image_path}")

        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )

        masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)

        self.assertEqual(len(masks), len(segmentation_masks))
        self.assertEqual(len(masks), len(bboxes))
        # The checks below index element 0, so require a non-empty result
        # instead of failing with an IndexError.
        self.assertGreater(len(masks), 0, "no masks were generated")

        print(type(masks[0]))
        print(type(segmentation_masks[0]))
        print(type(bboxes[0]))

        # Only the first segmentation mask is compared against the image shape.
        self.assertEqual(segmentation_masks[0].shape, image.shape)