"""Layers used for up-sampling or down-sampling images.

Many functions are ported from https://github.com/NVlabs/stylegan2.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .op import upfirdn2d


def get_weight(module,
               shape,
               weight_var='weight',
               kernel_init=None):
  """Get/create weight tensor for a convolution or fully-connected layer."""
  # Note: assumes `module` exposes a Flax-style `param(name, init, shape)`
  # method; a plain `torch.nn.Module` does not define one.
  return module.param(weight_var, kernel_init, shape)


class Conv2d(nn.Module):
  """Conv2d layer with optional up/downsampling (StyleGAN2)."""

  def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
               resample_kernel=(1, 3, 3, 1),
               use_bias=True,
               kernel_init=None):
    super().__init__()
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1
    self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
    if kernel_init is not None:
      self.weight.data = kernel_init(self.weight.data.shape)
    if use_bias:
      self.bias = nn.Parameter(torch.zeros(out_ch))

    self.up = up
    self.down = down
    self.resample_kernel = resample_kernel
    self.kernel = kernel
    self.use_bias = use_bias

  def forward(self, x):
    if self.up:
      x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
    elif self.down:
      x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
    else:
      x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)

    if self.use_bias:
      x = x + self.bias.reshape(1, -1, 1, 1)

    return x
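
# A minimal usage sketch (not part of the original module; the channel counts
# and input size below are illustrative assumptions):
#   up_conv = Conv2d(in_ch=64, out_ch=64, kernel=3, up=True)
#   down_conv = Conv2d(in_ch=64, out_ch=64, kernel=3, down=True)
#   x = torch.randn(4, 64, 32, 32)
#   up_conv(x).shape    # expected: torch.Size([4, 64, 64, 64])
#   down_conv(x).shape  # expected: torch.Size([4, 64, 16, 16])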


def naive_upsample_2d(x, factor=2):
  """Upsample by repeating each pixel `factor` times along H and W."""
  _N, C, H, W = x.shape
  x = torch.reshape(x, (-1, C, H, 1, W, 1))
  x = x.repeat(1, 1, 1, factor, 1, factor)
  return torch.reshape(x, (-1, C, H * factor, W * factor))


def naive_downsample_2d(x, factor=2):
  """Downsample by averaging non-overlapping `factor x factor` blocks."""
  _N, C, H, W = x.shape
  x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
  return torch.mean(x, dim=(3, 5))
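
# Sketch (illustrative assumption): upsampling by `factor` and then
# downsampling by the same `factor` with these naive ops is the identity,
# since the repeated pixels are averaged back.
#   x = torch.randn(1, 3, 8, 8)
#   torch.allclose(naive_downsample_2d(naive_upsample_2d(x, 2), 2), x)  # True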


def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
  """Fused `upsample_2d()` followed by `F.conv2d()`.

  Padding is performed only once at the beginning, not between the two
  operations. The fused op is considerably more efficient than performing
  the same calculation with separate resampling and convolution ops, and it
  supports gradients of arbitrary order.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    w: Weight tensor of the shape `[outChannels, inChannels, filterH,
      filterW]`. Grouped convolution can be performed by `inChannels =
      x.shape[1] // numGroups`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
      The default is `[1] * factor`, which corresponds to nearest-neighbor
      upsampling.
    factor: Integer upsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H * factor, W * factor]` with the same
    datatype as `x`.
  """
  assert isinstance(factor, int) and factor >= 1

  # Check weight shape (PyTorch layout: [outC, inC, kH, kW]).
  assert len(w.shape) == 4
  convH = w.shape[2]
  convW = w.shape[3]
  inC = w.shape[1]
  outC = w.shape[0]
  assert convW == convH

  # Setup filter kernel.
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * (gain * (factor ** 2))
  p = (k.shape[0] - factor) - (convW - 1)

  # Determine data dimensions.
  stride = (factor, factor)
  output_shape = ((_shape(x, 2) - 1) * factor + convH,
                  (_shape(x, 3) - 1) * factor + convW)
  output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
                    output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
  assert output_padding[0] >= 0 and output_padding[1] >= 0
  num_groups = _shape(x, 1) // inC

  # Transpose weights: reverse the spatial axes (PyTorch tensors do not support
  # negative-step slicing, hence `torch.flip`) and swap the channel axes to
  # match `conv_transpose2d`'s `[inC, outC // groups, kH, kW]` layout.
  w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
  w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
  w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

  x = F.conv_transpose2d(x, w, stride=stride,
                         output_padding=output_padding, padding=0)

  # Apply the FIR filter to smooth the upsampled signal.
  return upfirdn2d(x, torch.tensor(k, device=x.device),
                   pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
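
# Shape sketch (illustrative assumption, using the 3x3 kernel and FIR filter
# that `Conv2d` above passes in by default):
#   w = torch.randn(64, 64, 3, 3)                 # [outC, inC, kH, kW]
#   x = torch.randn(2, 64, 32, 32)
#   upsample_conv_2d(x, w, k=(1, 3, 3, 1)).shape  # (2, 64, 64, 64)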


def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
  """Fused `F.conv2d()` followed by `downsample_2d()`.

  Padding is performed only once at the beginning, not between the two
  operations. The fused op is considerably more efficient than performing
  the same calculation with separate resampling and convolution ops, and it
  supports gradients of arbitrary order.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    w: Weight tensor of the shape `[outChannels, inChannels, filterH,
      filterW]`. Grouped convolution can be performed by `inChannels =
      x.shape[1] // numGroups`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
      The default is `[1] * factor`, which corresponds to average pooling.
    factor: Integer downsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H // factor, W // factor]` with the same
    datatype as `x`.
  """
  assert isinstance(factor, int) and factor >= 1
  _outC, _inC, convH, convW = w.shape
  assert convW == convH

  # Setup filter kernel, apply it, then run the strided convolution.
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * gain
  p = (k.shape[0] - factor) + (convW - 1)
  s = [factor, factor]
  x = upfirdn2d(x, torch.tensor(k, device=x.device),
                pad=((p + 1) // 2, p // 2))
  return F.conv2d(x, w, stride=s, padding=0)
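
# Shape sketch (illustrative assumption):
#   w = torch.randn(64, 64, 3, 3)                   # [outC, inC, kH, kW]
#   x = torch.randn(2, 64, 32, 32)
#   conv_downsample_2d(x, w, k=(1, 3, 3, 1)).shape  # (2, 64, 16, 16)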


def _setup_kernel(k):
  """Build a 2D FIR kernel from a 1D (separable) or 2D filter, normalized to unit sum."""
  k = np.asarray(k, dtype=np.float32)
  if k.ndim == 1:
    k = np.outer(k, k)
  k /= np.sum(k)
  assert k.ndim == 2
  assert k.shape[0] == k.shape[1]
  return k
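
# For the default resampling filter (1, 3, 3, 1) this is the separable 4x4
# binomial kernel np.outer([1, 3, 3, 1], [1, 3, 3, 1]) / 64, whose first row
# is [1/64, 3/64, 3/64, 1/64].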


def _shape(x, dim):
  return x.shape[dim]


def upsample_2d(x, k=None, factor=2, gain=1):
  r"""Upsample a batch of 2D images with the given filter.

  Accepts a batch of 2D images of the shape `[N, C, H, W]` and upsamples each
  image with the given filter. The filter is normalized so that if the input
  pixels are constant, they will be scaled by the specified `gain`. Pixels
  outside the image are assumed to be zero, and the filter is padded with
  zeros so that its shape is a multiple of the upsampling factor.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
      The default is `[1] * factor`, which corresponds to nearest-neighbor
      upsampling.
    factor: Integer upsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H * factor, W * factor]`.
  """
  assert isinstance(factor, int) and factor >= 1
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * (gain * (factor ** 2))
  p = k.shape[0] - factor
  return upfirdn2d(x, torch.tensor(k, device=x.device),
                   up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
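
# Sketch (illustrative assumption): away from the borders, upsampling a
# constant image preserves its value, because the normalized kernel is scaled
# by `gain * factor ** 2`.
#   x = torch.ones(1, 3, 8, 8)
#   upsample_2d(x, k=(1, 3, 3, 1)).shape  # (1, 3, 16, 16)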


def downsample_2d(x, k=None, factor=2, gain=1):
  r"""Downsample a batch of 2D images with the given filter.

  Accepts a batch of 2D images of the shape `[N, C, H, W]` and downsamples
  each image with the given filter. The filter is normalized so that if the
  input pixels are constant, they will be scaled by the specified `gain`.
  Pixels outside the image are assumed to be zero, and the filter is padded
  with zeros so that its shape is a multiple of the downsampling factor.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
      The default is `[1] * factor`, which corresponds to average pooling.
    factor: Integer downsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H // factor, W // factor]`.
  """
  assert isinstance(factor, int) and factor >= 1
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * gain
  p = k.shape[0] - factor
  return upfirdn2d(x, torch.tensor(k, device=x.device),
                   down=factor, pad=((p + 1) // 2, p // 2))
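
# Sketch (illustrative assumption): with the default `k = [1] * factor`, this
# reduces each non-overlapping `factor x factor` block to its mean, i.e. it
# matches `naive_downsample_2d`.
#   x = torch.arange(16.).reshape(1, 1, 4, 4)
#   downsample_2d(x).shape  # (1, 1, 2, 2)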