|
""" |
|
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. |
|
BSD License. All rights reserved. |
|
|
|
Redistribution and use in source and binary forms, with or without |
|
modification, are permitted provided that the following conditions are met: |
|
|
|
* Redistributions of source code must retain the above copyright notice, this |
|
list of conditions and the following disclaimer. |
|
|
|
* Redistributions in binary form must reproduce the above copyright notice, |
|
this list of conditions and the following disclaimer in the documentation |
|
and/or other materials provided with the distribution. |
|
|
|
THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL |
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. |
|
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL |
|
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, |
|
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING |
|
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
|
""" |
|
import functools |
|
|
|
import numpy as np |
|
import pytorch_lightning as pl |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import models |
|
|
|
|
|
|
|
|
|
|
|
def weights_init(m):
    """Re-initialise the weights of a single module in place.

    Meant to be passed to ``nn.Module.apply``. Modules whose class name
    contains ``"Conv"`` (this includes ``ConvTranspose2d``) get weights drawn
    from N(0, 0.02). Modules whose class name contains ``"BatchNorm2d"`` get
    weights drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched; note that conv biases are not modified.

    Args:
        m: The module to initialise.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if "BatchNorm2d" in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
|
|
|
|
|
def get_norm_layer(norm_type="instance"):
    """Return a factory that builds a 2-D normalisation layer.

    Args:
        norm_type: ``"batch"`` for ``nn.BatchNorm2d`` with learnable affine
            parameters, or ``"instance"`` for ``nn.InstanceNorm2d`` without
            affine parameters.

    Returns:
        A ``functools.partial`` that takes the channel count.

    Raises:
        NotImplementedError: If ``norm_type`` is not one of the above.
    """
    if norm_type == "batch":
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == "instance":
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
|
|
|
|
|
def define_G(
    input_nc,
    output_nc,
    ngf,
    netG,
    n_downsample_global=3,
    n_blocks_global=9,
    n_local_enhancers=1,
    n_blocks_local=3,
    norm="instance",
    gpu_ids=None,
    last_op=nn.Tanh(),
):
    """Build a generator, optionally move it to a GPU, and initialise its weights.

    Args:
        input_nc: Number of input channels.
        output_nc: Number of output channels.
        ngf: Base number of generator filters.
        netG: Architecture name: ``"global"``, ``"local"`` or ``"encoder"``.
        n_downsample_global: Downsampling steps of the global generator / encoder.
        n_blocks_global: Residual blocks in the global generator.
        n_local_enhancers: Number of local enhancer branches (``"local"`` only).
        n_blocks_local: Residual blocks per local branch (``"local"`` only).
        norm: Normalisation type forwarded to ``get_norm_layer``.
        gpu_ids: Optional sequence of CUDA device ids. When non-empty, the
            network is moved to ``gpu_ids[0]``. ``None`` (default) or an empty
            sequence leaves the network where it was constructed.
        last_op: Final module appended by ``GlobalGenerator`` (``"global"``
            only). ``None`` omits it.

    Returns:
        The constructed network with ``weights_init`` applied.

    Raises:
        NotImplementedError: If ``netG`` or ``norm`` is not a supported choice.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == "global":
        net = GlobalGenerator(
            input_nc,
            output_nc,
            ngf,
            n_downsample_global,
            n_blocks_global,
            norm_layer,
            last_op=last_op,
        )
    elif netG == "local":
        net = LocalEnhancer(
            input_nc,
            output_nc,
            ngf,
            n_downsample_global,
            n_blocks_global,
            n_local_enhancers,
            n_blocks_local,
            norm_layer,
        )
    elif netG == "encoder":
        net = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a real
        # exception so callers see the intended message.
        raise NotImplementedError("generator not implemented!")

    if gpu_ids is not None and len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
