File size: 1,184 Bytes
c2ca15f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet3d.registry import MODELS


def test_multibin_loss():
    """Exercise ``MultiBinLoss`` on a small hand-written batch.

    First confirms that constructing the loss with an unsupported
    ``reduction`` raises ``AssertionError``, then builds the loss through the
    ``MODELS`` registry and compares its output on fixed inputs against a
    reference value.
    """
    from mmdet3d.models.losses import MultiBinLoss

    # reduction should be one of 'none', 'mean' or 'sum'; anything else
    # is expected to be rejected at construction time.
    with pytest.raises(AssertionError):
        MultiBinLoss(reduction='l2')

    # Two samples with 16 predicted values each.
    # NOTE(review): the meaning of each column is defined by MultiBinLoss
    # and is not visible from this test -- confirm against the loss itself.
    predictions = torch.tensor([
        [
            0.81, 0.32, 0.78, 0.52, 0.24, 0.12, 0.32, 0.11,
            1.20, 1.30, 0.20, 0.11, 0.12, 0.11, 0.23, 0.31,
        ],
        [
            0.02, 0.19, 0.78, 0.22, 0.31, 0.12, 0.22, 0.11,
            1.20, 1.30, 0.45, 0.51, 0.12, 0.11, 0.13, 0.61,
        ],
    ])
    # Two samples with 8 target values each.
    targets = torch.tensor([
        [1, 1, 0, 0, 2.14, 3.12, 0.68, -2.15],
        [1, 1, 0, 0, 3.12, 3.12, 2.34, 1.23],
    ])

    # Build through the registry so the config-driven construction path is
    # the one being exercised.
    loss_module = MODELS.build(
        dict(type='MultiBinLoss', reduction='none', loss_weight=1.0))
    actual = loss_module(predictions, targets, num_dir_bins=4)

    reference = torch.tensor(2.1120)
    assert torch.allclose(actual, reference, atol=1e-4)