|
|
|
from unittest import TestCase |
|
|
|
import torch |
|
from mmengine import ConfigDict, DefaultScope |
|
|
|
from mmdet3d.models import Seg3DTTAModel |
|
from mmdet3d.registry import MODELS |
|
from mmdet3d.structures import Det3DDataSample |
|
from mmdet3d.testing import get_detector_cfg |
|
|
|
|
|
class TestSeg3DTTAModel(TestCase):
    """Tests for ``Seg3DTTAModel`` wrapped around a Cylinder3D segmentor."""

    def test_seg3d_tta_model(self):
        """Build a Cylinder3D TTA model and run ``test_step`` over 4 flip augs.

        The model is always built, which exercises config loading and the
        ``MODELS`` registry. The forward pass only runs when CUDA is
        available. Otherwise the test is reported as skipped rather than
        silently passing without having run ``test_step``.
        """
        # Bound locally: the top-level ``from mmdet3d.models import ...``
        # does not bind the ``mmdet3d`` package name in this module.
        import mmdet3d.models

        self.assertTrue(hasattr(mmdet3d.models, 'Cylinder3D'))
        # Make 'mmdet3d' the default registry scope before building via MODELS.
        DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d')
        segmentor3d_cfg = get_detector_cfg(
            'cylinder3d/cylinder3d_4xb4-3x_semantickitti.py')
        cfg = ConfigDict(type='Seg3DTTAModel', module=segmentor3d_cfg)
        model: Seg3DTTAModel = MODELS.build(cfg)

        if not torch.cuda.is_available():
            # NOTE(review): the forward pass is only exercised on GPU,
            # presumably because Cylinder3D's sparse ops need CUDA — confirm.
            self.skipTest('test_step of Seg3DTTAModel requires CUDA')

        # One single-sample entry per augmentation, covering every
        # (horizontal_flip, vertical_flip) combination.
        flip_combos = [(False, False), (False, True), (True, False),
                       (True, True)]
        points = [{'points': [torch.randn(200, 4)]} for _ in flip_combos]
        data_samples = [[
            Det3DDataSample(
                metainfo=dict(
                    pcd_horizontal_flip=h_flip, pcd_vertical_flip=v_flip))
        ] for h_flip, v_flip in flip_combos]

        model.eval().cuda()
        results = model.test_step(
            dict(inputs=points, data_samples=data_samples))

        # NOTE(review): expects the TTA merge to return one merged prediction
        # per input sample (batch size 1 here) carrying ``pred_pts_seg``.
        self.assertEqual(len(results), 1)
        self.assertTrue(hasattr(results[0], 'pred_pts_seg'))
|
|