Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,110 Bytes
e312782 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import unittest
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
import cv2
class TestSegmentAnything2Assist(unittest.TestCase):
    """Tests for the ``SegmentAnything2Assist`` wrapper around SAM2 models.

    Every model here is constructed with ``device="cpu"``. Models built with
    ``download=True`` presumably fetch checkpoints on first use (network
    access may be required — TODO confirm). ``test_generate_automatic_mask``
    reads ``test/assets/liberty.jpg`` relative to the current working
    directory.
    """

    def _loading_all_sam_model_types(self):
        """Check that every SAM2 model type can be downloaded and reloaded.

        This helper is not collected by unittest because of its leading
        underscore. That is presumably deliberate, since it builds all four
        model types; rename it to ``test_...`` to enable it.

        For each model type it:

        1. constructs the model with ``download=True`` and asserts
           ``is_model_available()``;
        2. reconstructs it from ``.tmp/checkpoints/<type>.pth`` with
           ``download=False`` (this step passes as long as construction does
           not raise);
        3. asserts that constructing it with ``model_path="."`` raises.
        """
        all_sam_models_type = [
            "sam2_hiera_tiny",
            "sam2_hiera_small",
            "sam2_hiera_base_plus",
            "sam2_hiera_large",
        ]
        for sam_model_type in all_sam_models_type:
            # subTest reports every failing model type instead of stopping
            # at the first one.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Reload from a local checkpoint without downloading. This
                # assumes step 1 stored the checkpoint under .tmp/checkpoints
                # — TODO confirm.
                SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )

                # "." is a directory, not a checkpoint file, so construction
                # must fail. The exact exception type is not pinned here.
                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    def test_generate_automatic_mask(self):
        """Check that automatic mask generation returns aligned, image-sized results.

        Runs the ``sam2_hiera_tiny`` model on ``test/assets/liberty.jpg``.
        It asserts that ``generate_automatic_masks`` returns non-empty
        ``masks``, ``segmentation_masks`` and ``bboxes`` of equal length, and
        that every segmentation mask has the same shape as the input image.
        """
        image = cv2.imread("test/assets/liberty.jpg")
        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail clearly here rather than deep inside the model.
        self.assertIsNotNone(image, "could not read test/assets/liberty.jpg")

        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )
        masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)

        # Check non-emptiness before any per-item check, so an empty result is
        # reported as a test failure rather than an IndexError.
        self.assertGreater(len(masks), 0)
        self.assertEqual(len(masks), len(segmentation_masks))
        self.assertEqual(len(masks), len(bboxes))

        for index, segmentation_mask in enumerate(segmentation_masks):
            with self.subTest(mask_index=index):
                self.assertEqual(segmentation_mask.shape, image.shape)
|