File size: 1,906 Bytes
34d1f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import pytest
import torch
import torch.nn.functional as F
from mmdet3d.models.decode_heads import MinkUNetHead
from mmdet3d.structures import Det3DDataSample, PointData
class TestMinkUNetHead(TestCase):
    """Unit tests for ``MinkUNetHead``."""

    def test_minkunet_head_loss(self):
        """Tests MinkUNet head forward output shape and semantic seg loss."""
        # The head consumes torchsparse tensors; skip (not pass) if missing.
        torchsparse = pytest.importorskip(
            'torchsparse', reason='test requires Torchsparse installation')
        # Previously the test silently passed without a GPU; make the skip
        # explicit so CPU-only runs report it correctly.
        if not torch.cuda.is_available():
            pytest.skip('test requires GPU and torch+cuda')

        num_classes = 19
        num_points = 100  # points per sample
        batch_size = 2

        minkunet_head = MinkUNetHead(channels=4, num_classes=num_classes)
        minkunet_head.cuda()

        coordinates, features = [], []
        for i in range(batch_size):
            c = torch.randint(0, 10, (num_points, 3)).int()
            # Append the sample index as a 4th coordinate column so the
            # samples remain distinguishable after concatenation.
            # NOTE(review): assumes torchsparse expects the batch index in
            # the last column -- confirm against the installed version.
            c = F.pad(c, (0, 1), mode='constant', value=i)
            coordinates.append(c)
            features.append(torch.rand(num_points, 4))
        features = torch.cat(features, dim=0).cuda()
        coordinates = torch.cat(coordinates, dim=0).cuda()
        x = torchsparse.SparseTensor(feats=features, coords=coordinates)

        # Test forward: one row of class logits per input point in the batch.
        seg_logits = minkunet_head(x)
        self.assertEqual(seg_logits.shape,
                         torch.Size([batch_size * num_points, num_classes]))

        # When truth is non-empty then losses
        # should be nonzero for random inputs
        # Each sample's mask has one label per point of that sample.
        voxel_semantic_mask = torch.randint(
            0, num_classes, (num_points, )).long().cuda()
        gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask)
        datasample = Det3DDataSample()
        datasample.gt_pts_seg = gt_pts_seg
        gt_losses = minkunet_head.loss(x, [datasample] * batch_size, {})
        gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()
        self.assertGreater(gt_sem_seg_loss, 0,
                           'semantic seg loss should be positive')
|