xqt's picture
REF: SAM2 AMG and the corresponding test case.
f91c3fb
raw
history blame
3.19 kB
import unittest
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
import cv2
import numpy
class TestSegmentAnything2Assist(unittest.TestCase):
    """Tests for the ``SegmentAnything2Assist`` wrapper around SAM2.

    Methods whose names start with an underscore (``_loading_all_sam_model_types``
    and ``_generate_automatic_mask``) do not match unittest's default ``test``
    prefix, so they are not collected automatically; they are kept as opt-in
    checks. NOTE(review): presumably disabled because they download every SAM2
    checkpoint / run automatic mask generation on CPU — confirm before enabling.
    """

    # Shared test image. The path is relative to the current working directory,
    # so the suite presumably has to be run from the repository root.
    IMAGE_PATH = "test/assets/liberty.jpg"

    # Smallest model variant, used by the mask-generation tests.
    TINY_MODEL_TYPE = "sam2_hiera_tiny"

    # Every SAM2 model type exercised by ``_loading_all_sam_model_types``.
    ALL_SAM_MODEL_TYPES = (
        "sam2_hiera_tiny",
        "sam2_hiera_small",
        "sam2_hiera_base_plus",
        "sam2_hiera_large",
    )

    def _load_test_image(self):
        """Read ``IMAGE_PATH`` with OpenCV, failing the test if it cannot be read.

        ``cv2.imread`` returns ``None`` instead of raising when the file is
        missing or unreadable; without this check the failure would surface
        later as an unrelated ``AttributeError`` on ``.shape``.

        Returns:
            The image array as returned by ``cv2.imread`` (3-channel by default).
        """
        image = cv2.imread(self.IMAGE_PATH)
        self.assertIsNotNone(image, f"could not read test image {self.IMAGE_PATH!r}")
        return image

    def _load_tiny_model(self):
        """Construct the ``sam2_hiera_tiny`` model on CPU with ``download=True``."""
        return SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name=self.TINY_MODEL_TYPE, download=True, device="cpu"
        )

    def _loading_all_sam_model_types(self):
        """Load every SAM2 model type three ways.

        The three ways are: with download enabled, from an explicit checkpoint
        path, and from an invalid path (``"."``), which must raise.
        """
        for sam_model_type in self.ALL_SAM_MODEL_TYPES:
            # subTest reports each model type separately instead of stopping at
            # the first failing one.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Load again from an explicit checkpoint file with download
                # disabled. NOTE(review): assumes the download above stored the
                # checkpoint at this location — confirm.
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )
                self.assertTrue(sam_model.is_model_available())

                # With downloading disabled, pointing model_path at "." is
                # expected to make construction raise.
                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    def _generate_automatic_mask(self):
        """Check shapes and dtypes of ``generate_automatic_masks`` output."""
        image = self._load_test_image()
        sam_model = self._load_tiny_model()

        segmentation_masks, bboxes, predicted_iou, stability_score = (
            sam_model.generate_automatic_masks(image)
        )

        # Masks: a non-empty 4-D uint8 stack whose entries each have the same
        # shape as the input image (last axis of size 3).
        self.assertEqual(segmentation_masks.ndim, 4)
        self.assertGreater(len(segmentation_masks), 0)
        self.assertEqual(segmentation_masks.shape[3], 3)
        self.assertEqual(segmentation_masks.dtype, numpy.uint8)
        for segmentation_mask in segmentation_masks:
            self.assertEqual(segmentation_mask.shape, image.shape)

        # Boxes: 2-D uint32 array with 4 values per row.
        self.assertEqual(bboxes.ndim, 2)
        self.assertEqual(bboxes.shape[1:], (4,))
        self.assertEqual(bboxes.dtype, numpy.uint32)

        # Scores: 1-D float32 arrays.
        self.assertEqual(predicted_iou.ndim, 1)
        self.assertEqual(predicted_iou.dtype, numpy.float32)
        self.assertEqual(stability_score.ndim, 1)
        self.assertEqual(stability_score.dtype, numpy.float32)

    def test_generate_masks_from_image(self):
        """Call ``generate_masks_from_image`` with every argument after the image set to None.

        Expects exactly one mask and one IoU score back.
        """
        image = self._load_test_image()
        sam_model = self._load_tiny_model()

        mask_chw, mask_iou = sam_model.generate_masks_from_image(
            image, None, None, None
        )

        # mask_chw is a 3-D stack with a single leading entry, so mask_chw[0]
        # is 2-D. cv2.imread returns a 3-channel (H, W, 3) image by default, so
        # the mask must be compared against the spatial size only; comparing
        # against the full image.shape could never succeed.
        self.assertEqual(mask_chw.ndim, 3)
        self.assertEqual(mask_chw.shape[0], 1)
        self.assertEqual(mask_chw[0].shape, image.shape[:2])

        self.assertEqual(mask_iou.ndim, 1)
        self.assertEqual(mask_iou.shape[0], 1)