Spaces:
Running
on
L40S
Running
on
L40S
''' | |
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. | |
BSD License. All rights reserved. | |
Redistribution and use in source and binary forms, with or without | |
modification, are permitted provided that the following conditions are met: | |
* Redistributions of source code must retain the above copyright notice, this | |
list of conditions and the following disclaimer. | |
* Redistributions in binary form must reproduce the above copyright notice, | |
this list of conditions and the following disclaimer in the documentation | |
and/or other materials provided with the distribution. | |
THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL | |
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. | |
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL | |
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, | |
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING | |
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. | |
''' | |
import torch | |
import torch.nn as nn | |
import functools | |
import numpy as np | |
import pytorch_lightning as pl | |
############################################################################### | |
# Functions | |
############################################################################### | |
def weights_init(m):
    """Initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply`` so that it visits every
    submodule. Dispatch is on the module's class *name*:

    * name contains ``'Conv'`` (e.g. ``Conv2d``, ``ConvTranspose2d``):
      weight resampled from N(0, 0.02); the bias is left as is;
    * otherwise, name contains ``'BatchNorm2d'``: weight resampled from
      N(1, 0.02) and bias set to 0;
    * any other module is left unchanged.
    """
    layer_kind = m.__class__.__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm2d' not in layer_kind:
        return
    # NOTE(review): assumes affine=True (weight/bias present); get_norm_layer
    # builds BatchNorm2d that way.
    m.weight.data.normal_(1.0, 0.02)
    m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    """Return a factory for 2-D normalisation layers.

    Args:
        norm_type: ``'batch'`` selects ``nn.BatchNorm2d`` with ``affine=True``;
            ``'instance'`` (the default) selects ``nn.InstanceNorm2d`` with
            ``affine=False``.

    Returns:
        A ``functools.partial`` that builds the layer when called with the
        channel count (and any further layer arguments).

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    known_layers = (
        ('batch', nn.BatchNorm2d, True),
        ('instance', nn.InstanceNorm2d, False),
    )
    for name, layer_cls, use_affine in known_layers:
        if norm_type == name:
            return functools.partial(layer_cls, affine=use_affine)
    raise NotImplementedError('normalization layer [%s] is not found' %
                              norm_type)
def define_G(input_nc,
             output_nc,
             ngf,
             netG,
             n_downsample_global=3,
             n_blocks_global=9,
             n_local_enhancers=1,
             n_blocks_local=3,
             norm='instance',
             gpu_ids=None,
             last_op=nn.Tanh()):
    """Build a generator network by name and initialise its weights.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of generator filters.
        netG: architecture name: ``'global'`` (GlobalGenerator),
            ``'local'`` (LocalEnhancer) or ``'encoder'`` (Encoder).
        n_downsample_global: downsampling steps of the global generator
            (``'global'``/``'local'``) or of the encoder (``'encoder'``).
        n_blocks_global: residual blocks in the global generator
            (``'global'``/``'local'``).
        n_local_enhancers: number of local enhancer stages (``'local'`` only).
        n_blocks_local: residual blocks per enhancer stage (``'local'`` only).
        norm: normalisation type forwarded to ``get_norm_layer``
            (``'batch'`` or ``'instance'``).
        gpu_ids: device ids; when non-empty the network is moved to
            ``gpu_ids[0]``. ``None`` (the default) means "stay on CPU", exactly
            like an empty list.
        last_op: final module appended by ``GlobalGenerator``; only used for
            ``'global'``. NOTE: the default ``nn.Tanh()`` is created once at
            definition time and shared between calls.

    Returns:
        The constructed network, with ``weights_init`` applied to every
        submodule.

    Raises:
        NotImplementedError: if ``norm`` or ``netG`` is not a supported name.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        net = GlobalGenerator(input_nc,
                              output_nc,
                              ngf,
                              n_downsample_global,
                              n_blocks_global,
                              norm_layer,
                              last_op=last_op)
    elif netG == 'local':
        net = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
                            n_blocks_global, n_local_enhancers,
                            n_blocks_local, norm_layer)
    elif netG == 'encoder':
        net = Encoder(input_nc, output_nc, ngf, n_downsample_global,
                      norm_layer)
    else:
        # A bare string cannot be raised (it would surface as a TypeError),
        # so raise a real exception that names the unsupported value.
        raise NotImplementedError('generator [%s] is not implemented' %
                                  (netG, ))
    if gpu_ids is not None and len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
def print_network(net):
    """Print a network followed by its total parameter count.

    Args:
        net: a module, or a list whose first element is the module to report
            (the remaining elements are ignored).
    """
    model = net[0] if isinstance(net, list) else net
    total_params = sum(p.numel() for p in model.parameters())
    print(model)
    print('Total number of parameters: %d' % total_params)
