# climateGAN / climategan/masker.py
# Author: vict0rsch — initial commit from `vict0rsch/climateGAN` (revision ce190ee, 8.18 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
from climategan.blocks import (
BaseDecoder,
Conv2dBlock,
InterpolateNearest2d,
SPADEResnetBlock,
)
def create_mask_decoder(opts, no_init=False, verbose=0):
    """Instantiate the mask decoder selected by ``opts.gen.m.use_spade``.

    Args:
        opts: configuration object. Reads ``opts.gen.m.use_spade`` and, when
            the SPADE decoder is requested, ``opts.tasks``.
        no_init (bool): currently unused; kept so existing callers that pass
            it keep working.
        verbose (int): if > 0, print which decoder is being added.

    Returns:
        A ``MaskSpadeDecoder`` if ``opts.gen.m.use_spade`` is truthy,
        otherwise a ``MaskBaseDecoder``.

    Raises:
        ValueError: if the SPADE decoder is requested but neither task ``"d"``
            nor task ``"s"`` is present in ``opts.tasks``.
    """
    if not opts.gen.m.use_spade:
        if verbose > 0:
            print(" - Add Base Mask Decoder")
        return MaskBaseDecoder(opts)

    if verbose > 0:
        print(" - Add Spade Mask Decoder")
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O`` and would let an invalid configuration through silently.
    if "d" not in opts.tasks and "s" not in opts.tasks:
        raise ValueError(
            "The SPADE mask decoder requires task 'd' or 's' in opts.tasks; "
            f"got {opts.tasks!r}"
        )
    return MaskSpadeDecoder(opts)
class MaskBaseDecoder(BaseDecoder):
    """Plain (non-SPADE) mask decoder configured from ``opts.gen.m``.

    Only computes the channel widths expected from the encoder and forwards
    everything else from ``opts.gen.m`` to ``BaseDecoder``.
    """

    def __init__(self, opts):
        mask_opts = opts.gen.m
        is_deeplabv3 = opts.gen.encoder.architecture == "deeplabv3"
        is_mobilenet = opts.gen.deeplabv3.backbone == "mobilenet"
        wants_low_level = mask_opts.use_low_level_feats
        dada = ("d" in opts.tasks) and mask_opts.use_dada

        # (input channels, low-level feature channels) per encoder setup:
        # DeepLabv3 + MobileNet uses (320, 24); every other setup uses 2048
        # input channels (and 256 low-level channels for DeepLabv3).
        if is_deeplabv3 and is_mobilenet:
            high_nc, low_nc = 320, 24
        else:
            high_nc, low_nc = 2048, 256

        # Low-level features are only wired in for DeepLabv3 when requested;
        # -1 tells BaseDecoder they are disabled.
        if not (is_deeplabv3 and wants_low_level):
            low_nc = -1

        super().__init__(
            n_upsample=mask_opts.n_upsample,
            n_res=mask_opts.n_res,
            input_dim=high_nc,
            proj_dim=mask_opts.proj_dim,
            output_dim=mask_opts.output_dim,
            norm=mask_opts.norm,
            activ=mask_opts.activ,
            pad_type=mask_opts.pad_type,
            output_activ="none",
            low_level_feats_dim=low_nc,
            use_dada=dada,
        )
