import torch
import torch.nn as nn
import torch.nn.functional as F

import climategan.strings as strings
from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock
from climategan.norms import SpectralNorm


def create_painter(opts, no_init=False, verbose=0):
    """Build the painter network (currently always a PainterSpadeDecoder)."""
    if verbose > 0:
        print(" - Add PainterSpadeDecoder Painter")
    return PainterSpadeDecoder(opts)


class PainterSpadeDecoder(nn.Module):
    def __init__(self, opts):
        """Create a SPADE-based decoder, which forwards z and the conditioning
        tensor cond (in the original SPADE paper, conditioning is on a semantic
        map only). All along, z is conditioned on cond. The first 3
        SPADEResnetBlocks (SRB) do not shrink the channel dimension, and an
        upsampling is applied after each of the first two, so 2 upsamplings
        have happened at this point. Then, for each remaining upsampling
        (w.r.t. spade_n_up), the SRB halves the number of channels. Before the
        final conv producing 3 channels, the number of channels is therefore:

            final_nc = channels(z) // 2 ** (spade_n_up - 2)

        Args:
            latent_dim (int): z's number of channels
            cond_nc (int): conditioning tensor's expected number of channels
            spade_n_up (int): total number of upsamplings from z
            spade_use_spectral_norm (bool): use spectral normalization?
            spade_param_free_norm (str): norm to use before SPADE de-normalization
            spade_kernel_size (int): SPADE conv layers' kernel size
        """
        super().__init__()

        latent_dim = opts.gen.p.latent_dim
        cond_nc = 3
        spade_n_up = opts.gen.p.spade_n_up
        spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
        spade_param_free_norm = opts.gen.p.spade_param_free_norm
        spade_kernel_size = 3

        self.z_nc = latent_dim
        self.spade_n_up = spade_n_up

        self.z_h = self.z_w = None

        self.fc = nn.Conv2d(3, latent_dim, 3, padding=1)
        self.head_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        self.G_middle_0 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.G_middle_1 = SPADEResnetBlock(
            self.z_nc,
            self.z_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )

        self.up_spades = nn.Sequential(
            *[
                SPADEResnetBlock(
                    self.z_nc // 2 ** i,
                    self.z_nc // 2 ** (i + 1),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                )
                for i in range(spade_n_up - 2)
            ]
        )

        self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)

        self.final_spade = SPADEResnetBlock(
            self.final_nc,
            self.final_nc,
            cond_nc,
            spade_use_spectral_norm,
            spade_param_free_norm,
            spade_kernel_size,
        )
        self.final_shortcut = None
        if opts.gen.p.use_final_shortcut:
            self.final_shortcut = nn.Sequential(
                *[
                    SpectralNorm(nn.Conv2d(self.final_nc, 3, 1)),
                    nn.BatchNorm2d(3),
                    nn.LeakyReLU(0.2, True),
                ]
            )

        self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)

        self.upsample = InterpolateNearest2d(scale_factor=2)

    def set_latent_shape(self, shape, is_input=True):
        """
        Sets the latent shape to start the upsampling from, i.e. z_h and z_w.
        If is_input is True, then this is the actual input shape, which should
        be divided by 2 ** spade_n_up.
        Otherwise, z_h and z_w are set directly from shape[-2] and shape[-1].

        Args:
            shape (tuple): The shape to start upsampling from.
            is_input (bool, optional): Whether to divide shape by 2 ** spade_n_up
        """
        if isinstance(shape, (list, tuple)):
            self.z_h = shape[-2]
            self.z_w = shape[-1]
        elif isinstance(shape, int):
            self.z_h = self.z_w = shape
        else:
            raise ValueError("Unknown shape type:", shape)

        if is_input:
            self.z_h = self.z_h // (2 ** self.spade_n_up)
            self.z_w = self.z_w // (2 ** self.spade_n_up)
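
    # Illustration (hypothetical numbers, not project defaults): with
    # spade_n_up = 7, calling set_latent_shape((3, 640, 640), is_input=True)
    # yields z_h = z_w = 640 // 2 ** 7 = 5, the resolution z is upsampled from.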

    def _apply(self, fn):
        # print("Applying SpadeDecoder", fn)
        super()._apply(fn)
        # self.head_0 = fn(self.head_0)
        # self.G_middle_0 = fn(self.G_middle_0)
        # self.G_middle_1 = fn(self.G_middle_1)
        # for i, up in enumerate(self.up_spades):
        #     self.up_spades[i] = fn(up)
        # self.conv_img = fn(self.conv_img)
        return self

    def forward(self, z, cond):
        if z is None:
            assert self.z_h is not None and self.z_w is not None
            # No latent provided: derive z from the conditioning tensor,
            # downsampled to the stored (z_h, z_w) starting resolution.
            z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))
        y = self.head_0(z, cond)

        y = self.upsample(y)
        y = self.G_middle_0(y, cond)
        y = self.upsample(y)
        y = self.G_middle_1(y, cond)

        # Remaining spade_n_up - 2 upsamplings, halving channels each time
        for i, up in enumerate(self.up_spades):
            y = self.upsample(y)
            y = up(y, cond)

        if self.final_shortcut is not None:
            # Condition the last SPADE block on a projection of y itself
            cond = self.final_shortcut(y)
        y = self.final_spade(y, cond)

        y = self.conv_img(F.leaky_relu(y, 2e-1))
        y = torch.tanh(y)
        return y

    def __str__(self):
        return strings.spadedecoder(self)
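

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The opts
    # values below are hypothetical stand-ins for the project's YAML config
    # (normally loaded by climategan's option utilities) and are chosen only
    # to exercise the decoder end to end.
    from types import SimpleNamespace

    opts = SimpleNamespace(
        gen=SimpleNamespace(
            p=SimpleNamespace(
                latent_dim=64,
                spade_n_up=5,
                spade_use_spectral_norm=True,
                spade_param_free_norm="instance",  # assumed accepted by climategan.norms
                use_final_shortcut=False,
            )
        )
    )

    painter = create_painter(opts, verbose=1)
    painter.set_latent_shape((256, 256), is_input=True)  # z_h = z_w = 256 // 2 ** 5 = 8

    cond = torch.randn(1, 3, 256, 256)  # e.g. a masked input image, cond_nc = 3
    out = painter(None, cond)  # z is inferred from cond through self.fc
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])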