|
|
|
import pytest |
|
import torch |
|
|
|
from mmdet3d.registry import MODELS |
|
|
|
|
|
def test_multibin_loss(): |
|
from mmdet3d.models.losses import MultiBinLoss |
|
|
|
|
|
with pytest.raises(AssertionError): |
|
multibin_loss = MultiBinLoss(reduction='l2') |
|
|
|
pred = torch.tensor([[ |
|
0.81, 0.32, 0.78, 0.52, 0.24, 0.12, 0.32, 0.11, 1.20, 1.30, 0.20, 0.11, |
|
0.12, 0.11, 0.23, 0.31 |
|
], |
|
[ |
|
0.02, 0.19, 0.78, 0.22, 0.31, 0.12, 0.22, 0.11, |
|
1.20, 1.30, 0.45, 0.51, 0.12, 0.11, 0.13, 0.61 |
|
]]) |
|
target = torch.tensor([[1, 1, 0, 0, 2.14, 3.12, 0.68, -2.15], |
|
[1, 1, 0, 0, 3.12, 3.12, 2.34, 1.23]]) |
|
multibin_loss_cfg = dict( |
|
type='MultiBinLoss', reduction='none', loss_weight=1.0) |
|
multibin_loss = MODELS.build(multibin_loss_cfg) |
|
output_multibin_loss = multibin_loss(pred, target, num_dir_bins=4) |
|
expected_multibin_loss = torch.tensor(2.1120) |
|
assert torch.allclose( |
|
output_multibin_loss, expected_multibin_loss, atol=1e-4) |
|
|