|
|
|
|
|
def define_D(
    input_nc,
    ndf,
    n_layers_D,
    norm='instance',
    use_sigmoid=False,
    num_D=1,
    getIntermFeat=False,
    gpu_ids=None,
):
    """Build a multi-scale discriminator, optionally move it to a GPU, and initialise it.

    Args:
        input_nc: Number of input channels.
        ndf: Base number of discriminator filters.
        n_layers_D: Number of strided conv layers per discriminator.
        norm: Normalisation type forwarded to ``get_norm_layer``.
        use_sigmoid: Append a Sigmoid to each discriminator's output.
        num_D: Number of discriminators / input scales.
        getIntermFeat: Return every stage's output instead of only the last.
        gpu_ids: Optional sequence of CUDA device ids. When non-empty, the
            network is moved to ``gpu_ids[0]``. ``None`` (default) or an empty
            sequence leaves the network where it was constructed.

    Returns:
        The ``MultiscaleDiscriminator`` with ``weights_init`` applied.

    Raises:
        NotImplementedError: If ``norm`` is not a supported choice.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    netD = MultiscaleDiscriminator(
        input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat
    )
    # ``None`` instead of a shared mutable ``[]`` default; both mean "no GPU".
    if gpu_ids is not None and len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    return netD
|
|
|
|
|
def print_network(net):
    """Print a module's structure followed by its total parameter count.

    Args:
        net: An ``nn.Module``, or a list whose first element is the module to
            report on (any other elements are ignored).
    """
    module = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in module.parameters())
    print(module)
    print("Total number of parameters: %d" % total)
|
|
|
|
|
|
|
|
|
|
|
class LocalEnhancer(pl.LightningModule):
    """Coarse-to-fine generator: a trimmed global generator plus local branches.

    The global generator runs on the input downsampled ``n_local_enhancers``
    times. Each local branch ``model<k>_1`` encodes the input at the next finer
    scale. That encoding is added to the running feature map, and the sum goes
    through ``model<k>_2`` (residual blocks plus a 2x transposed-conv
    upsample). Only the last branch ends with an output conv and a Tanh.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=32,
        n_downsample_global=3,
        n_blocks_global=9,
        n_local_enhancers=1,
        n_blocks_local=3,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
    ):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        # The global generator works at the coarsest scale with a wider base
        # width. Its last three layers (ReflectionPad2d, output Conv2d and its
        # default Tanh last_op) are dropped so that it yields features, not an
        # image.
        coarse_layers = GlobalGenerator(
            input_nc,
            output_nc,
            ngf * (2**n_local_enhancers),
            n_downsample_global,
            n_blocks_global,
            norm_layer,
        ).model
        self.model = nn.Sequential(*list(coarse_layers)[:-3])

        for level in range(1, n_local_enhancers + 1):
            width = ngf * (2**(n_local_enhancers - level))

            # Encode the full-resolution input for this level (one 2x downsample).
            down_layers = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, width, kernel_size=7, padding=0),
                norm_layer(width),
                nn.ReLU(True),
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(width * 2),
                nn.ReLU(True),
            ]

            # Residual refinement followed by a 2x upsample back to this level's resolution.
            up_layers = [
                ResnetBlock(width * 2, padding_type=padding_type, norm_layer=norm_layer)
                for _ in range(n_blocks_local)
            ]
            up_layers.extend([
                nn.ConvTranspose2d(
                    width * 2,
                    width,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(width),
                nn.ReLU(True),
            ])

            # Only the finest branch projects to the output image.
            if level == n_local_enhancers:
                up_layers.extend([
                    nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh(),
                ])

            setattr(self, "model%d_1" % level, nn.Sequential(*down_layers))
            setattr(self, "model%d_2" % level, nn.Sequential(*up_layers))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, input):
        """Run the coarse generator and then each local branch, coarse to fine.

        ``input`` is presumably an (N, C, H, W) image batch whose spatial size
        survives the repeated 2x downsampling (TODO confirm at call sites).
        """
        # pyramid[k] is the input downsampled k times; pyramid[-1] is the coarsest.
        pyramid = [input]
        for _ in range(self.n_local_enhancers):
            pyramid.append(self.downsample(pyramid[-1]))

        features = self.model(pyramid[-1])

        for level in range(1, self.n_local_enhancers + 1):
            encode = getattr(self, "model%d_1" % level)
            refine = getattr(self, "model%d_2" % level)
            finer_input = pyramid[self.n_local_enhancers - level]
            features = refine(encode(finer_input) + features)
        return features
