|
import torch |
|
import torch.nn as nn |
|
import math |
|
import torch.nn.functional as F |
|
from torch.nn.parameter import Parameter |
|
|
|
class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance:

    Produces ``s * cos(theta + m)`` for each sample's target class and
    ``s * cos(theta)`` for all other classes, where ``theta`` is the angle
    between the (L2-normalized) input embedding and the class weight vector.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature (scale applied to the output logits)
        m: margin added to the target-class angle
        easy_margin: if True, only apply the margin where cos(theta) > 0
        ls_eps: label-smoothing epsilon mixed into the one-hot targets
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Precomputed constants for the angle-addition formula cos(theta + m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold past which theta + m would exceed pi; beyond it fall back
        # to the linear penalty cos(theta) - m*sin(m) so the logit stays
        # monotonically decreasing in theta (standard ArcFace trick).
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Apply the additive angular margin.

        Args:
            input: (batch, in_features) raw embeddings.
            label: (batch,) integer class ids.

        Returns:
            (batch, out_features) margin-adjusted logits scaled by ``s``.
        """
        # Cosine similarity between normalized embeddings and class weights.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp the radicand: float error can push cosine^2 above 1, which
        # would make sqrt return NaN and poison the whole batch.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Build the one-hot targets on the same device as the activations.
        # (The original hard-coded device='cuda', which broke CPU runs.)
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            # Label smoothing: soften the one-hot targets.
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features

        # Margin-adjusted logit on the target class, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output
|
|
|
|
|
def l2_norm(input, axis = 1):
    """Return `input` divided by its L2 norm along `axis` (dims kept)."""
    magnitude = torch.norm(input, 2, axis, True)
    return input / magnitude
|
class ElasticArcFace(nn.Module):
    """ArcFace head with a per-sample random margin drawn from N(m, std)
    instead of a fixed margin (ElasticFace-style training).

    Args:
        in_features: embedding dimension of the input.
        out_features: number of classes.
        s: scale applied to the final logits.
        m: mean of the Gaussian margin distribution.
        std: standard deviation of the Gaussian margin distribution.
        plus: if True, pair the sampled margins with samples according to
            their cosine rank instead of assigning them randomly.
        k: unused here; presumably kept for signature compatibility with
            the other heads in this file — TODO confirm.
    """
    def __init__(self, in_features, out_features, s=64.0, m=0.50,std=0.0125,plus=False, k=None):
        super(ElasticArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # Class-center matrix stored as (in_features, out_features) — note
        # the transposed layout relative to nn.Linear's weight.
        self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
        nn.init.normal_(self.kernel, std=0.01)
        self.std=std
        self.plus=plus

    def forward(self, embbedings, label):
        """Return s * cos(theta + margin_i) logits with margin_i ~ N(m, std).

        Args:
            embbedings: (batch, in_features) raw embeddings.
            label: (batch,) integer class ids; entries equal to -1 receive
                no margin (their logits stay s * cos(theta)).
        """
        embbedings = l2_norm(embbedings, axis=1)
        kernel_norm = l2_norm(self.kernel, axis=0)
        cos_theta = torch.mm(embbedings, kernel_norm)
        # Clamp keeps acos_ below well-defined (float error can exceed +/-1).
        cos_theta = cos_theta.clamp(-1, 1)
        # Rows that carry a valid label; -1 marks samples to skip.
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
        # One random margin per labelled sample.
        margin = torch.normal(mean=self.m, std=self.std, size=label[index, None].size(), device=cos_theta.device)
        if self.plus:
            # "Plus" variant: sort target-class cosines (descending) and the
            # sampled margins (ascending), then assign margins by that rank
            # pairing — presumably so easier samples get smaller margins;
            # verify against the ElasticFace reference implementation.
            with torch.no_grad():
                distmat = cos_theta[index, label.view(-1)].detach().clone()
                _, idicate_cosie = torch.sort(distmat, dim=0, descending=True)
                margin, _ = torch.sort(margin, dim=0)
                m_hot.scatter_(1, label[index, None], margin[idicate_cosie])
        else:
            m_hot.scatter_(1, label[index, None], margin)
        # Work in angle space: theta -> theta + margin on target entries only,
        # then back to cosine and scale. All ops are in-place on cos_theta.
        cos_theta.acos_()
        cos_theta[index] += m_hot
        cos_theta.cos_().mul_(self.s)
        return cos_theta
|
|
|
|
|
|
|
class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace projection: keeps ``k`` weight centers per class
    and scores each class by its best-matching center.
    """

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features*k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), +1/sqrt(fan_in)]."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, features):
        """Return (batch, out_features) max cosine over the k sub-centers."""
        sims = F.linear(F.normalize(features), F.normalize(self.weight))
        per_center = sims.view(-1, self.out_features, self.k)
        return per_center.max(dim=2).values
|
|
|
class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """ArcFace margin layer with a per-class (adaptive) margin.

    Args:
        margins: sequence of length ``out_dim`` giving the margin per class.
        out_dim: number of classes.
        s: scale applied to the output logits.
    """

    def __init__(self, margins, out_dim, s):
        super().__init__()
        self.s = s
        # Buffer (not Parameter): margins are constants but should follow
        # the module across .to(device) / state_dict.
        self.register_buffer('margins', torch.tensor(margins))
        self.out_dim = out_dim

    def forward(self, logits, labels):
        """Apply cos(theta + m_class) to the target entries of ``logits``.

        Args:
            logits: (batch, out_dim) cosine similarities, expected in [-1, 1].
            labels: (batch,) integer class ids.

        Returns:
            (batch, out_dim) margin-adjusted logits scaled by ``s``.
        """
        ms = self.margins[labels]  # per-sample margin from the target class
        cos_m = torch.cos(ms)
        sin_m = torch.sin(ms)
        # Threshold past which theta + m would exceed pi; fall back to the
        # linear penalty cos(theta) - m*sin(m) there (standard ArcFace trick).
        th = torch.cos(math.pi - ms)
        mm = torch.sin(math.pi - ms) * ms
        labels = F.one_hot(labels, self.out_dim).float()
        cosine = logits
        # Clamp the radicand: incoming cosines can exceed 1 by float error,
        # and sqrt of a negative would yield NaN logits.
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(min=0.0))
        phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)  # cos(theta + m)
        phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
        # Margin-adjusted logit on the target class, plain cosine elsewhere.
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output *= self.s
        return output
|
|
|
class ArcFaceSubCenterDynamic(nn.Module):
    """Sub-center ArcFace classifier with per-class adaptive margins.

    Pipeline: features -> sub-center cosine logits -> margin-adjusted,
    scaled logits.

    Args:
        embedding_dim: dimension of the input embeddings.
        output_classes: number of classes.
        margins: per-class margins (length ``output_classes``).
        s: scale applied to the final logits.
        k: number of sub-centers per class.
    """

    def __init__(self, embedding_dim, output_classes, margins, s, k=2):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.output_classes = output_classes
        self.margins = margins
        self.s = s
        # Cosine-similarity head with k weight centers per class.
        self.wmetric_classify = ArcMarginProduct_subcenter(
            self.embedding_dim, self.output_classes, k=k
        )
        # Applies cos(theta + m_class) and scales by s.
        self.warcface_margin = ArcFaceLossAdaptiveMargin(
            margins=self.margins, out_dim=self.output_classes, s=self.s
        )

    def forward(self, features, labels):
        """Return margin-adjusted class logits for ``features``."""
        cosine_logits = self.wmetric_classify(features.float())
        return self.warcface_margin(cosine_logits, labels)