import unittest

import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist

import cv2


class TestSegmentAnything2Assist(unittest.TestCase):
    """Tests for the ``SegmentAnything2Assist`` wrapper class.

    NOTE(review): these tests construct real models on CPU with
    ``download=True``, so they presumably need network access and disk
    space for the checkpoints -- confirm before running in CI.
    """

    def setUp(self) -> None:
        """Defer to the base-class set-up; no fixtures are created here."""
        return super().setUp()

    def tearDown(self) -> None:
        """Defer to the base-class tear-down; nothing needs cleaning up."""
        return super().tearDown()

    def _loading_all_sam_model_types(self):
        """Check that every SAM2 model type can be constructed.

        For each model type this exercises three construction paths:
        with ``download=True``, with ``download=False`` and an explicit
        checkpoint path, and with ``download=False`` and a path ("."),
        which is expected to raise.

        The leading underscore keeps the unittest loader from collecting
        this method automatically; it only runs when called explicitly.
        """
        all_sam_models_type = [
            "sam2_hiera_tiny",
            "sam2_hiera_small",
            "sam2_hiera_base_plus",
            "sam2_hiera_large",
        ]

        for sam_model_type in all_sam_models_type:
            # subTest reports each model type separately instead of
            # stopping at the first variant that fails.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Constructing from an explicit checkpoint path must not
                # raise. Presumably this is where the download=True call
                # above stored the file -- TODO confirm.
                SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )

                # "." is not a checkpoint file, so construction must fail.
                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    def test_generate_automatic_mask(self):
        """Check the outputs of ``generate_automatic_masks`` are consistent.

        The three returned sequences must be non-empty and of equal
        length, and the first segmentation mask must have the same shape
        as the input image.
        """
        image = cv2.imread("test/assets/liberty.jpg")
        # cv2.imread signals a missing or unreadable file by returning
        # None rather than raising, so fail here with a clear message.
        self.assertIsNotNone(
            image, "could not read test image 'test/assets/liberty.jpg'"
        )

        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )

        masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)

        # Guard the [0] lookup below: an empty result should be reported
        # as an assertion failure, not as an IndexError.
        self.assertGreater(len(masks), 0)
        self.assertEqual(len(masks), len(segmentation_masks))
        self.assertEqual(len(masks), len(bboxes))

        self.assertEqual(segmentation_masks[0].shape, image.shape)