Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn import Parameter | |
import math | |
class ArcNet(nn.Module):
    """Classification head that applies an additive angular margin to logits.

    Features and class weights are L2-normalised so their product is a cosine
    similarity. For each sample's target class the cosine ``cos(theta)`` is
    replaced by ``cos(theta + margin)``; every other class keeps its plain
    cosine. The result is multiplied by ``scale``.

    Args:
        feature_dim: Size of the incoming feature vectors.
        class_dim: Number of classes (columns of the weight matrix).
        margin: Angular margin in radians added to the target-class angle.
        scale: Factor applied to the final logits.
        easy_margin: If True, the margin is applied only where
            ``cos(theta) > 0`` and the plain cosine is used elsewhere.
            If False, the margin is applied where ``cos(theta) > -cos(margin)``
            and ``cos(theta) - margin * sin(margin)`` is used elsewhere.
    """

    def __init__(self,
                 feature_dim: int,
                 class_dim: int,
                 margin: float = 0.2,
                 scale: float = 30.0,
                 easy_margin: bool = False) -> None:
        super().__init__()
        self.feature_dim = feature_dim
        self.class_dim = class_dim
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        # torch.empty replaces the legacy torch.FloatTensor(size) constructor;
        # the values are overwritten by the Xavier initialisation just below.
        self.weight = Parameter(
            torch.empty(feature_dim, class_dim, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Return margin-adjusted, scaled logits.

        Args:
            input: Feature batch whose last dimension is ``feature_dim``
                (normalised along dim 1, so a 2-D ``(batch, feature_dim)``
                tensor is expected).
            label: Integer class indices, one per sample. Any integer dtype
                is accepted; shapes ``(batch,)`` and ``(batch, 1)`` both work.

        Returns:
            Tensor of shape ``(batch, class_dim)``.
        """
        # F.normalize clamps the norm away from zero, so an all-zero feature
        # row or weight column produces zeros instead of NaN (0 / 0).
        features = F.normalize(input, p=2.0, dim=1)
        weight = F.normalize(self.weight, p=2.0, dim=0)

        cos = torch.matmul(features, weight)
        # The small constant keeps the sqrt argument positive when |cos| is 1
        # up to rounding error.
        sin = torch.sqrt(1.0 - torch.square(cos) + 1e-6)

        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        phi = cos * cos_m - sin * sin_m

        if self.easy_margin:
            phi = self._paddle_where_more_than(cos, 0, phi, cos)
        else:
            threshold = -cos_m  # equals cos(pi - margin)
            fallback = cos - sin_m * self.margin
            phi = self._paddle_where_more_than(cos, threshold, phi, fallback)

        # Flatten before encoding so (batch,) and (batch, 1) labels both give
        # a (batch, class_dim) mask; one_hot itself requires int64 indices.
        target = F.one_hot(label.reshape(-1).long(), self.class_dim).bool()
        output = torch.where(target, phi, cos)
        return output * self.scale

    def _paddle_where_more_than(self, target, limit, x, y):
        """Select ``x`` where ``target > limit`` and ``y`` elsewhere."""
        # torch.where picks elements instead of blending with a 0/1 mask, so
        # a non-finite value in the unselected branch cannot leak through.
        return torch.where(target > limit, x, y)