|
|
|
from unittest import TestCase |
|
|
|
import pytest |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from mmdet3d.models.decode_heads import MinkUNetHead |
|
from mmdet3d.structures import Det3DDataSample, PointData |
|
|
|
|
|
class TestMinkUNetHead(TestCase):
    """Unit tests for ``MinkUNetHead``.

    The test needs the optional ``torchsparse`` package and a CUDA device.
    It is reported as skipped (not passed) when either is missing.
    """

    def test_minkunet_head_loss(self):
        """Tests MinkUNet head forward output shape and semantic seg loss."""
        try:
            import torchsparse
        except ImportError:
            pytest.skip('test requires Torchsparse installation')
        # Skip explicitly instead of wrapping the body in an ``if``: otherwise
        # the test silently reports PASSED on CPU-only machines even though
        # the head was never exercised.
        if not torch.cuda.is_available():
            pytest.skip('test requires GPU and CUDA')

        batch_size = 2
        num_voxels = 100  # voxels per sample
        channels = 4  # feature width; also passed as the head's `channels`
        num_classes = 19

        # nn.Module.cuda() returns the module itself, so this can be chained.
        minkunet_head = MinkUNetHead(
            channels=channels, num_classes=num_classes).cuda()

        # Build a batched sparse input. Each sample contributes `num_voxels`
        # rows of 3 random integer coordinates in [0, 10); F.pad appends one
        # extra trailing column holding the sample's batch index.
        coordinates, features = [], []
        for batch_idx in range(batch_size):
            coords = torch.randint(0, 10, (num_voxels, 3)).int()
            coords = F.pad(coords, (0, 1), mode='constant', value=batch_idx)
            coordinates.append(coords)
            features.append(torch.rand(num_voxels, channels))
        features = torch.cat(features, dim=0).cuda()
        coordinates = torch.cat(coordinates, dim=0).cuda()
        x = torchsparse.SparseTensor(feats=features, coords=coordinates)

        # Test forward: one row of class logits per input voxel.
        seg_logits = minkunet_head.forward(x)
        self.assertEqual(seg_logits.shape,
                         torch.Size([batch_size * num_voxels, num_classes]))

        # When the ground truth is non-empty, random inputs should produce a
        # strictly positive loss. The same data sample is reused for every
        # batch entry. Each entry carries `num_voxels` labels, presumably
        # concatenated by the head to line up with the logits above.
        # TODO confirm against MinkUNetHead.loss.
        voxel_semantic_mask = torch.randint(
            0, num_classes, (num_voxels, ), device='cuda')
        gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask)

        datasample = Det3DDataSample()
        datasample.gt_pts_seg = gt_pts_seg

        # The third argument is an empty dict; in mmdet3d heads this slot is
        # presumably the train config. TODO confirm.
        gt_losses = minkunet_head.loss(x, [datasample] * batch_size, {})

        gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()

        self.assertGreater(gt_sem_seg_loss, 0,
                           'semantic seg loss should be positive')
|
|