File size: 1,906 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import pytest
import torch
import torch.nn.functional as F

from mmdet3d.models.decode_heads import MinkUNetHead
from mmdet3d.structures import Det3DDataSample, PointData


class TestMinkUNetHead(TestCase):

    def test_minkunet_head_loss(self):
        """Tests PAConv head loss."""

        try:
            import torchsparse
        except ImportError:
            pytest.skip('test requires Torchsparse installation')
        if torch.cuda.is_available():
            minkunet_head = MinkUNetHead(channels=4, num_classes=19)

            minkunet_head.cuda()
            coordinates, features = [], []
            for i in range(2):
                c = torch.randint(0, 10, (100, 3)).int()
                c = F.pad(c, (0, 1), mode='constant', value=i)
                coordinates.append(c)
                f = torch.rand(100, 4)
                features.append(f)
            features = torch.cat(features, dim=0).cuda()
            coordinates = torch.cat(coordinates, dim=0).cuda()
            x = torchsparse.SparseTensor(feats=features, coords=coordinates)

            # Test forward
            seg_logits = minkunet_head.forward(x)

            self.assertEqual(seg_logits.shape, torch.Size([200, 19]))

            # When truth is non-empty then losses
            # should be nonzero for random inputs
            voxel_semantic_mask = torch.randint(0, 19, (100, )).long().cuda()
            gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask)

            datasample = Det3DDataSample()
            datasample.gt_pts_seg = gt_pts_seg

            gt_losses = minkunet_head.loss(x, [datasample, datasample], {})

            gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()

            self.assertGreater(gt_sem_seg_loss, 0,
                               'semantic seg loss should be positive')