Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from climategan.blocks import ( | |
BaseDecoder, | |
Conv2dBlock, | |
InterpolateNearest2d, | |
SPADEResnetBlock, | |
) | |
def create_mask_decoder(opts, no_init=False, verbose=0):
    """Instantiate the masker's decoder selected by ``opts.gen.m.use_spade``.

    Args:
        opts: experiment options. ``opts.gen.m.use_spade`` picks the decoder
            type; for the SPADE decoder, ``opts.tasks`` must contain "d" or "s".
        no_init (bool): accepted for interface compatibility; not used here.
        verbose (int): when > 0, print which decoder is being added.

    Returns:
        MaskSpadeDecoder when ``use_spade`` is truthy, MaskBaseDecoder otherwise.
    """
    if not opts.gen.m.use_spade:
        if verbose > 0:
            print(" - Add Base Mask Decoder")
        return MaskBaseDecoder(opts)

    if verbose > 0:
        print(" - Add Spade Mask Decoder")
    # The SPADE decoder needs a conditioning task; presumably "d"/"s" are the
    # depth / segmentation tasks -- confirm against the task definitions.
    assert "d" in opts.tasks or "s" in opts.tasks
    return MaskSpadeDecoder(opts)
class MaskBaseDecoder(BaseDecoder):
    """Non-SPADE mask decoder: a ``BaseDecoder`` configured from ``opts.gen.m``.

    Channel counts of the encoder features depend on the encoder:

    * deeplabv3 + mobilenet: 320 high-level channels, 24 low-level channels
    * deeplabv3 + other backbone: 2048 high-level, 256 low-level
    * any other encoder: 2048 high-level, no low-level features
    """

    def __init__(self, opts):
        is_v3 = opts.gen.encoder.architecture == "deeplabv3"
        is_mobilenet = opts.gen.deeplabv3.backbone == "mobilenet"
        mask_opts = opts.gen.m
        wants_low_level = mask_opts.use_low_level_feats
        dada = ("d" in opts.tasks) and mask_opts.use_dada

        # (high-level channels, low-level channels or None) per encoder flavour
        if is_v3 and is_mobilenet:
            high_dim, low_dim = 320, 24
        elif is_v3:
            high_dim, low_dim = 2048, 256
        else:
            high_dim, low_dim = 2048, None

        # -1 is passed when low-level features are disabled or unavailable
        # (sentinel meaning is defined by BaseDecoder -- confirm there)
        if not wants_low_level or low_dim is None:
            low_dim = -1

        super().__init__(
            n_upsample=mask_opts.n_upsample,
            n_res=mask_opts.n_res,
            input_dim=high_dim,
            proj_dim=mask_opts.proj_dim,
            output_dim=mask_opts.output_dim,
            norm=mask_opts.norm,
            activ=mask_opts.activ,
            pad_type=mask_opts.pad_type,
            output_activ="none",
            low_level_feats_dim=low_dim,
            use_dada=dada,
        )