|
|
|
|
|
class GlobalGenerator(pl.LightningModule):
    """Encoder / residual-bottleneck / decoder image-to-image generator.

    The network is built from these stages, in order:
    1. A 7x7 stem conv.
    2. ``n_downsampling`` stride-2 convs, each doubling the width.
    3. ``n_blocks`` ``ResnetBlock``s.
    4. ``n_downsampling`` transposed convs, each halving the width.
    5. A 7x7 output conv, optionally followed by ``last_op``.

    A single ``nn.ReLU(True)`` instance is reused as the activation throughout.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=64,
        n_downsampling=3,
        n_blocks=9,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        last_op=nn.Tanh(),
    ):
        assert n_blocks >= 0
        super(GlobalGenerator, self).__init__()
        act = nn.ReLU(True)

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            act,
        ]

        # Downsampling path: each step halves H/W and doubles the channel count.
        for step in range(n_downsampling):
            c_in = ngf * 2**step
            layers.extend([
                nn.Conv2d(c_in, c_in * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(c_in * 2),
                act,
            ])

        # Residual bottleneck at the narrowest resolution.
        bottleneck_width = ngf * 2**n_downsampling
        layers.extend(
            ResnetBlock(
                bottleneck_width,
                padding_type=padding_type,
                activation=act,
                norm_layer=norm_layer,
            )
            for _ in range(n_blocks)
        )

        # Upsampling path mirrors the downsampling path.
        for step in range(n_downsampling):
            c_in = ngf * 2**(n_downsampling - step)
            c_out = int(c_in / 2)
            layers.extend([
                nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(c_out),
                act,
            ])

        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
        ])
        if last_op is not None:
            layers.append(last_op)
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the full generator stack to ``input``."""
        return self.model(input)
