import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class ResnetGenerator(nn.Module):
    def __init__(self, ngf=64, img_size=256, light=False):
        super(ResnetGenerator, self).__init__()
        self.light = light

        self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf),
                                        nn.ReLU(True))

        self.HourGlass1 = HourGlass(ngf, ngf)
        self.HourGlass2 = HourGlass(ngf, ngf)

        # Down-Sampling
        self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf*2),
                                        nn.ReLU(True))

        self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf*4),
                                        nn.ReLU(True))

        # Encoder Bottleneck
        self.EncodeBlock1 = ResnetBlock(ngf*4)
        self.EncodeBlock2 = ResnetBlock(ngf*4)
        self.EncodeBlock3 = ResnetBlock(ngf*4)
        self.EncodeBlock4 = ResnetBlock(ngf*4)

        # Class Activation Map
        self.gap_fc = nn.Linear(ngf*4, 1)
        self.gmp_fc = nn.Linear(ngf*4, 1)
        self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))
        else:
            self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))

        # Decoder Bottleneck
        self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)

        # Up-Sampling
        self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf*2),
                                      nn.ReLU(True))

        self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf),
                                      nn.ReLU(True))

        self.HourGlass3 = HourGlass(ngf, ngf)
        self.HourGlass4 = HourGlass(ngf, ngf, False)

        self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.Tanh())
    def forward(self, x):
        x = self.ConvBlock1(x)
        x = self.HourGlass1(x)
        x = self.HourGlass2(x)

        # Down-Sampling
        x = self.DownBlock1(x)
        x = self.DownBlock2(x)

        # Encoder Bottleneck: pooled features feed SoftAdaLIN in the decoder
        x = self.EncodeBlock1(x)
        content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock2(x)
        content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock3(x)
        content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock4(x)
        content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)

        # Class Activation Map: reuse the GAP/GMP classifier weights as channel attention
        gap = F.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = F.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        # Style features for the decoder's gamma/beta
        if self.light:
            x_ = F.adaptive_avg_pool2d(x, 1)
            style_features = self.FC(x_.view(x_.shape[0], -1))
        else:
            style_features = self.FC(x.view(x.shape[0], -1))

        # Decoder Bottleneck
        x = self.DecodeBlock1(x, content_features4, style_features)
        x = self.DecodeBlock2(x, content_features3, style_features)
        x = self.DecodeBlock3(x, content_features2, style_features)
        x = self.DecodeBlock4(x, content_features1, style_features)

        # Up-Sampling
        x = self.UpBlock1(x)
        x = self.UpBlock2(x)
        x = self.HourGlass3(x)
        x = self.HourGlass4(x)
        out = self.ConvBlock2(x)

        return out, cam_logit, heatmap

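# Shape sketch (an assumption for illustration, not part of the original file):
# with ngf=64 and a 256x256 RGB input, ConvBlock1 and the first two hourglasses
# keep a 256x256x64 feature map, the two stride-2 DownBlocks reduce it to
# 64x64x256 for the encoder/CAM/decoder bottleneck (so heatmap is Nx1x64x64 and
# cam_logit is Nx2), and the two Upsample blocks plus ConvBlock2 restore a
# 256x256x3 output in [-1, 1].
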
class ConvBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(ConvBlock, self).__init__()
        self.dim_out = dim_out

        self.ConvBlock1 = nn.Sequential(nn.InstanceNorm2d(dim_in),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_in, dim_out//2, kernel_size=3, stride=1, bias=False))
        self.ConvBlock2 = nn.Sequential(nn.InstanceNorm2d(dim_out//2),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_out//2, dim_out//4, kernel_size=3, stride=1, bias=False))
        self.ConvBlock3 = nn.Sequential(nn.InstanceNorm2d(dim_out//4),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_out//4, dim_out//4, kernel_size=3, stride=1, bias=False))
        self.ConvBlock4 = nn.Sequential(nn.InstanceNorm2d(dim_in),
                                        nn.ReLU(True),
                                        nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False))

    def forward(self, x):
        residual = x

        x1 = self.ConvBlock1(x)
        x2 = self.ConvBlock2(x1)
        x3 = self.ConvBlock3(x2)
        # The three branches concatenate to dim_out channels: dim_out//2 + dim_out//4 + dim_out//4
        out = torch.cat((x1, x2, x3), 1)

        # Project the residual with a 1x1 conv when the channel counts differ
        if residual.size(1) != self.dim_out:
            residual = self.ConvBlock4(residual)

        return residual + out

class HourGlass(nn.Module):
    def __init__(self, dim_in, dim_out, use_res=True):
        super(HourGlass, self).__init__()
        self.use_res = use_res

        self.HG = nn.Sequential(HourGlassBlock(dim_in, dim_out),
                                ConvBlock(dim_out, dim_out),
                                nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1, bias=False),
                                nn.InstanceNorm2d(dim_out),
                                nn.ReLU(True))

        self.Conv1 = nn.Conv2d(dim_out, 3, kernel_size=1, stride=1)

        if self.use_res:
            self.Conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1)
            self.Conv3 = nn.Conv2d(3, dim_out, kernel_size=1, stride=1)

    def forward(self, x):
        ll = self.HG(x)
        tmp_out = self.Conv1(ll)

        if self.use_res:
            ll = self.Conv2(ll)
            tmp_out_ = self.Conv3(tmp_out)
            return x + ll + tmp_out_
        else:
            return tmp_out

class HourGlassBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(HourGlassBlock, self).__init__()
        self.ConvBlock1_1 = ConvBlock(dim_in, dim_out)
        self.ConvBlock1_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock2_1 = ConvBlock(dim_out, dim_out)
        self.ConvBlock2_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock3_1 = ConvBlock(dim_out, dim_out)
        self.ConvBlock3_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock4_1 = ConvBlock(dim_out, dim_out)
        self.ConvBlock4_2 = ConvBlock(dim_out, dim_out)

        self.ConvBlock5 = ConvBlock(dim_out, dim_out)

        self.ConvBlock6 = ConvBlock(dim_out, dim_out)
        self.ConvBlock7 = ConvBlock(dim_out, dim_out)
        self.ConvBlock8 = ConvBlock(dim_out, dim_out)
        self.ConvBlock9 = ConvBlock(dim_out, dim_out)

    def forward(self, x):
        # Encoder path: each level keeps a skip branch and average-pools by 2
        skip1 = self.ConvBlock1_1(x)
        down1 = F.avg_pool2d(x, 2)
        down1 = self.ConvBlock1_2(down1)
        skip2 = self.ConvBlock2_1(down1)
        down2 = F.avg_pool2d(down1, 2)
        down2 = self.ConvBlock2_2(down2)
        skip3 = self.ConvBlock3_1(down2)
        down3 = F.avg_pool2d(down2, 2)
        down3 = self.ConvBlock3_2(down3)
        skip4 = self.ConvBlock4_1(down3)
        down4 = F.avg_pool2d(down3, 2)
        down4 = self.ConvBlock4_2(down4)

        center = self.ConvBlock5(down4)

        # Decoder path: nearest-neighbour upsample back to each skip resolution and fuse
        up4 = self.ConvBlock6(center)
        up4 = F.interpolate(up4, scale_factor=2)
        up4 = skip4 + up4
        up3 = self.ConvBlock7(up4)
        up3 = F.interpolate(up3, scale_factor=2)
        up3 = skip3 + up3
        up2 = self.ConvBlock8(up3)
        up2 = F.interpolate(up2, scale_factor=2)
        up2 = skip2 + up2
        up1 = self.ConvBlock9(up2)
        up1 = F.interpolate(up1, scale_factor=2)
        up1 = skip1 + up1

        return up1

class ResnetBlock(nn.Module):
    def __init__(self, dim, use_bias=False):
        super(ResnetBlock, self).__init__()
        conv_block = []
        conv_block += [nn.ReflectionPad2d(1),
                       nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1),
                       nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
                       nn.InstanceNorm2d(dim)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

class ResnetSoftAdaLINBlock(nn.Module):
    def __init__(self, dim, use_bias=False):
        super(ResnetSoftAdaLINBlock, self).__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm1 = SoftAdaLIN(dim)
        self.relu1 = nn.ReLU(True)

        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm2 = SoftAdaLIN(dim)

    def forward(self, x, content_features, style_features):
        out = self.pad1(x)
        out = self.conv1(out)
        out = self.norm1(out, content_features, style_features)
        out = self.relu1(out)

        out = self.pad2(out)
        out = self.conv2(out)
        out = self.norm2(out, content_features, style_features)
        return out + x

class ResnetAdaLINBlock(nn.Module):
    def __init__(self, dim, use_bias=False):
        super(ResnetAdaLINBlock, self).__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm1 = adaLIN(dim)
        self.relu1 = nn.ReLU(True)

        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm2 = adaLIN(dim)

    def forward(self, x, gamma, beta):
        out = self.pad1(x)
        out = self.conv1(out)
        out = self.norm1(out, gamma, beta)
        out = self.relu1(out)

        out = self.pad2(out)
        out = self.conv2(out)
        out = self.norm2(out, gamma, beta)
        return out + x

class SoftAdaLIN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(SoftAdaLIN, self).__init__()
        self.norm = adaLIN(num_features, eps)

        self.w_gamma = Parameter(torch.zeros(1, num_features))
        self.w_beta = Parameter(torch.zeros(1, num_features))

        self.c_gamma = nn.Sequential(nn.Linear(num_features, num_features),
                                     nn.ReLU(True),
                                     nn.Linear(num_features, num_features))
        self.c_beta = nn.Sequential(nn.Linear(num_features, num_features),
                                    nn.ReLU(True),
                                    nn.Linear(num_features, num_features))
        self.s_gamma = nn.Linear(num_features, num_features)
        self.s_beta = nn.Linear(num_features, num_features)

    def forward(self, x, content_features, style_features):
        content_gamma, content_beta = self.c_gamma(content_features), self.c_beta(content_features)
        style_gamma, style_beta = self.s_gamma(style_features), self.s_beta(style_features)

        # Learnable weights softly blend the style- and content-derived gamma/beta
        # before they are passed to AdaLIN.
        w_gamma, w_beta = self.w_gamma.expand(x.shape[0], -1), self.w_beta.expand(x.shape[0], -1)
        soft_gamma = (1. - w_gamma) * style_gamma + w_gamma * content_gamma
        soft_beta = (1. - w_beta) * style_beta + w_beta * content_beta

        out = self.norm(x, soft_gamma, soft_beta)
        return out

class adaLIN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(adaLIN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        # Adaptive Layer-Instance Normalization: rho interpolates between the
        # instance-normalized and layer-normalized responses, then the externally
        # supplied gamma/beta rescale and shift the result.
        in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        rho = self.rho.expand(input.shape[0], -1, -1, -1)
        out = rho * out_in + (1 - rho) * out_ln
        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
        return out

class LIN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(LIN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, input):
        in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        rho = self.rho.expand(input.shape[0], -1, -1, -1)
        out = rho * out_in + (1 - rho) * out_ln
        out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
        return out

class Discriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=5):
        super(Discriminator, self).__init__()
        model = [nn.ReflectionPad2d(1),
                 nn.utils.spectral_norm(
                     nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
                 nn.LeakyReLU(0.2, True)]

        for i in range(1, n_layers - 2):
            mult = 2 ** (i - 1)
            model += [nn.ReflectionPad2d(1),
                      nn.utils.spectral_norm(
                          nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
                      nn.LeakyReLU(0.2, True)]

        mult = 2 ** (n_layers - 2 - 1)
        model += [nn.ReflectionPad2d(1),
                  nn.utils.spectral_norm(
                      nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
                  nn.LeakyReLU(0.2, True)]

        # Class Activation Map
        mult = 2 ** (n_layers - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.pad = nn.ReflectionPad2d(1)
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))

        self.model = nn.Sequential(*model)
    def forward(self, input):
        x = self.model(input)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.leaky_relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        x = self.pad(x)
        out = self.conv(x)

        return out, cam_logit, heatmap

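# Usage sketch (an assumption, not part of the original file): the discriminator
# returns a patch-level logit map together with the 2-channel CAM logits and a
# single-channel attention heatmap, e.g.
#   disc = Discriminator(input_nc=3, ndf=64, n_layers=5)
#   out, cam_logit, heatmap = disc(torch.randn(1, 3, 256, 256))
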
class RhoClipper(object):
    def __init__(self, min, max):
        self.clip_min = min
        self.clip_max = max
        assert min < max

    def __call__(self, module):
        if hasattr(module, 'rho'):
            w = module.rho.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.rho.data = w

class WClipper(object):
    def __init__(self, min, max):
        self.clip_min = min
        self.clip_max = max
        assert min < max

    def __call__(self, module):
        if hasattr(module, 'w_gamma'):
            w = module.w_gamma.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.w_gamma.data = w

        if hasattr(module, 'w_beta'):
            w = module.w_beta.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.w_beta.data = w
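

# Minimal smoke test, added as a sketch; the shapes and the way the clippers are
# applied are assumptions based on the classes above, not part of the original file.
if __name__ == '__main__':
    gen = ResnetGenerator(ngf=64, img_size=256, light=True)
    disc = Discriminator(input_nc=3, ndf=64, n_layers=5)

    img = torch.randn(1, 3, 256, 256)
    fake, gen_cam_logit, gen_heatmap = gen(img)
    patch_logit, disc_cam_logit, disc_heatmap = disc(img)

    # The clippers are typically applied to the generator after each optimizer
    # step to keep rho and the SoftAdaLIN blending weights inside [0, 1].
    gen.apply(RhoClipper(0, 1))
    gen.apply(WClipper(0, 1))

    print(fake.shape, gen_cam_logit.shape, gen_heatmap.shape)
    print(patch_logit.shape, disc_cam_logit.shape, disc_heatmap.shape)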