|
|
|
"""Contains the implementation of discriminator described in StyleGAN. |
|
|
|
Paper: https://arxiv.org/pdf/1812.04948.pdf |
|
|
|
Official TensorFlow implementation: https://github.com/NVlabs/stylegan |
|
""" |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.cuda.amp import autocast |
|
|
|
__all__ = ['StyleGANDiscriminator'] |
|
|
|
|
|
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024] |
|
|
|
|
|
_FUSED_SCALE_ALLOWED = [True, False, 'auto'] |
|
|
|
|
|
|
|
class StyleGANDiscriminator(nn.Module):
    """Defines the discriminator network in StyleGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the backbone:

    (1) resolution: The resolution of the input image. (default: -1)
    (2) init_res: Smallest resolution of the convolutional backbone.
        (default: 4)
    (3) image_channels: Number of channels of the input image. (default: 3)
    (4) fused_scale: The strategy of fusing `conv2d` and `downsample` as one
        operator. `True` means blocks from all resolutions will fuse. `False`
        means blocks from all resolutions will not fuse. `auto` means blocks
        from resolutions higher than (or equal to) `fused_scale_res` will fuse.
        (default: `auto`)
    (5) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample`
        as one operator. This field only takes effect if `fused_scale` is set
        as `auto`. (default: 128)
    (6) use_wscale: Whether to use weight scaling. (default: True)
    (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
    (8) lr_mul: Learning rate multiplier for backbone. (default: 1.0)
    (9) mbstd_groups: Group size for the minibatch standard deviation layer.
        `0` means disable. (default: 4)
    (10) mbstd_channels: Number of new channels (appended to the original
        feature map) after the minibatch standard deviation layer. (default: 1)
    (11) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (12) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    (13) filter_kernel: Kernel used for filtering (e.g., downsampling).
        (default: (1, 2, 1))
    (14) eps: A small value to avoid divide overflow. (default: 1e-8)

    Settings for conditional model:

    (1) label_dim: Dimension of the additional label for conditional generation.
        In one-hot conditioning case, it is equal to the number of classes. If
        set to 0, conditioning training will be disabled. (default: 0)

    Runtime settings:

    (1) enable_amp: Whether to enable automatic mixed precision training.
        (default: False)
    """

    def __init__(self,
                 resolution=-1,
                 init_res=4,
                 image_channels=3,
                 fused_scale='auto',
                 fused_scale_res=128,
                 use_wscale=True,
                 wscale_gain=np.sqrt(2.0),
                 lr_mul=1.0,
                 mbstd_groups=4,
                 mbstd_channels=1,
                 fmaps_base=16 << 10,
                 fmaps_max=512,
                 filter_kernel=(1, 2, 1),
                 eps=1e-8,
                 label_dim=0):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported, or `fused_scale`
                is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if fused_scale not in _FUSED_SCALE_ALLOWED:
            raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
                             f'Options allowed: {_FUSED_SCALE_ALLOWED}.')

        self.init_res = init_res
        self.init_res_log2 = int(np.log2(init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(resolution))
        self.image_channels = image_channels
        self.fused_scale = fused_scale
        self.fused_scale_res = fused_scale_res
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.mbstd_groups = mbstd_groups
        self.mbstd_channels = mbstd_channels
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.eps = eps
        self.label_dim = label_dim

        # Level-of-details (lod) controls progressive growing at runtime; it is
        # registered as a buffer so that it is saved/loaded with checkpoints.
        self.register_buffer('lod', torch.zeros(()))
        # Mapping from PyTorch variable names to the official TF variable
        # names, used for loading converted TensorFlow weights.
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        # Build blocks from the highest resolution down to `init_res`.
        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            in_channels = self.get_nf(res)
            out_channels = self.get_nf(res // 2)
            block_idx = self.final_res_log2 - res_log2

            # `FromRGB` layer: projects the raw image into feature space. One
            # such input head exists per resolution for progressive growing.
            self.add_module(
                f'input{block_idx}',
                ConvLayer(in_channels=image_channels,
                          out_channels=in_channels,
                          kernel_size=1,
                          add_bias=True,
                          scale_factor=1,
                          fused_scale=False,
                          filter_kernel=None,
                          use_wscale=use_wscale,
                          wscale_gain=wscale_gain,
                          lr_mul=lr_mul,
                          activation_type='lrelu'))
            self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                f'FromRGB_lod{block_idx}/weight')
            self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                f'FromRGB_lod{block_idx}/bias')

            if res != self.init_res:
                # Intermediate block: one plain conv followed by one
                # downsampling conv.
                layer_name = f'layer{2 * block_idx}'
                self.add_module(
                    layer_name,
                    ConvLayer(in_channels=in_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=1,
                              fused_scale=False,
                              filter_kernel=None,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              lr_mul=lr_mul,
                              activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv0/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv0/bias')

                layer_name = f'layer{2 * block_idx + 1}'
                self.add_module(
                    layer_name,
                    ConvLayer(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=2,
                              fused_scale=(res >= fused_scale_res
                                           if fused_scale == 'auto'
                                           else fused_scale),
                              filter_kernel=filter_kernel,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              lr_mul=lr_mul,
                              activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv1_down/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv1_down/bias')

            else:
                # Final block (at `init_res`): minibatch-std statistics, one
                # conv, then a dense layer collapsing the spatial dimensions.
                self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups,
                                               new_channels=mbstd_channels,
                                               eps=eps)

                layer_name = f'layer{2 * block_idx}'
                self.add_module(
                    layer_name,
                    ConvLayer(in_channels=in_channels + mbstd_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=1,
                              fused_scale=False,
                              filter_kernel=None,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              lr_mul=lr_mul,
                              activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv/bias')

                layer_name = f'layer{2 * block_idx + 1}'
                # NOTE: Use `layer_name` for registration so the module name
                # is guaranteed to match the `pth_to_tf_var_mapping` keys
                # below (previously the f-string was duplicated inline).
                self.add_module(
                    layer_name,
                    DenseLayer(in_channels=in_channels * res * res,
                               out_channels=in_channels,
                               add_bias=True,
                               use_wscale=use_wscale,
                               wscale_gain=wscale_gain,
                               lr_mul=lr_mul,
                               activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Dense0/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Dense0/bias')

        # Final dense layer producing the score(s). With conditioning, one
        # score per class is produced and projected with the label in
        # `forward()`. After the loop above, `res` equals `init_res`, which is
        # exactly the resolution used in the TF variable names here.
        self.output = DenseLayer(in_channels=in_channels,
                                 out_channels=max(label_dim, 1),
                                 add_bias=True,
                                 use_wscale=use_wscale,
                                 wscale_gain=1.0,
                                 lr_mul=lr_mul,
                                 activation_type='linear')
        self.pth_to_tf_var_mapping['output.weight'] = (
            f'{res}x{res}/Dense1/weight')
        self.pth_to_tf_var_mapping['output.bias'] = (
            f'{res}x{res}/Dense1/bias')

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, label=None, lod=None, enable_amp=False):
        """Runs the discriminator on a batch of images.

        Args:
            image: Input tensor with shape [batch, channels, height, width].
            label: Optional conditioning label with shape
                [batch, label_dim]; required when `self.label_dim > 0`.
            lod: Level-of-details for progressive growing. Defaults to the
                buffered `self.lod` value.
            enable_amp: Whether to run the backbone under autocast.

        Returns:
            A dict with keys `score` (the discriminator output) and `label`.

        Raises:
            ValueError: If the image shape, lod, or label is invalid.
        """
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')

        lod = self.lod.item() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-details (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        if self.label_dim:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with dimension {self.label_dim}) as input, '
                                 f'but no label is received!')
            batch = image.shape[0]
            if (label.ndim != 2 or label.shape != (batch, self.label_dim)):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to {batch}, and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)

        with autocast(enabled=enable_amp):
            # Progressive-growing evaluation: for non-integer lod, the outputs
            # of two adjacent input heads are linearly blended with `alpha`.
            for res_log2 in range(
                    self.final_res_log2, self.init_res_log2 - 1, -1):
                block_idx = current_lod = self.final_res_log2 - res_log2
                if current_lod <= lod < current_lod + 1:
                    x = getattr(self, f'input{block_idx}')(image)
                elif current_lod - 1 < lod < current_lod:
                    alpha = lod - np.floor(lod)
                    y = getattr(self, f'input{block_idx}')(image)
                    x = y * alpha + x * (1 - alpha)
                if lod < current_lod + 1:
                    if res_log2 == self.init_res_log2:
                        x = self.mbstd(x)
                    x = getattr(self, f'layer{2 * block_idx}')(x)
                    x = getattr(self, f'layer{2 * block_idx + 1}')(x)
                if lod > current_lod:
                    # Downsample the raw image so the next (smaller) input
                    # head sees it at the matching resolution.
                    image = F.avg_pool2d(
                        image, kernel_size=2, stride=2, padding=0)
            x = self.output(x)

        if self.label_dim:
            # Projection discriminator: inner product of per-class scores
            # with the one-hot (or soft) label.
            x = (x * label).sum(dim=1, keepdim=True)

        results = {
            'score': x,
            'label': label
        }
        return results
|
|
|
|
|
class MiniBatchSTDLayer(nn.Module):
    """Implements the minibatch standard deviation layer."""

    def __init__(self, groups, new_channels, eps):
        super().__init__()
        self.groups = groups
        self.new_channels = new_channels
        self.eps = eps

    def extra_repr(self):
        return (f'groups={self.groups}, '
                f'new_channels={self.new_channels}, '
                f'epsilon={self.eps}')

    def forward(self, x):
        # The layer is a no-op when grouping or new channels are disabled.
        if self.groups <= 1 or self.new_channels < 1:
            return x

        batch, channels, height, width = x.shape
        group_size = min(self.groups, batch)
        extra = self.new_channels
        sub_channels = channels // extra

        # Split the batch into groups: [G, N/G, nC, C/nC, H, W].
        stat = x.reshape(group_size, -1, extra, sub_channels, height, width)
        stat = stat - stat.mean(dim=0)       # Center within each group.
        stat = stat.square().mean(dim=0)     # Variance across the group.
        stat = (stat + self.eps).sqrt()      # Std, stabilized by eps.
        stat = stat.mean(dim=(2, 3, 4))      # Average over channels/space.
        stat = stat.reshape(-1, extra, 1, 1)
        # Broadcast the statistic back over the group and spatial dims, then
        # append it as extra feature channels.
        stat = stat.repeat(group_size, 1, height, width)
        return torch.cat((x, stat), dim=1)
|
|
|
|
|
class Blur(torch.autograd.Function):
    """Defines blur operation with customized gradient computation."""

    @staticmethod
    def forward(ctx, x, kernel):
        # Only depth-wise 3x3 kernels are supported.
        assert kernel.shape[2] == 3 and kernel.shape[3] == 3
        ctx.save_for_backward(kernel)
        return F.conv2d(input=x,
                        weight=kernel,
                        bias=None,
                        stride=1,
                        padding=1,
                        groups=x.shape[1])

    @staticmethod
    def backward(ctx, dy):
        kernel, = ctx.saved_tensors
        # Delegate to a dedicated Function so that second-order gradients
        # (e.g. for gradient penalty) stay cheap.
        grad_input = BlurBackPropagation.apply(dy, kernel)
        return grad_input, None, None
|
|
|
|
|
class BlurBackPropagation(torch.autograd.Function):
    """Defines the back propagation of blur operation.

    NOTE: This is used to speed up the backward of gradient penalty.
    """

    @staticmethod
    def forward(ctx, dy, kernel):
        ctx.save_for_backward(kernel)
        # The gradient of a correlation is a correlation with the
        # spatially-flipped kernel.
        return F.conv2d(input=dy,
                        weight=kernel.flip((2, 3)),
                        bias=None,
                        stride=1,
                        padding=1,
                        groups=dy.shape[1])

    @staticmethod
    def backward(ctx, ddx):
        kernel, = ctx.saved_tensors
        # Second flip cancels the first, giving back the original kernel.
        return (F.conv2d(input=ddx,
                         weight=kernel,
                         bias=None,
                         stride=1,
                         padding=1,
                         groups=ddx.shape[1]),
                None,
                None)
|
|
|
|
|
class ConvLayer(nn.Module):
    """Implements the convolutional layer.

    If downsampling is needed (i.e., `scale_factor = 2`), the feature map will
    be filtered with `filter_kernel` first. If `fused_scale` is set as `True`,
    `conv2d` and `downsample` will be fused as one operator, using stride
    convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 add_bias,
                 scale_factor,
                 fused_scale,
                 filter_kernel,
                 use_wscale,
                 wscale_gain,
                 lr_mul,
                 activation_type):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels.
            add_bias: Whether to add bias onto the convolutional result.
            scale_factor: Scale factor for downsampling. `1` means skip
                downsampling.
            fused_scale: Whether to fuse `conv2d` and `downsample` as one
                operator, using stride convolution.
            filter_kernel: Kernel used for filtering.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning multiplier for both weight and bias.
            activation_type: Type of activation.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.add_bias = add_bias
        self.scale_factor = scale_factor
        self.fused_scale = fused_scale
        self.filter_kernel = filter_kernel
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.activation_type = activation_type

        # Equalized learning rate: keep the stored parameter at unit variance
        # and fold the He-init constant into a runtime multiplier instead.
        shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(torch.randn(*shape) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if scale_factor > 1:
            assert filter_kernel is not None
            # Build a normalized separable 2D blur kernel from the 1D taps.
            taps = np.array(filter_kernel, dtype=np.float32).reshape(1, -1)
            blur = taps.T.dot(taps)
            blur = blur / np.sum(blur)
            self.register_buffer(
                'filter', torch.from_numpy(blur[np.newaxis, np.newaxis]))

        # Fused downsampling is realized via stride; otherwise stride stays 1
        # and an explicit average pooling follows the convolution.
        self.stride = scale_factor if (scale_factor > 1 and fused_scale) else 1
        self.padding = kernel_size // 2

        assert activation_type in ['linear', 'relu', 'lrelu']

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'ksize={self.kernel_size}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'lr_mul={self.lr_mul:.3f}, '
                f'downsample={self.scale_factor}, '
                f'fused_scale={self.fused_scale}, '
                f'downsample_filter={self.filter_kernel}, '
                f'act={self.activation_type}')

    def forward(self, x):
        if self.scale_factor > 1:
            # Blur before downsampling; always executed in fp32 even when the
            # surrounding code runs under autocast.
            with autocast(enabled=False):
                blur_kernel = self.filter.repeat(self.in_channels, 1, 1, 1)
                x = Blur.apply(x.float(), blur_kernel)

        weight = self.weight
        if self.wscale != 1.0:
            weight = weight * self.wscale
        if self.bias is None:
            bias = None
        else:
            bias = self.bias
            if self.bscale != 1.0:
                bias = bias * self.bscale

        if self.scale_factor > 1 and self.fused_scale:
            # Average four shifted copies of the zero-padded kernel so a
            # strided convolution matches conv followed by 2x average pooling.
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
        x = F.conv2d(x,
                     weight=weight,
                     bias=bias,
                     stride=self.stride,
                     padding=self.padding)
        if self.scale_factor > 1 and not self.fused_scale:
            factor = self.scale_factor
            x = F.avg_pool2d(x, kernel_size=factor, stride=factor, padding=0)

        if self.activation_type == 'linear':
            return x
        if self.activation_type == 'relu':
            return F.relu(x, inplace=True)
        if self.activation_type == 'lrelu':
            return F.leaky_relu(x, negative_slope=0.2, inplace=True)
        raise NotImplementedError(f'Not implemented activation type '
                                  f'`{self.activation_type}`!')
|
|
|
|
|
class DenseLayer(nn.Module):
    """Implements the dense layer."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias,
                 use_wscale,
                 wscale_gain,
                 lr_mul,
                 activation_type):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning multiplier for both weight and bias.
            activation_type: Type of activation.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_bias = add_bias
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.activation_type = activation_type

        # Equalized learning rate: store a unit-variance parameter and apply
        # the He-init constant as a runtime multiplier.
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            self.weight = nn.Parameter(
                torch.randn(out_channels, in_channels) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(out_channels, in_channels) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        assert activation_type in ['linear', 'relu', 'lrelu']

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'lr_mul={self.lr_mul:.3f}, '
                f'act={self.activation_type}')

    def forward(self, x):
        # Collapse any trailing dimensions so conv feature maps can be fed in
        # directly.
        if x.ndim != 2:
            x = x.flatten(start_dim=1)

        weight = self.weight
        if self.wscale != 1.0:
            weight = weight * self.wscale
        if self.bias is None:
            bias = None
        else:
            bias = self.bias
            if self.bscale != 1.0:
                bias = bias * self.bscale

        x = F.linear(x, weight=weight, bias=bias)

        if self.activation_type == 'linear':
            return x
        if self.activation_type == 'relu':
            return F.relu(x, inplace=True)
        if self.activation_type == 'lrelu':
            return F.leaky_relu(x, negative_slope=0.2, inplace=True)
        raise NotImplementedError(f'Not implemented activation type '
                                  f'`{self.activation_type}`!')
|
|
|
|
|
|