|
|
|
|
|
|
|
class NLayerDiscriminator(nn.Module):
    """Fully convolutional discriminator built from 4x4 conv stages.

    The stages are:
    1. A stride-2 conv followed by LeakyReLU.
    2. ``n_layers - 1`` stride-2 conv + norm + LeakyReLU stages, with the width
       doubling and capped at 512.
    3. One stride-1 conv + norm + LeakyReLU stage.
    4. A stride-1 conv down to a single channel.
    5. Optionally, a Sigmoid stage.

    With ``getIntermFeat`` each stage is registered as ``model<n>`` and
    ``forward`` returns every stage's output. Otherwise all layers live in
    ``self.model`` and only the final output is returned.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        getIntermFeat=False
    ):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        sequence = [[
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf),
                nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        # Number of registered stages: n_layers + 2, plus one for the Sigmoid.
        # forward must iterate over all of them; hard-coding n_layers + 2
        # silently skipped the Sigmoid in intermediate-feature mode.
        self.n_stages = len(sequence)

        if getIntermFeat:
            for n, stage in enumerate(sequence):
                setattr(self, 'model' + str(n), nn.Sequential(*stage))
        else:
            self.model = nn.Sequential(*[layer for stage in sequence for layer in stage])

    def forward(self, input):
        """Return the list of per-stage outputs (``getIntermFeat``) or the final map."""
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_stages):
                model = getattr(self, 'model' + str(n))
                res.append(model(res[-1]))
            return res[1:]
        return self.model(input)
|
|
|
|
|
class MultiscaleDiscriminator(pl.LightningModule):
    """Apply ``num_D`` independent ``NLayerDiscriminator``s to an input pyramid.

    Discriminator ``num_D - 1`` sees the full-resolution input. Each subsequent
    one sees the input average-pooled once more. ``forward`` returns one list
    per scale: all stage outputs with ``getIntermFeat``, otherwise a
    one-element list holding the final output.
    """

    def __init__(
        self,
        input_nc,
        ndf=64,
        n_layers=3,
        norm_layer=nn.BatchNorm2d,
        use_sigmoid=False,
        num_D=3,
        getIntermFeat=False
    ):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        # In intermediate-feature mode, NLayerDiscriminator registers one
        # ``model<j>`` per stage: n_layers + 2 conv stages, plus a Sigmoid stage
        # when use_sigmoid is set. Copying/running only n_layers + 2 stages
        # dropped that Sigmoid.
        self.n_stages = n_layers + 2 + (1 if use_sigmoid else 0)

        for i in range(num_D):
            netD = NLayerDiscriminator(
                input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
            )
            if getIntermFeat:
                for j in range(self.n_stages):
                    setattr(
                        self, 'scale' + str(i) + '_layer' + str(j), getattr(netD, 'model' + str(j))
                    )
            else:
                setattr(self, 'layer' + str(i), netD.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator; ``model`` is a list of stages when ``getIntermFeat``."""
        if self.getIntermFeat:
            result = [input]
            for stage in model:
                result.append(stage(result[-1]))
            return result[1:]
        return [model(input)]

    def forward(self, input):
        """Return one list of outputs per scale, full resolution first."""
        num_D = self.num_D
        result = []
        input_downsampled = input.clone()
        for i in range(num_D):
            scale = num_D - 1 - i
            if self.getIntermFeat:
                model = [
                    getattr(self, 'scale' + str(scale) + '_layer' + str(j))
                    for j in range(self.n_stages)
                ]
            else:
                model = getattr(self, 'layer' + str(scale))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result
|
|
|
|
|
|
|
class ResnetBlock(pl.LightningModule):
    """Residual block: ``x + [pad, conv3x3, norm, act, (dropout), pad, conv3x3, norm](x)``.

    The channel count ``dim`` is unchanged, so the skip connection is a plain
    addition.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, activation, use_dropout
        )

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble the residual branch as an ``nn.Sequential``.

        Args:
            dim: Channel count of input and output.
            padding_type: ``"reflect"``, ``"replicate"`` or ``"zero"``.
            norm_layer: Factory called with ``dim`` to build each norm layer.
            activation: Module inserted after the first norm.
            use_dropout: Insert ``nn.Dropout(0.5)`` after the activation.

        Raises:
            NotImplementedError: For an unknown ``padding_type``.
        """

        def padding_for_conv():
            # Explicit padding layers, plus the ``padding`` argument for the next conv.
            if padding_type == "reflect":
                return [nn.ReflectionPad2d(1)], 0
            if padding_type == "replicate":
                return [nn.ReplicationPad2d(1)], 0
            if padding_type == "zero":
                return [], 1
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        pad_layers, conv_pad = padding_for_conv()
        layers = pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad),
            norm_layer(dim),
            activation,
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        pad_layers, conv_pad = padding_for_conv()
        layers += pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad),
            norm_layer(dim),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Add the residual branch output to the input."""
        return x + self.conv_block(x)
|
|
|
|
|
class Encoder(pl.LightningModule):
    """Encoder-decoder that outputs per-instance averaged feature maps.

    The convolutional trunk mirrors ``GlobalGenerator`` without residual blocks
    and always ends in Tanh. ``forward`` then replaces, for every instance id,
    batch element and output channel, the values inside that instance's region
    by their mean.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Downsampling path: each step halves H/W and doubles the width.
        for step in range(n_downsampling):
            c_in = ngf * 2**step
            layers += [
                nn.Conv2d(c_in, c_in * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(c_in * 2),
                nn.ReLU(True),
            ]

        # Upsampling path back to the input resolution.
        for step in range(n_downsampling):
            c_in = ngf * 2**(n_downsampling - step)
            c_out = int(c_in / 2)
            layers += [
                nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(c_out),
                nn.ReLU(True),
            ]

        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, input, inst):
        """Encode ``input`` and average each output channel within each instance.

        NOTE(review): assumes ``inst`` is an (N, 1, H, W) integer-valued
        instance map with the same spatial size as the output. The indexing
        below offsets the output channel by ``inst``'s channel index; confirm
        at call sites.
        """
        outputs = self.model(input)
        outputs_mean = outputs.clone()
        instance_ids = np.unique(inst.cpu().numpy().astype(int))
        batch_size = input.size()[0]

        for inst_id in instance_ids:
            for b in range(batch_size):
                # Coordinates (0, c, y, x) within the 1-element slice inst[b:b+1].
                coords = (inst[b:b + 1] == int(inst_id)).nonzero()
                batch_idx = coords[:, 0] + b
                rows = coords[:, 2]
                cols = coords[:, 3]
                for channel in range(self.output_nc):
                    chan_idx = coords[:, 1] + channel
                    region = outputs[batch_idx, chan_idx, rows, cols]
                    outputs_mean[batch_idx, chan_idx, rows, cols] = (
                        torch.mean(region).expand_as(region)
                    )
        return outputs_mean
