Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,188 Bytes
e312782 f91c3fb e312782 f91c3fb e312782 f91c3fb e312782 f91c3fb e312782 f91c3fb e312782 f91c3fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import unittest
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
import cv2
import numpy
class TestSegmentAnything2Assist(unittest.TestCase):
    """Tests for the ``SegmentAnything2Assist`` wrapper, all run on the CPU.

    Only methods whose names start with ``test`` are collected by unittest.
    The two methods prefixed with an underscore are therefore not run
    automatically.
    NOTE(review): presumably they are disabled because they download
    checkpoints and are slow -- confirm before re-enabling.
    """

    def setUp(self) -> None:
        """Run the default ``TestCase`` set-up; no extra fixtures are needed."""
        return super().setUp()

    def tearDown(self) -> None:
        """Run the default ``TestCase`` tear-down; nothing extra to clean up."""
        return super().tearDown()

    def _loading_all_sam_model_types(self):
        """Check that every SAM2 model variant can be downloaded and loaded.

        For each variant this:

        1. constructs the wrapper with ``download=True`` and asserts that
           ``is_model_available()`` is truthy;
        2. constructs it again with ``download=False`` from
           ``.tmp/checkpoints/<variant>.pth``
           (assumes step 1 stored the checkpoint there -- TODO confirm);
        3. asserts that constructing it with ``model_path="."`` raises.
        """
        all_sam_models_type = [
            "sam2_hiera_tiny",
            "sam2_hiera_small",
            "sam2_hiera_base_plus",
            "sam2_hiera_large",
        ]

        for sam_model_type in all_sam_models_type:
            # subTest labels a failure with the variant that caused it and
            # lets the remaining variants still run.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Loading from an explicit checkpoint path must not raise.
                SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )

                # A path that is not a checkpoint file must be rejected.
                # The concrete exception type is not visible from this file,
                # so the broad ``Exception`` is kept.
                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    def _generate_automatic_mask(self):
        """Check the shapes and element types returned by automatic masking.

        Asserts that ``generate_automatic_masks`` returns:

        * a 4-D ``uint8`` mask array whose entries each match the shape of
          the input image, with a last axis of size 3;
        * a 2-D ``uint32`` array of bounding boxes with 4 values per row;
        * 1-D ``float32`` arrays for the predicted IoU and stability score.
        """
        image = cv2.imread("test/assets/liberty.jpg")
        # cv2.imread signals a missing or unreadable file by returning None.
        self.assertIsNotNone(
            image, "could not read test image test/assets/liberty.jpg"
        )

        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )

        segmentation_masks, bboxes, predicted_iou, stability_score = (
            sam_model.generate_automatic_masks(image)
        )

        self.assertEqual(len(segmentation_masks.shape), 4)
        self.assertEqual(segmentation_masks[0].shape, image.shape)
        self.assertEqual(segmentation_masks.shape[3], 3)
        self.assertIsInstance(segmentation_masks[0][0][0][0], numpy.uint8)

        self.assertEqual(len(bboxes.shape), 2)
        self.assertEqual(bboxes[0].shape, (4,))
        self.assertIsInstance(bboxes[0][0], numpy.uint32)

        self.assertEqual(len(predicted_iou.shape), 1)
        self.assertIsInstance(predicted_iou[0], numpy.float32)

        self.assertEqual(len(stability_score.shape), 1)
        self.assertIsInstance(stability_score[0], numpy.float32)

        for segmentation_mask in segmentation_masks:
            self.assertEqual(segmentation_mask.shape, image.shape)

    def test_generate_masks_from_image(self):
        """Check the shapes returned by ``generate_masks_from_image``.

        The call passes ``None`` for the three arguments after the image.
        The test asserts that exactly one mask and one IoU value come back
        and that the mask covers the full height and width of the image.
        """
        image = cv2.imread("test/assets/liberty.jpg")
        # cv2.imread signals a missing or unreadable file by returning None.
        self.assertIsNotNone(
            image, "could not read test image test/assets/liberty.jpg"
        )

        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )

        mask_chw, mask_iou = sam_model.generate_masks_from_image(
            image, None, None, None
        )

        self.assertEqual(len(mask_chw.shape), 3)
        # mask_chw is 3-D, so mask_chw[0] is a 2-D plane. The image loaded
        # with default flags has three dimensions, so only its first two
        # (height, width) can be compared with the mask plane.
        self.assertEqual(mask_chw[0].shape, image.shape[:2])
        self.assertEqual(mask_chw.shape[0], 1)

        self.assertEqual(len(mask_iou.shape), 1)
        self.assertEqual(mask_iou.shape[0], 1)
|