"""Normalization layers used in blocks."""
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveInstanceNorm2d(nn.Module):
    """Adaptive Instance Normalization: the affine parameters (weight, bias)
    are not learned here but assigned externally before each forward pass."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        assert (
            self.weight is not None and self.bias is not None
        ), "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Apply instance norm: folding the batch dimension into the channel
        # dimension makes batch_norm compute per-sample, per-channel statistics.
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
        out = F.batch_norm(
            x_reshaped,
            running_mean,
            running_var,
            self.weight,
            self.bias,
            True,
            self.momentum,
            self.eps,
        )
        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + "(" + str(self.num_features) + ")"
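

# Illustrative usage sketch (not part of the original module): AdaIN expects an
# external module, typically an MLP over a style code, to assign `weight` and
# `bias` (each of length batch * num_features) before calling forward. The
# shapes and the random "style params" below are assumptions for demonstration.
def _demo_adain():
    adain = AdaptiveInstanceNorm2d(num_features=64)
    content = torch.randn(2, 64, 32, 32)  # (batch, channels, H, W)
    style_params = torch.randn(2 * 2 * 64)  # hypothetical style-MLP output
    adain.bias = style_params[: 2 * 64]  # per-sample, per-channel shift
    adain.weight = style_params[2 * 64 :]  # per-sample, per-channel scale
    return adain(content)
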
class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # These two lines run much faster in pytorch 0.4
            # than the general per-sample version in the else branch.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x
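

# Illustrative usage sketch (not part of the original module): this LayerNorm
# normalizes each sample over all of its non-batch dimensions, then applies a
# per-channel affine transform, so it accepts (B, C, H, W) feature maps.
def _demo_layernorm():
    ln = LayerNorm(num_features=64)
    feats = torch.randn(4, 64, 16, 16)
    return ln(feats)  # output has the same shape as the input
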
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks"
    by Takeru Miyato, Toshiki Kataoka, Masanori Koyama and Yuichi Yoshida, and on
    the PyTorch implementation:
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        # Power iteration: estimate the leading left/right singular vectors of
        # the flattened weight matrix, then its largest singular value sigma.
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            getattr(self.module, self.name + "_u")
            getattr(self.module, self.name + "_v")
            getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # Replace the original weight with u, v and the unnormalized weight_bar;
        # the spectrally normalized weight is re-assigned at every forward pass.
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
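

# Illustrative usage sketch (not part of the original module): SpectralNorm wraps
# an existing layer and re-normalizes its weight by the estimated spectral norm
# on every forward call. The layer sizes below are arbitrary.
def _demo_spectral_norm():
    conv = SpectralNorm(nn.Conv2d(3, 64, kernel_size=3, padding=1))
    images = torch.randn(2, 3, 32, 32)
    return conv(images)
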
class SPADE(nn.Module):
    def __init__(self, param_free_norm_type, kernel_size, norm_nc, cond_nc):
        super().__init__()

        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        # elif param_free_norm_type == "syncbatch":
        #     self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(
                "%s is not a recognized param-free norm type in SPADE"
                % param_free_norm_type
            )

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        pw = kernel_size // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(cond_nc, nhidden, kernel_size=kernel_size, padding=pw),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(
            nhidden, norm_nc, kernel_size=kernel_size, padding=pw
        )
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=pw)

    def forward(self, x, segmap):
        # Part 1. Generate parameter-free normalized activations.
        normalized = self.param_free_norm(x)

        # Part 2. Produce scale and bias conditioned on the semantic map.
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # Apply scale and bias.
        out = normalized * (1 + gamma) + beta
        return out
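

# Illustrative usage sketch (not part of the original module): SPADE modulates
# the normalized activations with scale and bias maps predicted from the
# (resized) semantic map. Channel counts below are assumptions for demonstration.
def _demo_spade():
    spade = SPADE(
        param_free_norm_type="instance", kernel_size=3, norm_nc=64, cond_nc=35
    )
    feats = torch.randn(2, 64, 32, 32)
    segmap = torch.randn(2, 35, 128, 128)  # e.g. a one-hot semantic map
    return spade(feats, segmap)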