|
|
|
from typing import Optional |
|
|
|
import torch |
|
from mmcv.ops.diff_iou_rotated import box2corners, oriented_box_intersection_2d |
|
from mmdet.models.losses.utils import weighted_loss |
|
from torch import Tensor |
|
from torch import nn as nn |
|
|
|
from mmdet3d.models import rotated_iou_3d_loss |
|
from mmdet3d.registry import MODELS |
|
|
|
|
|
def diff_diou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:
    """Calculate differentiable DIoU of rotated 3d boxes.

    The result is the 3D IoU of each box pair minus a distance penalty:
    the squared distance between the two box centers divided by the squared
    diagonal of the smallest axis-aligned box enclosing both inputs.

    Args:
        box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha).
        box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha).

    Returns:
        Tensor: (B, N) DIoU.
    """
    # Bird's-eye-view footprint of each box: columns 0/1 are the planar
    # center, 3/4 the planar size and 6 the rotation angle.
    box1 = box3d1[..., [0, 1, 3, 4, 6]]
    box2 = box3d2[..., [0, 1, 3, 4, 6]]
    corners1 = box2corners(box1)
    corners2 = box2corners(box2)
    intersection, _ = oriented_box_intersection_2d(corners1, corners2)

    # Extent along z: column 2 is the center, column 5 the size.
    zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5
    zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5
    zmax2 = box3d2[..., 2] + box3d2[..., 5] * 0.5
    zmin2 = box3d2[..., 2] - box3d2[..., 5] * 0.5
    # Boxes that do not overlap along z contribute zero intersection.
    z_overlap = (torch.min(zmax1, zmax2) -
                 torch.max(zmin1, zmin2)).clamp_(min=0.)

    intersection_3d = intersection * z_overlap
    volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5]
    volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5]
    union_3d = volume1 + volume2 - intersection_3d

    # Axis-aligned bounds of every rotated footprint, obtained by reducing
    # over the corner axis (dim 2) of the corner tensors.
    x1_max = torch.max(corners1[..., 0], dim=2)[0]
    x1_min = torch.min(corners1[..., 0], dim=2)[0]
    y1_max = torch.max(corners1[..., 1], dim=2)[0]
    y1_min = torch.min(corners1[..., 1], dim=2)[0]

    x2_max = torch.max(corners2[..., 0], dim=2)[0]
    x2_min = torch.min(corners2[..., 0], dim=2)[0]
    y2_max = torch.max(corners2[..., 1], dim=2)[0]
    y2_min = torch.min(corners2[..., 1], dim=2)[0]

    # Smallest axis-aligned box that encloses both inputs.
    x_max = torch.max(x1_max, x2_max)
    x_min = torch.min(x1_min, x2_min)
    y_max = torch.max(y1_max, y2_max)
    y_min = torch.min(y1_min, y2_min)

    z_max = torch.max(zmax1, zmax2)
    z_min = torch.min(zmin1, zmin2)

    # Squared distance between the 3D centers (x, y, z). This must be read
    # from the full boxes: the first three columns of the BEV slices above
    # are (x, y, w), which would mix the box width into the distance and
    # ignore the offset along z.
    r2 = ((box3d1[..., :3] - box3d2[..., :3])**2).sum(dim=-1)
    # Squared diagonal of the enclosing box.
    c2 = (x_min - x_max)**2 + (y_min - y_max)**2 + (z_min - z_max)**2

    return intersection_3d / union_3d - r2 / c2
|
|
|
|
|
@weighted_loss
def rotated_diou_3d_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Compute the DIoU loss (1 - DIoU) for paired rotated 3D boxes.

    The i-th prediction is compared only with the i-th target; no pairwise
    matching between the two sets takes place.

    Args:
        pred (torch.Tensor): Predicted boxes with shape [N, 7], seven values
            per box in the layout consumed by ``diff_diou_rotated_3d``.
        target (torch.Tensor): Ground-truth boxes with shape [N, 7], same
            layout as ``pred``.

    Returns:
        torch.Tensor: ``1 - DIoU`` of every prediction/target pair.
    """
    # diff_diou_rotated_3d expects a leading batch axis, so wrap both inputs
    # in a batch of size one and take that single batch entry back out.
    batched_pred = pred[None]
    batched_target = target[None]
    diou = diff_diou_rotated_3d(batched_pred, batched_target)[0]
    return 1 - diou
|
|
|
|
|
@MODELS.register_module()
class TR3DRotatedIoU3DLoss(nn.Module):
    """Rotated 3D box loss with a selectable IoU or DIoU formulation.

    Works like the original RotatedIoU3DLoss, with a DIoU mode added on top.
    The two classes are expected to be merged in the future.

    Args:
        mode (str): 'iou' for the intersection-over-union loss or 'diou' for
            the distance-IoU loss. Defaults to 'iou'.
        reduction (str): How per-box losses are reduced. One of 'none',
            'sum' or 'mean'. Defaults to 'mean'.
        loss_weight (float): Scalar multiplier applied to the loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 mode: str = 'iou',
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        assert mode in ('iou', 'diou')
        # Select the underlying loss function once, at construction time.
        if mode == 'iou':
            self.loss = rotated_iou_3d_loss
        else:
            self.loss = rotated_diou_3d_loss
        assert reduction in ('none', 'sum', 'mean')
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Compute the weighted, reduced box loss.

        Args:
            pred (Tensor): Bbox predictions with shape [..., 7].
            target (Tensor): Bbox targets (gt) with shape [..., 7].
            weight (Tensor, optional): Weight of loss. Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Reduction used for this call
                instead of the one given at construction. One of 'none',
                'sum' or 'mean'. Defaults to None.

        Returns:
            Tensor: Loss between predictions and targets, scaled by
            ``loss_weight``.
        """
        has_weight = weight is not None
        if has_weight and not torch.any(weight > 0):
            # No positively weighted element: return a product of sums
            # instead of calling the loss function, which keeps ``pred``
            # in the autograd graph.
            return pred.sum() * weight.sum()

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        if has_weight and weight.dim() > 1:
            # Collapse the last axis so there is one weight per box.
            weight = weight.mean(-1)

        raw_loss = self.loss(
            pred,
            target,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * raw_loss
|
|