|
|
|
|
|
class Vgg19(nn.Module):
    """Pretrained torchvision VGG19 ``features`` split into five sequential slices.

    The slices cover ``features`` indices [0, 2), [2, 7), [7, 12), [12, 21) and
    [21, 30). ``forward`` returns the output of each slice in order. These are
    presumably relu1_1 ... relu5_1 of torchvision's VGG19 layout (hence the
    ``h_relu`` naming); confirm against the torchvision version in use.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        cut_points = (0, 2, 7, 12, 21, 30)
        for k in range(5):
            part = torch.nn.Sequential()
            for idx in range(cut_points[k], cut_points[k + 1]):
                part.add_module(str(idx), features[idx])
            setattr(self, "slice%d" % (k + 1), part)
        if not requires_grad:
            # Freeze every parameter (used as a fixed feature extractor).
            self.requires_grad_(False)

    def forward(self, X):
        """Return the list of the five slice outputs, shallowest first."""
        feats = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            feats.append(h)
        return feats
|
|
|
|
|
class VGG19FeatLayer(nn.Module):
    """Named-activation extractor over pretrained VGG19 ``features``.

    ``forward`` normalises the input with ImageNet mean/std buffers, then
    records every layer's output in a dict. Keys are ``conv{b}_{i}``,
    ``relu{b}_{i}``, ``pool_{b}`` and ``bn_{b}``, where ``b`` is the block index
    (incremented after each max-pool). Note that ``i`` is a single counter
    shared by convs and ReLUs within a block, so the first ReLU of block 3 is
    ``relu3_2``.
    """

    def __init__(self):
        super(VGG19FeatLayer, self).__init__()
        backbone = models.vgg19(weights=models.VGG19_Weights.DEFAULT)
        self.vgg19 = backbone.features.eval()

        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x):
        """Return an ordered dict of every layer's activation, keyed by layer name.

        Raises:
            RuntimeError: If ``self.vgg19`` contains an unsupported layer type.
        """
        feats = {}
        h = (x - self.mean) / self.std
        block, pos = 1, 0
        for module in self.vgg19.children():
            if isinstance(module, nn.Conv2d):
                pos += 1
                key = "conv{}_{}".format(block, pos)
            elif isinstance(module, nn.ReLU):
                pos += 1
                key = "relu{}_{}".format(block, pos)
                # Out-of-place ReLU so the stored conv activation is not overwritten.
                module = nn.ReLU(inplace=False)
            elif isinstance(module, nn.MaxPool2d):
                key = "pool_{}".format(block)
                block += 1
                pos = 0
            elif isinstance(module, nn.BatchNorm2d):
                key = "bn_{}".format(block)
            else:
                raise RuntimeError("Unrecognized layer: {}".format(module.__class__.__name__))
            h = module(h)
            feats[key] = h
        return feats
|
|
|
|
|
class VGGLoss(pl.LightningModule):
    """Weighted L1 distance between ``Vgg19`` features of two images.

    Slice outputs are weighted 1/32, 1/16, 1/8, 1/4 and 1, shallowest first.
    The target's features are detached, so gradients flow only through ``x``.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19().eval()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """Return the weighted sum of per-slice L1 losses between ``x`` and ``y``."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for idx, fx in enumerate(feats_x):
            total += self.weights[idx] * self.criterion(fx, feats_y[idx].detach())
        return total
