|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import torch.nn.functional as F |
|
|
|
|
|
class BaseNetwork(nn.Module):
    """``nn.Module`` base class that adds a recursive weight initialiser."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def init_weights(self, init_type='normal', gain=0.02):
        """Re-initialise the parameters of every sub-module in place.

        Modules whose class name contains ``Conv`` or ``Linear`` and that
        expose a ``weight`` attribute get their weight drawn according to
        ``init_type`` and their bias (when present) set to zero.
        ``BatchNorm2d`` modules get weight ~ N(1, gain) and zero bias.

        Args:
            init_type: normal | xavier | kaiming | orthogonal. Any other
                value leaves Conv/Linear weights untouched (biases are
                still zeroed).
            gain: standard deviation for ``normal`` (and for BatchNorm2d
                weights); gain factor for ``xavier`` and ``orthogonal``.
                Ignored by ``kaiming``.

        Adapted from
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        """

        def init_func(m):
            classname = m.__class__.__name__
            is_conv_or_linear = 'Conv' in classname or 'Linear' in classname

            if hasattr(m, 'weight') and is_conv_or_linear:
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)

                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

            elif 'BatchNorm2d' in classname:
                # BatchNorm2d(affine=False) registers weight/bias as None;
                # skip them instead of failing on ``None.data``.
                if m.weight is not None:
                    nn.init.normal_(m.weight.data, 1.0, gain)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)
|
|
|
def weights_init(init_type='gaussian'):
    """Return an initialiser callable suitable for ``nn.Module.apply``.

    The returned function only touches modules whose class name *starts
    with* ``Conv`` or ``Linear`` and that have a ``weight`` attribute. The
    weight is initialised according to ``init_type`` and the bias (when
    present) is set to zero.

    Args:
        init_type: gaussian | xavier | kaiming | orthogonal | default.
            ``default`` keeps the existing weights.

    Raises:
        AssertionError: when the returned function is applied to a
            Conv/Linear module and ``init_type`` is not one of the
            supported names.
    """
    # ``math`` is needed for the sqrt(2) gains below; imported here so the
    # function does not depend on a file-level import being present.
    import math

    def init_fun(m):
        classname = m.__class__.__name__
        if classname.startswith(('Conv', 'Linear')) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                # Explicit raise (rather than ``assert 0``) so the check is
                # not stripped when Python runs with -O.
                raise AssertionError(
                    "Unsupported initialization: {}".format(init_type))

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun
|
|
|
class PartialConv(nn.Module):
    """Mask-aware 2-D convolution.

    Two convolutions with identical geometry are kept:

    * ``input_conv`` - the learnable convolution, applied to
      ``input * mask``;
    * ``mask_conv`` - a frozen, bias-free convolution with all-ones
      weights, used to sum the mask values under every sliding window.

    The arguments are forwarded positionally to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, False)
        self.input_conv.apply(weights_init('kaiming'))

        # Number of mask entries summed for one output element, i.e. the
        # size of a single mask filter: (in_channels // groups) * kH * kW.
        # Reading it off the weight tensor stays correct for grouped
        # convolutions and tuple kernel sizes; for an int kernel with
        # groups=1 it equals in_channels * kernel_size * kernel_size.
        self.slide_winsize = self.mask_conv.weight[0].numel()

        torch.nn.init.constant_(self.mask_conv.weight, 1.0)

        # The mask convolution is a fixed counting kernel, never trained.
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, input, mask):
        """Apply the partial convolution.

        Args:
            input: feature tensor fed to ``input_conv`` after being
                multiplied element-wise by ``mask``.
            mask: tensor multiplied with ``input`` and fed to
                ``mask_conv``; entries equal to 0 contribute nothing.

        Returns:
            ``(output, new_mask)``. ``output`` is the convolution result
            with its bias-free part scaled by
            ``slide_winsize / window_mask_sum``, and set to 0 wherever the
            window mask sum is 0. ``new_mask`` has the shape of ``output``
            and is 0 at those positions and 1 elsewhere.
        """
        output = self.input_conv(input * mask)
        if self.input_conv.bias is not None:
            output_bias = self.input_conv.bias.view(1, -1, 1, 1).expand_as(
                output)
        else:
            output_bias = torch.zeros_like(output)

        # The mask statistics must not take part in autograd.
        with torch.no_grad():
            output_mask = self.mask_conv(mask)

        no_update_holes = output_mask == 0

        # Replace empty-window sums by 1 so the division below is safe;
        # those positions are zeroed right afterwards anyway.
        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)

        # Only the bias-free response is renormalised; the bias is added
        # back unscaled.
        output_pre = ((output - output_bias) * self.slide_winsize) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes, 0.0)

        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)

        return output, new_mask
