import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import numpy as np
import torch.nn.functional as F
from torch.nn.modules.normalization import LayerNorm
import os
from torch.nn.utils import spectral_norm
from torchvision import models
###############################################################################
# Helper functions
###############################################################################
def init_weights(net, init_type='normal', init_gain=0.02):
"""Initialize network weights.
Parameters:
net (network) -- network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
work better for some applications. Feel free to try yourself.
"""
def init_func(m): # define the initialization function
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
#init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
init.kaiming_normal_(m.weight.data, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function <init_func>
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init=True):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
        init (bool)        -- whether to apply init_weights to the network
Return an initialized network.
"""
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
if init:
init_weights(net, init_type, init_gain=init_gain)
return net
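# Usage sketch (illustrative only; the layers below are arbitrary): build a module,
# then register devices and initialize its weights in one call.
#
#   net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(True))
#   net = init_net(net, init_type='kaiming', init_gain=0.02, gpu_ids=[])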
def get_scheduler(optimizer, opt):
"""Return a learning rate scheduler
Parameters:
optimizer -- the optimizer of the network
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
For 'linear', we keep the same learning rate for the first <opt.niter> epochs
and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
See https://pytorch.org/docs/stable/optim.html for more details.
"""
if opt.lr_policy == 'linear':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
return scheduler
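# Usage sketch (illustrative only): `opt` can be any object exposing the flags read
# above; the values below are hypothetical, not project defaults.
#
#   from argparse import Namespace
#   opt = Namespace(lr_policy='linear', epoch_count=1, niter=100, niter_decay=100)
#   optimizer = torch.optim.Adam(net.parameters(), lr=2e-4)
#   scheduler = get_scheduler(optimizer, opt)
#   for epoch in range(opt.niter + opt.niter_decay):
#       ...  # train one epoch
#       scheduler.step()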
class LayerNormWarpper(nn.Module):
def __init__(self, num_features):
super(LayerNormWarpper, self).__init__()
self.num_features = int(num_features)
    def forward(self, x):
        # Equivalent to the original nn.LayerNorm([C, H, W], elementwise_affine=False),
        # but functional, so it runs on x's own device instead of hard-coding .cuda().
        return F.layer_norm(x, [self.num_features, x.size(2), x.size(3)])
def get_norm_layer(norm_type='instance'):
"""Return a normalization layer
Parameters:
norm_type (str) -- the name of the normalization layer: batch | instance | none
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
"""
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == 'layer':
norm_layer = functools.partial(LayerNormWarpper)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def get_non_linearity(layer_type='relu'):
if layer_type == 'relu':
nl_layer = functools.partial(nn.ReLU, inplace=True)
elif layer_type == 'lrelu':
nl_layer = functools.partial(
nn.LeakyReLU, negative_slope=0.2, inplace=True)
elif layer_type == 'elu':
nl_layer = functools.partial(nn.ELU, inplace=True)
elif layer_type == 'selu':
nl_layer = functools.partial(nn.SELU, inplace=True)
elif layer_type == 'prelu':
nl_layer = functools.partial(nn.PReLU)
else:
        raise NotImplementedError(
            'nonlinearity activation [%s] is not found' % layer_type)
return nl_layer
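# Both factories return constructors rather than layer instances: call the returned
# object with a channel count (norm) or with no arguments (nonlinearity) to build a
# module. Minimal sketch:
#
#   norm_layer = get_norm_layer('instance')   # functools.partial(nn.InstanceNorm2d, ...)
#   nl_layer = get_non_linearity('lrelu')     # functools.partial(nn.LeakyReLU, ...)
#   block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), norm_layer(64), nl_layer())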
def define_G(input_nc, output_nc, nz, ngf, netG='unet_128', norm='batch', nl='relu', use_noise=False,
use_dropout=False, init_type='xavier', init_gain=0.02, gpu_ids=[], where_add='input', upsample='bilinear'):
net = None
norm_layer = get_norm_layer(norm_type=norm)
nl_layer = get_non_linearity(layer_type=nl)
# print(norm, norm_layer)
if nz == 0:
where_add = 'input'
if netG == 'unet_128' and where_add == 'input':
net = G_Unet_add_input(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
elif netG == 'unet_128_G' and where_add == 'input':
net = G_Unet_add_input_G(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
elif netG == 'unet_256' and where_add == 'input':
net = G_Unet_add_input(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
elif netG == 'unet_256_G' and where_add == 'input':
net = G_Unet_add_input_G(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
use_dropout=use_dropout, upsample=upsample, device=gpu_ids)
elif netG == 'unet_128' and where_add == 'all':
net = G_Unet_add_all(input_nc, output_nc, nz, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
use_dropout=use_dropout, upsample=upsample)
elif netG == 'unet_256' and where_add == 'all':
net = G_Unet_add_all(input_nc, output_nc, nz, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer, use_noise=use_noise,
use_dropout=use_dropout, upsample=upsample)
else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
# print(net)
return init_net(net, init_type, init_gain, gpu_ids)
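# Usage sketch (illustrative only; the channel counts are hypothetical): a 128x128
# U-Net generator with a 4-channel input, a 1-channel output and no latent code.
#
#   netG = define_G(input_nc=4, output_nc=1, nz=0, ngf=48, netG='unet_128_G',
#                   norm='layer', nl='lrelu', use_noise=True, use_dropout=True,
#                   init_type='kaiming', upsample='bilinear')
#   fake = netG(torch.randn(1, 4, 128, 128))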
def define_C(input_nc, output_nc, nz, ngf, netC='unet_128', norm='instance', nl='relu',
use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], upsample='basic'):
net = None
norm_layer = get_norm_layer(norm_type=norm)
nl_layer = get_non_linearity(layer_type=nl)
if netC == 'resnet_9blocks':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
elif netC == 'resnet_6blocks':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
elif netC == 'unet_128':
net = G_Unet_add_input_C(input_nc, output_nc, 0, 7, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
use_dropout=use_dropout, upsample=upsample)
elif netC == 'unet_256':
net = G_Unet_add_input(input_nc, output_nc, 0, 8, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
use_dropout=use_dropout, upsample=upsample)
elif netC == 'unet_32':
net = G_Unet_add_input(input_nc, output_nc, 0, 5, ngf, norm_layer=norm_layer, nl_layer=nl_layer,
use_dropout=use_dropout, upsample=upsample)
else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netC)
return init_net(net, init_type, init_gain, gpu_ids)
def define_D(input_nc, ndf, netD, norm='batch', nl='lrelu', init_type='xavier', init_gain=0.02, num_Ds=1, gpu_ids=[]):
net = None
norm_layer = get_norm_layer(norm_type=norm)
nl = 'lrelu' # use leaky relu for D
nl_layer = get_non_linearity(layer_type=nl)
if netD == 'basic_128':
net = D_NLayers(input_nc, ndf, n_layers=2, norm_layer=norm_layer, nl_layer=nl_layer)
elif netD == 'basic_256':
net = D_NLayers(input_nc, ndf, n_layers=3, norm_layer=norm_layer, nl_layer=nl_layer)
elif netD == 'basic_128_multi':
net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=2, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
elif netD == 'basic_256_multi':
net = D_NLayersMulti(input_nc=input_nc, ndf=ndf, n_layers=3, norm_layer=norm_layer, num_D=num_Ds, nl_layer=nl_layer)
else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
return init_net(net, init_type, init_gain, gpu_ids)
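# Usage sketch (illustrative only): a two-scale PatchGAN discriminator for 256x256
# inputs; the forward pass returns one patch prediction map per scale.
#
#   netD = define_D(input_nc=4, ndf=64, netD='basic_256_multi', norm='instance', num_Ds=2)
#   preds = netD(torch.randn(1, 4, 256, 256))   # list with 2 prediction maps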
def define_E(input_nc, output_nc, ndf, netE, norm='batch', nl='lrelu',
init_type='xavier', init_gain=0.02, gpu_ids=[], vaeLike=False):
net = None
norm_layer = get_norm_layer(norm_type=norm)
nl = 'lrelu' # use leaky relu for E
nl_layer = get_non_linearity(layer_type=nl)
if netE == 'resnet_128':
net = E_ResNet(input_nc, output_nc, ndf, n_blocks=4, norm_layer=norm_layer,
nl_layer=nl_layer, vaeLike=vaeLike)
elif netE == 'resnet_256':
net = E_ResNet(input_nc, output_nc, ndf, n_blocks=5, norm_layer=norm_layer,
nl_layer=nl_layer, vaeLike=vaeLike)
elif netE == 'conv_128':
net = E_NLayers(input_nc, output_nc, ndf, n_layers=4, norm_layer=norm_layer,
nl_layer=nl_layer, vaeLike=vaeLike)
elif netE == 'conv_256':
net = E_NLayers(input_nc, output_nc, ndf, n_layers=5, norm_layer=norm_layer,
nl_layer=nl_layer, vaeLike=vaeLike)
else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)
return init_net(net, init_type, init_gain, gpu_ids, False)
class ResnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, norm_layer=None, use_dropout=False, n_blocks=6, padding_type='replicate'):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
model = [nn.ReplicationPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
bias=use_bias)]
if norm_layer is not None:
model += [norm_layer(ngf)]
model += [nn.ReLU(True)]
# n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.ReplicationPad2d(1),nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=0, bias=use_bias)]
# model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
# stride=2, padding=1, bias=use_bias)]
if norm_layer is not None:
model += [norm_layer(ngf * mult * 2)]
model += [nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
# model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
# kernel_size=3, stride=2,
# padding=1, output_padding=1,
# bias=use_bias)]
# if norm_layer is not None:
# model += [norm_layer(ngf * mult / 2)]
# model += [nn.ReLU(True)]
model += upsampleLayer(ngf * mult, int(ngf * mult / 2), upsample='bilinear', padding_type=padding_type)
if norm_layer is not None:
model += [norm_layer(int(ngf * mult / 2))]
model += [nn.ReLU(True)]
model +=[nn.ReplicationPad2d(1),
nn.Conv2d(int(ngf * mult / 2), int(ngf * mult / 2), kernel_size=3, padding=0)]
if norm_layer is not None:
                model += [norm_layer(int(ngf * mult / 2))]
model += [nn.ReLU(True)]
model += [nn.ReplicationPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
#model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
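# Shape sketch (illustrative only): the encoder halves the spatial size n_downsampling
# times, the ResNet blocks keep it, and the bilinear decoder restores it, so the output
# resolution matches the input.
#
#   netC = ResnetGenerator(2, 8, ngf=24, norm_layer=get_norm_layer('layer'), n_blocks=6)
#   y = netC(torch.randn(1, 2, 256, 256))   # -> (1, 8, 256, 256)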
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
if norm_layer is not None:
conv_block += [norm_layer(dim)]
conv_block += [nn.ReLU(True)]
# if use_dropout:
# conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias)]
if norm_layer is not None:
conv_block += [norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class D_NLayersMulti(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3,
norm_layer=nn.BatchNorm2d, num_D=1, nl_layer=None):
super(D_NLayersMulti, self).__init__()
# st()
self.num_D = num_D
self.nl_layer=nl_layer
if num_D == 1:
layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
self.model = nn.Sequential(*layers)
else:
layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
self.add_module("model_0", nn.Sequential(*layers))
self.down = nn.functional.interpolate
for i in range(1, num_D):
ndf_i = int(round(ndf / (2**i)))
layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
self.add_module("model_%d" % i, nn.Sequential(*layers))
def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
kw = 3
padw = 1
sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
stride=2, padding=padw)), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw))]
if norm_layer:
sequence += [norm_layer(ndf * nf_mult)]
sequence += [self.nl_layer()]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw))]
if norm_layer:
sequence += [norm_layer(ndf * nf_mult)]
sequence += [self.nl_layer()]
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult, 1,
kernel_size=kw, stride=1, padding=padw))]
return sequence
def forward(self, input):
if self.num_D == 1:
return self.model(input)
result = []
down = input
for i in range(self.num_D):
model = getattr(self, "model_%d" % i)
result.append(model(down))
if i != self.num_D - 1:
down = self.down(down, scale_factor=0.5, mode='bilinear')
return result
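# Usage sketch (illustrative only): each extra discriminator sees a 0.5x bilinearly
# downsampled copy of the input, so the returned list goes from fine to coarse scale.
#
#   d = D_NLayersMulti(input_nc=3, ndf=64, n_layers=3, norm_layer=None,
#                      num_D=2, nl_layer=get_non_linearity('lrelu'))
#   maps = d(torch.randn(1, 3, 256, 256))   # list of 2 patch maps (second from the 0.5x input)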
class D_NLayers(nn.Module):
"""Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, nl_layer=None):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
            nl_layer        -- nonlinearity factory; accepted for compatibility with define_D
                               (LeakyReLU is used internally)
        """
super(D_NLayers, self).__init__()
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 3
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.model(input)
class G_Unet_add_input(nn.Module):
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
upsample='basic', device=0):
super(G_Unet_add_input, self).__init__()
self.nz = nz
max_nchn = 8
noise = []
for i in range(num_downs+1):
if use_noise:
noise.append(True)
else:
noise.append(False)
# construct unet structure
#print(num_downs)
unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=noise[num_downs-1],
innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
for i in range(num_downs - 5):
unet_block = UnetBlock_A(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise[num_downs-i-3],
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
unet_block = UnetBlock_A(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock_A(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock_A(ngf, ngf, ngf * 2, unet_block, noise[0],
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock_A(input_nc + nz, output_nc, ngf, unet_block, None,
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
self.model = unet_block
def forward(self, x, z=None):
if self.nz > 0:
z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
z.size(0), z.size(1), x.size(2), x.size(3))
x_with_z = torch.cat([x, z_img], 1)
else:
x_with_z = x # no z
return torch.tanh(self.model(x_with_z))
# return self.model(x_with_z)
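# Usage sketch (illustrative only): when nz > 0 the latent z is broadcast over the
# spatial grid and concatenated to the input as extra channels before the U-Net.
#
#   g = G_Unet_add_input(input_nc=3, output_nc=3, nz=8, num_downs=7, ngf=64,
#                        norm_layer=get_norm_layer('instance'),
#                        nl_layer=get_non_linearity('relu'))
#   y = g(torch.randn(2, 3, 128, 128), z=torch.randn(2, 8))   # -> (2, 3, 128, 128)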
class G_Unet_add_input_G(nn.Module):
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
upsample='basic', device=0):
super(G_Unet_add_input_G, self).__init__()
self.nz = nz
max_nchn = 8
noise = []
for i in range(num_downs+1):
if use_noise:
noise.append(True)
else:
noise.append(False)
# construct unet structure
#print(num_downs)
unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
for i in range(num_downs - 5):
unet_block = UnetBlock_G(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
unet_block = UnetBlock_G(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise[2],
norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
unet_block = UnetBlock_G(ngf * 2, ngf * 2, ngf * 4, unet_block, noise[1],
norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
unet_block = UnetBlock_G(ngf, ngf, ngf * 2, unet_block, noise[0],
norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
unet_block = UnetBlock_G(input_nc + nz, output_nc, ngf, unet_block, None,
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample='basic')
self.model = unet_block
def forward(self, x, z=None):
if self.nz > 0:
z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
z.size(0), z.size(1), x.size(2), x.size(3))
x_with_z = torch.cat([x, z_img], 1)
else:
x_with_z = x # no z
# return F.tanh(self.model(x_with_z))
return self.model(x_with_z)
class G_Unet_add_input_C(nn.Module):
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False,
upsample='basic', device=0):
super(G_Unet_add_input_C, self).__init__()
self.nz = nz
max_nchn = 8
# construct unet structure
#print(num_downs)
unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, noise=False,
innermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
for i in range(num_downs - 5):
unet_block = UnetBlock(ngf * max_nchn, ngf * max_nchn, ngf * max_nchn, unet_block, noise=False,
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
unet_block = UnetBlock(ngf * 4, ngf * 4, ngf * max_nchn, unet_block, noise=False,
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock(ngf * 2, ngf * 2, ngf * 4, unet_block, noise=False,
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock(ngf, ngf, ngf * 2, unet_block, noise=False,
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock(input_nc + nz, output_nc, ngf, unet_block, noise=False,
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
self.model = unet_block
def forward(self, x, z=None):
if self.nz > 0:
z_img = z.view(z.size(0), z.size(1), 1, 1).expand(
z.size(0), z.size(1), x.size(2), x.size(3))
x_with_z = torch.cat([x, z_img], 1)
else:
x_with_z = x # no z
# return torch.tanh(self.model(x_with_z))
return self.model(x_with_z)
def upsampleLayer(inplanes, outplanes, kw=1, upsample='basic', padding_type='replicate'):
# padding_type = 'zero'
if upsample == 'basic':
upconv = [nn.ConvTranspose2d(inplanes, outplanes, kernel_size=4, stride=2, padding=1)]#, padding_mode='replicate'
    elif upsample == 'bilinear' or upsample == 'nearest' or upsample == 'linear':
        # align_corners is only valid for interpolating modes; 'nearest' must leave it unset.
        align = True if upsample != 'nearest' else None
        upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=align),
                  #nn.ReplicationPad2d(1),
                  nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0)]
# p = kw//2
# upconv = [nn.Upsample(scale_factor=2, mode=upsample, align_corners=True),
# nn.Conv2d(inplanes, outplanes, kernel_size=kw, stride=1, padding=p, padding_mode='replicate')]
else:
raise NotImplementedError(
'upsample layer [%s] not implemented' % upsample)
return upconv
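# Minimal sketch of the two modes: 'basic' uses a stride-2 transposed convolution,
# while the interpolating modes upsample by 2x and then mix channels with a 1x1
# convolution (which avoids the checkerboard artifacts of deconvolution).
#
#   up = nn.Sequential(*upsampleLayer(128, 64, upsample='bilinear'))
#   up(torch.randn(1, 128, 32, 32)).shape   # -> (1, 64, 64, 64)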
class UnetBlock_G(nn.Module):
def __init__(self, input_nc, outer_nc, inner_nc,
submodule=None, noise=None, outermost=False, innermost=False,
norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
super(UnetBlock_G, self).__init__()
self.outermost = outermost
p = 0
downconv = []
if padding_type == 'reflect':
downconv += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
downconv += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError(
'padding [%s] is not implemented' % padding_type)
downconv += [nn.Conv2d(input_nc, inner_nc,
kernel_size=3, stride=2, padding=p)]
# downsample is different from upsample
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc) if norm_layer is not None else None
uprelu = nl_layer()
uprelu2 = nl_layer()
uppad = nn.ReplicationPad2d(1)
upnorm = norm_layer(outer_nc) if norm_layer is not None else None
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
self.noiseblock = ApplyNoise(outer_nc)
self.noise = noise
if outermost:
upconv = upsampleLayer(inner_nc * 2, inner_nc, upsample=upsample, padding_type=padding_type)
uppad = nn.ReplicationPad2d(3)
upconv2 = nn.Conv2d(inner_nc, outer_nc, kernel_size=7, padding=0)
down = downconv
up = [uprelu] + upconv
if upnorm is not None:
up += [norm_layer(inner_nc)]
# upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
# upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=0)
# down = downconv
# up = [uprelu] + upconv
# if upnorm is not None:
# up += [norm_layer(outer_nc)]
up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
down = [downrelu] + downconv
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
model = down + up
else:
upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
down = [downrelu] + downconv
if downnorm is not None:
down += [downnorm]
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else:
x2 = self.model(x)
if self.noise:
x2 = self.noiseblock(x2, self.noise)
return torch.cat([x2, x], 1)
class UnetBlock(nn.Module):
def __init__(self, input_nc, outer_nc, inner_nc,
submodule=None, noise=None, outermost=False, innermost=False,
norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
super(UnetBlock, self).__init__()
self.outermost = outermost
p = 0
downconv = []
if padding_type == 'reflect':
downconv += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
downconv += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError(
'padding [%s] is not implemented' % padding_type)
downconv += [nn.Conv2d(input_nc, inner_nc,
kernel_size=3, stride=2, padding=p)]
# downsample is different from upsample
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc) if norm_layer is not None else None
uprelu = nl_layer()
uprelu2 = nl_layer()
uppad = nn.ReplicationPad2d(1)
upnorm = norm_layer(outer_nc) if norm_layer is not None else None
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
self.noiseblock = ApplyNoise(outer_nc)
self.noise = noise
if outermost:
upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
down = downconv
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
down = [downrelu] + downconv
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
model = down + up
else:
upconv = upsampleLayer(inner_nc * 2, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p)
down = [downrelu] + downconv
if downnorm is not None:
down += [downnorm]
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else:
x2 = self.model(x)
if self.noise:
x2 = self.noiseblock(x2, self.noise)
return torch.cat([x2, x], 1)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetBlock_A(nn.Module):
def __init__(self, input_nc, outer_nc, inner_nc,
submodule=None, noise=None, outermost=False, innermost=False,
norm_layer=None, nl_layer=None, use_dropout=False, upsample='basic', padding_type='replicate'):
super(UnetBlock_A, self).__init__()
self.outermost = outermost
p = 0
downconv = []
if padding_type == 'reflect':
downconv += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
downconv += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError(
'padding [%s] is not implemented' % padding_type)
downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
kernel_size=3, stride=2, padding=p))]
# downsample is different from upsample
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc) if norm_layer is not None else None
uprelu = nl_layer()
uprelu2 = nl_layer()
uppad = nn.ReplicationPad2d(1)
upnorm = norm_layer(outer_nc) if norm_layer is not None else None
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
self.noiseblock = ApplyNoise(outer_nc)
self.noise = noise
if outermost:
upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
down = downconv
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = upsampleLayer(inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
down = [downrelu] + downconv
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
model = down + up
else:
upconv = upsampleLayer(inner_nc * 1, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
down = [downrelu] + downconv
if downnorm is not None:
down += [downnorm]
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else:
x2 = self.model(x)
if self.noise:
x2 = self.noiseblock(x2, self.noise)
if x2.shape[-1]==x.shape[-1]:
return x2 + x
else:
x2 = F.interpolate(x2, x.shape[2:])
return x2 + x
class E_ResNet(nn.Module):
def __init__(self, input_nc=3, output_nc=1, ndf=64, n_blocks=4,
norm_layer=None, nl_layer=None, vaeLike=False):
super(E_ResNet, self).__init__()
self.vaeLike = vaeLike
max_ndf = 4
conv_layers = [
nn.Conv2d(input_nc, ndf, kernel_size=3, stride=2, padding=1, bias=True)]
for n in range(1, n_blocks):
input_ndf = ndf * min(max_ndf, n)
output_ndf = ndf * min(max_ndf, n + 1)
conv_layers += [BasicBlock(input_ndf,
output_ndf, norm_layer, nl_layer)]
conv_layers += [nl_layer(), nn.AdaptiveAvgPool2d(4)]
if vaeLike:
self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
self.fcVar = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
else:
self.fc = nn.Sequential(*[nn.Linear(output_ndf * 16, output_nc)])
self.conv = nn.Sequential(*conv_layers)
def forward(self, x):
x_conv = self.conv(x)
conv_flat = x_conv.view(x.size(0), -1)
output = self.fc(conv_flat)
if self.vaeLike:
outputVar = self.fcVar(conv_flat)
return output, outputVar
else:
            return output
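# Usage sketch (illustrative only): with vaeLike=True the encoder returns a mean and a
# log-variance vector, as used by a VAE-style latent sampler.
#
#   e = E_ResNet(input_nc=3, output_nc=8, ndf=64, n_blocks=4,
#                norm_layer=get_norm_layer('instance'),
#                nl_layer=get_non_linearity('lrelu'), vaeLike=True)
#   mu, logvar = e(torch.randn(2, 3, 128, 128))   # each of shape (2, 8)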
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class G_Unet_add_all(nn.Module):
def __init__(self, input_nc, output_nc, nz, num_downs, ngf=64,
norm_layer=None, nl_layer=None, use_dropout=False, use_noise=False, upsample='basic'):
super(G_Unet_add_all, self).__init__()
self.nz = nz
self.mapping = G_mapping(self.nz, self.nz, 512, normalize_latents=False, lrmul=1)
self.truncation_psi = 0
self.truncation_cutoff = 0
        # The "- 2" reflects starting from a 4x4 feature map; for resolution 512
        # this gives num_layers = int(log2(512)) * 2 - 2 = 16.
        num_layers = int(np.log2(512)) * 2 - 2
        # Noise inputs: fixed per-resolution noise buffers (note the hard-coded "cuda" device below).
self.noise_inputs = []
for layer_idx in range(num_layers):
res = layer_idx // 2 + 2
shape = [1, 1, 2 ** res, 2 ** res]
self.noise_inputs.append(torch.randn(*shape).to("cuda"))
# construct unet structure
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=None, innermost=True,
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
for i in range(num_downs - 6):
unet_block = UnetBlock_with_z(ngf * 8, ngf * 8, ngf * 8, nz, submodule=unet_block,
norm_layer=norm_layer, nl_layer=nl_layer, use_dropout=use_dropout, upsample=upsample)
unet_block = UnetBlock_with_z(ngf * 4, ngf * 4, ngf * 8, nz, submodule=unet_block,
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock_with_z(ngf * 2, ngf * 2, ngf * 4, nz, submodule=unet_block,
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock_with_z(ngf, ngf, ngf * 2, nz, submodule=unet_block,
norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
unet_block = UnetBlock_with_z(input_nc, output_nc, ngf, nz, submodule=unet_block,
outermost=True, norm_layer=norm_layer, nl_layer=nl_layer, upsample=upsample)
self.model = unet_block
def forward(self, x, z):
dlatents1, num_layers = self.mapping(z)
dlatents1 = dlatents1.unsqueeze(1)
dlatents1 = dlatents1.expand(-1, int(num_layers), -1)
# Apply truncation trick.
if self.truncation_psi and self.truncation_cutoff:
coefs = np.ones([1, num_layers, 1], dtype=np.float32)
for i in range(num_layers):
if i < self.truncation_cutoff:
coefs[:, i, :] *= self.truncation_psi
"""Linear interpolation.
a + (b - a) * t (a = 0)
reduce to
b * t
"""
dlatents1 = dlatents1 * torch.Tensor(coefs).to(dlatents1.device)
return torch.tanh(self.model(x, dlatents1, self.noise_inputs))
class ApplyNoise(nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels
self.weight = nn.Parameter(torch.randn(channels), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(channels), requires_grad=True)
    def forward(self, x, noise):
        # Only the first half of the channels get a learned noise scale/bias; the second
        # half is padded with zeros. Fresh Gaussian noise is drawn on every call, so the
        # `noise` argument is accepted for interface compatibility but not used directly.
        W, _ = torch.split(self.weight.view(1, -1, 1, 1), self.channels // 2, dim=1)
        B, _ = torch.split(self.bias.view(1, -1, 1, 1), self.channels // 2, dim=1)
        Z = torch.zeros_like(W)
        w = torch.cat([W, Z], dim=1).to(x.device)
        b = torch.cat([B, Z], dim=1).to(x.device)
        adds = w * torch.randn_like(x) + b
        return x + adds.type_as(x)
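# Usage sketch: per-channel noise injection on a feature map; the output keeps the
# input shape.
#
#   an = ApplyNoise(channels=64)
#   y = an(torch.randn(1, 64, 32, 32), noise=None)   # -> (1, 64, 32, 32)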
class FC(nn.Module):
def __init__(self,
in_channels,
out_channels,
gain=2**(0.5),
use_wscale=False,
lrmul=1.0,
bias=True):
"""
The complete conversion of Dense/FC/Linear Layer of original Tensorflow version.
"""
super(FC, self).__init__()
he_std = gain * in_channels ** (-0.5) # He init
if use_wscale:
init_std = 1.0 / lrmul
self.w_lrmul = he_std * lrmul
else:
init_std = he_std / lrmul
self.w_lrmul = lrmul
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels) * init_std)
if bias:
self.bias = torch.nn.Parameter(torch.zeros(out_channels))
self.b_lrmul = lrmul
else:
self.bias = None
def forward(self, x):
if self.bias is not None:
out = F.linear(x, self.weight * self.w_lrmul, self.bias * self.b_lrmul)
else:
out = F.linear(x, self.weight * self.w_lrmul)
out = F.leaky_relu(out, 0.2, inplace=True)
return out
class ApplyStyle(nn.Module):
"""
@ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
"""
def __init__(self, latent_size, channels, use_wscale, nl_layer):
super(ApplyStyle, self).__init__()
modules = [nn.Linear(latent_size, channels*2)]
if nl_layer:
modules += [nl_layer()]
self.linear = nn.Sequential(*modules)
def forward(self, x, latent):
style = self.linear(latent) # style => [batch_size, n_channels*2]
shape = [-1, 2, x.size(1), 1, 1]
style = style.view(shape) # [batch_size, 2, n_channels, ...]
x = x * (style[:, 0] + 1.) + style[:, 1]
return x
class PixelNorm(nn.Module):
def __init__(self, epsilon=1e-8):
"""
@notice: avoid in-place ops.
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
"""
super(PixelNorm, self).__init__()
self.epsilon = epsilon
def forward(self, x):
tmp = torch.mul(x, x) # or x ** 2
tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)
return x * tmp1
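# PixelNorm normalizes each spatial position across channels:
#   y[b, :, h, w] = x[b, :, h, w] / sqrt(mean_c(x[b, c, h, w]^2) + eps)
# i.e. the feature-vector normalization applied to latents in StyleGAN-style mapping networks.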
class InstanceNorm(nn.Module):
def __init__(self, epsilon=1e-8):
"""
@notice: avoid in-place ops.
https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
"""
super(InstanceNorm, self).__init__()
self.epsilon = epsilon
def forward(self, x):
x = x - torch.mean(x, (2, 3), True)
tmp = torch.mul(x, x) # or x ** 2
tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
return x * tmp
class LayerEpilogue(nn.Module):
def __init__(self, channels, dlatent_size, use_wscale, use_noise,
use_pixel_norm, use_instance_norm, use_styles, nl_layer=None):
super(LayerEpilogue, self).__init__()
self.use_noise = use_noise
if use_noise:
self.noise = ApplyNoise(channels)
self.act = nn.LeakyReLU(negative_slope=0.2)
if use_pixel_norm:
self.pixel_norm = PixelNorm()
else:
self.pixel_norm = None
if use_instance_norm:
self.instance_norm = InstanceNorm()
else:
self.instance_norm = None
if use_styles:
self.style_mod = ApplyStyle(dlatent_size, channels, use_wscale=use_wscale, nl_layer=nl_layer)
else:
self.style_mod = None
def forward(self, x, noise, dlatents_in_slice=None):
# if noise is not None:
if self.use_noise:
x = self.noise(x, noise)
x = self.act(x)
if self.pixel_norm is not None:
x = self.pixel_norm(x)
if self.instance_norm is not None:
x = self.instance_norm(x)
if self.style_mod is not None:
x = self.style_mod(x, dlatents_in_slice)
return x
class G_mapping(nn.Module):
def __init__(self,
mapping_fmaps=512,
dlatent_size=512,
resolution=512,
normalize_latents=True, # Normalize latent vectors (Z) before feeding them to the mapping layers?
use_wscale=True, # Enable equalized learning rate?
lrmul=0.01, # Learning rate multiplier for the mapping layers.
gain=2**(0.5), # original gain in tensorflow.
nl_layer=None
):
super(G_mapping, self).__init__()
self.mapping_fmaps = mapping_fmaps
func = [
nn.Linear(self.mapping_fmaps, dlatent_size)
]
if nl_layer:
func += [nl_layer()]
for j in range(0,4):
func += [
nn.Linear(dlatent_size, dlatent_size)
]
if nl_layer:
func += [nl_layer()]
self.func = nn.Sequential(*func)
#FC(self.mapping_fmaps, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
#FC(dlatent_size, dlatent_size, gain, lrmul=lrmul, use_wscale=use_wscale),
self.normalize_latents = normalize_latents
self.resolution_log2 = int(np.log2(resolution))
self.num_layers = self.resolution_log2 * 2 - 2
self.pixel_norm = PixelNorm()
        # The "- 2" reflects starting from a 4x4 feature map; with the default
        # resolution of 512, num_layers = 16.
def forward(self, x):
if self.normalize_latents:
x = self.pixel_norm(x)
out = self.func(x)
return out, self.num_layers
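# Usage sketch (illustrative only; dimensions are hypothetical): the mapping network
# turns a latent batch into style codes that G_Unet_add_all broadcasts over num_layers.
#
#   mapping = G_mapping(mapping_fmaps=8, dlatent_size=8, resolution=512, normalize_latents=False)
#   w, num_layers = mapping(torch.randn(2, 8))   # w: (2, 8), num_layers: 16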
class UnetBlock_with_z(nn.Module):
def __init__(self, input_nc, outer_nc, inner_nc, nz=0,
submodule=None, outermost=False, innermost=False,
norm_layer=None, nl_layer=None, use_dropout=False,
upsample='basic', padding_type='replicate'):
super(UnetBlock_with_z, self).__init__()
p = 0
downconv = []
if padding_type == 'reflect':
downconv += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
downconv += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError(
'padding [%s] is not implemented' % padding_type)
self.outermost = outermost
self.innermost = innermost
self.nz = nz
# input_nc = input_nc + nz
downconv += [spectral_norm(nn.Conv2d(input_nc, inner_nc,
kernel_size=3, stride=2, padding=p))]
# downsample is different from upsample
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc) if norm_layer is not None else None
uprelu = nl_layer()
uprelu2 = nl_layer()
uppad = nn.ReplicationPad2d(1)
upnorm = norm_layer(outer_nc) if norm_layer is not None else None
upnorm2 = norm_layer(outer_nc) if norm_layer is not None else None
        use_styles = self.nz > 0
if outermost:
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
upconv = upsampleLayer(
inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
down = downconv
up = [uprelu] + upconv
if upnorm is not None:
up += [upnorm]
up +=[uprelu2, uppad, upconv2] #+ [nn.Tanh()]
elif innermost:
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=True,
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
upconv = upsampleLayer(
inner_nc, outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
down = [downrelu] + downconv
up = [uprelu] + upconv
if norm_layer is not None:
up += [norm_layer(outer_nc)]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
else:
self.adaIn = LayerEpilogue(inner_nc, self.nz, use_wscale=True, use_noise=False,
use_pixel_norm=True, use_instance_norm=True, use_styles=use_styles, nl_layer=nl_layer)
upconv = upsampleLayer(
inner_nc , outer_nc, upsample=upsample, padding_type=padding_type)
upconv2 = spectral_norm(nn.Conv2d(outer_nc, outer_nc, kernel_size=3, padding=p))
down = [downrelu] + downconv
if norm_layer is not None:
down += [norm_layer(inner_nc)]
up = [uprelu] + upconv
if norm_layer is not None:
up += [norm_layer(outer_nc)]
up += [uprelu2, uppad, upconv2]
if upnorm2 is not None:
up += [upnorm2]
if use_dropout:
up += [nn.Dropout(0.5)]
self.down = nn.Sequential(*down)
self.submodule = submodule
self.up = nn.Sequential(*up)
def forward(self, x, z, noise):
if self.outermost:
x1 = self.down(x)
x2 = self.submodule(x1, z[:,2:], noise[2:])
return self.up(x2)
elif self.innermost:
x1 = self.down(x)
x_and_z = self.adaIn(x1, noise[0], z[:,0])
x2 = self.up(x_and_z)
x2 = F.interpolate(x2, x.shape[2:])
return x2 + x
else:
x1 = self.down(x)
x2 = self.submodule(x1, z[:,2:], noise[2:])
x_and_z = self.adaIn(x2, noise[0], z[:,0])
return self.up(x_and_z) + x
class E_NLayers(nn.Module):
def __init__(self, input_nc, output_nc=1, ndf=64, n_layers=4,
norm_layer=None, nl_layer=None, vaeLike=False):
super(E_NLayers, self).__init__()
self.vaeLike = vaeLike
kw, padw = 3, 1
sequence = [spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kw,
stride=2, padding=padw, padding_mode='replicate')), nl_layer()]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, padding_mode='replicate'))]
if norm_layer is not None:
sequence += [norm_layer(ndf * nf_mult)]
sequence += [nl_layer()]
sequence += [nn.AdaptiveAvgPool2d(4)]
self.conv = nn.Sequential(*sequence)
self.fc = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
if vaeLike:
self.fcVar = nn.Sequential(*[spectral_norm(nn.Linear(ndf * nf_mult * 16, output_nc))])
def forward(self, x):
x_conv = self.conv(x)
conv_flat = x_conv.view(x.size(0), -1)
output = self.fc(conv_flat)
if self.vaeLike:
outputVar = self.fcVar(conv_flat)
return output, outputVar
return output
class BasicBlock(nn.Module):
    def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
        # norm_layer / nl_layer are optional so that the E_ResNet call
        # BasicBlock(input_ndf, output_ndf, norm_layer, nl_layer) is accepted;
        # when they are omitted, the original defaults (layer norm + ReLU) apply.
        super(BasicBlock, self).__init__()
        layers = []
        if norm_layer is None:
            norm_layer = get_norm_layer(norm_type='layer')  # functools.partial(LayerNormWarpper)
        if norm_layer is not None:
            layers += [norm_layer(inplanes)]
        layers += [nl_layer() if nl_layer is not None else nn.ReLU()]
        layers += [nn.ReplicationPad2d(1),
                   nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=1,
                             padding=0, bias=True)]
        self.conv = nn.Sequential(*layers)
def forward(self, x):
return self.conv(x)
def define_SVAE(inc=96, outc=3, outplanes=64, blocks=1, netVAE='SVAE', model_name='', load_ext=True, save_dir='',
init_type="normal", init_gain=0.02, gpu_ids=[]):
if netVAE == 'SVAE':
net = ScreenVAE(inc=inc, outc=outc, outplanes=outplanes, blocks=blocks, save_dir=save_dir,
init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netVAE)
init_net(net, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
net.load_networks('latest')
return net
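# Usage sketch (illustrative only; the checkpoint directory is hypothetical): build the
# frozen ScreenVAE wrapper and load its encoder/decoder weights from `save_dir`
# (expects files such as 'latest_net_enc.pth' and 'latest_net_dec.pth').
#
#   svae = define_SVAE(inc=1, outc=4, save_dir='checkpoints/ScreenVAE', gpu_ids=[0])
#   scr = svae(img, line)   # screen-tone representation of the input image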
class ScreenVAE(nn.Module):
def __init__(self,inc=1,outc=4, outplanes=64, downs=5, blocks=2,load_ext=True, save_dir='',init_type="normal", init_gain=0.02, gpu_ids=[]):
super(ScreenVAE, self).__init__()
self.inc = inc
self.outc = outc
self.save_dir = save_dir
norm_layer=functools.partial(LayerNormWarpper)
nl_layer=nn.LeakyReLU
self.model_names=['enc','dec']
self.enc=define_C(inc+1, outc*2, 0, 24, netC='resnet_6blocks',
norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
gpu_ids=gpu_ids, upsample='bilinear')
self.dec=define_G(outc, inc, 0, 48, netG='unet_128_G',
norm='layer', nl='lrelu', use_dropout=True, init_type='kaiming',
gpu_ids=gpu_ids, where_add='input', upsample='bilinear', use_noise=True)
for param in self.parameters():
param.requires_grad = False
def load_networks(self, epoch):
"""Load all the networks from the disk.
Parameters:
            epoch (int or str) -- current epoch or 'latest'; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from %s' % load_path)
state_dict = torch.load(
load_path, map_location=lambda storage, loc: storage.cuda())
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
net.load_state_dict(state_dict)
del state_dict
def npad(self, im, pad=128):
h,w = im.shape[-2:]
hp = h //pad*pad+pad
wp = w //pad*pad+pad
return F.pad(im, (0, wp-w, 0, hp-h), mode='replicate')
def forward(self, x, line=None, img_input=True, output_screen_only=True):
if img_input:
if line is None:
line = torch.ones_like(x)
else:
line = torch.sign(line)
x = torch.clamp(x + (1-line),-1,1)
h,w = x.shape[-2:]
input = torch.cat([x, line], 1)
input = self.npad(input)
inter = self.enc(input)[:,:,:h,:w]
scr, logvar = torch.split(inter, (self.outc, self.outc), dim=1)
if output_screen_only:
return scr
recons = self.dec(scr)
return recons, scr, logvar
else:
h,w = x.shape[-2:]
x = self.npad(x)
recons = self.dec(x)[:,:,:h,:w]
recons = (recons+1)*(line+1)/2-1
return torch.clamp(recons,-1,1)
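# Forward-path sketch (illustrative only): with img_input=True the encoder maps an image
# (plus an optional line drawing) to `outc` screen-tone channels; with img_input=False
# the decoder reconstructs an image from such channels.
#
#   scr = svae(img, line)                                       # (N, 1, H, W) -> (N, 4, H, W)
#   recons, scr, logvar = svae(img, line, output_screen_only=False)
#   img_back = svae(scr, line, img_input=False)                 # back to an image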