from enum import Enum import math import numpy as np import torch from torch import nn from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add from models.stylegan2.model import EqualLinear # Adapted from https://github.com/omertov/encoder4editing class ProgressiveStage(Enum): WTraining = 0 Delta1Training = 1 Delta2Training = 2 Delta3Training = 3 Delta4Training = 4 Delta5Training = 5 Delta6Training = 6 Delta7Training = 7 Delta8Training = 8 Delta9Training = 9 Delta10Training = 10 Delta11Training = 11 Delta12Training = 12 Delta13Training = 13 Delta14Training = 14 Delta15Training = 15 Delta16Training = 16 Delta17Training = 17 Inference = 18 class GradualStyleBlock(Module): def __init__(self, in_c, out_c, spatial): super(GradualStyleBlock, self).__init__() self.out_c = out_c self.spatial = spatial num_pools = int(np.log2(spatial)) modules = [] modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU()] for i in range(num_pools - 1): modules += [ Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1), nn.LeakyReLU() ] self.convs = nn.Sequential(*modules) self.linear = EqualLinear(out_c, out_c, lr_mul=1) def forward(self, x): x = self.convs(x) x = x.view(-1, self.out_c) x = self.linear(x) return x class GradualStyleEncoder(Module): def __init__(self, num_layers, mode='ir', opts=None): super(GradualStyleEncoder, self).__init__() assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) self.styles = nn.ModuleList() log_size = int(math.log(opts.stylegan_size, 2)) self.style_count = 2 * log_size - 2 self.coarse_ind = 3 self.middle_ind = 7 for i in range(self.style_count): if i < self.coarse_ind: style = GradualStyleBlock(512, 512, 16) elif i < self.middle_ind: style = GradualStyleBlock(512, 512, 32) else: style = GradualStyleBlock(512, 512, 64) self.styles.append(style) self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) def forward(self, x): x = self.input_layer(x) latents = [] modulelist = list(self.body._modules.values()) for i, l in enumerate(modulelist): x = l(x) if i == 6: c1 = x elif i == 20: c2 = x elif i == 23: c3 = x for j in range(self.coarse_ind): latents.append(self.styles[j](c3)) p2 = _upsample_add(c3, self.latlayer1(c2)) for j in range(self.coarse_ind, self.middle_ind): latents.append(self.styles[j](p2)) p1 = _upsample_add(p2, self.latlayer2(c1)) for j in range(self.middle_ind, self.style_count): latents.append(self.styles[j](p1)) out = torch.stack(latents, dim=1) return out class Encoder4Editing(Module): def __init__(self, num_layers, mode='ir', stylegan_size=1024): super(Encoder4Editing, self).__init__() assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) self.styles = nn.ModuleList() log_size = int(math.log(stylegan_size, 2)) self.style_count = 2 * log_size - 2 self.coarse_ind = 3 self.middle_ind = 7 for i in range(self.style_count): if i < self.coarse_ind: style = GradualStyleBlock(512, 512, 16) elif i < self.middle_ind: style = GradualStyleBlock(512, 512, 32) else: style = GradualStyleBlock(512, 512, 64) self.styles.append(style) self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0) self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0) self.progressive_stage = ProgressiveStage.Inference def get_deltas_starting_dimensions(self): ''' Get a list of the initial dimension of every delta from which it is applied ''' return list(range(self.style_count)) # Each dimension has a delta applied to it def set_progressive_stage(self, new_stage: ProgressiveStage): self.progressive_stage = new_stage print('Changed progressive stage to: ', new_stage) def forward(self, x): x = self.input_layer(x) modulelist = list(self.body._modules.values()) for i, l in enumerate(modulelist): x = l(x) if i == 6: c1 = x elif i == 20: c2 = x elif i == 23: c3 = x # Infer main W and duplicate it w0 = self.styles[0](c3) w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2) stage = self.progressive_stage.value features = c3 for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas if i == self.coarse_ind: p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features features = p2 elif i == self.middle_ind: p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features features = p1 delta_i = self.styles[i](features) w[:, i] += delta_i return w class BackboneEncoderUsingLastLayerIntoW(Module): def __init__(self, num_layers, mode='ir', opts=None): super(BackboneEncoderUsingLastLayerIntoW, self).__init__() print('Using BackboneEncoderUsingLastLayerIntoW') assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' blocks = get_blocks(num_layers) if mode == 'ir': unit_module = bottleneck_IR elif mode == 'ir_se': unit_module = bottleneck_IR_SE self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) self.linear = EqualLinear(512, 512, lr_mul=1) modules = [] for block in blocks: for bottleneck in block: modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) self.body = Sequential(*modules) log_size = int(math.log(opts.stylegan_size, 2)) self.style_count = 2 * log_size - 2 def forward(self, x): x = self.input_layer(x) x = self.body(x) x = self.output_pool(x) x = x.view(-1, 512) x = self.linear(x) return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)