mm3dtest / tests /test_structures /test_det3d_data_sample.py
giantmonkeyTC
2344
34d1f8b
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.structures import InstanceData
from mmdet3d.structures import Det3DDataSample, PointData
def _equal(a, b):
if isinstance(a, (torch.Tensor, np.ndarray)):
return (a == b).all()
else:
return a == b
class TestDet3DDataSample(TestCase):
def test_init(self):
meta_info = dict(
img_size=[256, 256],
scale_factor=np.array([1.5, 1.5]),
img_shape=torch.rand(4))
det3d_data_sample = Det3DDataSample(metainfo=meta_info)
assert 'img_size' in det3d_data_sample
assert det3d_data_sample.img_size == [256, 256]
assert det3d_data_sample.get('img_size') == [256, 256]
def test_setter(self):
det3d_data_sample = Det3DDataSample()
# test gt_instances_3d
gt_instances_3d_data = dict(
bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
gt_instances_3d = InstanceData(**gt_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.gt_instances_3d.bboxes_3d,
gt_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.gt_instances_3d.labels_3d,
gt_instances_3d_data['labels_3d'])
# test pred_instances_3d
pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
pred_instances_3d = InstanceData(**pred_instances_3d_data)
det3d_data_sample.pred_instances_3d = pred_instances_3d
assert 'pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.pred_instances_3d.bboxes_3d,
pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.pred_instances_3d.labels_3d,
pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.pred_instances_3d.scores_3d,
pred_instances_3d_data['scores_3d'])
# test pts_pred_instances_3d
pts_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
pts_pred_instances_3d = InstanceData(**pts_pred_instances_3d_data)
det3d_data_sample.pts_pred_instances_3d = pts_pred_instances_3d
assert 'pts_pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.pts_pred_instances_3d.bboxes_3d,
pts_pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.pts_pred_instances_3d.labels_3d,
pts_pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.pts_pred_instances_3d.scores_3d,
pts_pred_instances_3d_data['scores_3d'])
# test img_pred_instances_3d
img_pred_instances_3d_data = dict(
bboxes_3d=torch.rand(2, 7),
labels_3d=torch.rand(2),
scores_3d=torch.rand(2))
img_pred_instances_3d = InstanceData(**img_pred_instances_3d_data)
det3d_data_sample.img_pred_instances_3d = img_pred_instances_3d
assert 'img_pred_instances_3d' in det3d_data_sample
assert _equal(det3d_data_sample.img_pred_instances_3d.bboxes_3d,
img_pred_instances_3d_data['bboxes_3d'])
assert _equal(det3d_data_sample.img_pred_instances_3d.labels_3d,
img_pred_instances_3d_data['labels_3d'])
assert _equal(det3d_data_sample.img_pred_instances_3d.scores_3d,
img_pred_instances_3d_data['scores_3d'])
# test gt_pts_seg
gt_pts_seg_data = dict(
pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
gt_pts_seg = PointData(**gt_pts_seg_data)
det3d_data_sample.gt_pts_seg = gt_pts_seg
assert 'gt_pts_seg' in det3d_data_sample
assert _equal(det3d_data_sample.gt_pts_seg.pts_instance_mask,
gt_pts_seg_data['pts_instance_mask'])
assert _equal(det3d_data_sample.gt_pts_seg.pts_semantic_mask,
gt_pts_seg_data['pts_semantic_mask'])
# test pred_pts_seg
pred_pts_seg_data = dict(
pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
pred_pts_seg = PointData(**pred_pts_seg_data)
det3d_data_sample.pred_pts_seg = pred_pts_seg
assert 'pred_pts_seg' in det3d_data_sample
assert _equal(det3d_data_sample.pred_pts_seg.pts_instance_mask,
pred_pts_seg_data['pts_instance_mask'])
assert _equal(det3d_data_sample.pred_pts_seg.pts_semantic_mask,
pred_pts_seg_data['pts_semantic_mask'])
# test type error
with pytest.raises(AssertionError):
det3d_data_sample.pred_instances_3d = torch.rand(2, 4)
with pytest.raises(AssertionError):
det3d_data_sample.pred_pts_seg = torch.rand(20)
def test_deleter(self):
tmp_instances_3d_data = dict(
bboxes_3d=torch.rand(4, 4), labels_3d=torch.rand(4))
det3d_data_sample = Det3DDataSample()
gt_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.gt_instances_3d = gt_instances_3d
assert 'gt_instances_3d' in det3d_data_sample
del det3d_data_sample.gt_instances_3d
assert 'gt_instances_3d' not in det3d_data_sample
pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.pred_instances_3d = pred_instances_3d
assert 'pred_instances_3d' in det3d_data_sample
del det3d_data_sample.pred_instances_3d
assert 'pred_instances_3d' not in det3d_data_sample
pts_pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.pts_pred_instances_3d = pts_pred_instances_3d
assert 'pts_pred_instances_3d' in det3d_data_sample
del det3d_data_sample.pts_pred_instances_3d
assert 'pts_pred_instances_3d' not in det3d_data_sample
img_pred_instances_3d = InstanceData(data=tmp_instances_3d_data)
det3d_data_sample.img_pred_instances_3d = img_pred_instances_3d
assert 'img_pred_instances_3d' in det3d_data_sample
del det3d_data_sample.img_pred_instances_3d
assert 'img_pred_instances_3d' not in det3d_data_sample
pred_pts_seg_data = dict(
pts_instance_mask=torch.rand(20), pts_semantic_mask=torch.rand(20))
pred_pts_seg = PointData(**pred_pts_seg_data)
det3d_data_sample.pred_pts_seg = pred_pts_seg
assert 'pred_pts_seg' in det3d_data_sample
del det3d_data_sample.pred_pts_seg
assert 'pred_pts_seg' not in det3d_data_sample