File size: 1,672 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmdet3d.models.decode_heads import DGCNNHead
from mmdet3d.structures import Det3DDataSample, PointData


class TestDGCNNHead(TestCase):

    def test_dgcnn_head_loss(self):
        """Tests DGCNN head loss."""

        dgcnn_head = DGCNNHead(
            fp_channels=(1024, 512),
            channels=256,
            num_classes=13,
            dropout_ratio=0.5,
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=dict(type='BN1d'),
            act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
            loss_decode=dict(
                type='mmdet.CrossEntropyLoss',
                use_sigmoid=False,
                class_weight=None,
                loss_weight=1.0),
            ignore_index=13)

        # DGCNN head expects dict format features
        fa_points = torch.rand(1, 4096, 1024).float()
        feat_dict = dict(fa_points=fa_points)

        # Test forward
        seg_logits = dgcnn_head.forward(feat_dict)

        self.assertEqual(seg_logits.shape, torch.Size([1, 13, 4096]))

        # When truth is non-empty then losses
        # should be nonzero for random inputs
        pts_semantic_mask = torch.randint(0, 13, (4096, )).long()
        gt_pts_seg = PointData(pts_semantic_mask=pts_semantic_mask)

        datasample = Det3DDataSample()
        datasample.gt_pts_seg = gt_pts_seg

        gt_losses = dgcnn_head.loss_by_feat(seg_logits, [datasample])

        gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()

        self.assertGreater(gt_sem_seg_loss, 0,
                           'semantic seg loss should be positive')