|
|
|
|
|
class PCBActiv(nn.Module):
    """``PartialConv`` followed by optional BatchNorm2d and activation.

    Args:
        in_ch: input channels of the partial convolution.
        out_ch: output channels of the partial convolution.
        bn: when truthy, a ``BatchNorm2d(out_ch)`` is applied after the
            convolution.
        sample: ``'down-7'``, ``'down-5'`` or ``'down-3'`` select a
            stride-2 convolution with kernel 7, 5 or 3; any other value
            selects a 3x3 stride-1 convolution.
        activ: ``'relu'`` or ``'leaky'`` (LeakyReLU, slope 0.2); any other
            value means no activation.
        conv_bias: whether the partial convolution has a bias.
    """

    def __init__(self, in_ch, out_ch, bn=True, sample='none-3', activ='relu',
                 conv_bias=False):
        super().__init__()

        # (kernel_size, stride, padding) for each down-sampling mode.
        down_geometry = {
            'down-5': (5, 2, 2),
            'down-7': (7, 2, 3),
            'down-3': (3, 2, 1),
        }
        kernel, stride, pad = down_geometry.get(sample, (3, 1, 1))
        self.conv = PartialConv(in_ch, out_ch, kernel, stride, pad,
                                bias=conv_bias)

        # ``bn`` / ``activation`` are only created when requested;
        # forward() relies on their presence or absence.
        if bn:
            self.bn = nn.BatchNorm2d(out_ch)

        if activ == 'relu':
            self.activation = nn.ReLU()
        elif activ == 'leaky':
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input, input_mask):
        """Return ``(features, mask)`` after conv, then bn/activation if present."""
        feat, feat_mask = self.conv(input, input_mask)
        for stage in ('bn', 'activation'):
            if hasattr(self, stage):
                feat = getattr(self, stage)(feat)
        return feat, feat_mask
