# miewid-msv2 / heads.py — MiewIdNet metric-learning heads
# (uploaded by lashao, revision a37ced9)
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class ArcMarginProduct(nn.Module):
    r"""ArcFace head: produces ``s * cos(theta + m)`` logits for metric learning.

    Reference: "ArcFace: Additive Angular Margin Loss for Deep Face
    Recognition" (Deng et al., CVPR 2019).

    Args:
        in_features: size of each input embedding
        out_features: number of classes
        s: norm (scale) applied to the output logits
        m: additive angular margin, in radians
        easy_margin: if True, apply the margin only where cos(theta) > 0
        ls_eps: label-smoothing epsilon applied to the one-hot targets
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Precomputed constants for cos(theta + m) and for the monotonicity
        # fallback used once theta + m would exceed pi (ArcFace paper, Eq. 4).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # clamp guards the sqrt against tiny negative values from float error
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate on the same device as the logits instead of hard-coding
        # 'cuda', so the head also works on CPU and on non-default GPUs.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # apply the margin only at the target class, then scale
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
def l2_norm(input, axis = 1):
    """Return *input* rescaled to unit L2 norm along *axis*."""
    magnitude = torch.norm(input, 2, axis, True)
    return torch.div(input, magnitude)
class ElasticArcFace(nn.Module):
    """ArcFace head with an 'elastic' margin drawn from N(m, std) per sample.

    Reference: "ElasticFace: Elastic Margin Loss for Deep Face Recognition"
    (Boutros et al.). With ``plus=True``, the sampled margins are sorted and
    matched to the per-sample target cosines so that samples already close
    to their class center receive the larger margins (ElasticFace-Arc+).

    Args:
        in_features: embedding dimension.
        out_features: number of classes.
        s: scale applied to the output logits.
        m: mean of the sampled angular margin (radians).
        std: standard deviation of the sampled margin.
        plus: enable the ElasticFace+ margin assignment described above.
        k: unused in this class; presumably kept for signature parity with
           the sub-center variants — TODO confirm it can be dropped.
    """
    def __init__(self, in_features, out_features, s=64.0, m=0.50,std=0.0125,plus=False, k=None):
        super(ElasticArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # class-center matrix, stored transposed: (in_features, out_features)
        self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
        nn.init.normal_(self.kernel, std=0.01)
        self.std=std
        self.plus=plus
    def forward(self, embbedings, label):
        # cos(theta) between L2-normalized embeddings and class centers
        embbedings = l2_norm(embbedings, axis=1)
        kernel_norm = l2_norm(self.kernel, axis=0)
        cos_theta = torch.mm(embbedings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        # rows whose label is -1 are excluded: they receive no margin
        index = torch.where(label != -1)[0]
        m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
        # one randomly sampled margin per selected sample
        margin = torch.normal(mean=self.m, std=self.std, size=label[index, None].size(), device=cos_theta.device) # Fast converge .clamp(self.m-self.std, self.m+self.std)
        if self.plus:
            with torch.no_grad():
                # rank samples by target-class cosine (descending) and hand
                # the sorted (ascending) margins out by that ranking
                distmat = cos_theta[index, label.view(-1)].detach().clone()
                _, idicate_cosie = torch.sort(distmat, dim=0, descending=True)
                margin, _ = torch.sort(margin, dim=0)
                m_hot.scatter_(1, label[index, None], margin[idicate_cosie])
        else:
            m_hot.scatter_(1, label[index, None], margin)
        # theta -> theta + m at the target class, then back to s * cos(.)
        # NOTE: acos_ / cos_ / mul_ mutate cos_theta in place
        cos_theta.acos_()
        cos_theta[index] += m_hot
        cos_theta.cos_().mul_(self.s)
        return cos_theta
########## Subcenter Arcface with dynamic margin ##########
class ArcMarginProduct_subcenter(nn.Module):
    """Sub-center ArcFace projection: k centers per class, max-pooled cosine.

    Forward returns an (N, out_features) matrix where each entry is the
    largest cosine similarity between the input and any of that class's
    k sub-centers.
    """

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        # uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, features):
        normed = F.normalize(features)
        centers = F.normalize(self.weight)
        sims = F.linear(normed, centers).view(-1, self.out_features, self.k)
        best, _ = sims.max(dim=2)
        return best
class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """Apply per-class additive angular margins to precomputed cosine logits.

    Unlike the fixed-margin ArcFace head, each class carries its own margin
    (e.g. larger margins for rarer classes), supplied via ``margins``.

    Args:
        margins: per-class margins in radians, indexed by class id
                 (length must cover the largest label used).
        out_dim: number of classes (width of the logits).
        s: scale applied to the output logits.
    """
    def __init__(self, margins, out_dim, s):
        super().__init__()
        self.s = s
        # Registered as a buffer (not a Parameter): follows .to(device) but
        # is never trained. Cast to float32 so an integer margin list does
        # not break torch.cos / torch.sin in forward.
        self.register_buffer('margins', torch.as_tensor(margins, dtype=torch.float32))
        self.out_dim = out_dim

    def forward(self, logits, labels):
        # per-sample margin, looked up from each sample's target class
        ms = self.margins[labels]
        cos_m = torch.cos(ms)
        sin_m = torch.sin(ms)
        # threshold/fallback keep cos(theta + m) monotone past theta + m = pi
        th = torch.cos(math.pi - ms)
        mm = torch.sin(math.pi - ms) * ms
        labels = F.one_hot(labels, self.out_dim).float()
        cosine = logits
        # clamp guards the sqrt against float drift pushing cosine^2 above 1
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(0, 1))
        phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)
        phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
        # margin only at the target class, identity elsewhere, then scale
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output *= self.s
        return output
class ArcFaceSubCenterDynamic(nn.Module):
    """Sub-center ArcFace head with per-class (dynamic) margins.

    Chains ArcMarginProduct_subcenter (cosine logits over k sub-centers per
    class) with ArcFaceLossAdaptiveMargin (per-class margin + scaling).
    """

    def __init__(self, embedding_dim, output_classes, margins, s, k=2):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.output_classes = output_classes
        self.margins = margins
        self.s = s
        self.wmetric_classify = ArcMarginProduct_subcenter(
            self.embedding_dim, self.output_classes, k=k
        )
        self.warcface_margin = ArcFaceLossAdaptiveMargin(
            margins=self.margins, out_dim=self.output_classes, s=self.s
        )

    def forward(self, features, labels):
        cosine_logits = self.wmetric_classify(features.float())
        return self.warcface_margin(cosine_logits, labels)