"""File for all blocks which are parts of decoders | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import climategan.strings as strings | |
from climategan.norms import SPADE, AdaptiveInstanceNorm2d, LayerNorm, SpectralNorm | |
class InterpolateNearest2d(nn.Module):
    """
    Custom implementation of nn.Upsample because pytorch/xla
    does not yet support scale_factor and needs to be provided with
    the output_size
    """

    def __init__(self, scale_factor=2):
        """
        Create an InterpolateNearest2d module

        Args:
            scale_factor (int, optional): Output size multiplier. Defaults to 2.
        """
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """
        Interpolate x in "nearest" mode on its last 2 dimensions

        Args:
            x (torch.Tensor): input to interpolate

        Returns:
            torch.Tensor: upsampled tensor with shape
                (*x.shape[:-2], x.shape[-2] * scale_factor, x.shape[-1] * scale_factor)
        """
        return F.interpolate(
            x,
            size=(x.shape[-2] * self.scale_factor, x.shape[-1] * self.scale_factor),
            mode="nearest",
        )
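

# A minimal usage sketch (the shapes below are illustrative assumptions, not
# from the original file): the module scales the last two dimensions by
# scale_factor, doubling them by default.
#
#   up = InterpolateNearest2d(scale_factor=2)
#   y = up(torch.randn(1, 3, 8, 8))
#   assert y.shape == (1, 3, 16, 16)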

# -----------------------------------------
# -----  Generic Convolutional Block  -----
# -----------------------------------------
class Conv2dBlock(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        norm="none",
        activation="relu",
        pad_type="zero",
        bias=True,
    ):
        super().__init__()
        self.use_bias = bias

        # initialize padding (applied explicitly in forward, so the
        # convolutions below are created without a padding argument)
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            raise ValueError("Unsupported padding type: {}".format(pad_type))

        # initialize normalization
        use_spectral_norm = False
        if norm.startswith("spectral_"):
            norm = norm.replace("spectral_", "")
            use_spectral_norm = True

        norm_dim = output_dim
        if norm == "batch":
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == "instance":
            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == "layer":
            self.norm = LayerNorm(norm_dim)
        elif norm == "adain":
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == "spectral":
            # spectral norm is applied to the convolution below, not here
            self.norm = None
        elif norm == "none":
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == "relu":
            self.activation = nn.ReLU(inplace=False)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=False)
        elif activation == "prelu":
            self.activation = nn.PReLU()
        elif activation == "selu":
            self.activation = nn.SELU(inplace=False)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

        # initialize convolution
        if norm == "spectral" or use_spectral_norm:
            self.conv = SpectralNorm(
                nn.Conv2d(
                    input_dim,
                    output_dim,
                    kernel_size,
                    stride,
                    dilation=dilation,
                    bias=self.use_bias,
                )
            )
        else:
            self.conv = nn.Conv2d(
                input_dim,
                output_dim,
                kernel_size,
                stride,
                dilation=dilation,
                # the bias is redundant when followed by batch norm
                bias=self.use_bias if norm != "batch" else False,
            )

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

    def __str__(self):
        return strings.conv2dblock(self)
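

# A minimal usage sketch (all shapes and argument values below are
# illustrative assumptions): the block applies pad -> conv -> norm ->
# activation, so a 3x3 convolution with padding=1 preserves spatial size.
#
#   block = Conv2dBlock(
#       64, 128, kernel_size=3, stride=1, padding=1,
#       norm="instance", activation="lrelu", pad_type="reflect",
#   )
#   y = block(torch.randn(1, 64, 32, 32))
#   assert y.shape == (1, 128, 32, 32)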

# -----------------------------
# -----  Residual Blocks  -----
# -----------------------------
class ResBlocks(nn.Module):
    """
    From https://github.com/NVlabs/MUNIT/blob/master/networks.py
    """

    def __init__(
        self, num_blocks, dim, norm="instance", activation="relu", pad_type="zero"
    ):
        super().__init__()
        self.model = nn.Sequential(
            *[
                ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
                for _ in range(num_blocks)
            ]
        )

    def forward(self, x):
        return self.model(x)

    def __str__(self):
        return strings.resblocks(self)
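

# A minimal usage sketch (assumed shapes): a stack of residual blocks
# preserves both channel and spatial dimensions.
#
#   body = ResBlocks(num_blocks=4, dim=256, norm="instance")
#   y = body(torch.randn(1, 256, 16, 16))
#   assert y.shape == (1, 256, 16, 16)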
class ResBlock(nn.Module):
    def __init__(self, dim, norm="instance", activation="relu", pad_type="zero"):
        super().__init__()
        self.dim = dim
        self.norm = norm
        self.activation = activation

        # two 3x3 conv blocks; the second has no activation so the
        # residual sum is not rectified inside the block
        model = [
            Conv2dBlock(
                dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type
            ),
            Conv2dBlock(
                dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type
            ),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out

    def __str__(self):
        return strings.resblock(self)
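

# A minimal sketch of the residual structure (assumed shapes):
# out = x + conv(activation(conv(x))), so input and output shapes match.
#
#   block = ResBlock(dim=64, norm="instance", activation="relu")
#   y = block(torch.randn(2, 64, 32, 32))
#   assert y.shape == (2, 64, 32, 32)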

# --------------------------
# -----  Base Decoder  -----
# --------------------------
class BaseDecoder(nn.Module):
    def __init__(
        self,
        n_upsample=4,
        n_res=4,
        input_dim=2048,
        proj_dim=64,
        output_dim=3,
        norm="batch",
        activ="relu",
        pad_type="zero",
        output_activ="tanh",
        low_level_feats_dim=-1,
        use_dada=False,
    ):
        super().__init__()
        self.low_level_feats_dim = low_level_feats_dim
        self.use_dada = use_dada

        self.model = []

        # optional 1x1 projection of the encoder's features to proj_dim channels
        if proj_dim != -1:
            self.proj_conv = Conv2dBlock(
                input_dim, proj_dim, 1, 1, 0, norm=norm, activation=activ
            )
        else:
            self.proj_conv = None
            proj_dim = input_dim

        # optional skip connection from low-level encoder features
        if low_level_feats_dim > 0:
            self.low_level_conv = Conv2dBlock(
                input_dim=low_level_feats_dim,
                output_dim=proj_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                pad_type=pad_type,
                norm=norm,
                activation=activ,
            )
            self.merge_feats_conv = Conv2dBlock(
                input_dim=2 * proj_dim,
                output_dim=proj_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                pad_type=pad_type,
                norm=norm,
                activation=activ,
            )
        else:
            self.low_level_conv = None

        self.model += [ResBlocks(n_res, proj_dim, norm, activ, pad_type=pad_type)]
        dim = proj_dim

        # upsampling blocks: each one doubles the spatial resolution
        # and halves the number of channels
        for _ in range(n_upsample):
            self.model += [
                InterpolateNearest2d(scale_factor=2),
                Conv2dBlock(
                    input_dim=dim,
                    output_dim=dim // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    pad_type=pad_type,
                    norm=norm,
                    activation=activ,
                ),
            ]
            dim //= 2

        # final conv layer: maps to output_dim channels, no normalization
        self.model += [
            Conv2dBlock(
                input_dim=dim,
                output_dim=output_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                pad_type=pad_type,
                norm="none",
                activation=output_activ,
            )
        ]
        self.model = nn.Sequential(*self.model)

    def forward(self, z, cond=None, z_depth=None):
        low_level_feat = None
        if isinstance(z, (list, tuple)):
            if self.low_level_conv is None:
                z = z[0]
            else:
                z, low_level_feat = z
                low_level_feat = self.low_level_conv(low_level_feat)
                low_level_feat = F.interpolate(
                    low_level_feat,
                    size=z.shape[-2:],
                    mode="bilinear",
                    align_corners=False,  # explicit default
                )

        if z_depth is not None and self.use_dada:
            z = z * z_depth

        if self.proj_conv is not None:
            z = self.proj_conv(z)

        if low_level_feat is not None:
            z = self.merge_feats_conv(torch.cat([low_level_feat, z], dim=1))

        return self.model(z)

    def __str__(self):
        return strings.basedecoder(self)
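

# A minimal end-to-end sketch with the default hyper-parameters (the input
# shape is an illustrative assumption): 2048 channels are projected to 64,
# then 4 upsampling stages take 8x8 -> 128x128 while channels go 64 -> 4,
# and the final conv maps to 3 output channels through a tanh.
#
#   decoder = BaseDecoder()
#   out = decoder(torch.randn(1, 2048, 8, 8))
#   assert out.shape == (1, 3, 128, 128)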

# --------------------------
# -----  SPADE Blocks  -----
# --------------------------
# https://github.com/NVlabs/SPADE/blob/0ff661e70131c9b85091d11a66e019c0f2062d4c
# /models/networks/generator.py
# 0ff661e on 13 Apr 2019
class SPADEResnetBlock(nn.Module):
    def __init__(
        self,
        fin,
        fout,
        cond_nc,
        spade_use_spectral_norm,
        spade_param_free_norm,
        spade_kernel_size,
        last_activation=None,
    ):
        super().__init__()
        # Attributes
        self.fin = fin
        self.fout = fout
        self.use_spectral_norm = spade_use_spectral_norm
        self.param_free_norm = spade_param_free_norm
        self.kernel_size = spade_kernel_size
        self.learned_shortcut = fin != fout
        self.last_activation = last_activation
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if spade_use_spectral_norm:
            self.conv_0 = SpectralNorm(self.conv_0)
            self.conv_1 = SpectralNorm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = SpectralNorm(self.conv_s)

        # define normalization layers
        self.norm_0 = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)
        self.norm_1 = SPADE(spade_param_free_norm, spade_kernel_size, fmiddle, cond_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)

    # note: the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map, as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.activation(self.norm_0(x, seg)))
        dx = self.conv_1(self.activation(self.norm_1(dx, seg)))

        out = x_s + dx
        if self.last_activation == "lrelu":
            return self.activation(out)
        elif self.last_activation is None:
            return out
        else:
            raise NotImplementedError(
                "Unsupported last_activation: {}".format(self.last_activation)
            )

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def activation(self, x):
        return F.leaky_relu(x, 2e-1)

    def __str__(self):
        return strings.spaderesblock(self)
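

# A minimal usage sketch (all shapes and the "instance" norm string are
# assumptions; the NVlabs SPADE layer resizes the conditioning map to the
# feature map's resolution internally, and climategan.norms.SPADE is assumed
# to do the same):
#
#   block = SPADEResnetBlock(
#       fin=64, fout=64, cond_nc=3,
#       spade_use_spectral_norm=True,
#       spade_param_free_norm="instance",
#       spade_kernel_size=3,
#   )
#   out = block(torch.randn(1, 64, 32, 32), torch.randn(1, 3, 256, 256))
#   assert out.shape == (1, 64, 32, 32)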