|
|
|
import unittest |
|
|
|
import numpy as np |
|
import torch |
|
from mmengine.testing import assert_allclose |
|
|
|
from mmdet3d.datasets import SUNRGBDDataset |
|
from mmdet3d.structures import DepthInstance3DBoxes |
|
|
|
|
|
def _generate_scannet_dataset_config():
    """Return the constructor arguments for the SUN RGB-D test dataset.

    Despite ``scannet`` appearing in the function name, every value built
    here points at the SUN RGB-D fixtures under ``tests/data/sunrgbd``.

    Side effect: unless a transform named ``Identity`` is already present in
    mmengine's ``TRANSFORMS`` registry, one is registered here so that the
    single-step pipeline returned below can be built.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)``, in exactly that order.
    """
    from mmcv.transforms.base import BaseTransform
    from mmengine.registry import TRANSFORMS

    # Guard against re-registration when this helper is called repeatedly.
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):
            """Pass ``info`` through, lifting ``gt_labels_3d`` when present."""

            def transform(self, info):
                # Without annotations there is nothing to copy.
                if 'ann_info' not in info:
                    return info
                # Expose the labels at the top level of the info dict.
                info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
                return info

    root = 'tests/data/sunrgbd'
    class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
                   'dresser', 'night_stand', 'bookshelf', 'bathtub')
    return (
        root,
        'sunrgbd_infos.pkl',
        class_names,
        dict(pts='points', img='sunrgbd_trainval'),
        [dict(type='Identity')],
        dict(use_camera=True, use_lidar=True),
    )
|
|
|
|
|
class TestScanNetDataset(unittest.TestCase):
    """Tests for :class:`SUNRGBDDataset` on the ``tests/data/sunrgbd`` data.

    NOTE(review): the class name ("ScanNet") and the method name
    (``test_sunrgbd_ataset``) look like copy-paste leftovers / typos; they
    are kept unchanged so existing test IDs and ``-k`` filters keep working.
    """

    def test_sunrgbd_ataset(self):
        """Check path prefixes, annotation parsing and class subsetting."""
        np.random.seed(0)
        data_root, ann_file, classes, data_prefix, \
            pipeline, modality = _generate_scannet_dataset_config()

        def build_dataset(dataset_classes):
            # Both datasets in this test share the fixture config and differ
            # only in the class list handed to ``metainfo``.
            return SUNRGBDDataset(
                data_root,
                ann_file,
                data_prefix=data_prefix,
                pipeline=pipeline,
                metainfo=dict(classes=dataset_classes),
                modality=modality)

        sunrgbd_dataset = build_dataset(classes)

        # Smoke-test the data entry points; their results are not inspected.
        sunrgbd_dataset.prepare_data(0)
        input_dict = sunrgbd_dataset.get_data_info(0)
        sunrgbd_dataset[0]

        # Stored paths must contain both data_root and the per-modality prefix.
        lidar_path = input_dict['lidar_points']['lidar_path']
        self.assertIn(data_prefix['pts'], lidar_path)
        self.assertIn(data_root, lidar_path)
        for img_info in input_dict['images'].values():
            # Only camera entries that carry an image path are checked.
            if 'img_path' in img_info:
                self.assertIn(data_prefix['img'], img_info['img_path'])
                self.assertIn(data_root, img_info['img_path'])

        ann_info = sunrgbd_dataset.parse_ann_info(input_dict)

        # Indices into the configured class tuple: bed (0), night_stand (7)
        # and dresser (6).
        expected_labels = np.array([0, 7, 6])
        self.assertEqual(ann_info['gt_labels_3d'].dtype, np.int64)
        assert_allclose(ann_info['gt_labels_3d'], expected_labels)
        self.assertIsInstance(ann_info['gt_bboxes_3d'], DepthInstance3DBoxes)
        self.assertEqual(len(ann_info['gt_bboxes_3d']), 3)
        assert_allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
                        torch.tensor(19.2575))

        # Restrict the dataset to a single class and re-parse sample 0.
        bed_dataset = build_dataset(['bed'])
        input_dict = bed_dataset.get_data_info(0)
        ann_info = bed_dataset.parse_ann_info(input_dict)

        self.assertIn('gt_labels_3d', ann_info)
        labels = ann_info['gt_labels_3d']
        # All three boxes remain, but no label may exceed the single kept
        # class index 0 (other classes presumably map to -1; confirm against
        # SUNRGBDDataset).
        self.assertTrue((labels <= 0).all())
        self.assertEqual(labels.dtype, np.int64)
        self.assertEqual(len(labels), 3)
        self.assertEqual(len(bed_dataset.metainfo['classes']), 1)
|
|