Spaces:
Build error
Build error
# Copyright 2020 Erik Härkönen. All rights reserved.
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may obtain a copy
# of the License at http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from collections import OrderedDict | |
from pathlib import Path | |
import requests | |
import pickle | |
import sys | |
import numpy as np | |
# Reimplementation of StyleGAN in PyTorch
# Source: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
class MyLinear(nn.Module):
    """Fully connected layer with equalized learning rate and a learning-rate multiplier.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        gain: Scale applied to the He standard deviation ``input_size ** -0.5``.
        use_wscale: When True, the stored weights are initialized with std
            ``1 / lrmul`` and the He constant is multiplied in at runtime;
            otherwise the He constant is baked into the initialization.
        lrmul: Learning-rate multiplier applied to weight and bias at runtime.
        bias: Whether to learn an additive bias (initialized to zero).
    """

    def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size**(-0.5)  # He initialization constant
        if not use_wscale:
            init_std, self.w_mul = he_std / lrmul, lrmul
        else:
            # Equalized learning rate: keep stored weights ~N(0, 1/lrmul^2) and
            # apply the He scaling on every forward pass instead.
            init_std, self.w_mul = 1.0 / lrmul, he_std * lrmul
        self.weight = nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if not bias:
            self.bias = None
            return
        self.bias = nn.Parameter(torch.zeros(output_size))
        self.b_mul = lrmul

    def forward(self, x):
        """Return ``x @ (weight * w_mul).T + bias * b_mul``."""
        scaled_bias = None if self.bias is None else self.bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, scaled_bias)
