import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class ResnetGenerator(nn.Module):
    def __init__(self, ngf=64, img_size=256, light=False):
        super(ResnetGenerator, self).__init__()
        self.light = light

        self.ConvBlock1 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, ngf, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf),
                                        nn.ReLU(True))

        self.HourGlass1 = HourGlass(ngf, ngf)
        self.HourGlass2 = HourGlass(ngf, ngf)

        # Down-Sampling
        self.DownBlock1 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf*2),
                                        nn.ReLU(True))

        self.DownBlock2 = nn.Sequential(nn.ReflectionPad2d(1),
                                        nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=0, bias=False),
                                        nn.InstanceNorm2d(ngf*4),
                                        nn.ReLU(True))

        # Encoder Bottleneck
        self.EncodeBlock1 = ResnetBlock(ngf*4)
        self.EncodeBlock2 = ResnetBlock(ngf*4)
        self.EncodeBlock3 = ResnetBlock(ngf*4)
        self.EncodeBlock4 = ResnetBlock(ngf*4)

        # Class Activation Map
        self.gap_fc = nn.Linear(ngf*4, 1)
        self.gmp_fc = nn.Linear(ngf*4, 1)
        self.conv1x1 = nn.Conv2d(ngf*8, ngf*4, kernel_size=1, stride=1)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block
        if self.light:
            self.FC = nn.Sequential(nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))
        else:
            self.FC = nn.Sequential(nn.Linear(img_size//4*img_size//4*ngf*4, ngf*4),
                                    nn.ReLU(True),
                                    nn.Linear(ngf*4, ngf*4),
                                    nn.ReLU(True))

        # Decoder Bottleneck
        self.DecodeBlock1 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock2 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock3 = ResnetSoftAdaLINBlock(ngf*4)
        self.DecodeBlock4 = ResnetSoftAdaLINBlock(ngf*4)

        # Up-Sampling
        self.UpBlock1 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*4, ngf*2, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf*2),
                                      nn.ReLU(True))

        self.UpBlock2 = nn.Sequential(nn.Upsample(scale_factor=2),
                                      nn.ReflectionPad2d(1),
                                      nn.Conv2d(ngf*2, ngf, kernel_size=3, stride=1, padding=0, bias=False),
                                      LIN(ngf),
                                      nn.ReLU(True))

        self.HourGlass3 = HourGlass(ngf, ngf)
        self.HourGlass4 = HourGlass(ngf, ngf, False)

        # HourGlass4 uses use_res=False and therefore already outputs a 3-channel map.
        self.ConvBlock2 = nn.Sequential(nn.ReflectionPad2d(3),
                                        nn.Conv2d(3, 3, kernel_size=7, stride=1, padding=0, bias=False),
                                        nn.Tanh())

    def forward(self, x):
        x = self.ConvBlock1(x)
        x = self.HourGlass1(x)
        x = self.HourGlass2(x)

        x = self.DownBlock1(x)
        x = self.DownBlock2(x)

        # Encoder bottleneck; pooled channel statistics of each block are kept as content features.
        x = self.EncodeBlock1(x)
        content_features1 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock2(x)
        content_features2 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock3(x)
        content_features3 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)
        x = self.EncodeBlock4(x)
        content_features4 = F.adaptive_avg_pool2d(x, 1).view(x.shape[0], -1)

        # Class Activation Map: reuse the classifier weights as channel-wise attention.
        gap = F.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = F.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        # Style features for SoftAdaLIN.
        if self.light:
            x_ = F.adaptive_avg_pool2d(x, 1)
            style_features = self.FC(x_.view(x_.shape[0], -1))
        else:
            style_features = self.FC(x.view(x.shape[0], -1))

        # Decoder bottleneck, conditioned on content and style statistics.
        x = self.DecodeBlock1(x, content_features4, style_features)
        x = self.DecodeBlock2(x, content_features3, style_features)
        x = self.DecodeBlock3(x, content_features2, style_features)
        x = self.DecodeBlock4(x, content_features1, style_features)

        x = self.UpBlock1(x)
        x = self.UpBlock2(x)

        x = self.HourGlass3(x)
        x = self.HourGlass4(x)
        out = self.ConvBlock2(x)

        return out, cam_logit, heatmap
class ConvBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(ConvBlock, self).__init__()
        self.dim_out = dim_out

        self.ConvBlock1 = nn.Sequential(nn.InstanceNorm2d(dim_in),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_in, dim_out//2, kernel_size=3, stride=1, bias=False))

        self.ConvBlock2 = nn.Sequential(nn.InstanceNorm2d(dim_out//2),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_out//2, dim_out//4, kernel_size=3, stride=1, bias=False))

        self.ConvBlock3 = nn.Sequential(nn.InstanceNorm2d(dim_out//4),
                                        nn.ReLU(True),
                                        nn.ReflectionPad2d(1),
                                        nn.Conv2d(dim_out//4, dim_out//4, kernel_size=3, stride=1, bias=False))

        self.ConvBlock4 = nn.Sequential(nn.InstanceNorm2d(dim_in),
                                        nn.ReLU(True),
                                        nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, bias=False))

    def forward(self, x):
        residual = x

        x1 = self.ConvBlock1(x)
        x2 = self.ConvBlock2(x1)
        x3 = self.ConvBlock3(x2)
        # The concatenated branches give dim_out//2 + dim_out//4 + dim_out//4 = dim_out channels.
        out = torch.cat((x1, x2, x3), 1)

        if residual.size(1) != self.dim_out:
            residual = self.ConvBlock4(residual)

        return residual + out


class HourGlass(nn.Module):
    def __init__(self, dim_in, dim_out, use_res=True):
        super(HourGlass, self).__init__()
        self.use_res = use_res

        self.HG = nn.Sequential(HourGlassBlock(dim_in, dim_out),
                                ConvBlock(dim_out, dim_out),
                                nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1, bias=False),
                                nn.InstanceNorm2d(dim_out),
                                nn.ReLU(True))

        self.Conv1 = nn.Conv2d(dim_out, 3, kernel_size=1, stride=1)

        if self.use_res:
            self.Conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=1, stride=1)
            self.Conv3 = nn.Conv2d(3, dim_out, kernel_size=1, stride=1)

    def forward(self, x):
        ll = self.HG(x)
        tmp_out = self.Conv1(ll)

        if self.use_res:
            ll = self.Conv2(ll)
            tmp_out_ = self.Conv3(tmp_out)
            return x + ll + tmp_out_
        else:
            # Without the residual path, the block returns the 3-channel intermediate map.
            return tmp_out


class HourGlassBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(HourGlassBlock, self).__init__()

        self.ConvBlock1_1 = ConvBlock(dim_in, dim_out)
        self.ConvBlock1_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock2_1 = ConvBlock(dim_out, dim_out)
        self.ConvBlock2_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock3_1 = ConvBlock(dim_out, dim_out)
        self.ConvBlock3_2 = ConvBlock(dim_out, dim_out)
        self.ConvBlock4_1 = ConvBlock(dim_out, dim_out)
        self.ConvBlock4_2 = ConvBlock(dim_out, dim_out)

        self.ConvBlock5 = ConvBlock(dim_out, dim_out)

        self.ConvBlock6 = ConvBlock(dim_out, dim_out)
        self.ConvBlock7 = ConvBlock(dim_out, dim_out)
        self.ConvBlock8 = ConvBlock(dim_out, dim_out)
        self.ConvBlock9 = ConvBlock(dim_out, dim_out)

    def forward(self, x):
        # Encoder path: four skip branches with 2x average-pool downsampling between levels.
        skip1 = self.ConvBlock1_1(x)
        down1 = F.avg_pool2d(x, 2)
        down1 = self.ConvBlock1_2(down1)

        skip2 = self.ConvBlock2_1(down1)
        down2 = F.avg_pool2d(down1, 2)
        down2 = self.ConvBlock2_2(down2)

        skip3 = self.ConvBlock3_1(down2)
        down3 = F.avg_pool2d(down2, 2)
        down3 = self.ConvBlock3_2(down3)

        skip4 = self.ConvBlock4_1(down3)
        down4 = F.avg_pool2d(down3, 2)
        down4 = self.ConvBlock4_2(down4)

        center = self.ConvBlock5(down4)

        # Decoder path: F.interpolate replaces the deprecated F.upsample
        # (same default nearest-neighbour behaviour).
        up4 = self.ConvBlock6(center)
        up4 = F.interpolate(up4, scale_factor=2)
        up4 = skip4 + up4

        up3 = self.ConvBlock7(up4)
        up3 = F.interpolate(up3, scale_factor=2)
        up3 = skip3 + up3

        up2 = self.ConvBlock8(up3)
        up2 = F.interpolate(up2, scale_factor=2)
        up2 = skip2 + up2

        up1 = self.ConvBlock9(up2)
        up1 = F.interpolate(up1, scale_factor=2)
        up1 = skip1 + up1

        return up1


class ResnetBlock(nn.Module):
    def __init__(self, dim, use_bias=False):
        super(ResnetBlock, self).__init__()
        conv_block = []
        conv_block += [nn.ReflectionPad2d(1),
                       nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1),
                       nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias),
                       nn.InstanceNorm2d(dim)]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out
class ResnetSoftAdaLINBlock(nn.Module):
    def __init__(self, dim, use_bias=False):
        super(ResnetSoftAdaLINBlock, self).__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm1 = SoftAdaLIN(dim)
        self.relu1 = nn.ReLU(True)

        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm2 = SoftAdaLIN(dim)

    def forward(self, x, content_features, style_features):
        out = self.pad1(x)
        out = self.conv1(out)
        out = self.norm1(out, content_features, style_features)
        out = self.relu1(out)

        out = self.pad2(out)
        out = self.conv2(out)
        out = self.norm2(out, content_features, style_features)
        return out + x


class ResnetAdaLINBlock(nn.Module):
    def __init__(self, dim, use_bias=False):
        super(ResnetAdaLINBlock, self).__init__()
        self.pad1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm1 = adaLIN(dim)
        self.relu1 = nn.ReLU(True)

        self.pad2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=0, bias=use_bias)
        self.norm2 = adaLIN(dim)

    def forward(self, x, gamma, beta):
        out = self.pad1(x)
        out = self.conv1(out)
        out = self.norm1(out, gamma, beta)
        out = self.relu1(out)

        out = self.pad2(out)
        out = self.conv2(out)
        out = self.norm2(out, gamma, beta)
        return out + x


class SoftAdaLIN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(SoftAdaLIN, self).__init__()
        self.norm = adaLIN(num_features, eps)

        # Learnable per-channel weights that blend content- and style-derived statistics.
        self.w_gamma = Parameter(torch.zeros(1, num_features))
        self.w_beta = Parameter(torch.zeros(1, num_features))

        self.c_gamma = nn.Sequential(nn.Linear(num_features, num_features),
                                     nn.ReLU(True),
                                     nn.Linear(num_features, num_features))
        self.c_beta = nn.Sequential(nn.Linear(num_features, num_features),
                                    nn.ReLU(True),
                                    nn.Linear(num_features, num_features))
        self.s_gamma = nn.Linear(num_features, num_features)
        self.s_beta = nn.Linear(num_features, num_features)

    def forward(self, x, content_features, style_features):
        content_gamma, content_beta = self.c_gamma(content_features), self.c_beta(content_features)
        style_gamma, style_beta = self.s_gamma(style_features), self.s_beta(style_features)

        # Soft interpolation between style and content statistics, weighted by w_gamma / w_beta.
        w_gamma, w_beta = self.w_gamma.expand(x.shape[0], -1), self.w_beta.expand(x.shape[0], -1)
        soft_gamma = (1. - w_gamma) * style_gamma + w_gamma * content_gamma
        soft_beta = (1. - w_beta) * style_beta + w_beta * content_beta

        out = self.norm(x, soft_gamma, soft_beta)
        return out
class adaLIN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(adaLIN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.9)

    def forward(self, input, gamma, beta):
        # Adaptive Layer-Instance Normalization: rho mixes instance- and layer-normalized activations,
        # then gamma/beta are applied per channel.
        in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + \
            (1 - self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
        return out


class LIN(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(LIN, self).__init__()
        self.eps = eps
        self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.rho.data.fill_(0.0)
        self.gamma.data.fill_(1.0)
        self.beta.data.fill_(0.0)

    def forward(self, input):
        # Layer-Instance Normalization with learnable rho, gamma and beta.
        in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
        out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
        ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True)
        out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
        out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + \
            (1 - self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
        out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
        return out


class Discriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=5):
        super(Discriminator, self).__init__()
        model = [nn.ReflectionPad2d(1),
                 nn.utils.spectral_norm(
                     nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
                 nn.LeakyReLU(0.2, True)]

        for i in range(1, n_layers - 2):
            mult = 2 ** (i - 1)
            model += [nn.ReflectionPad2d(1),
                      nn.utils.spectral_norm(
                          nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
                      nn.LeakyReLU(0.2, True)]

        mult = 2 ** (n_layers - 2 - 1)
        model += [nn.ReflectionPad2d(1),
                  nn.utils.spectral_norm(
                      nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
                  nn.LeakyReLU(0.2, True)]

        # Class Activation Map
        mult = 2 ** (n_layers - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.pad = nn.ReflectionPad2d(1)
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))

        self.model = nn.Sequential(*model)

    def forward(self, input):
        x = self.model(input)

        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.leaky_relu(self.conv1x1(x))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        x = self.pad(x)
        out = self.conv(x)

        return out, cam_logit, heatmap
class RhoClipper(object):
    """Clamps the rho parameter of (ada)LIN modules to [clip_min, clip_max]."""

    def __init__(self, min, max):
        self.clip_min = min
        self.clip_max = max
        assert min < max

    def __call__(self, module):
        if hasattr(module, 'rho'):
            w = module.rho.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.rho.data = w


class WClipper(object):
    """Clamps the SoftAdaLIN blending weights w_gamma and w_beta to [clip_min, clip_max]."""

    def __init__(self, min, max):
        self.clip_min = min
        self.clip_max = max
        assert min < max

    def __call__(self, module):
        if hasattr(module, 'w_gamma'):
            w = module.w_gamma.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.w_gamma.data = w

        if hasattr(module, 'w_beta'):
            w = module.w_beta.data
            w = w.clamp(self.clip_min, self.clip_max)
            module.w_beta.data = w
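
# Minimal usage sketch (not part of the original module): exercises the generator and
# discriminator defined above and prints output shapes. The batch size, 256x256 input,
# light=True and n_layers=5 are illustrative assumptions, not values prescribed by this file.
if __name__ == "__main__":
    gen = ResnetGenerator(ngf=64, img_size=256, light=True)
    disc = Discriminator(input_nc=3, ndf=64, n_layers=5)

    dummy = torch.randn(1, 3, 256, 256)  # hypothetical input batch
    with torch.no_grad():
        fake, gen_cam_logit, gen_heatmap = gen(dummy)
        d_out, d_cam_logit, d_heatmap = disc(fake)

    print(fake.shape, gen_cam_logit.shape, gen_heatmap.shape)
    print(d_out.shape, d_cam_logit.shape, d_heatmap.shape)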