import torch
import torch.nn as nn


class EvoNorm2d(nn.Module):
    """Sample-based normalization-activation layer for NCHW feature maps.

    With ``nonlinearity=True`` the layer computes::

        x * sigmoid(v * x) / group_std(x) * weight + bias

    which is the EvoNorm-S0 formulation.  With ``nonlinearity=False`` it
    reduces to the per-channel affine transform ``x * weight + bias``.

    All statistics are computed from the current input only; the module
    keeps no running averages.

    Args:
        num_features: Number of channels ``C`` of the expected input.
        eps: Constant added to the group variance before the square root.
        nonlinearity: Whether to apply the gated, std-normalized branch.
        group: Number of channel groups used by ``group_std``; the input
            channel count must be divisible by it when ``nonlinearity`` is
            enabled.
    """

    __constants__ = ['num_features', 'eps', 'nonlinearity']

    def __init__(self, num_features, eps=1e-5, nonlinearity=True, group=32):
        super(EvoNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.nonlinearity = nonlinearity
        self.group = group

        # Parameters are allocated uninitialized and filled in by
        # reset_parameters() below.
        self.weight = nn.Parameter(torch.empty(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1, 1))
        if self.nonlinearity:
            # Per-channel gate scale; only exists on the non-linear path.
            self.v = nn.Parameter(torch.empty(1, num_features, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Set ``weight`` (and ``v`` if present) to one and ``bias`` to zero."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.nonlinearity:
            nn.init.ones_(self.v)

    def group_std(self, x, groups=32):
        """Return the per-group standard deviation of ``x``.

        The channels are split into ``groups`` contiguous groups and the
        (biased) variance is taken over every channel and spatial position
        inside each group, separately for each sample.

        Args:
            x: Tensor of shape ``(N, C, H, W)``.
            groups: Number of channel groups; must divide ``C``.

        Returns:
            Tensor of shape ``(N, C, 1, 1)`` holding ``sqrt(var + eps)``;
            channels that belong to the same group share the same value.

        Raises:
            ValueError: If ``C`` is not divisible by ``groups``.
        """
        N, C, H, W = x.shape
        if C % groups != 0:
            raise ValueError(
                'number of channels ({}) must be divisible by groups ({})'
                .format(C, groups)
            )
        per_group = C // groups

        grouped = torch.reshape(x, (N, groups, per_group, H, W))
        # Reduce over the channels of the group as well as the spatial
        # dims.  The biased estimator with eps inside the sqrt stays finite
        # even when a group holds a single element.
        var = torch.var(grouped, dim=(2, 3, 4), unbiased=False, keepdim=True)
        std = torch.sqrt(var + self.eps)

        # Broadcast the per-group value back to one entry per channel.
        std = std.expand(N, groups, per_group, 1, 1)
        return torch.reshape(std, (N, C, 1, 1))

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(N, C, H, W)``."""
        if self.nonlinearity:
            num = x * torch.sigmoid(self.v * x)
            return num / self.group_std(x, self.group) * self.weight + self.bias
        else:
            return x * self.weight + self.bias