File size: 5,601 Bytes
34d1f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmcv.ops.diff_iou_rotated import box2corners, oriented_box_intersection_2d
from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from torch import nn as nn
from mmdet3d.models import rotated_iou_3d_loss
from mmdet3d.registry import MODELS
def diff_diou_rotated_3d(box3d1: Tensor, box3d2: Tensor) -> Tensor:
    """Calculate differentiable DIoU of rotated 3d boxes.

    The result is ``IoU_3d - r2 / c2`` where ``r2`` is the squared distance
    between the two 3D box centers and ``c2`` is the squared diagonal of the
    axis-aligned box enclosing both inputs.

    Args:
        box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha).
        box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha).

    Returns:
        Tensor: (B, N) DIoU.
    """
    # Bird's-eye-view projection of each box: (x, y, w, h, alpha).
    bev_idx = [0, 1, 3, 4, 6]
    box1 = box3d1[..., bev_idx]
    box2 = box3d2[..., bev_idx]
    # NOTE(review): box2corners is expected to return (B, N, 4, 2) corner
    # coordinates, which is what the dim=2 reductions below rely on.
    corners1 = box2corners(box1)
    corners2 = box2corners(box2)
    intersection, _ = oriented_box_intersection_2d(corners1, corners2)

    # Vertical extent of each box; index 5 is the size along z.
    half_dz1 = box3d1[..., 5] * 0.5
    half_dz2 = box3d2[..., 5] * 0.5
    zmax1 = box3d1[..., 2] + half_dz1
    zmin1 = box3d1[..., 2] - half_dz1
    zmax2 = box3d2[..., 2] + half_dz2
    zmin2 = box3d2[..., 2] - half_dz2
    # Boxes that do not overlap vertically contribute zero intersection.
    z_overlap = (torch.min(zmax1, zmax2) -
                 torch.max(zmin1, zmin2)).clamp(min=0.)

    intersection_3d = intersection * z_overlap
    volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5]
    volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5]
    union_3d = volume1 + volume2 - intersection_3d

    # Axis-aligned bounds of each rotated footprint, taken over its corners.
    x1_max = torch.max(corners1[..., 0], dim=2)[0]
    x1_min = torch.min(corners1[..., 0], dim=2)[0]
    y1_max = torch.max(corners1[..., 1], dim=2)[0]
    y1_min = torch.min(corners1[..., 1], dim=2)[0]
    x2_max = torch.max(corners2[..., 0], dim=2)[0]
    x2_min = torch.min(corners2[..., 0], dim=2)[0]
    y2_max = torch.max(corners2[..., 1], dim=2)[0]
    y2_min = torch.min(corners2[..., 1], dim=2)[0]

    # Smallest axis-aligned 3D box enclosing both inputs.
    x_max = torch.max(x1_max, x2_max)
    x_min = torch.min(x1_min, x2_min)
    y_max = torch.max(y1_max, y2_max)
    y_min = torch.min(y1_min, y2_min)
    z_max = torch.max(zmax1, zmax2)
    z_min = torch.min(zmin1, zmin2)

    # Squared distance between the 3D centers (x, y, z). This must use the
    # full 3D boxes: the BEV tensors hold (x, y, w, ...), so slicing their
    # first three columns would mix the width into the distance and drop z.
    r2 = ((box3d1[..., :3] - box3d2[..., :3])**2).sum(dim=-1)
    # Squared diagonal of the enclosing box.
    c2 = (x_min - x_max)**2 + (y_min - y_max)**2 + (z_min - z_max)**2

    return intersection_3d / union_3d - r2 / c2
@weighted_loss
def rotated_diou_3d_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Compute the per-box DIoU loss (1 - DIoU) for rotated 3D boxes.

    Predictions and targets are matched one-to-one by row. Both inputs get
    a leading batch axis of size one before being handed to
    ``diff_diou_rotated_3d``; that axis is stripped from the result again.

    Args:
        pred (torch.Tensor): Bbox predictions with shape [N, 7]; each row is
            a box in the layout consumed by ``diff_diou_rotated_3d``.
        target (torch.Tensor): Bbox targets (gt) with shape [N, 7], same
            layout as ``pred``.

    Returns:
        torch.Tensor: DIoU loss between predictions and targets, one value
        per box before the ``weighted_loss`` decorator is applied.
    """
    batched_pred = pred[None]
    batched_target = target[None]
    diou = diff_diou_rotated_3d(batched_pred, batched_target)
    # Drop the artificial batch axis added above.
    return 1 - diou[0]
@MODELS.register_module()
class TR3DRotatedIoU3DLoss(nn.Module):
    """Rotated 3D bounding box loss with selectable IoU or DIoU mode.

    Behaves like the original RotatedIoU3DLoss with one addition: a DIoU
    mode. The two classes should be merged in the future.

    Args:
        mode (str): 'iou' for intersection over union or 'diou' for
            distance-iou loss. Defaults to 'iou'.
        reduction (str): Method to reduce losses.
            The valid reduction method are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 mode: str = 'iou',
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        assert mode in ['iou', 'diou']
        # Select the underlying loss function once, at construction time.
        if mode == 'iou':
            self.loss = rotated_iou_3d_loss
        else:
            self.loss = rotated_diou_3d_loss
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[float] = None,
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Compute the weighted (D)IoU loss.

        Args:
            pred (Tensor): Bbox predictions with shape [..., 7]
                (x, y, z, w, l, h, alpha).
            target (Tensor): Bbox targets (gt) with shape [..., 7]
                (x, y, z, w, l, h, alpha).
            weight (Tensor, optional): Weight of loss. A weight with more
                than one dimension is averaged over its last dimension.
                Defaults to None.
            avg_factor (float, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction method are 'none', 'sum' or 'mean'.
                Takes precedence over the reduction given at construction.
                Defaults to None.

        Returns:
            Tensor: IoU loss between predictions and targets.
        """
        has_weight = weight is not None
        if has_weight and not (weight > 0).any():
            # No positive weight anywhere: skip the box computation and
            # return a product built from both inputs (intended to be 0).
            return pred.sum() * weight.sum()

        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        if has_weight and weight.dim() > 1:
            # Collapse the trailing dimension to one weight per box.
            weight = weight.mean(-1)

        raw_loss = self.loss(
            pred,
            target,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return self.loss_weight * raw_loss
|