"""Discriminator from StyleGAN. https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py""" |
|
|
|
import functools |
|
import math |
|
from typing import Any, Tuple |
|
import flax.linen as nn |
|
from flax.linen.initializers import xavier_uniform |
|
import jax |
|
from jax import lax |
|
import jax.numpy as jnp |
|
import ml_collections |
|
|
|
default_kernel_init = xavier_uniform() |
|
|
|
def _conv_dimension_numbers(input_shape):
  """Computes the dimension numbers based on the input shape."""
  ndim = len(input_shape)
  lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
  rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
  out_spec = lhs_spec
  return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
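
# For the common NHWC case (rank-4 activations), the helper above resolves to
# lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2), i.e. an
# ('NHWC', 'HWIO', 'NHWC') convolution. A minimal sanity-check sketch (the
# shapes below are illustrative assumptions, not values used in this module):
#
#   dn = _conv_dimension_numbers((8, 32, 32, 64))
#   assert dn == lax.conv_dimension_numbers(
#       (8, 32, 32, 64), (3, 3, 1, 64), ('NHWC', 'HWIO', 'NHWC'))
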
class BlurPool2D(nn.Module):
  """A layer to do channel-wise blurring + subsampling on 2D inputs.

  Reference:
    Zhang et al. Making Convolutional Networks Shift-Invariant Again.
    https://arxiv.org/pdf/1904.11486.pdf.
  """
  filter_size: int = 4
  strides: Tuple[int, int] = (2, 2)
  padding: str = 'SAME'

  def setup(self):
    if self.filter_size == 3:
      self.filter = [1., 2., 1.]
    elif self.filter_size == 4:
      self.filter = [1., 3., 3., 1.]
    elif self.filter_size == 5:
      self.filter = [1., 4., 6., 4., 1.]
    elif self.filter_size == 6:
      self.filter = [1., 5., 10., 10., 5., 1.]
    elif self.filter_size == 7:
      self.filter = [1., 6., 15., 20., 15., 6., 1.]
    else:
      raise ValueError('Only filter_size of 3, 4, 5, 6 or 7 is supported.')

    # Build a normalized 2D binomial filter of shape [k, k, 1, 1].
    self.filter = jnp.array(self.filter, dtype=jnp.float32)
    self.filter = self.filter[:, None] * self.filter[None, :]
    self.filter /= jnp.sum(self.filter)
    self.filter = jnp.reshape(
        self.filter, [self.filter.shape[0], self.filter.shape[1], 1, 1])

  @nn.compact
  def __call__(self, inputs):
    channel_num = inputs.shape[-1]
    dimension_numbers = _conv_dimension_numbers(inputs.shape)
    # Apply the blur filter depthwise (one group per channel), then subsample.
    depthwise_filter = jnp.tile(self.filter, [1, 1, 1, channel_num])
    with jax.default_matmul_precision('float32'):
      outputs = lax.conv_general_dilated(
          inputs,
          depthwise_filter,
          self.strides,
          self.padding,
          feature_group_count=channel_num,
          dimension_numbers=dimension_numbers)
    return outputs
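
# Example usage (sketch): blur-pooling a dummy NHWC batch halves the spatial
# resolution and keeps the channel count. The shapes and PRNG key below are
# illustrative assumptions only.
#
#   pool = BlurPool2D(filter_size=4)
#   variables = pool.init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 16)))
#   y = pool.apply(variables, jnp.ones((1, 32, 32, 16)))  # -> (1, 16, 16, 16)
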
class ResBlock(nn.Module):
  """StyleGAN ResBlock for D.

  https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py#L618
  """
  filters: int
  activation_fn: Any

  @nn.compact
  def __call__(self, x):
    input_dim = x.shape[-1]
    residual = x
    x = nn.Conv(input_dim, (3, 3), kernel_init=default_kernel_init)(x)
    x = self.activation_fn(x)
    x = BlurPool2D(filter_size=4)(x)
    # Downsample the skip path the same way and match its channel count.
    residual = BlurPool2D(filter_size=4)(residual)
    residual = nn.Conv(
        self.filters, (1, 1), use_bias=False,
        kernel_init=default_kernel_init)(residual)
    x = nn.Conv(self.filters, (3, 3), kernel_init=default_kernel_init)(x)
    x = self.activation_fn(x)
    # Scale the sum by 1/sqrt(2) to keep the output variance roughly constant.
    out = (residual + x) / math.sqrt(2)
    return out
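
# Example usage (sketch): a ResBlock halves the spatial resolution and maps the
# channel count to `filters`. The activation, shapes, and PRNG key below are
# illustrative assumptions only.
#
#   block = ResBlock(filters=128, activation_fn=jax.nn.leaky_relu)
#   params = block.init(jax.random.PRNGKey(0), jnp.ones((1, 64, 64, 64)))
#   y = block.apply(params, jnp.ones((1, 64, 64, 64)))  # -> (1, 32, 32, 128)
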
class Discriminator(nn.Module):
  """StyleGAN Discriminator."""
  config: ml_collections.ConfigDict

  def setup(self):
    self.input_size = self.config.image_size
    self.activation_fn = functools.partial(
        jax.nn.leaky_relu, negative_slope=0.2)
    self.channel_multiplier = 1

  @nn.compact
  def __call__(self, x):
    # Number of filters per feature-map resolution.
    filters = {
        4: 512,
        8: 512,
        16: 512,
        32: 512,
        64: 256 * self.channel_multiplier,
        128: 128 * self.channel_multiplier,
        256: 64 * self.channel_multiplier,
        512: 32 * self.channel_multiplier,
        1024: 16 * self.channel_multiplier,
    }
    x = nn.Conv(
        filters[self.input_size], (3, 3),
        kernel_init=default_kernel_init)(x)
    x = self.activation_fn(x)
    # Stack of ResBlocks, halving the resolution until it reaches 4x4.
    log_size = int(math.log2(self.input_size))
    for i in range(log_size, 2, -1):
      x = ResBlock(filters[2**(i - 1)], self.activation_fn)(x)
    x = nn.Conv(filters[4], (3, 3), kernel_init=default_kernel_init)(x)
    x = self.activation_fn(x)
    # Flatten and map to a single real/fake logit per example.
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(filters[4], kernel_init=default_kernel_init)(x)
    x = self.activation_fn(x)
    x = nn.Dense(1, kernel_init=default_kernel_init)(x)
    return x
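

if __name__ == '__main__':
  # Smoke-test sketch, not part of the original module. It assumes the config
  # only needs an `image_size` entry and uses a small 64x64 input; adjust to
  # match the real training configuration.
  config = ml_collections.ConfigDict({'image_size': 64})
  model = Discriminator(config=config)
  fake_images = jnp.ones((2, 64, 64, 3), dtype=jnp.float32)
  variables = model.init(jax.random.PRNGKey(0), fake_images)
  logits = model.apply(variables, fake_images)
  print('logits shape:', logits.shape)  # expected: (2, 1)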