|
|
|
|
|
class GANLoss(pl.LightningModule):
    """Least-squares (MSE) or BCE adversarial loss against constant targets.

    Target tensors filled with ``target_real_label`` / ``target_fake_label``
    are cached and reused while the prediction's shape, device and dtype stay
    the same.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        # Legacy attribute kept for backward compatibility only. Targets are now
        # built with torch.full_like, so they follow the prediction's device and
        # dtype; this hard-coded CUDA type failed on CPU.
        self.tensor = torch.cuda.FloatTensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    @staticmethod
    def _needs_new_target(cached, input):
        # Rebuild when nothing is cached or the cached target cannot be used
        # element-wise with ``input``. Comparing numel alone let a target of a
        # different shape, device or dtype be reused.
        return (
            cached is None
            or cached.shape != input.shape
            or cached.device != input.device
            or cached.dtype != input.dtype
        )

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target tensor shaped like ``input`` (no grad)."""
        if target_is_real:
            if self._needs_new_target(self.real_label_var, input):
                self.real_label_var = torch.full_like(
                    input, self.real_label, requires_grad=False
                )
            return self.real_label_var
        if self._needs_new_target(self.fake_label_var, input):
            self.fake_label_var = torch.full_like(
                input, self.fake_label, requires_grad=False
            )
        return self.fake_label_var

    def forward(self, input, target_is_real):
        """Compute the adversarial loss.

        Args:
            input: One of:
                - a prediction tensor;
                - a list whose last element is the prediction;
                - a list of such lists, one per discriminator scale (e.g. the
                  output of ``MultiscaleDiscriminator``), whose losses are summed.
            target_is_real: Compare against the real label if True, else the fake label.

        Returns:
            The scalar loss tensor.
        """
        if isinstance(input, torch.Tensor):
            # A bare tensor is the prediction itself. Indexing it with [-1]
            # would score only the last batch element.
            return self.loss(input, self.get_target_tensor(input, target_is_real))
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                loss += self.loss(pred, self.get_target_tensor(pred, target_is_real))
            return loss
        pred = input[-1]
        return self.loss(pred, self.get_target_tensor(pred, target_is_real))