class MaskSpadeDecoder(nn.Module):
    """Mask decoder made of SPADE residual blocks conditioned on ``cond``."""

    def __init__(self, opts):
        """Build the decoder from ``opts``.

        Encoder features are first fused/projected to ``z_nc`` channels
        (``opts.gen.m.spade.latent_dim``). They then go through ``num_layers``
        ``SPADEResnetBlock``s, each conditioned on ``cond`` and each followed
        by ``InterpolateNearest2d(scale_factor=2)``. Block ``i`` maps
        ``int(z_nc / 2**i)`` channels to ``int(z_nc / 2**(i + 1))``, so the
        last feature map has::

            final_nc = int(z_nc / 2**num_layers)

        channels, which ``mask_conv`` maps to a single channel with no output
        activation.

        How the encoder features are fused depends on the encoder:

        * ``deeplabv3`` + ``mobilenet``: a ``(high, low)`` pair with 320 and
          24 channels. ``low`` is convolved to 320 channels, concatenated with
          ``high`` and merged to ``z_nc`` channels.
        * ``deeplabv3`` + ``resnet``: a ``(high, low)`` pair with 2048 and 256
          channels, merged the same way. If ``opts.gen.m.use_proj`` is set,
          both are first projected to ``opts.gen.m.proj_dim`` channels.
        * ``deeplabv2``: a single 2048-channel tensor convolved to ``z_nc``.

        Args:
            opts: configuration object. Reads ``opts.gen.m.spade.*``
                (``latent_dim``, ``cond_nc``, ``spade_use_spectral_norm``,
                ``spade_param_free_norm``, ``activations.all_lrelu``,
                ``num_layers``), ``opts.gen.encoder.architecture``,
                ``opts.gen.deeplabv3.backbone`` and, for the ResNet backbone,
                ``opts.gen.m.use_proj`` / ``opts.gen.m.proj_dim``.

        Raises:
            ValueError: if the encoder architecture/backbone is not supported.
        """
        super().__init__()
        self.opts = opts
        spade_opts = opts.gen.m.spade
        latent_dim = spade_opts.latent_dim
        cond_nc = spade_opts.cond_nc
        spade_use_spectral_norm = spade_opts.spade_use_spectral_norm
        spade_param_free_norm = spade_opts.spade_param_free_norm
        spade_activation = "lrelu" if spade_opts.activations.all_lrelu else None
        spade_kernel_size = 3
        self.num_layers = spade_opts.num_layers
        self.z_nc = latent_dim

        # NOTE: submodules are created in the same order as before so that
        # parameter registration order (and RNG consumption at init) is kept.
        architecture = opts.gen.encoder.architecture
        if (
            architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            self.input_dim = [320, 24]
            self.low_level_conv = self._conv_block(
                self.input_dim[1], self.input_dim[0]
            )
            self.merge_feats_conv = self._conv_block(
                self.input_dim[0] * 2, self.z_nc
            )
        elif (
            architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "resnet"
        ):
            self.input_dim = [2048, 256]
            if self.opts.gen.m.use_proj:
                proj_dim = self.opts.gen.m.proj_dim
                self.low_level_conv = self._conv_block(self.input_dim[1], proj_dim)
                self.high_level_conv = self._conv_block(self.input_dim[0], proj_dim)
                self.merge_feats_conv = self._conv_block(proj_dim * 2, self.z_nc)
            else:
                self.low_level_conv = self._conv_block(
                    self.input_dim[1], self.input_dim[0]
                )
                self.merge_feats_conv = self._conv_block(
                    self.input_dim[0] * 2, self.z_nc
                )
        elif architecture == "deeplabv2":
            self.input_dim = 2048
            self.fc_conv = self._conv_block(self.input_dim, self.z_nc)
        else:
            raise ValueError("Unknown encoder type")

        # Each block halves the channel count: z_nc -> z_nc/2 -> z_nc/4 ...
        self.spade_blocks = nn.Sequential(
            *[
                SPADEResnetBlock(
                    int(self.z_nc / (2**i)),
                    int(self.z_nc / (2 ** (i + 1))),
                    cond_nc,
                    spade_use_spectral_norm,
                    spade_param_free_norm,
                    spade_kernel_size,
                    spade_activation,
                )
                for i in range(self.num_layers)
            ]
        )
        self.final_nc = int(self.z_nc / (2**self.num_layers))
        self.mask_conv = Conv2dBlock(
            self.final_nc,
            1,
            3,
            padding=1,
            activation="none",
            pad_type="reflect",
            norm="spectral",
        )
        self.upsample = InterpolateNearest2d(scale_factor=2)

    @staticmethod
    def _conv_block(in_nc, out_nc):
        """Return the ``Conv2dBlock`` configuration shared by all fusion convs."""
        return Conv2dBlock(
            in_nc,
            out_nc,
            3,
            padding=1,
            activation="lrelu",
            pad_type="reflect",
            norm="spectral_batch",
        )

    def forward(self, z, cond, z_depth=None):
        """Decode encoder features into a single-channel mask map.

        Args:
            z: either a ``(high, low)`` pair of feature tensors (DeepLabv3
                encoders) or a single feature tensor (DeepLabv2). ``low`` is
                resized to ``high``'s spatial size with bilinear interpolation
                before being concatenated with it.
            cond: conditioning tensor passed to every SPADE block (presumably
                ``cond_nc`` channels — TODO confirm expected shape).
            z_depth: unused; kept for interface compatibility.

        Returns:
            Output of ``mask_conv`` (one channel, no activation), upsampled by
            a factor of 2 after each SPADE block.
        """
        if isinstance(z, (list, tuple)):
            z_h, z_l = z
            z_l = self.low_level_conv(z_l)
            z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
            # ``high_level_conv`` only exists for the ResNet backbone with
            # ``use_proj``. Follow the modules actually built instead of
            # re-reading ``opts.gen.m.use_proj``, which would raise
            # AttributeError for a MobileNet model configured with use_proj.
            high_level_conv = getattr(self, "high_level_conv", None)
            if high_level_conv is not None:
                z_h = high_level_conv(z_h)
            y = self.merge_feats_conv(torch.cat([z_h, z_l], dim=1))
        else:
            y = self.fc_conv(z)

        for spade_block in self.spade_blocks:
            y = spade_block(y, cond)
            y = self.upsample(y)
        return self.mask_conv(y)

    def __str__(self):
        return "MaskerSpadeDecoder"