class MaskSpadeDecoder(nn.Module):
    """SPADE-based mask decoder.

    Encoder features are merged into ``z_nc`` channels, then passed through
    ``num_layers`` SPADEResnetBlocks conditioned on ``cond``. Each block halves
    the channel count and is followed by a 2x nearest upsampling. A final conv
    maps the remaining channels to a 1-channel output with no activation.
    """

    def __init__(self, opts):
        """Create the decoder; all hyper-parameters are read from ``opts``.

        Options read:
            opts.gen.m.spade.latent_dim (int): channels of the merged features (z_nc)
            opts.gen.m.spade.cond_nc (int): channels of the conditioning tensor
            opts.gen.m.spade.num_layers (int): number of SPADE blocks / upsamplings
            opts.gen.m.spade.spade_use_spectral_norm (bool): spectral norm in SPADE
            opts.gen.m.spade.spade_param_free_norm (str): norm used before SPADE
                de-normalization
            opts.gen.m.spade.activations.all_lrelu (bool): pass "lrelu" to the
                SPADE blocks (otherwise None)
            opts.gen.m.use_proj, opts.gen.m.proj_dim: deeplabv3/resnet only --
                project both feature levels to ``proj_dim`` before merging
            opts.gen.encoder.architecture, opts.gen.deeplabv3.backbone: select
                the expected encoder feature channels

        The number of channels before ``mask_conv`` is
        ``final_nc = int(latent_dim / 2 ** num_layers)``.

        Args:
            opts: experiment options.

        Raises:
            ValueError: if the encoder architecture/backbone is not supported.
        """
        super().__init__()
        self.opts = opts
        spade_opts = opts.gen.m.spade
        latent_dim = spade_opts.latent_dim
        cond_nc = spade_opts.cond_nc
        spade_use_spectral_norm = spade_opts.spade_use_spectral_norm
        spade_param_free_norm = spade_opts.spade_param_free_norm
        spade_activation = "lrelu" if spade_opts.activations.all_lrelu else None
        spade_kernel_size = 3

        self.num_layers = spade_opts.num_layers
        self.z_nc = latent_dim
        # Whether forward() must also project the high-level features. Only the
        # deeplabv3/resnet branch builds ``high_level_conv``; forward() used to
        # read opts.gen.m.use_proj directly and crashed for deeplabv3/mobilenet.
        self.use_proj = False

        architecture = opts.gen.encoder.architecture
        if architecture == "deeplabv3" and opts.gen.deeplabv3.backbone == "mobilenet":
            self.input_dim = [320, 24]
            # lift low-level features to the high-level channel count, then
            # merge their concatenation down to z_nc
            self.low_level_conv = self._conv_block(self.input_dim[1], self.input_dim[0])
            self.merge_feats_conv = self._conv_block(self.input_dim[0] * 2, self.z_nc)
        elif architecture == "deeplabv3" and opts.gen.deeplabv3.backbone == "resnet":
            self.input_dim = [2048, 256]
            if opts.gen.m.use_proj:
                self.use_proj = True
                proj_dim = opts.gen.m.proj_dim
                # creation order kept: low, high, merge (parameter order matters
                # for optimizer state dicts)
                self.low_level_conv = self._conv_block(self.input_dim[1], proj_dim)
                self.high_level_conv = self._conv_block(self.input_dim[0], proj_dim)
                self.merge_feats_conv = self._conv_block(proj_dim * 2, self.z_nc)
            else:
                self.low_level_conv = self._conv_block(
                    self.input_dim[1], self.input_dim[0]
                )
                self.merge_feats_conv = self._conv_block(
                    self.input_dim[0] * 2, self.z_nc
                )
        elif architecture == "deeplabv2":
            self.input_dim = 2048
            self.fc_conv = self._conv_block(self.input_dim, self.z_nc)
        else:
            raise ValueError("Unknown encoder type")

        # channels before each SPADE block and after the last one:
        # z_nc, z_nc / 2, ..., z_nc / 2 ** num_layers
        channels = [int(self.z_nc / (2**i)) for i in range(self.num_layers + 1)]
        # ModuleList (not Sequential): blocks take (y, cond), so they cannot be
        # chained by Sequential.forward; state_dict keys ("spade_blocks.<i>.")
        # are identical either way.
        self.spade_blocks = nn.ModuleList(
            SPADEResnetBlock(
                c_in,
                c_out,
                cond_nc,
                spade_use_spectral_norm,
                spade_param_free_norm,
                spade_kernel_size,
                spade_activation,
            )
            for c_in, c_out in zip(channels[:-1], channels[1:])
        )
        self.final_nc = channels[-1]
        self.mask_conv = Conv2dBlock(
            self.final_nc,
            1,
            3,
            padding=1,
            activation="none",
            pad_type="reflect",
            norm="spectral",
        )
        self.upsample = InterpolateNearest2d(scale_factor=2)

    @staticmethod
    def _conv_block(in_dim, out_dim):
        """Return the 3x3 reflect-padded Conv2dBlock ("lrelu", "spectral_batch")
        shared by all feature projection / merge layers."""
        return Conv2dBlock(
            in_dim,
            out_dim,
            3,
            padding=1,
            activation="lrelu",
            pad_type="reflect",
            norm="spectral_batch",
        )

    def forward(self, z, cond, z_depth=None):
        """Decode encoder features into a 1-channel mask map.

        Args:
            z: a ``(high_level, low_level)`` pair of feature maps (deeplabv3
                encoders) or a single feature map (deeplabv2).
            cond: conditioning tensor given to every SPADE block; expected to
                have ``cond_nc`` channels.
            z_depth: unused; kept for interface compatibility.

        Returns:
            Output of ``mask_conv``: 1 channel, no final activation.
        """
        if isinstance(z, (list, tuple)):
            z_h, z_l = z
            z_l = self.low_level_conv(z_l)
            # bring low-level features to the high-level spatial size
            z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
            if self.use_proj:
                z_h = self.high_level_conv(z_h)
            y = self.merge_feats_conv(torch.cat([z_h, z_l], dim=1))
        else:
            y = self.fc_conv(z)
        for spade_block in self.spade_blocks:
            y = self.upsample(spade_block(y, cond))
        return self.mask_conv(y)

    def __str__(self):
        return "MaskerSpadeDecoder"