|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class Loss(nn.modules.loss._Loss):
    """Base class for the custom losses defined in this module.

    Subclass it and implement ``forward``. Every keyword argument is passed
    straight through to :class:`torch.nn.modules.loss._Loss`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
|
|
|
|
|
class AdditiveMarginSoftmaxLoss(Loss):
    """Computes Additive Margin Softmax (CosFace) Loss.

    Paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition

    The margin is subtracted from the target-class cosine before scaling, so for
    each sample with target class ``y``::

        loss = -log(exp(s * (cos_y - m))
                    / (exp(s * (cos_y - m)) + sum_{j != y} exp(s * cos_j)))

    and the mean over the batch is returned.

    Args:
        scale: scale value applied to the cosine logits.
        margin: margin value subtracted from the target-class cosine.
    """

    def __init__(self, scale=30.0, margin=0.2):
        super().__init__()

        # Kept for attribute compatibility; ``forward`` does not use it.
        self.eps = 1e-7
        self.scale = scale
        self.margin = margin

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return the mean CosFace loss over the batch.

        Args:
            logits: 2-D tensor indexed as ``logits[sample, class]``; presumably
                cosine similarities in [-1, 1] (not checked here).
            labels: target class index for each sample, shape ``(batch,)``.

        Returns:
            Scalar tensor holding the mean loss. The input ``logits`` is not
            modified.
        """
        # gather/scatter require int64 indices; this also accepts int32 labels.
        index = labels.long().unsqueeze(1)

        logits_target = logits.gather(1, index).squeeze(1)
        numerator = self.scale * (logits_target - self.margin)

        # Out-of-place scatter: masks the target column with -inf (exp -> 0)
        # without overwriting the caller's tensor, which would also corrupt
        # autograd for any earlier op that saved ``logits``.
        logits_others = logits.scatter(1, index, float('-inf'))

        # -log(exp(a) / (exp(a) + sum(exp(b)))) == logsumexp([a, b]) - a,
        # evaluated without overflowing exp().
        all_terms = torch.cat((numerator.unsqueeze(1), self.scale * logits_others), dim=1)
        loss = torch.logsumexp(all_terms, dim=1) - numerator
        return loss.mean()
|
|
|
|
|
class AdditiveAngularMarginSoftmaxLoss(Loss):
    """Computes Additive Angular Margin Softmax (ArcFace) Loss.

    Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition

    The margin is added to the target-class *angle* ``theta_y = acos(cos_y)``,
    so for each sample with target class ``y``::

        loss = -log(exp(s * cos(theta_y + m))
                    / (exp(s * cos(theta_y + m)) + sum_{j != y} exp(s * cos_j)))

    and the mean over the batch is returned.

    Args:
        scale: scale value applied to the cosine logits.
        margin: angular margin (radians) added to the target-class angle.
    """

    def __init__(self, scale=20.0, margin=1.35):
        super().__init__()

        # Clamp tolerance keeping acos inputs strictly inside (-1, 1),
        # presumably to avoid acos's unbounded gradient at +-1.
        self.eps = 1e-7
        self.scale = scale
        # NOTE(review): a 1.35 rad (~77 deg) additive angular margin is unusually
        # large, and cos(theta + m) stops decreasing once theta + m > pi --
        # confirm this default is intended.
        self.margin = margin

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return the mean ArcFace loss over the batch.

        Args:
            logits: 2-D tensor indexed as ``logits[sample, class]``; presumably
                cosine similarities in [-1, 1] (only the target entries are
                clamped here).
            labels: target class index for each sample, shape ``(batch,)``.

        Returns:
            Scalar tensor holding the mean loss. The input ``logits`` is not
            modified.
        """
        # gather/scatter require int64 indices; this also accepts int32 labels.
        index = labels.long().unsqueeze(1)

        # logits[i, labels[i]] for every i, without a (batch, batch) intermediate.
        cos_target = logits.gather(1, index).squeeze(1)
        theta_target = torch.acos(torch.clamp(cos_target, -1.0 + self.eps, 1 - self.eps))
        numerator = self.scale * torch.cos(theta_target + self.margin)

        # Non-target logits: the target column is masked with -inf so it
        # contributes exp(-inf) = 0, replacing the per-sample concat loop.
        logits_others = logits.scatter(1, index, float('-inf'))

        # -log(exp(a) / (exp(a) + sum(exp(b)))) == logsumexp([a, b]) - a,
        # evaluated without overflowing exp().
        all_terms = torch.cat((numerator.unsqueeze(1), self.scale * logits_others), dim=1)
        return torch.mean(torch.logsumexp(all_terms, dim=1) - numerator)