import torch
from torch import Tensor, nn
from torch.nn import functional as F

from mmdet3d.registry import MODELS


def one_hot(label: Tensor,
            n_classes: int,
            requires_grad: bool = True) -> Tensor:
    """Return the one-hot encoding of a label map.

    Args:
        label (Tensor): Label map with integer class indices,
            in shape (N, H, W).
        n_classes (int): Number of classes.
        requires_grad (bool): Whether the returned tensor requires
            gradient. Defaults to True.

    Returns:
        Tensor: One-hot label map in shape (N, n_classes, H, W).
    """
    device = label.device
    one_hot_label = torch.eye(
        n_classes, device=device, requires_grad=requires_grad)[label]
    # (N, H, W, C) -> (N, C, H, W)
    one_hot_label = one_hot_label.transpose(1, 3).transpose(2, 3)

    return one_hot_label


@MODELS.register_module()
class BoundaryLoss(nn.Module):
    """Boundary loss.

    Penalizes the mismatch between predicted and ground-truth boundary
    maps through a differentiable boundary F1 (BF1) score, returning
    ``1 - BF1`` averaged over classes and batch.

    Args:
        theta0 (int): Kernel size of the max-pooling used to extract
            boundaries from the (one-hot) segmentation maps. Defaults to 3.
        theta (int): Boundary dilation kernel size. Currently unused in
            ``forward``. Defaults to 5.
        loss_weight (float): Weight of this loss item. Defaults to 1.0.
    """

    def __init__(self,
                 theta0: int = 3,
                 theta: int = 5,
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.theta0 = theta0
        self.theta = theta
        self.loss_weight = loss_weight

    def forward(self, pred: Tensor, gt: Tensor) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted logits in shape (N, C, H, W).
            gt (Tensor): Ground truth label map in shape (N, H, W).

        Returns:
            Tensor: Scalar loss tensor.
        """
        pred = F.softmax(pred, dim=1)
        n, c, _, _ = pred.shape

        # One-hot encode the ground truth to (N, C, H, W).
        one_hot_gt = one_hot(gt, c)

        # Boundary map of the ground truth: max-pooling dilates the
        # complement of each class map; subtracting the complement again
        # keeps only the pixels along the region boundary.
        gt_b = F.max_pool2d(
            1 - one_hot_gt,
            kernel_size=self.theta0,
            stride=1,
            padding=(self.theta0 - 1) // 2)
        gt_b -= 1 - one_hot_gt

        # Boundary map of the prediction, computed the same way on the
        # softmax probabilities.
        pred_b = F.max_pool2d(
            1 - pred,
            kernel_size=self.theta0,
            stride=1,
            padding=(self.theta0 - 1) // 2)
        pred_b -= 1 - pred

        # Flatten the spatial dimensions.
        gt_b = gt_b.view(n, c, -1)
        pred_b = pred_b.view(n, c, -1)

        # Soft precision and recall of the boundary pixels per class.
        P = torch.sum(pred_b * gt_b, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
        R = torch.sum(pred_b * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)

        # Boundary F1 score.
        BF1 = 2 * P * R / (P + R + 1e-7)

        # Average (1 - BF1) over classes and batch.
        loss = torch.mean(1 - BF1)

        return self.loss_weight * loss
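

# The block below is a minimal usage sketch, not part of the original module:
# it builds random logits and labels (assumed shapes, illustration only) and
# constructs `BoundaryLoss` directly rather than through the MODELS registry.
if __name__ == '__main__':
    batch, num_classes, height, width = 2, 4, 64, 64
    logits = torch.randn(batch, num_classes, height, width)
    labels = torch.randint(0, num_classes, (batch, height, width))

    criterion = BoundaryLoss(theta0=3, theta=5, loss_weight=1.0)
    loss = criterion(logits, labels)
    print('boundary loss:', loss.item())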