|
|
|
|
|
class IDMRFLoss(pl.LightningModule):
    """Implicit-diversified MRF loss between feature maps of two images.

    For each configured layer, every generated feature vector is compared,
    by cosine similarity, to every target feature vector (1x1 patches). The
    similarities are turned into relative distances and softmax-like
    contextual weights. The loss is minus the log of the mean best match,
    summed over the batch.

    With the default ``VGG19FeatLayer`` numbering, ``relu3_2`` / ``relu4_2``
    are the ReLUs directly after the first conv of blocks 3 and 4.

    Side effects: ``forward`` stores ``style_loss`` and ``content_loss``, and
    each ``mrf_loss`` call overwrites ``patches_OIHW`` and ``cs_NCHW``.
    """

    def __init__(self, featlayer=VGG19FeatLayer):
        super(IDMRFLoss, self).__init__()
        self.featlayer = featlayer()
        self.feat_style_layers = {'relu3_2': 1.0, 'relu4_2': 1.0}
        self.feat_content_layers = {'relu4_2': 1.0}
        self.bias = 1.0
        self.nn_stretch_sigma = 0.5
        self.lambda_style = 1.0
        self.lambda_content = 1.0

    def sum_normalize(self, featmaps):
        """Divide ``featmaps`` by its sum over dim 1."""
        return featmaps / torch.sum(featmaps, dim=1, keepdim=True)

    def patch_extraction(self, featmaps):
        """Turn a single-sample feature map into 1x1 conv kernels, one per location.

        Also stores the result on ``self.patches_OIHW``.
        """
        size = 1
        stride = 1
        windows = featmaps.unfold(2, size, stride).unfold(3, size, stride)
        patches = windows.permute(0, 2, 3, 1, 4, 5)
        shape = patches.size()
        self.patches_OIHW = patches.view(-1, shape[3], shape[4], shape[5])
        return self.patches_OIHW

    def compute_relative_distances(self, cdist):
        """Divide distances by the per-location minimum over dim 1 (plus 1e-5)."""
        epsilon = 1e-5
        nearest = torch.min(cdist, dim=1, keepdim=True)[0]
        return cdist / (nearest + epsilon)

    def exp_norm_relative_dist(self, relative_dist):
        """Map relative distances to normalised weights; also stored on ``self.cs_NCHW``."""
        weights = torch.exp((self.bias - relative_dist) / self.nn_stretch_sigma)
        self.cs_NCHW = self.sum_normalize(weights)
        return self.cs_NCHW

    def mrf_loss(self, gen, tar):
        """Return the MRF loss between generated and target feature maps (summed over batch)."""
        mean_tar = torch.mean(tar, 1, keepdim=True)
        gen_centered = gen - mean_tar
        tar_centered = tar - mean_tar

        # Unit-normalise every feature vector along the channel dimension.
        gen_unit = gen_centered / torch.norm(gen_centered, p=2, dim=1, keepdim=True)
        tar_unit = tar_centered / torch.norm(tar_centered, p=2, dim=1, keepdim=True)

        # Per sample: cosine similarity of each generated location to each
        # target location, via a conv whose kernels are the target's 1x1 patches.
        similarities = []
        for idx in range(tar.size(0)):
            kernels = self.patch_extraction(tar_unit[idx:idx + 1, :, :, :])
            similarities.append(F.conv2d(gen_unit[idx:idx + 1, :, :, :], kernels))
        cosine_dist = torch.cat(similarities, dim=0)

        # Map similarity in [-1, 1] to a distance in [0, 1].
        dist_01 = -(cosine_dist - 1) / 2
        contextual = self.exp_norm_relative_dist(self.compute_relative_distances(dist_01))

        shape = contextual.size()
        best_match = torch.max(contextual.view(shape[0], shape[1], -1), dim=2)[0]
        mean_best = torch.mean(best_match, dim=1)
        return torch.sum(-torch.log(mean_best))

    def forward(self, gen, tar):
        """Return ``lambda_style * style + lambda_content * content`` MRF loss."""
        gen_feats = self.featlayer(gen)
        tar_feats = self.featlayer(tar)

        style_terms = [
            weight * self.mrf_loss(gen_feats[name], tar_feats[name])
            for name, weight in self.feat_style_layers.items()
        ]
        self.style_loss = functools.reduce(lambda acc, term: acc + term, style_terms) * self.lambda_style

        content_terms = [
            weight * self.mrf_loss(gen_feats[name], tar_feats[name])
            for name, weight in self.feat_content_layers.items()
        ]
        self.content_loss = functools.reduce(
            lambda acc, term: acc + term, content_terms
        ) * self.lambda_content

        return self.style_loss + self.content_loss
|
|