|
|
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch import nn as nn |
|
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss |
|
|
|
from mmdet3d.registry import MODELS |
|
|
|
|
|
def chamfer_distance(
        src: Tensor,
        dst: Tensor,
        src_weight: Union[Tensor, float] = 1.0,
        dst_weight: Union[Tensor, float] = 1.0,
        criterion_mode: str = 'l2',
        reduction: str = 'mean') -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Calculate Chamfer Distance of two sets.

    For every (source, destination) point pair the element-wise criterion
    is evaluated per coordinate and summed over the last axis, giving a
    [B, N, M] pairwise distance matrix. Each point is then matched to its
    nearest point in the other set. With ``criterion_mode='l2'``
    (``mse_loss``) the pairwise distance is the squared Euclidean distance.

    Args:
        src (Tensor): Source set with shape [B, N, C] to
            calculate Chamfer Distance.
        dst (Tensor): Destination set with shape [B, M, C] to
            calculate Chamfer Distance.
        src_weight (Tensor or float): Weight of source loss; multiplied
            with the per-point [B, N] source distances before reduction.
            Defaults to 1.0.
        dst_weight (Tensor or float): Weight of destination loss;
            multiplied with the per-point [B, M] destination distances
            before reduction. Defaults to 1.0.
        criterion_mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'. Defaults to 'l2'.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.

    Returns:
        tuple: Source and Destination loss with the corresponding indices.

        - loss_src (Tensor): The min distance from source to destination,
          scaled by ``src_weight``. With ``reduction='none'`` it keeps the
          per-point layout (shape [B, N] after broadcasting with
          ``src_weight``); otherwise it is a scalar.
        - loss_dst (Tensor): The min distance from destination to source,
          scaled by ``dst_weight``. With ``reduction='none'`` it keeps the
          per-point layout (shape [B, M] after broadcasting with
          ``dst_weight``); otherwise it is a scalar.
        - indices1 (Tensor): For each source point, the index of its
          nearest destination point, shape [B, N].
        - indices2 (Tensor): For each destination point, the index of its
          nearest source point, shape [B, M].

    Raises:
        NotImplementedError: If ``criterion_mode`` or ``reduction`` is not
            one of the supported values.
    """
    if criterion_mode == 'smooth_l1':
        criterion = smooth_l1_loss
    elif criterion_mode == 'l1':
        criterion = l1_loss
    elif criterion_mode == 'l2':
        criterion = mse_loss
    else:
        raise NotImplementedError(
            f'Unsupported criterion_mode {criterion_mode!r}; expected one of '
            "'smooth_l1', 'l1' or 'l2'.")

    # Validate before the (potentially large) pairwise computation.
    if reduction not in ('none', 'sum', 'mean'):
        raise NotImplementedError(
            f'Unsupported reduction {reduction!r}; expected one of '
            "'none', 'sum' or 'mean'.")

    # Broadcast both sets to [B, N, M, C]. ``expand`` returns a view and does
    # not copy, unlike ``repeat``, which would materialise two full-size
    # tensors in addition to the criterion output.
    num_src, num_dst = src.shape[1], dst.shape[1]
    src_expand = src.unsqueeze(2).expand(-1, -1, num_dst, -1)
    dst_expand = dst.unsqueeze(1).expand(-1, num_src, -1, -1)

    # [B, N, M]: criterion per coordinate, summed over the coordinate axis.
    distance = criterion(src_expand, dst_expand, reduction='none').sum(-1)
    # Nearest destination for every source point (min over M) and nearest
    # source for every destination point (min over N).
    src2dst_distance, indices1 = torch.min(distance, dim=2)
    dst2src_distance, indices2 = torch.min(distance, dim=1)

    loss_src = src2dst_distance * src_weight
    loss_dst = dst2src_distance * dst_weight

    if reduction == 'sum':
        loss_src = torch.sum(loss_src)
        loss_dst = torch.sum(loss_dst)
    elif reduction == 'mean':
        loss_src = torch.mean(loss_src)
        loss_dst = torch.mean(loss_dst)
    # reduction == 'none': keep the per-point losses as they are.

    return loss_src, loss_dst, indices1, indices2
|
|
|
|
|
@MODELS.register_module()
class ChamferDistance(nn.Module):
    """Calculate Chamfer Distance of two sets.

    Args:
        mode (str): Criterion mode to calculate distance.
            The valid modes are 'smooth_l1', 'l1' or 'l2'. Defaults to 'l2'.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_src_weight (float): Weight of loss_source. Defaults to 1.0.
        loss_dst_weight (float): Weight of loss_target. Defaults to 1.0.
    """

    def __init__(self,
                 mode: str = 'l2',
                 reduction: str = 'mean',
                 loss_src_weight: float = 1.0,
                 loss_dst_weight: float = 1.0) -> None:
        super().__init__()

        assert mode in ['smooth_l1', 'l1', 'l2']
        assert reduction in ['none', 'sum', 'mean']
        self.mode = mode
        self.reduction = reduction
        # Global scale factors applied on top of the per-call weights.
        self.loss_src_weight = loss_src_weight
        self.loss_dst_weight = loss_dst_weight

    def forward(
        self,
        source: Tensor,
        target: Tensor,
        src_weight: Union[Tensor, float] = 1.0,
        dst_weight: Union[Tensor, float] = 1.0,
        reduction_override: Optional[str] = None,
        return_indices: bool = False,
        **kwargs
    ) -> Union[Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        """Forward function of loss calculation.

        Args:
            source (Tensor): Source set with shape [B, N, C] to
                calculate Chamfer Distance.
            target (Tensor): Destination set with shape [B, M, C] to
                calculate Chamfer Distance.
            src_weight (Tensor | float):
                Weight of source loss. Defaults to 1.0.
            dst_weight (Tensor | float):
                Weight of destination loss. Defaults to 1.0.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                When None, ``self.reduction`` is used. Defaults to None.
            return_indices (bool): Whether to return indices.
                Defaults to False.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            tuple[Tensor]: If ``return_indices=True``, return losses of
            source and target with their corresponding indices in the
            order of ``(loss_source, loss_target, indices1, indices2)``.
            If ``return_indices=False``, return
            ``(loss_source, loss_target)``. Both losses are additionally
            scaled by ``loss_src_weight`` / ``loss_dst_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_source, loss_target, indices1, indices2 = chamfer_distance(
            source, target, src_weight, dst_weight, self.mode, reduction)

        # Scale out-of-place so tensors that are part of the autograd graph
        # are never mutated in place.
        loss_source = loss_source * self.loss_src_weight
        loss_target = loss_target * self.loss_dst_weight

        if return_indices:
            return loss_source, loss_target, indices1, indices2
        return loss_source, loss_target
|
|