|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ..builder import LOSSES |
|
|
|
|
|
@LOSSES.register_module() |
|
class AdaptiveWingLoss(nn.Module): |
|
"""Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face |
|
Alignment via Heatmap Regression' Wang et al. ICCV'2019. |
|
|
|
Args: |
|
alpha (float), omega (float), epsilon (float), theta (float) |
|
are hyper-parameters. |
|
use_target_weight (bool): Option to use weighted MSE loss. |
|
Different joint types may have different target weights. |
|
loss_weight (float): Weight of the loss. Default: 1.0. |
|
""" |
|
|
|
def __init__(self, |
|
alpha=2.1, |
|
omega=14, |
|
epsilon=1, |
|
theta=0.5, |
|
use_target_weight=False, |
|
loss_weight=1.): |
|
super().__init__() |
|
self.alpha = float(alpha) |
|
self.omega = float(omega) |
|
self.epsilon = float(epsilon) |
|
self.theta = float(theta) |
|
self.use_target_weight = use_target_weight |
|
self.loss_weight = loss_weight |
|
|
|
def criterion(self, pred, target): |
|
"""Criterion of wingloss. |
|
|
|
Note: |
|
batch_size: N |
|
num_keypoints: K |
|
|
|
Args: |
|
pred (torch.Tensor[NxKxHxW]): Predicted heatmaps. |
|
target (torch.Tensor[NxKxHxW]): Target heatmaps. |
|
""" |
|
H, W = pred.shape[2:4] |
|
delta = (target - pred).abs() |
|
|
|
A = self.omega * ( |
|
1 / (1 + torch.pow(self.theta / self.epsilon, self.alpha - target)) |
|
) * (self.alpha - target) * (torch.pow( |
|
self.theta / self.epsilon, |
|
self.alpha - target - 1)) * (1 / self.epsilon) |
|
C = self.theta * A - self.omega * torch.log( |
|
1 + torch.pow(self.theta / self.epsilon, self.alpha - target)) |
|
|
|
losses = torch.where( |
|
delta < self.theta, |
|
self.omega * |
|
torch.log(1 + |
|
torch.pow(delta / self.epsilon, self.alpha - target)), |
|
A * delta - C) |
|
|
|
return torch.mean(losses) |
|
|
|
def forward(self, output, target, target_weight): |
|
"""Forward function. |
|
|
|
Note: |
|
batch_size: N |
|
num_keypoints: K |
|
|
|
Args: |
|
output (torch.Tensor[NxKxHxW]): Output heatmaps. |
|
target (torch.Tensor[NxKxHxW]): Target heatmaps. |
|
target_weight (torch.Tensor[NxKx1]): |
|
Weights across different joint types. |
|
""" |
|
if self.use_target_weight: |
|
loss = self.criterion(output * target_weight.unsqueeze(-1), |
|
target * target_weight.unsqueeze(-1)) |
|
else: |
|
loss = self.criterion(output, target) |
|
|
|
return loss * self.loss_weight |
|
|