climateGAN / climategan /discriminator.py
vict0rsch's picture
initial commit from `vict0rsch/climateGAN`
ce190ee
raw
history blame
12.8 kB
"""Discriminator architectures for ClimateGAN's GAN components (tasks p, m and s)
"""
import functools
import torch
import torch.nn as nn
from climategan.blocks import SpectralNorm
from climategan.tutils import init_weights
# from torch.optim import lr_scheduler
# mainly from https://github.com/sangwoomo/instagan/blob/master/models/networks.py
def create_discriminator(opts, device, no_init=False, verbose=0):
    """Build the OmniDiscriminator and initialize every sub-network's weights.

    Args:
        opts: experiment options. ``opts.dis[task].init_type`` and
            ``opts.dis[task].init_gain`` are read for each task key present
            in the built discriminator.
        device: target passed to ``.to()`` on the returned module.
        no_init (bool): if True, return the freshly built module right away,
            without weight initialization.
        verbose (int): forwarded to ``init_weights``.

    Returns:
        OmniDiscriminator: the (initialized, moved) discriminator container.
    """
    disc = OmniDiscriminator(opts)
    if no_init:
        # NOTE(review): this path returns before `.to(device)`, so `device`
        # is ignored here — confirm callers move the module themselves.
        return disc

    for task, model in disc.items():
        if isinstance(model, nn.ModuleDict):
            # One entry per named sub-discriminator (e.g. "global"/"local")
            targets = [
                (sub_model, f"create_discriminator {task} {domain}")
                for domain, sub_model in model.items()
            ]
        else:
            targets = [(model, f"create_discriminator {task}")]

        for target, caller in targets:
            init_weights(
                target,
                init_type=opts.dis[task].init_type,
                init_gain=opts.dis[task].init_gain,
                verbose=verbose,
                caller=caller,
            )
    return disc.to(device)
def define_D(
    input_nc,
    ndf,
    n_layers=3,
    norm="batch",
    use_sigmoid=False,
    get_intermediate_features=False,
    num_D=1,
):
    """Build a multi-scale PatchGAN discriminator.

    ``norm`` is turned into a normalization-layer factory via
    ``get_norm_layer``; every other argument is forwarded unchanged to
    ``MultiscaleDiscriminator``.

    Args:
        input_nc (int): number of input channels.
        ndf (int): number of filters of the first convolution.
        n_layers (int): depth of each per-scale discriminator.
        norm (str): normalization type name understood by ``get_norm_layer``.
        use_sigmoid (bool): append a Sigmoid to each discriminator.
        get_intermediate_features (bool): return all intermediate outputs.
        num_D (int): number of scales / discriminators.

    Returns:
        MultiscaleDiscriminator
    """
    return MultiscaleDiscriminator(
        input_nc,
        ndf,
        n_layers=n_layers,
        norm_layer=get_norm_layer(norm_type=norm),
        use_sigmoid=use_sigmoid,
        get_intermediate_features=get_intermediate_features,
        num_D=num_D,
    )
