File size: 6,949 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import pytest
import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample, PointData


def _equal(a, b):
    if isinstance(a, (torch.Tensor, np.ndarray)):
        return (a == b).all()
    else:
        return a == b


class TestDet3DDataSample(TestCase):
    """Tests for the metainfo and data-field accessors of Det3DDataSample."""

    def _check_field_roundtrip(self, data_sample, field_name, element_cls,
                               field_data):
        """Assign a field through its setter and verify it reads back.

        Builds ``element_cls(**field_data)``, assigns it to ``field_name``
        on ``data_sample``, then checks that the field is reported by ``in``
        and that every entry of ``field_data`` is readable back unchanged
        from the stored element.
        """
        element = element_cls(**field_data)
        setattr(data_sample, field_name, element)
        assert field_name in data_sample
        stored = getattr(data_sample, field_name)
        for key, value in field_data.items():
            assert _equal(getattr(stored, key), value), \
                f'{field_name}.{key} does not match the value that was set'

    def _check_field_deletion(self, data_sample, field_name, element):
        """Assign ``element`` to ``field_name``, delete it, check removal."""
        setattr(data_sample, field_name, element)
        assert field_name in data_sample
        delattr(data_sample, field_name)
        assert field_name not in data_sample

    def test_init(self):
        meta_info = dict(
            img_size=[256, 256],
            scale_factor=np.array([1.5, 1.5]),
            img_shape=torch.rand(4))

        det3d_data_sample = Det3DDataSample(metainfo=meta_info)
        assert 'img_size' in det3d_data_sample
        assert det3d_data_sample.img_size == [256, 256]
        assert det3d_data_sample.get('img_size') == [256, 256]

        # Array- and tensor-valued meta fields must be stored as well.
        assert 'scale_factor' in det3d_data_sample
        assert _equal(det3d_data_sample.scale_factor,
                      meta_info['scale_factor'])
        assert 'img_shape' in det3d_data_sample
        assert _equal(det3d_data_sample.img_shape, meta_info['img_shape'])

    def test_setter(self):
        det3d_data_sample = Det3DDataSample()

        # test gt_instances_3d
        self._check_field_roundtrip(
            det3d_data_sample, 'gt_instances_3d', InstanceData,
            dict(bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4)))

        # test pred_instances_3d / pts_pred_instances_3d /
        # img_pred_instances_3d, which share the same field layout
        for field_name in ('pred_instances_3d', 'pts_pred_instances_3d',
                           'img_pred_instances_3d'):
            self._check_field_roundtrip(
                det3d_data_sample, field_name, InstanceData,
                dict(
                    bboxes_3d=torch.rand(2, 7),
                    labels_3d=torch.rand(2),
                    scores_3d=torch.rand(2)))

        # test gt_pts_seg / pred_pts_seg
        for field_name in ('gt_pts_seg', 'pred_pts_seg'):
            self._check_field_roundtrip(
                det3d_data_sample, field_name, PointData,
                dict(
                    pts_instance_mask=torch.rand(20),
                    pts_semantic_mask=torch.rand(20)))

        # test type error
        with pytest.raises(AssertionError):
            det3d_data_sample.pred_instances_3d = torch.rand(2, 4)

        with pytest.raises(AssertionError):
            det3d_data_sample.pred_pts_seg = torch.rand(20)

    def test_deleter(self):
        det3d_data_sample = Det3DDataSample()

        for field_name in ('gt_instances_3d', 'pred_instances_3d',
                           'pts_pred_instances_3d', 'img_pred_instances_3d'):
            # Fields are passed as keyword arguments so that bboxes_3d and
            # labels_3d become real InstanceData fields; ``data=<dict>``
            # would store the whole dict under a single ``data`` field.
            instances_3d = InstanceData(
                bboxes_3d=torch.rand(4, 7), labels_3d=torch.rand(4))
            self._check_field_deletion(det3d_data_sample, field_name,
                                       instances_3d)

        for field_name in ('gt_pts_seg', 'pred_pts_seg'):
            pts_seg = PointData(
                pts_instance_mask=torch.rand(20),
                pts_semantic_mask=torch.rand(20))
            self._check_field_deletion(det3d_data_sample, field_name,
                                       pts_seg)