class MyConv2d(nn.Module):
    """2-D convolution with equalized learning rate and a learning-rate multiplier.

    Can upsample the input 2x before convolving (``upscale=True``) and can run
    an ``intermediate`` module on the convolution output before the bias is added.

    Args:
        input_channels: Number of input feature maps.
        output_channels: Number of output feature maps.
        kernel_size: Square kernel size; padding is ``kernel_size // 2``.
        gain: Scale applied to the He std ``(input_channels * kernel_size**2) ** -0.5``.
        use_wscale: When True, weights are stored with std ``1 / lrmul`` and the
            He constant is applied at runtime; otherwise it is baked into the init.
        lrmul: Learning-rate multiplier applied to weight and bias at runtime.
        bias: Whether to learn an additive per-channel bias (initialized to zero).
        intermediate: Optional module run between the convolution and the bias.
        upscale: Upsample 2x before convolving. When twice the input's shorter
            side is at least 128, a fused transposed convolution is used instead
            of a separate upsample followed by a convolution.
    """

    def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
                 intermediate=None, upscale=False):
        super().__init__()
        self.upscale = Upscale2d() if upscale else None
        fan_in = input_channels * kernel_size ** 2
        he_std = gain * fan_in ** (-0.5)  # He initialization constant
        self.kernel_size = kernel_size
        if not use_wscale:
            init_std, self.w_mul = he_std / lrmul, lrmul
        else:
            init_std, self.w_mul = 1.0 / lrmul, he_std * lrmul
        weight_shape = (output_channels, input_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.randn(*weight_shape) * init_std)
        if not bias:
            self.bias = None
        else:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        self.intermediate = intermediate

    def forward(self, x):
        """Convolve ``x``; upscale / intermediate / bias are applied as configured."""
        bias = None if self.bias is None else self.bias * self.b_mul
        weight = self.weight * self.w_mul
        pad = self.kernel_size // 2
        fused = self.upscale is not None and min(x.shape[2:]) * 2 >= 128
        if fused:
            # Fused upscale + conv as in the original StyleGAN. Per the original
            # author this is not numerically identical to upscale-then-conv.
            # conv_transpose2d expects weights laid out as (in, out, kh, kw).
            kernel = weight.permute(1, 0, 2, 3)
            # Pad by one and sum the four shifted copies (kernel grows by 1);
            # the original author flagged that this may quadruple the weight.
            kernel = F.pad(kernel, (1, 1, 1, 1))
            kernel = kernel[:, :, 1:, 1:] + kernel[:, :, :-1, 1:] + kernel[:, :, 1:, :-1] + kernel[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, kernel, stride=2, padding=(kernel.size(-1) - 1) // 2)
        else:
            if self.upscale is not None:
                x = self.upscale(x)
            if self.intermediate is None:
                # Nothing to run between conv and bias: let conv2d add the bias.
                return F.conv2d(x, weight, bias, padding=pad)
            x = F.conv2d(x, weight, None, padding=pad)
        if self.intermediate is not None:
            x = self.intermediate(x)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x
class NoiseLayer(nn.Module):
    """Add per-pixel noise, shared across channels, scaled by a learned per-channel weight.

    The noise source is chosen in priority order:
      1. the ``noise`` argument passed to ``forward``;
      2. the ``self.noise`` attribute — set it on every NoiseLayer to get fixed,
         pre-defined noise (useful for analysis);
      3. otherwise fresh standard-normal noise of shape ``(N, 1, H, W)`` on
         ``x``'s device and dtype.
    """

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        self.noise = None

    def forward(self, x, noise=None):
        if noise is None:
            noise = self.noise
        if noise is None:
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        per_channel = self.weight.view(1, -1, 1, 1)
        return x + per_channel * noise
class StyleMod(nn.Module):
    """Style modulation: ``x * (scale + 1) + shift`` with scale/shift predicted from a latent.

    A linear layer maps the latent to ``2 * channels`` values, reshaped into a
    per-channel scale and shift that broadcast over the trailing (spatial)
    dimensions of ``x``.
    """

    def __init__(self, latent_size, channels, use_wscale):
        super().__init__()
        self.lin = MyLinear(latent_size, channels * 2, gain=1.0, use_wscale=use_wscale)

    def forward(self, x, latent):
        n_channels = x.size(1)
        spatial_ones = [1] * (x.dim() - 2)
        # [batch, 2 * C] -> [batch, 2, C, 1, ..., 1]
        style = self.lin(latent).view(-1, 2, n_channels, *spatial_ones)
        scale, shift = style[:, 0], style[:, 1]
        return x * (scale + 1.) + shift
class PixelNormLayer(nn.Module):
    """Normalize the feature vector at each position to unit RMS along dim 1.

    Computes ``x / sqrt(mean(x**2, dim=1) + epsilon)``.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        mean_sq = x.pow(2).mean(dim=1, keepdim=True)
        return x * (mean_sq + self.epsilon).rsqrt()
class BlurLayer(nn.Module):
    """Depthwise low-pass filter applied independently to every channel.

    Args:
        kernel: Filter taps. A 1-D sequence/tensor is turned into a 2-D kernel via
            its outer product; a 2-D kernel is used as-is. ``None`` selects the
            default ``(1, 2, 1)``.
        normalize: Divide the kernel by its sum so a constant input is preserved.
        flip: Flip the kernel along both spatial axes.
        stride: Convolution stride.

    Raises:
        ValueError: If ``kernel`` is not 1-D or 2-D.
    """

    def __init__(self, kernel=(1, 2, 1), normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        # Honour the caller's kernel (it used to be unconditionally overwritten
        # with [1, 2, 1]); the default is an immutable tuple.
        if kernel is None:
            kernel = (1, 2, 1)
        kernel = torch.as_tensor(kernel, dtype=torch.float32).clone()
        if kernel.dim() == 1:
            kernel = kernel[:, None] * kernel[None, :]
        elif kernel.dim() != 2:
            raise ValueError(f'blur kernel must be 1-D or 2-D, got shape {tuple(kernel.shape)}')
        kernel = kernel[None, None]  # (1, 1, kh, kw)
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # Negative-step slicing is not supported on torch tensors; use flip().
            kernel = torch.flip(kernel, dims=(2, 3))
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        """Blur ``x`` (N, C, H, W) with the same kernel on every channel."""
        channels = x.size(1)
        # Repeat the single-channel kernel for a grouped (depthwise) convolution.
        weight = self.kernel.expand(channels, -1, -1, -1)
        padding = ((self.kernel.size(2) - 1) // 2, (self.kernel.size(3) - 1) // 2)
        return F.conv2d(x, weight, stride=self.stride, padding=padding, groups=channels)
def upscale2d(x, factor=2, gain=1):
    """Nearest-neighbour upsample an NCHW tensor by an integer ``factor``.

    Each pixel is repeated ``factor`` times along height and width. When
    ``gain != 1`` the input is multiplied by ``gain`` first. With ``factor == 1``
    the (possibly scaled) input is returned as-is.
    """
    assert x.dim() == 4
    if gain != 1:
        x = x * gain
    if factor == 1:
        return x
    n, c, h, w = x.shape
    # Add singleton axes after H and W, broadcast them to `factor`, then fold
    # them back in: (N, C, H, f, W, f) -> (N, C, H*f, W*f).
    tiled = x[:, :, :, None, :, None].expand(n, c, h, factor, w, factor)
    return tiled.reshape(n, c, h * factor, w * factor)
class Upscale2d(nn.Module):
    """Module wrapper around the ``upscale2d`` function.

    Args:
        factor: Integer upsampling factor (must be >= 1).
        gain: Multiplier applied to the input before upsampling.
    """

    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.factor, self.gain = factor, gain

    def forward(self, x):
        return upscale2d(x, gain=self.gain, factor=self.factor)
class G_mapping(nn.Sequential):
    """StyleGAN mapping network: pixel norm followed by 8 dense 512 -> 512 layers.

    Maps latents z with 512 features to intermediate latents w with 512
    features. Every dense layer uses a 0.01 learning-rate multiplier and is
    followed by the chosen activation. Layer names are ``pixel_norm``,
    ``dense0`` ... ``dense7`` and ``dense0_act`` ... ``dense7_act``.

    Args:
        nonlinearity: ``'lrelu'`` (LeakyReLU, slope 0.2) or ``'relu'``.
        use_wscale: Enable equalized learning rate in the dense layers.
    """

    def __init__(self, nonlinearity='lrelu', use_wscale=True):
        # Activations must be nn.Module instances: nn.Sequential rejects plain
        # functions such as torch.relu with a TypeError.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        layers = [('pixel_norm', PixelNormLayer())]
        for i in range(8):
            layers.append((f'dense{i}', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)))
            # The same stateless activation module is shared by all layers.
            layers.append((f'dense{i}_act', act))
        super().__init__(OrderedDict(layers))
class Truncation(nn.Module):
    """Truncation trick: pull the first ``max_layer`` per-layer latents towards an average.

    For layer indices ``< max_layer`` the output is
    ``torch.lerp(avg_latent, x, threshold) = avg_latent + threshold * (x - avg_latent)``;
    the remaining layers are returned unchanged.

    Args:
        avg_latent: Average latent, registered as a buffer; must broadcast
            against ``x`` (presumably shape ``(latent_dim,)`` — confirm with caller).
        max_layer: Number of leading layers (along dim 1) to truncate.
        threshold: Interpolation weight; 1.0 leaves ``x`` unchanged, 0.0 returns
            the average.
    """

    def __init__(self, avg_latent, max_layer=8, threshold=0.7):
        super().__init__()
        self.max_layer = max_layer
        self.threshold = threshold
        self.register_buffer('avg_latent', avg_latent)

    def forward(self, x):
        """Truncate ``x``, a 3-D tensor whose dim 1 indexes layers."""
        assert x.dim() == 3
        interp = torch.lerp(self.avg_latent, x, self.threshold)
        # Build the layer mask on x's device; a CPU mask breaks torch.where
        # when x lives on a GPU.
        layer_idx = torch.arange(x.size(1), device=x.device)
        do_trunc = (layer_idx < self.max_layer).view(1, -1, 1)
        return torch.where(do_trunc, interp, x)
class LayerEpilogue(nn.Module):
    """Steps applied after each synthesis convolution.

    In order: optional per-pixel noise, the activation, optional pixel norm,
    optional instance norm (collected in ``top_epi``), then optional style
    modulation driven by the layer's latent slice.

    Args:
        channels: Number of feature maps in the input.
        dlatent_size: Latent size fed to the style-modulation layer.
        use_wscale: Equalized learning rate for the style-modulation layer.
        use_noise: Add a NoiseLayer first.
        use_pixel_norm: Add a PixelNormLayer after the activation.
        use_instance_norm: Add ``nn.InstanceNorm2d`` after that.
        use_styles: Apply StyleMod with the latent slice passed to ``forward``.
        activation_layer: nn.Module used as the activation.
    """

    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        layers = []
        if use_noise:
            layers.append(('noise', NoiseLayer(channels)))
        layers.append(('activation', activation_layer))
        if use_pixel_norm:
            # PixelNormLayer is the pixel-norm class defined in this file.
            layers.append(('pixel_norm', PixelNormLayer()))
        if use_instance_norm:
            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
        self.top_epi = nn.Sequential(OrderedDict(layers))
        if use_styles:
            self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
        else:
            self.style_mod = None

    def forward(self, x, dlatents_in_slice=None):
        """Apply the epilogue to ``x``.

        Args:
            x: Feature maps.
            dlatents_in_slice: Latent for style modulation. Ignored when styles
                are disabled.
        """
        x = self.top_epi(x)
        if self.style_mod is not None:
            x = self.style_mod(x, dlatents_in_slice)
        # With styles disabled the slice is ignored rather than asserted to be
        # None: the synthesis blocks always pass a latent slice, so an assertion
        # here would make use_styles=False unusable.
        return x
class InputBlock(nn.Module):
    """First (4x4) synthesis block.

    Builds the initial 4x4 feature map either from a learned constant plus a
    per-channel bias, or from a dense layer applied to the first latent. It then
    runs epilogue -> 3x3 conv -> epilogue, using latents ``[:, 0]`` and ``[:, 1]``
    of ``dlatents_in_range``.
    """

    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        self.const_input_layer = const_input_layer
        self.nf = nf
        if not self.const_input_layer:
            # gain/4 is a tweak to match the official Progressive GAN implementation.
            self.dense = MyLinear(dlatent_size, nf*16, gain=gain/4, use_wscale=use_wscale)
        else:
            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))  # named 'const' in TF
            self.bias = nn.Parameter(torch.ones(nf))
        epi_args = (nf, dlatent_size, use_wscale, use_noise, use_pixel_norm,
                    use_instance_norm, use_styles, activation_layer)
        self.epi1 = LayerEpilogue(*epi_args)
        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(*epi_args)

    def forward(self, dlatents_in_range):
        n = dlatents_in_range.size(0)
        if not self.const_input_layer:
            x = self.dense(dlatents_in_range[:, 0]).view(n, self.nf, 4, 4)
        else:
            x = self.const.expand(n, -1, -1, -1) + self.bias.view(1, -1, 1, 1)
        x = self.epi1(x, dlatents_in_range[:, 0])
        return self.epi2(self.conv(x), dlatents_in_range[:, 1])
class GSynthesisBlock(nn.Module):
    """Synthesis block for resolutions above 4x4 (res = 3..resolution_log2).

    conv0_up (2x upscale + 3x3 conv, with an optional blur run before the bias)
    -> epilogue -> conv1 (3x3) -> epilogue, using latents ``[:, 0]`` and
    ``[:, 1]`` of ``dlatents_in_range``.
    """

    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        blur = BlurLayer(blur_filter) if blur_filter else None
        epi_args = (out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm,
                    use_instance_norm, use_styles, activation_layer)
        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
                                 intermediate=blur, upscale=True)
        self.epi1 = LayerEpilogue(*epi_args)
        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(*epi_args)

    def forward(self, x, dlatents_in_range):
        x = self.epi1(self.conv0_up(x), dlatents_in_range[:, 0])
        return self.epi2(self.conv1(x), dlatents_in_range[:, 1])
class G_synthesis(nn.Module):
    """StyleGAN synthesis network: per-layer latents (W) -> image.

    Builds one block per resolution from 4x4 up to ``resolution``: an InputBlock
    at 4x4, then GSynthesisBlocks that each double the spatial size. A final 1x1
    ``torgb`` convolution follows. Each block consumes two consecutive latents,
    so ``forward`` needs at least ``self.num_layers == 2 * log2(resolution) - 2``
    of them.

    ``randomize_noise`` and ``dtype`` are accepted for signature compatibility
    with the TF original but are not used.
    """

    def __init__(
        self,
        dlatent_size=512,           # Disentangled latent (W) dimensionality.
        num_channels=3,             # Number of output color channels.
        resolution=1024,            # Output resolution.
        fmap_base=8192,             # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,             # log2 feature map reduction when doubling the resolution.
        fmap_max=512,               # Maximum number of feature maps in any layer.
        use_styles=True,            # Enable style inputs?
        const_input_layer=True,     # First layer is a learned constant?
        use_noise=True,             # Enable noise inputs?
        randomize_noise=True,       # Unused (TF compatibility).
        nonlinearity='lrelu',       # Activation function: 'relu', 'lrelu'
        use_wscale=True,            # Enable equalized learning rate?
        use_pixel_norm=False,       # Enable pixelwise feature vector normalization?
        use_instance_norm=True,     # Enable instance normalization?
        dtype=torch.float32,        # Unused (TF compatibility).
        blur_filter=(1, 2, 1),      # Low-pass filter for resampling; None = no filtering.
    ):
        super().__init__()

        def nf(stage):
            """Number of feature maps at the given stage."""
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.dlatent_size = dlatent_size
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2**resolution_log2 and resolution >= 4
        self.num_layers = resolution_log2 * 2 - 2
        # Activations must be nn.Module instances: they end up inside
        # nn.Sequential containers, which reject plain functions like torch.relu.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]

        blocks = []
        channels = last_channels = None
        for res in range(2, resolution_log2 + 1):
            channels = nf(res - 1)
            name = '{s}x{s}'.format(s=2**res)
            if res == 2:
                block = InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
                                   use_noise, use_pixel_norm, use_instance_norm, use_styles, act)
            else:
                block = GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale,
                                        use_noise, use_pixel_norm, use_instance_norm, use_styles, act)
            blocks.append((name, block))
            last_channels = channels
        # Only the ToRGB layer of the final resolution is needed, so it is built
        # once instead of being rebuilt and discarded on every iteration.
        self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)
        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(self, dlatents_in):
        """Synthesize images from per-layer latents.

        Args:
            dlatents_in: Disentangled latents (W), ``[minibatch, num_layers, dlatent_size]``.
                Block ``i`` reads latents ``2*i`` and ``2*i + 1`` along dim 1.

        Returns:
            The output of the final ``torgb`` convolution (``num_channels`` maps).
        """
        x = None
        for i, block in enumerate(self.blocks.values()):
            layer_latents = dlatents_in[:, 2 * i:2 * i + 2]
            x = block(layer_latents) if i == 0 else block(x, layer_latents)
        return self.torgb(x)
class StyleGAN_G(nn.Sequential):
    """Full StyleGAN generator: mapping network followed by synthesis network.

    Args:
        resolution: Output image resolution (power of two, >= 4).
        truncation: Currently unused (the truncation layer is disabled).
    """

    def __init__(self, resolution, truncation=1.0):
        layers = OrderedDict([
            ('g_mapping', G_mapping()),
            # ('truncation', Truncation(avg_latent)),
            ('g_synthesis', G_synthesis(resolution=resolution)),
        ])
        # Initialize nn.Module before setting any attributes on it.
        super().__init__(layers)
        self.resolution = resolution
        self.layers = layers

    def forward(self, x, latent_is_w=False):
        """Generate images.

        Args:
            x: Either one latent batch shared by all 18 layers, or a list of
                exactly 18 latent batches (one per layer).
            latent_is_w: If True, ``x`` is already in W space and the mapping
                network is skipped.

        Returns:
            The output of the synthesis network.
        """
        mapping = self.layers['g_mapping']
        if isinstance(x, list):
            assert len(x) == 18, 'Must provide 1 or 18 latents'
            if not latent_is_w:
                x = [mapping.forward(z) for z in x]
            x = torch.stack(x, dim=1)
        else:
            if not latent_is_w:
                x = mapping.forward(x)
            x = x.unsqueeze(1).expand(-1, 18, -1)
        return self.layers['g_synthesis'].forward(x)

    # From: https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/
    def load_weights(self, checkpoint):
        """Load a PyTorch state dict from ``checkpoint``.

        Tensors are mapped to CPU first, so checkpoints saved on a GPU also load
        on CPU-only machines. ``load_state_dict`` then copies the values into
        this module's existing parameters on their current device.
        """
        self.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    def export_from_tf(self, pickle_path):
        """Convert an official TF StyleGAN pickle into this module's weights.

        Loads the ``(G, D, Gs)`` pickle with the bundled ``stylegan_tf`` code and
        translates the Gs variable names and layouts to this module's state-dict
        keys. It prints any key or shape differences, loads the weights, and saves
        the state dict next to the pickle with a ``.pt`` suffix.

        Security: ``pickle.load`` can execute arbitrary code; only pass trusted files.
        """
        module_path = str((Path(__file__).parent / 'stylegan_tf').resolve())
        if module_path not in sys.path:
            sys.path.append(module_path)
        import dnnlib
        import dnnlib.tflib
        dnnlib.tflib.init_tf()
        with open(pickle_path, 'rb') as f:
            _G, _D, gs = pickle.load(f)
        # Only the Gs network (the third pickle entry) is exported.
        state_Gs = OrderedDict((k, torch.from_numpy(v.value().eval())) for k, v in gs.trainables.items())

        def key_translate(k):
            """Map a TF variable name to this module's state-dict key."""
            k = k.lower().split('/')
            if k[0] == 'g_synthesis':
                if not k[1].startswith('torgb'):
                    k.insert(1, 'blocks')
                k = '.'.join(k)
                k = (k.replace('const.const', 'const').replace('const.bias', 'bias').replace('const.stylemod', 'epi1.style_mod.lin')
                      .replace('const.noise.weight', 'epi1.top_epi.noise.weight')
                      .replace('conv.noise.weight', 'epi2.top_epi.noise.weight')
                      .replace('conv.stylemod', 'epi2.style_mod.lin')
                      .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
                      .replace('conv0_up.stylemod', 'epi1.style_mod.lin')
                      .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
                      .replace('conv1.stylemod', 'epi2.style_mod.lin')
                      .replace('torgb_lod0', 'torgb'))
            else:
                k = '.'.join(k)
            return k

        def weight_translate(key, w):
            """Convert a TF weight to PyTorch layout given its translated key.

            Dense weights are transposed; 4-D conv kernels are permuted with
            (3, 2, 0, 1) into PyTorch's (out, in, kh, kw) layout.
            """
            if key.endswith('.weight'):
                if w.dim() == 2:
                    w = w.t()
                elif w.dim() == 4:
                    w = w.permute(3, 2, 0, 1)
                else:
                    assert w.dim() == 1
            return w

        param_dict = {}
        for tf_key, value in state_Gs.items():
            key = key_translate(tf_key)
            # The per-LOD torgb layers are not used by this model; drop them.
            if 'torgb_lod' in key:
                continue
            param_dict[key] = weight_translate(key, value)

        # Report differences between this module's state dict and the converted
        # weights (each key reported once).
        sd_shapes = {k: v.shape for k, v in self.state_dict().items()}
        param_shapes = {k: v.shape for k, v in param_dict.items()}
        for k in dict.fromkeys([*sd_shapes, *param_shapes]):
            pds = param_shapes.get(k)
            sds = sd_shapes.get(k)
            if pds is None:
                print("sd only", k, sds)
            elif sds is None:
                print("pd only", k, pds)
            elif sds != pds:
                print("mismatch!", k, pds, sds)
        # strict=False: entries without a TF counterpart (e.g. the blur kernel
        # buffers) keep their current values.
        self.load_state_dict(param_dict, strict=False)
        torch.save(self.state_dict(), Path(pickle_path).with_suffix('.pt'))