|
|
|
class Inpaint_Depth_Net(nn.Module):
    """Encoder/decoder of ``PCBActiv`` blocks with skip concatenations.

    The network takes a 4-channel input and produces a 1-channel output.
    ``layer_size`` stride-2 encoder stages are mirrored by ``layer_size``
    decoder stages; each decoder stage upsamples by 2 and concatenates the
    matching encoder feature (and mask) before its partial convolution.

    Args:
        layer_size: number of encoder / decoder stages. Stages 1-4 are
            fixed; stages 5..layer_size are 512-channel blocks.
        upsampling_mode: ``mode`` passed to ``F.interpolate`` for the
            features (masks always use ``'nearest'``).
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest'):
        super().__init__()
        in_channels = 4
        out_channels = 1
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size

        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7', conv_bias=True)
        self.enc_2 = PCBActiv(64, 128, sample='down-5', conv_bias=True)
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        for i in range(4, self.layer_size):
            name = 'enc_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512, 512, sample='down-3'))

        for i in range(4, self.layer_size):
            name = 'dec_{:d}'.format(i + 1)
            setattr(self, name, PCBActiv(512 + 512, 512, activ='leaky'))
        self.dec_4 = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3 = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2 = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1 = PCBActiv(64 + in_channels, out_channels,
                              bn=False, activ=None, conv_bias=True)

    def add_border(self, input, mask_flag, PCONV=True):
        """Centre ``input`` on a canvas whose H and W are multiples of ``2 ** layer_size``.

        Args:
            input: tensor whose last two dimensions are height and width.
            mask_flag: whether ``input`` is a mask.
            PCONV: only used when ``mask_flag`` is truthy; if it is exactly
                ``False`` the border is filled with ones, otherwise with
                zeros.

        Returns:
            ``(padded, [top, bottom, left, right])`` where
            ``padded[..., top:bottom, left:right]`` holds ``input``.
        """
        with torch.no_grad():
            h, w = input.shape[-2], input.shape[-1]
            require_len_unit = 2 ** self.layer_size
            residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h)
            residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w)
            canvas_shape = (input.shape[0], input.shape[1], h + residual_h, w + residual_w)

            # The border is 1 only for a mask used without partial
            # convolutions; in every other case it is 0.
            if mask_flag and PCONV is False:
                enlarge_input = torch.ones(canvas_shape, device=input.device)
            else:
                enlarge_input = torch.zeros(canvas_shape, device=input.device)

            anchor_h = residual_h // 2
            anchor_w = residual_w // 2
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input

        return enlarge_input, [anchor_h, anchor_h + h, anchor_w, anchor_w + w]

    def forward_3P(self, mask, context, depth, edge, unit_length=128, cuda=None):
        """Run the network without gradients on zero-padded inputs.

        The inputs are concatenated along dim 1 in the order
        ``(depth, edge, context, mask)``, centred on a zero canvas whose
        height and width are multiples of ``unit_length``, passed through
        ``forward`` and cropped back to the original height and width.

        Args:
            mask, context, depth, edge: tensors concatenated along dim 1.
            unit_length: the padded size is rounded up to a multiple of
                this value.
            cuda: device for the padded canvas. When ``None`` the canvas
                is created on the device of the inputs.

        Returns:
            The network output cropped to the input height and width.
        """
        with torch.no_grad():
            input = torch.cat((depth, edge, context, mask), dim=1)
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2

            # Follow the inputs' device unless the caller asked for one.
            device = input.device if cuda is None else cuda
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input

            depth_output = self.forward(enlarge_input)
            depth_output = depth_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

        return depth_output

    def forward(self, input_feat, refine_border=False, sample=False, PCONV=True):
        """Encode and decode ``input_feat``.

        The mask is the sum of the last two channels of ``input_feat``
        clamped to [0, 1] and repeated over all input channels.

        Args:
            input_feat: input tensor; its last two channels define the mask.
            refine_border: when exactly ``True``, input and mask are padded
                with ``add_border`` first and the output is cropped back.
            sample: unused; kept for interface compatibility.
            PCONV: forwarded to ``add_border`` for the mask padding.

        Returns:
            Output of the last decoder stage (``dec_1``).
        """
        input = input_feat
        input_mask = (input_feat[:, -2:-1] + input_feat[:, -1:]).clamp(0, 1).repeat(1, input.shape[1], 1, 1)

        if refine_border is True:
            input, anchor = self.add_border(input, mask_flag=False)
            input_mask, anchor = self.add_border(input_mask, mask_flag=True, PCONV=PCONV)

        # feats[i] / masks[i] hold the output of encoder stage i
        # (index 0 is the network input).
        feats, masks = [input], [input_mask]
        for i in range(1, self.layer_size + 1):
            encoder = getattr(self, 'enc_{:d}'.format(i))
            h, h_mask = encoder(feats[-1], masks[-1])
            feats.append(h)
            masks.append(h_mask)

        h, h_mask = feats[-1], masks[-1]
        for i in range(self.layer_size, 0, -1):
            h = F.interpolate(h, scale_factor=2, mode=self.upsampling_mode)
            h_mask = F.interpolate(h_mask, scale_factor=2, mode='nearest')

            # Skip connection: concatenate the encoder feature/mask of the
            # same resolution.
            h = torch.cat([h, feats[i - 1]], dim=1)
            h_mask = torch.cat([h_mask, masks[i - 1]], dim=1)
            h, h_mask = getattr(self, 'dec_{:d}'.format(i))(h, h_mask)

        output = h
        if refine_border is True:
            output = output[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]

        return output
|
|
|
class Inpaint_Edge_Net(BaseNetwork):
    """Convolutional encoder / residual middle / decoder producing a 1-channel sigmoid map.

    The network takes a 7-channel input. Three encoder stages (the last
    two stride-2) are followed by ``residual_blocks`` dilated
    ``ResnetBlock`` modules and three decoder stages; every decoder stage
    receives the concatenation of the previous output with the encoder
    feature of matching resolution.

    Args:
        residual_blocks: number of ``ResnetBlock(256, 2)`` modules in the
            middle section.
        init_weights: when truthy, ``self.init_weights()`` is called at the
            end of construction.
    """

    def __init__(self, residual_blocks=8, init_weights=True):
        super(Inpaint_Edge_Net, self).__init__()
        in_channels = 7
        out_channels = 1
        self.encoder = []

        self.encoder_0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, padding=0), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))

        self.encoder_1 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))

        self.encoder_2 = nn.Sequential(
            spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(256, track_running_stats=False),
            nn.ReLU(True))

        blocks = [ResnetBlock(256, 2) for _ in range(residual_blocks)]
        self.middle = nn.Sequential(*blocks)

        # Decoder input widths are doubled because each stage consumes
        # [previous output, encoder skip feature] concatenated on dim 1.
        self.decoder_0 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=256 + 256, out_channels=128, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(128, track_running_stats=False),
            nn.ReLU(True))

        self.decoder_1 = nn.Sequential(
            spectral_norm(nn.ConvTranspose2d(in_channels=128 + 128, out_channels=64, kernel_size=4, stride=2, padding=1), True),
            nn.InstanceNorm2d(64, track_running_stats=False),
            nn.ReLU(True))

        self.decoder_2 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels=64 + 64, out_channels=out_channels, kernel_size=7, padding=0),
        )

        if init_weights:
            self.init_weights()

    def add_border(self, input, channel_pad_1=None):
        """Centre ``input`` on a canvas whose H and W are multiples of 16.

        Args:
            input: tensor whose last two dimensions are height and width.
            channel_pad_1: optional iterable of channel indices whose
                border is filled with 1 instead of 0.

        Returns:
            ``(padded, [top, bottom, left, right])`` where
            ``padded[..., top:bottom, left:right]`` holds ``input``.
        """
        h = input.shape[-2]
        w = input.shape[-1]
        require_len_unit = 16
        residual_h = int(np.ceil(h / float(require_len_unit)) * require_len_unit - h)
        residual_w = int(np.ceil(w / float(require_len_unit)) * require_len_unit - w)
        enlarge_input = torch.zeros((input.shape[0], input.shape[1], h + residual_h, w + residual_w)).to(input.device)
        if channel_pad_1 is not None:
            for channel in channel_pad_1:
                enlarge_input[:, channel] = 1
        anchor_h = residual_h // 2
        anchor_w = residual_w // 2
        enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input

        return enlarge_input, [anchor_h, anchor_h + h, anchor_w, anchor_w + w]

    def forward_3P(self, mask, context, rgb, disp, edge, unit_length=128, cuda=None):
        """Run the network without gradients on zero-padded inputs.

        The inputs are concatenated along dim 1 in the order
        ``(rgb, disp / disp.max(), edge, context, mask)``, centred on a
        zero canvas whose height and width are multiples of
        ``unit_length``, passed through ``forward`` and cropped back.

        Args:
            mask, context, rgb, disp, edge: tensors concatenated along
                dim 1; ``disp`` is divided by its global maximum first.
            unit_length: the padded size is rounded up to a multiple of
                this value.
            cuda: device for the padded canvas. When ``None`` the canvas
                is created on the device of the inputs.

        Returns:
            The network output cropped to the input height and width.
        """
        with torch.no_grad():
            input = torch.cat((rgb, disp / disp.max(), edge, context, mask), dim=1)
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2

            # Follow the inputs' device unless the caller asked for one.
            device = input.device if cuda is None else cuda
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input

            edge_output = self.forward(enlarge_input)
            edge_output = edge_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

        return edge_output

    def forward(self, x, refine_border=False):
        """Return the sigmoid of the decoder output for ``x``.

        Args:
            x: input tensor (7 channels, as fixed by ``encoder_0``).
            refine_border: when truthy, ``x`` is padded with
                ``add_border`` (channel 5 padded with ones) and the result
                is cropped back to the original height and width.
        """
        if refine_border:
            x, anchor = self.add_border(x, [5])

        x1 = self.encoder_0(x)
        x2 = self.encoder_1(x1)
        x3 = self.encoder_2(x2)
        x4 = self.middle(x3)
        x5 = self.decoder_0(torch.cat((x4, x3), dim=1))
        x6 = self.decoder_1(torch.cat((x5, x2), dim=1))
        x7 = self.decoder_2(torch.cat((x6, x1), dim=1))
        x = torch.sigmoid(x7)

        if refine_border:
            x = x[..., anchor[0]:anchor[1], anchor[2]:anchor[3]]

        return x
|
|
|
class Inpaint_Color_Net(nn.Module):
    """Seven-stage ``PCBActiv`` encoder/decoder producing a 3-channel sigmoid output.

    The network takes a 6-channel input. Every decoder stage consumes the
    upsampled previous output concatenated with the encoder feature (and
    mask) of matching resolution.

    Args:
        layer_size: stored on the instance; the seven encoder and decoder
            stages are hard-coded regardless of its value.
        upsampling_mode: ``mode`` passed to ``F.interpolate`` for the
            features (masks always use ``'nearest'``).
        add_hole_mask, add_two_layer, add_border: accepted for interface
            compatibility; not used.
    """

    def __init__(self, layer_size=7, upsampling_mode='nearest', add_hole_mask=False, add_two_layer=False, add_border=False):
        super().__init__()
        self.freeze_enc_bn = False
        self.upsampling_mode = upsampling_mode
        self.layer_size = layer_size
        in_channels = 6

        self.enc_1 = PCBActiv(in_channels, 64, bn=False, sample='down-7')
        self.enc_2 = PCBActiv(64, 128, sample='down-5')
        self.enc_3 = PCBActiv(128, 256, sample='down-5')
        self.enc_4 = PCBActiv(256, 512, sample='down-3')
        self.enc_5 = PCBActiv(512, 512, sample='down-3')
        self.enc_6 = PCBActiv(512, 512, sample='down-3')
        self.enc_7 = PCBActiv(512, 512, sample='down-3')

        self.dec_7 = PCBActiv(512 + 512, 512, activ='leaky')
        self.dec_6 = PCBActiv(512 + 512, 512, activ='leaky')

        self.dec_5A = PCBActiv(512 + 512, 512, activ='leaky')
        self.dec_4A = PCBActiv(512 + 256, 256, activ='leaky')
        self.dec_3A = PCBActiv(256 + 128, 128, activ='leaky')
        self.dec_2A = PCBActiv(128 + 64, 64, activ='leaky')
        self.dec_1A = PCBActiv(64 + in_channels, 3, bn=False, activ=None, conv_bias=True)

    def cat(self, A, B):
        """Concatenate ``A`` and ``B`` along dim 1."""
        return torch.cat((A, B), dim=1)

    def upsample(self, feat, mask):
        """Upsample ``feat`` (``upsampling_mode``) and ``mask`` (nearest) by a factor of 2."""
        feat = F.interpolate(feat, scale_factor=2, mode=self.upsampling_mode)
        mask = F.interpolate(mask, scale_factor=2, mode='nearest')

        return feat, mask

    def forward_3P(self, mask, context, rgb, edge, unit_length=128, cuda=None):
        """Run the network without gradients on zero-padded inputs.

        The inputs are concatenated along dim 1 in the order
        ``(rgb, edge, context, mask)``, centred on a zero canvas whose
        height and width are multiples of ``unit_length``, passed through
        ``forward`` and cropped back to the original height and width.

        Args:
            mask, context, rgb, edge: tensors concatenated along dim 1.
            unit_length: the padded size is rounded up to a multiple of
                this value.
            cuda: device for the padded canvas. When ``None`` the canvas
                is created on the device of the inputs.

        Returns:
            The network output cropped to the input height and width.
        """
        with torch.no_grad():
            input = torch.cat((rgb, edge, context, mask), dim=1)
            n, c, h, w = input.shape
            residual_h = int(np.ceil(h / float(unit_length)) * unit_length - h)
            residual_w = int(np.ceil(w / float(unit_length)) * unit_length - w)
            anchor_h = residual_h // 2
            anchor_w = residual_w // 2

            # Follow the inputs' device unless the caller asked for one.
            device = input.device if cuda is None else cuda
            enlarge_input = torch.zeros((n, c, h + residual_h, w + residual_w), device=device)
            enlarge_input[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w] = input

            rgb_output = self.forward(enlarge_input)
            rgb_output = rgb_output[..., anchor_h:anchor_h + h, anchor_w:anchor_w + w]

        return rgb_output

    def forward(self, input, add_border=False):
        """Encode and decode ``input`` and return the sigmoid of the last decoder stage.

        The mask is the sum of the last two channels of ``input`` clamped
        to [0, 1] and repeated over all input channels.

        Args:
            input: input tensor; its last two channels define the mask.
            add_border: unused; kept for interface compatibility.
        """
        input_mask = (input[:, -2:-1] + input[:, -1:]).clamp(0, 1)
        f_0, h_0 = input, input_mask.repeat((1, input.shape[1], 1, 1))

        # Encoder: f_i are features, h_i the matching masks.
        f_1, h_1 = self.enc_1(f_0, h_0)
        f_2, h_2 = self.enc_2(f_1, h_1)
        f_3, h_3 = self.enc_3(f_2, h_2)
        f_4, h_4 = self.enc_4(f_3, h_3)
        f_5, h_5 = self.enc_5(f_4, h_4)
        f_6, h_6 = self.enc_6(f_5, h_5)
        f_7, h_7 = self.enc_7(f_6, h_6)

        # Decoder: upsample, concatenate the skip feature/mask, convolve.
        o_7, k_7 = self.upsample(f_7, h_7)
        o_6, k_6 = self.dec_7(self.cat(o_7, f_6), self.cat(k_7, h_6))
        o_6, k_6 = self.upsample(o_6, k_6)
        o_5, k_5 = self.dec_6(self.cat(o_6, f_5), self.cat(k_6, h_5))
        o_5A, k_5A = self.upsample(o_5, k_5)

        o_4A, k_4A = self.dec_5A(self.cat(o_5A, f_4), self.cat(k_5A, h_4))
        o_4A, k_4A = self.upsample(o_4A, k_4A)
        o_3A, k_3A = self.dec_4A(self.cat(o_4A, f_3), self.cat(k_4A, h_3))
        o_3A, k_3A = self.upsample(o_3A, k_3A)
        o_2A, k_2A = self.dec_3A(self.cat(o_3A, f_2), self.cat(k_3A, h_2))
        o_2A, k_2A = self.upsample(o_2A, k_2A)
        o_1A, k_1A = self.dec_2A(self.cat(o_2A, f_1), self.cat(k_2A, h_1))
        o_1A, k_1A = self.upsample(o_1A, k_1A)
        o_0A, k_0A = self.dec_1A(self.cat(o_1A, f_0), self.cat(k_1A, h_0))

        return torch.sigmoid(o_0A)

    def train(self, mode=True):
        """Set training mode, keeping encoder BatchNorm layers in eval mode when frozen.

        When ``self.freeze_enc_bn`` is truthy, every ``BatchNorm2d``
        sub-module whose qualified name contains ``'enc'`` is switched
        back to eval mode after the regular mode change.

        Returns:
            ``self``, matching ``nn.Module.train`` so that
            ``model.train()`` / ``model.eval()`` can be chained or
            assigned.
        """
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and 'enc' in name:
                    module.eval()
        return self
|
|
|
class Discriminator(BaseNetwork):
    """Five-stage 4x4 convolutional discriminator.

    Args:
        use_sigmoid: when truthy, the returned score map is passed through
            a sigmoid.
        use_spectral_norm: when truthy, every convolution is wrapped with
            spectral normalisation and created without a bias.
        init_weights: when truthy, ``self.init_weights()`` is called at the
            end of construction.
        in_channels: number of channels of the input tensor.
    """

    def __init__(self, use_sigmoid=True, use_spectral_norm=True, init_weights=True, in_channels=None):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        conv_bias = not use_spectral_norm

        def conv4x4(n_in, n_out, stride):
            # 4x4 convolution with padding 1, optionally spectrally normalised.
            layer = nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=4,
                              stride=stride, padding=1, bias=conv_bias)
            return spectral_norm(layer, use_spectral_norm)

        # The first stage is registered under both ``conv1`` and ``features``.
        self.conv1 = self.features = nn.Sequential(
            conv4x4(in_channels, 64, 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv2 = nn.Sequential(
            conv4x4(64, 128, 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv3 = nn.Sequential(
            conv4x4(128, 256, 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv4 = nn.Sequential(
            conv4x4(256, 512, 1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv5 = nn.Sequential(
            conv4x4(512, 1, 1),
        )

        if init_weights:
            self.init_weights()

    def forward(self, x):
        """Return ``(score_map, [conv1, conv2, conv3, conv4, conv5])`` for ``x``."""
        stage_outputs = []
        hidden = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            hidden = stage(hidden)
            stage_outputs.append(hidden)

        if self.use_sigmoid:
            outputs = torch.sigmoid(hidden)
        else:
            outputs = hidden

        return outputs, stage_outputs
|
|
|
class ResnetBlock(nn.Module):
    """Residual block of two bias-free, spectrally normalised 3x3 convolutions.

    Args:
        dim: number of input and output channels.
        dilation: dilation (and reflection padding) of the first
            convolution; the second one always uses dilation 1.
    """

    def __init__(self, dim, dilation=1):
        super(ResnetBlock, self).__init__()

        def conv3x3(rate):
            # Padding is supplied by the preceding ReflectionPad2d layer.
            layer = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3,
                              padding=0, dilation=rate, bias=False)
            return spectral_norm(layer, True)

        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(dilation),
            conv3x3(dilation),
            nn.InstanceNorm2d(dim, track_running_stats=False),
            nn.LeakyReLU(negative_slope=0.2),

            nn.ReflectionPad2d(1),
            conv3x3(1),
            nn.InstanceNorm2d(dim, track_running_stats=False),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual
|
|
|
|
|
def spectral_norm(module, mode=True):
    """Wrap ``module`` with ``nn.utils.spectral_norm`` when ``mode`` is truthy.

    With a falsy ``mode`` the module is returned unchanged.
    """
    return nn.utils.spectral_norm(module) if mode else module
|
|