def get_norm_layer(norm_type="instance"):
    """Return a factory for the 2D normalization layer named by ``norm_type``.

    Args:
        norm_type (str): one of ``"batch"``, ``"instance"`` or ``"none"``.
            Falsy values (``None``, ``""``) fall back to ``"instance"`` with a
            printed notice.

    Returns:
        functools.partial | None: a partial of ``nn.BatchNorm2d``
        (``affine=True``) or ``nn.InstanceNorm2d`` (``affine=False``,
        ``track_running_stats=False``) to be called with a channel count,
        or ``None`` when ``norm_type == "none"``.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if not norm_type:
        # Interpolate the offending value so the notice is actually informative
        print(f"norm_type is {norm_type!r}, defaulting to instance")
        norm_type = "instance"
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
    if norm_type == "none":
        return None
    raise NotImplementedError(f"normalization layer [{norm_type}] is not found")
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator made of spectrally-normalized 4x4 convolutions.

    The network is split into consecutive groups registered as ``model0``,
    ``model1``, ... so that ``forward`` can expose every group's output:

    * ``model0``: stride-2 conv ``input_nc -> ndf`` + LeakyReLU(0.2)
    * ``n_layers - 1`` groups: stride-2 conv (channels doubled each time,
      capped at ``8 * ndf``) + optional norm + LeakyReLU(0.2)
    * one stride-1 conv to ``min(2 ** n_layers, 8) * ndf`` channels +
      optional norm + LeakyReLU(0.2)
    * a stride-1 conv down to a single channel
    * a Sigmoid group, only when ``use_sigmoid`` is True

    Args:
        input_nc (int): number of input channels.
        ndf (int): number of filters of the first convolution.
        n_layers (int): number of stride-2 convolutions (for n_layers >= 1).
        norm_layer: normalization layer class, a ``functools.partial`` of
            one (as returned by ``get_norm_layer``), or ``None`` to use no
            normalization.
        use_sigmoid (bool): append a Sigmoid after the last convolution.
        get_intermediate_features (bool): if True, ``forward`` returns the
            list of all group outputs, otherwise only the last one.
    """

    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
    ):
        super(NLayerDiscriminator, self).__init__()
        # Accept both a bare norm class and a functools.partial wrapping one
        norm_cls = (
            norm_layer.func
            if isinstance(norm_layer, functools.partial)
            else norm_layer
        )
        # Convs followed by a norm layer only keep a bias for InstanceNorm2d;
        # without any normalization the bias is always kept.
        use_bias = norm_cls is None or norm_cls == nn.InstanceNorm2d
        self.get_intermediate_features = get_intermediate_features

        def norm(channels):
            # get_norm_layer("none") yields None: emit no normalization module
            return [] if norm_layer is None else [norm_layer(channels)]

        kw = 4
        padw = 1
        # Each inner list becomes one nn.Sequential group (model0, model1, ...)
        sequence = [
            [
                SpectralNorm(
                    nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
                ),
                nn.LeakyReLU(0.2, True),
            ]
        ]
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence.append(
                [
                    SpectralNorm(  # TODO replace with Conv2dBlock
                        nn.Conv2d(
                            ndf * nf_mult_prev,
                            ndf * nf_mult,
                            kernel_size=kw,
                            stride=2,
                            padding=padw,
                            bias=use_bias,
                        )
                    ),
                    *norm(ndf * nf_mult),
                    nn.LeakyReLU(0.2, True),
                ]
            )

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence.append(
            [
                SpectralNorm(
                    nn.Conv2d(
                        ndf * nf_mult_prev,
                        ndf * nf_mult,
                        kernel_size=kw,
                        stride=1,
                        padding=padw,
                        bias=use_bias,
                    )
                ),
                *norm(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]
        )
        # Final single-channel prediction map
        sequence.append(
            [
                SpectralNorm(
                    nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
                )
            ]
        )
        if use_sigmoid:
            sequence.append([nn.Sigmoid()])

        # Register groups separately so intermediate outputs can be extracted
        for idx, group in enumerate(sequence):
            self.add_module(f"model{idx}", nn.Sequential(*group))

    def forward(self, input):
        """Run the groups in registration order.

        Returns:
            The list of every group's output when
            ``self.get_intermediate_features`` is True, otherwise only the
            final group's output.
        """
        results = [input]
        for submodel in self.children():
            results.append(submodel(results[-1]))
        if self.get_intermediate_features:
            return results[1:]
        return results[-1]
# Source: https://github.com/NVIDIA/pix2pixHD
class MultiscaleDiscriminator(nn.Module):
    """``num_D`` NLayerDiscriminators applied to successively pooled inputs.

    Discriminator ``i`` (registered as ``discriminator_i``) sees the input
    average-pooled ``i`` times with a 3x3 kernel, stride 2, padding 1 and
    padded values excluded from the average.

    Args:
        input_nc (int): number of input channels.
        ndf (int): number of filters of each discriminator's first conv.
        n_layers (int): depth of each NLayerDiscriminator.
        norm_layer: normalization layer (class, partial or None).
        use_sigmoid (bool): append a Sigmoid to each discriminator.
        get_intermediate_features (bool): whether each discriminator returns
            all of its intermediate outputs.
        num_D (int): number of scales / discriminators.
    """

    def __init__(
        self,
        input_nc=3,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        get_intermediate_features=True,
        num_D=3,
    ):
        super(MultiscaleDiscriminator, self).__init__()
        self.n_layers = n_layers
        self.ndf = ndf
        self.norm_layer = norm_layer
        self.use_sigmoid = use_sigmoid
        self.get_intermediate_features = get_intermediate_features
        self.num_D = num_D
        for i in range(self.num_D):
            netD = NLayerDiscriminator(
                input_nc=input_nc,
                ndf=self.ndf,
                n_layers=self.n_layers,
                norm_layer=self.norm_layer,
                use_sigmoid=self.use_sigmoid,
                get_intermediate_features=self.get_intermediate_features,
            )
            self.add_module("discriminator_%d" % i, netD)
        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

    def forward(self, input):
        """Run every discriminator on its own scale, finest first.

        Returns:
            A list with one entry per scale. Each entry is the list of
            intermediate outputs when ``get_intermediate_features`` is True,
            otherwise a one-element list holding the final output.
        """
        discriminators = [
            D for name, D in self.named_children() if "discriminator" in name
        ]
        result = []
        for i, D in enumerate(discriminators):
            if i > 0:
                # Pool only between scales: the result after the last
                # discriminator would never be used.
                input = self.downsample(input)
            out = D(input)
            result.append(out if self.get_intermediate_features else [out])
        return result
class OmniDiscriminator(nn.ModuleDict):
    """Container holding one discriminator entry per enabled task.

    Keys are task names found in ``opts.tasks``:

    * ``"p"``: a ``define_D`` discriminator on 4 input channels (image + mask),
      or, when ``opts.dis.p.use_local_discriminator`` is set, a ``ModuleDict``
      with ``"global"`` and ``"local"`` 3-channel ``define_D`` discriminators.
    * ``"m"`` (only if ``opts.gen.m.use_advent``): ``ModuleDict({"Advent": D})``
      where ``D`` comes from ``get_fc_discriminator`` (architecture ``"base"``)
      or ``define_D`` (architecture ``"OmniDiscriminator"``), on 2 channels.
    * ``"s"`` (only if ``opts.gen.s.use_advent``): ``ModuleDict({"Advent": D})``
      with an 11-class ``get_fc_discriminator``.

    ``get_fc_discriminator`` entries use spectral normalization when the
    task's ``gan_type`` is ``"WGAN_norm"``.

    Raises:
        Exception: if task ``"m"`` uses advent with an unknown architecture.
    """

    def __init__(self, opts):
        super().__init__()
        tasks = opts.tasks

        if "p" in tasks:
            p_opts = opts.dis.p

            def painter_disc(input_nc):
                # All painter discriminators share the same hyper-parameters
                return define_D(
                    input_nc=input_nc,
                    ndf=p_opts.ndf,
                    n_layers=p_opts.n_layers,
                    norm=p_opts.norm,
                    use_sigmoid=p_opts.use_sigmoid,
                    get_intermediate_features=p_opts.get_intermediate_features,
                    num_D=p_opts.num_D,
                )

            if p_opts.use_local_discriminator:
                self["p"] = nn.ModuleDict(
                    {"global": painter_disc(3), "local": painter_disc(3)}
                )
            else:
                self["p"] = painter_disc(4)  # image + mask

        if "m" in tasks and opts.gen.m.use_advent:
            m_opts = opts.dis.m
            architecture = m_opts.architecture
            if architecture == "base":
                advent = get_fc_discriminator(
                    num_classes=2, use_norm=m_opts.gan_type == "WGAN_norm"
                )
            elif architecture == "OmniDiscriminator":
                advent = define_D(
                    input_nc=2,
                    ndf=m_opts.ndf,
                    n_layers=m_opts.n_layers,
                    norm=m_opts.norm,
                    use_sigmoid=m_opts.use_sigmoid,
                    get_intermediate_features=m_opts.get_intermediate_features,
                    num_D=m_opts.num_D,
                )
            else:
                raise Exception("This Discriminator is currently not supported!")
            self["m"] = nn.ModuleDict({"Advent": advent})

        if "s" in tasks and opts.gen.s.use_advent:
            self["s"] = nn.ModuleDict(
                {
                    "Advent": get_fc_discriminator(
                        num_classes=11,
                        use_norm=opts.dis.s.gan_type == "WGAN_norm",
                    )
                }
            )
def get_fc_discriminator(num_classes=2, ndf=64, use_norm=False):
    """Build a fully-convolutional discriminator of five strided convolutions.

    Each convolution uses a 4x4 kernel, stride 2 and padding 1. Channels go
    ``num_classes -> ndf -> 2*ndf -> 4*ndf -> 8*ndf -> 1``. Every convolution
    except the last is followed by an in-place ``LeakyReLU(0.2)``.

    Args:
        num_classes (int): number of input channels.
        ndf (int): number of filters of the first convolution.
        use_norm (bool): wrap every convolution in ``SpectralNorm``.

    Returns:
        torch.nn.Sequential: 9 modules (5 convolutions interleaved with 4
        activations).
    """
    widths = [num_classes, ndf, ndf * 2, ndf * 4, ndf * 8, 1]
    last_stage = len(widths) - 2
    layers = []
    for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
        conv = torch.nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
        layers.append(SpectralNorm(conv) if use_norm else conv)
        if stage < last_stage:
            # No activation after the final prediction conv
            layers.append(torch.nn.LeakyReLU(negative_slope=0.2, inplace=True))
    return torch.nn.Sequential(*layers)