############################################################################## | |
# Generator | |
############################################################################## | |
class LocalEnhancer(pl.LightningModule):
    """Coarse-to-fine generator: a global generator refined by local stages.

    ``self.model`` holds the layers of a ``GlobalGenerator`` of width
    ``ngf * 2**n_local_enhancers`` with its last three layers removed. With the
    default ``last_op`` those are the final reflection pad, 7x7 conv and Tanh,
    so the global branch ends at its last upsampling block.

    Each enhancer stage ``n`` (1..n_local_enhancers) is stored as two
    ``nn.Sequential`` attributes:

    * ``model{n}_1``: 7x7 conv stem plus one stride-2 conv, applied to the
      input at pyramid level ``n_local_enhancers - n``;
    * ``model{n}_2``: ``n_blocks_local`` residual blocks plus one stride-2
      transposed conv. The last stage additionally ends with a 7x7 conv to
      ``output_nc`` channels and Tanh.

    ``padding_type`` is used only by the local residual blocks; the global
    generator is built with its own default padding.
    """

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=32,
                 n_downsample_global=3,
                 n_blocks_global=9,
                 n_local_enhancers=1,
                 n_blocks_local=3,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        # Global branch at the coarsest scale, minus its trailing
        # pad / 7x7 conv / output activation.
        global_layers = GlobalGenerator(input_nc, output_nc,
                                        ngf * (2**n_local_enhancers),
                                        n_downsample_global, n_blocks_global,
                                        norm_layer).model
        kept = list(global_layers)[:len(global_layers) - 3]
        self.model = nn.Sequential(*kept)

        # Local enhancer stages, from coarse (stage 1) to fine.
        for stage in range(1, n_local_enhancers + 1):
            width = ngf * (2**(n_local_enhancers - stage))

            front = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, width, kernel_size=7, padding=0),
                norm_layer(width),
                nn.ReLU(True),
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2,
                          padding=1),
                norm_layer(width * 2),
                nn.ReLU(True),
            ]

            back = [
                ResnetBlock(width * 2,
                            padding_type=padding_type,
                            norm_layer=norm_layer)
                for _ in range(n_blocks_local)
            ]
            back.extend([
                nn.ConvTranspose2d(width * 2, width, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(width),
                nn.ReLU(True),
            ])
            if stage == n_local_enhancers:
                # Only the finest stage maps back to output_nc channels.
                back.extend([
                    nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh(),
                ])

            setattr(self, 'model%d_1' % stage, nn.Sequential(*front))
            setattr(self, 'model%d_2' % stage, nn.Sequential(*back))

        # Used to build the input pyramid in forward().
        self.downsample = nn.AvgPool2d(3,
                                       stride=2,
                                       padding=[1, 1],
                                       count_include_pad=False)

    def forward(self, input):
        """Run the global branch on the coarsest input, then refine stage by stage."""
        # pyramid[k] is `input` average-pooled k times.
        pyramid = [input]
        for _ in range(self.n_local_enhancers):
            pyramid.append(self.downsample(pyramid[-1]))

        # Output at the coarsest level.
        result = self.model(pyramid[-1])

        # Each stage fuses its own-scale input features with the previous
        # result by addition, then upsamples.
        for stage in range(1, self.n_local_enhancers + 1):
            front = getattr(self, 'model%d_1' % stage)
            back = getattr(self, 'model%d_2' % stage)
            stage_input = pyramid[self.n_local_enhancers - stage]
            result = back(front(stage_input) + result)
        return result
class GlobalGenerator(pl.LightningModule):
    """Encoder / residual-bottleneck / decoder generator.

    ``self.model`` is an ``nn.Sequential`` laid out as:

    1. reflection pad 3 -> 7x7 conv (input_nc -> ngf) -> norm -> ReLU;
    2. ``n_downsampling`` x [3x3 stride-2 conv doubling the channels -> norm
       -> ReLU];
    3. ``n_blocks`` x ``ResnetBlock`` at ``ngf * 2**n_downsampling`` channels;
    4. ``n_downsampling`` x [3x3 stride-2 transposed conv halving the channels
       -> norm -> ReLU];
    5. reflection pad 3 -> 7x7 conv (ngf -> output_nc) -> ``last_op``
       (omitted when ``last_op`` is None).

    One ``nn.ReLU(True)`` instance is shared by every stage, including the
    residual blocks. NOTE: the default ``last_op=nn.Tanh()`` is created once at
    definition time and shared by all generators that use the default.
    """

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=64,
                 n_downsampling=3,
                 n_blocks=9,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect',
                 last_op=nn.Tanh()):
        assert (n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        relu = nn.ReLU(True)

        # Stem at full resolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            relu,
        ]

        # Encoder: level k maps ngf * 2**k -> ngf * 2**(k + 1) channels.
        for level in range(n_downsampling):
            in_ch = ngf * 2**level
            layers.extend([
                nn.Conv2d(in_ch, in_ch * 2, kernel_size=3, stride=2,
                          padding=1),
                norm_layer(in_ch * 2),
                relu,
            ])

        # Residual bottleneck.
        bottleneck_ch = ngf * 2**n_downsampling
        layers.extend(
            ResnetBlock(bottleneck_ch,
                        padding_type=padding_type,
                        activation=relu,
                        norm_layer=norm_layer) for _ in range(n_blocks))

        # Decoder: mirror of the encoder, from the deepest level back up.
        for level in range(n_downsampling, 0, -1):
            in_ch = ngf * 2**level
            out_ch = int(in_ch / 2)
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(out_ch),
                relu,
            ])

        # Output head.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
        ])
        if last_op is not None:
            layers.append(last_op)
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the generator to ``input``."""
        return self.model(input)
# Define a resnet block
class ResnetBlock(pl.LightningModule):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: pad -> 3x3 conv -> norm -> activation -> [Dropout(0.5)]
    -> pad -> 3x3 conv -> norm, keeping ``dim`` channels throughout.

    ``padding_type`` chooses how the convolutions are padded:

    * ``'reflect'`` / ``'replicate'``: an explicit 1-pixel padding layer before
      each conv;
    * ``'zero'``: the conv's own ``padding=1``.

    Any other value raises ``NotImplementedError``.

    NOTE: the default ``activation=nn.ReLU(True)`` is a single instance created
    at definition time and shared by every block relying on the default.
    """

    def __init__(self,
                 dim,
                 padding_type,
                 norm_layer,
                 activation=nn.ReLU(True),
                 use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
                                                activation, use_dropout)

    @staticmethod
    def _padding_for(padding_type):
        """Return (explicit padding layers, conv ``padding`` value) for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' %
                                  padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation,
                         use_dropout):
        """Assemble the residual branch as an ``nn.Sequential``."""
        pad_layers, conv_pad = self._padding_for(padding_type)
        branch = pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad),
            norm_layer(dim),
            activation,
        ]
        if use_dropout:
            branch.append(nn.Dropout(0.5))

        # Fresh padding layer for the second conv.
        pad_layers, conv_pad = self._padding_for(padding_type)
        branch += pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad),
            norm_layer(dim),
        ]
        return nn.Sequential(*branch)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        return x + self.conv_block(x)
class Encoder(pl.LightningModule):
    """Conv encoder/decoder whose output is average-pooled per instance.

    ``self.model`` is a 7x7 conv stem, ``n_downsampling`` stride-2 convs
    (doubling channels), ``n_downsampling`` stride-2 transposed convs
    (halving channels) and a final 7x7 conv to ``output_nc`` channels with
    Tanh. Each ReLU is a separate instance.
    """

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=32,
                 n_downsampling=4,
                 norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        # Stem at full resolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Downsampling: level k maps ngf * 2**k -> ngf * 2**(k + 1) channels.
        for level in range(n_downsampling):
            in_ch = ngf * 2**level
            layers.extend([
                nn.Conv2d(in_ch, in_ch * 2, kernel_size=3, stride=2,
                          padding=1),
                norm_layer(in_ch * 2),
                nn.ReLU(True),
            ])

        # Upsampling, mirroring the downsampling path.
        for level in range(n_downsampling, 0, -1):
            in_ch = ngf * 2**level
            out_ch = int(in_ch / 2)
            layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(out_ch),
                nn.ReLU(True),
            ])

        # Output head.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, input, inst):
        """Encode ``input`` and replace features by their per-instance mean.

        For every id in ``inst`` (after casting to int), every sample ``b`` and
        every output channel, the values at the positions where
        ``inst[b] == id`` are replaced by their mean. The result is a copy;
        ``self.model(input)`` itself is not modified.

        NOTE(review): assumes ``inst`` is (B, 1, H, W) with the same spatial
        size as the model output. The channel index of ``inst`` is added to the
        output-channel index, so a multi-channel ``inst`` would be misread.
        """
        features = self.model(input)
        pooled = features.clone()
        instance_ids = np.unique(inst.cpu().numpy().astype(int))
        batch_size = input.size(0)
        for inst_id in instance_ids:
            target = int(inst_id)
            for b in range(batch_size):
                # (n, 4) coordinates inside the one-sample slice inst[b:b + 1].
                coords = (inst[b:b + 1] == target).nonzero()
                rows = coords[:, 0] + b
                ys = coords[:, 2]
                xs = coords[:, 3]
                for ch in range(self.output_nc):
                    region = (rows, coords[:, 1] + ch, ys, xs)
                    values = features[region]
                    # An id absent from sample b gives an empty region, so
                    # nothing is written.
                    pooled[region] = torch.mean(values).expand_as(values)
        return pooled