# python3.7
"""Contains the implementation of discriminator described in VolumeGAN.

Paper: https://arxiv.org/pdf/2112.10759.pdf
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from third_party.stylegan2_official_ops import bias_act
from third_party.stylegan2_official_ops import upfirdn2d
from third_party.stylegan2_official_ops import conv2d_gradfix

__all__ = ['VolumeGANDiscriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Architectures allowed.
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']


class VolumeGANDiscriminator(nn.Module):
"""Defines the discriminator network in VolumeGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the backbone:

(1) resolution: The resolution of the input image. (default: -1)
(2) init_res: The initial resolution to start with convolution. (default: 4)
(3) image_channels: Number of channels of the input image. (default: 3)
(4) architecture: Type of architecture. Support `origin`, `skip`, and
`resnet`. (default: `resnet`)
(5) use_wscale: Whether to use weight scaling. (default: True)
(6) wscale_gain: The factor to control weight scaling. (default: 1.0)
(7) lr_mul: Learning rate multiplier for backbone. (default: 1.0)
(8) mbstd_groups: Group size for the minibatch standard deviation layer.
`0` means disable. (default: 4)
(9) mbstd_channels: Number of new channels (appended to the original feature
map) after the minibatch standard deviation layer. (default: 1)
(10) fmaps_base: Factor to control number of feature maps for each layer.
(default: 32 << 10)
(11) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
(12) filter_kernel: Kernel used for filtering (e.g., downsampling).
(default: (1, 3, 3, 1))
(13) conv_clamp: A threshold to clamp the output of convolution layers to
avoid overflow under FP16 training. (default: None)
    (14) eps: A small value added to the denominator to avoid division by
        zero. (default: 1e-8)

    Settings for conditional model:

    (1) label_dim: Dimension of the additional label for conditional generation.
        In the one-hot conditioning case, it equals the number of classes. If
        set to 0, conditional training will be disabled. (default: 0)
(2) embedding_dim: Dimension of the embedding space, if needed.
(default: 512)
(3) embedding_bias: Whether to add bias to embedding learning.
(default: True)
(4) embedding_use_wscale: Whether to use weight scaling for embedding
learning. (default: True)
(5) embedding_lr_mul: Learning rate multiplier for the embedding learning.
(default: 1.0)
(6) normalize_embedding: Whether to normalize the embedding. (default: True)
(7) mapping_layers: Number of layers of the additional mapping network after
embedding. (default: 0)
(8) mapping_fmaps: Number of hidden channels of the additional mapping
network after embedding. (default: 512)
(9) mapping_use_wscale: Whether to use weight scaling for the additional
mapping network. (default: True)
(10) mapping_lr_mul: Learning rate multiplier for the additional mapping
network after embedding. (default: 0.1)

    Runtime settings:

(1) fp16_res: Layers at resolution higher than (or equal to) this field will
use `float16` precision for computation. This is merely used for
acceleration. If set as `None`, all layers will use `float32` by
default. (default: None)
(2) impl: Implementation mode of some particular ops, e.g., `filtering`,
`bias_act`, etc. `cuda` means using the official CUDA implementation
from StyleGAN2, while `ref` means using the native PyTorch ops.
(default: `cuda`)
"""

    def __init__(self,
# Settings for backbone.
resolution=-1,
init_res=4,
image_channels=3,
architecture='resnet',
use_wscale=True,
wscale_gain=1.0,
lr_mul=1.0,
mbstd_groups=4,
mbstd_channels=1,
fmaps_base=32 << 10,
fmaps_max=512,
filter_kernel=(1, 3, 3, 1),
conv_clamp=None,
eps=1e-8,
# Settings for conditional model.
label_dim=0,
embedding_dim=512,
embedding_bias=True,
embedding_use_wscale=True,
embedding_lr_mul=1.0,
normalize_embedding=True,
mapping_layers=0,
mapping_fmaps=512,
mapping_use_wscale=True,
mapping_lr_mul=0.1):
"""Initializes with basic settings.

        Raises:
            ValueError: If `resolution` or `architecture` is not supported.
        """
super().__init__()
if resolution not in _RESOLUTIONS_ALLOWED:
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
architecture = architecture.lower()
if architecture not in _ARCHITECTURES_ALLOWED:
raise ValueError(f'Invalid architecture: `{architecture}`!\n'
f'Architectures allowed: '
f'{_ARCHITECTURES_ALLOWED}.')
self.init_res = init_res
self.init_res_log2 = int(np.log2(init_res))
self.resolution = resolution
self.final_res_log2 = int(np.log2(resolution))
self.image_channels = image_channels
self.architecture = architecture
self.use_wscale = use_wscale
self.wscale_gain = wscale_gain
self.lr_mul = lr_mul
self.mbstd_groups = mbstd_groups
self.mbstd_channels = mbstd_channels
self.fmaps_base = fmaps_base
self.fmaps_max = fmaps_max
self.filter_kernel = filter_kernel
self.conv_clamp = conv_clamp
self.eps = eps
self.label_dim = label_dim
self.embedding_dim = embedding_dim
self.embedding_bias = embedding_bias
self.embedding_use_wscale = embedding_use_wscale
self.embedding_lr_mul = embedding_lr_mul
self.normalize_embedding = normalize_embedding
self.mapping_layers = mapping_layers
self.mapping_fmaps = mapping_fmaps
self.mapping_use_wscale = mapping_use_wscale
self.mapping_lr_mul = mapping_lr_mul
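        # Maps PyTorch parameter names to TensorFlow variable names, kept so
        # that checkpoints from the official TF implementation can be loaded.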
self.pth_to_tf_var_mapping = {}
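        # Runtime level-of-detail state; used as the default `lod` in
        # `forward()` when none is provided.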
self.register_buffer('lod', torch.zeros(()))
# Embedding for conditional discrimination.
self.use_embedding = label_dim > 0 and embedding_dim > 0
if self.use_embedding:
self.embedding = DenseLayer(in_channels=label_dim,
out_channels=embedding_dim,
add_bias=embedding_bias,
init_bias=0.0,
use_wscale=embedding_use_wscale,
wscale_gain=wscale_gain,
lr_mul=embedding_lr_mul,
activation_type='linear')
self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight'
if self.embedding_bias:
self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias'
if self.normalize_embedding:
self.norm = PixelNormLayer(dim=1, eps=eps)
for i in range(mapping_layers):
in_channels = (embedding_dim if i == 0 else mapping_fmaps)
out_channels = (embedding_dim if i == (mapping_layers - 1) else
mapping_fmaps)
layer_name = f'mapping{i}'
self.add_module(layer_name,
DenseLayer(in_channels=in_channels,
out_channels=out_channels,
add_bias=True,
init_bias=0.0,
use_wscale=mapping_use_wscale,
wscale_gain=wscale_gain,
lr_mul=mapping_lr_mul,
activation_type='lrelu'))
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
f'Mapping{i}/weight')
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
f'Mapping{i}/bias')
# Convolutional backbone.
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
res = 2 ** res_log2
in_channels = self.get_nf(res)
out_channels = self.get_nf(res // 2)
block_idx = self.final_res_log2 - res_log2
# Input convolution layer for each resolution (if needed).
layer_name = f'input{block_idx}'
self.add_module(layer_name,
ConvLayer(in_channels=image_channels,
out_channels=in_channels,
kernel_size=1,
add_bias=True,
scale_factor=1,
filter_kernel=None,
use_wscale=use_wscale,
wscale_gain=wscale_gain,
lr_mul=lr_mul,
activation_type='lrelu',
conv_clamp=conv_clamp))
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
f'{res}x{res}/FromRGB/weight')
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
f'{res}x{res}/FromRGB/bias')
# Convolution block for each resolution (except the last one).
if res != self.init_res:
# First layer (kernel 3x3) without downsampling.
layer_name = f'layer{2 * block_idx}'
self.add_module(layer_name,
ConvLayer(in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
add_bias=True,
scale_factor=1,
filter_kernel=None,
use_wscale=use_wscale,
wscale_gain=wscale_gain,
lr_mul=lr_mul,
activation_type='lrelu',
conv_clamp=conv_clamp))
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
f'{res}x{res}/Conv0/weight')
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
f'{res}x{res}/Conv0/bias')
                # Second layer (kernel 3x3) with downsampling.
layer_name = f'layer{2 * block_idx + 1}'
self.add_module(layer_name,
ConvLayer(in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
add_bias=True,
scale_factor=2,
filter_kernel=filter_kernel,
use_wscale=use_wscale,
wscale_gain=wscale_gain,
lr_mul=lr_mul,
activation_type='lrelu',
conv_clamp=conv_clamp))
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
f'{res}x{res}/Conv1_down/weight')
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
f'{res}x{res}/Conv1_down/bias')
# Residual branch (kernel 1x1) with downsampling, without bias,
# with linear activation.
if self.architecture == 'resnet':
layer_name = f'residual{block_idx}'
self.add_module(layer_name,
ConvLayer(in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
add_bias=False,
scale_factor=2,
filter_kernel=filter_kernel,
use_wscale=use_wscale,
wscale_gain=wscale_gain,
lr_mul=lr_mul,
activation_type='linear',
conv_clamp=None))
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
f'{res}x{res}/Skip/weight')
# Convolution block for last resolution.
else:
self.mbstd = MiniBatchSTDLayer(
groups=mbstd_groups, new_channels=mbstd_channels, eps=eps)
# First layer (kernel 3x3) without downsampling.
layer_name = f'layer{2 * block_idx}'
self.add_module(
layer_name,
ConvLayer(in_channels=in_channels + mbstd_channels,
out_channels=in_channels,
kernel_size=3,
add_bias=True,
scale_factor=1,
filter_kernel=None,
use_wscale=use_wscale,
wscale_gain=wscale_gain,
lr_mul=lr_mul,
activation_type='lrelu',
conv_clamp=conv_clamp))
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
f'{res}x{res}/Conv/weight')
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
f'{res}x{res}/Conv/bias')
# Second layer, as a fully-connected layer.
layer_name = f'layer{2 * block_idx + 1}'
self.add_module(layer_name,
DenseLayer(in_channels=in_channels * res * res,
out_channels=in_channels,
add_bias=True,
init_bias=0.0,
use_wscale=use_wscale,
wscale_gain=wscale_gain,
lr_mul=lr_mul,
activation_type='lrelu'))
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
f'{res}x{res}/Dense0/weight')
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
f'{res}x{res}/Dense0/bias')
# Final dense layer to output score.
self.output = DenseLayer(in_channels=in_channels,
out_channels=(embedding_dim
if self.use_embedding
else max(label_dim, 1)),
add_bias=True,
init_bias=0.0,
use_wscale=use_wscale,
wscale_gain=wscale_gain,
lr_mul=lr_mul,
activation_type='linear')
self.pth_to_tf_var_mapping['output.weight'] = 'Output/weight'
self.pth_to_tf_var_mapping['output.bias'] = 'Output/bias'
# Used for downsampling input image for `skip` architecture.
if self.architecture == 'skip':
self.register_buffer(
'filter', upfirdn2d.setup_filter(filter_kernel))

    def get_nf(self, res):
"""Gets number of feature maps according to current resolution."""
return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, lod=None, label=None, fp16_res=None, impl='cuda'):
# Check shape.
expected_shape = (self.image_channels, self.resolution, self.resolution)
if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should have shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals {self.image_channels} and '
                             f'`height` and `width` equal {self.resolution}!\n'
                             f'But `{image.shape}` was received!')
if self.label_dim > 0:
if label is None:
                raise ValueError(f'The model requires an additional label '
                                 f'(with dimension {self.label_dim}) as '
                                 f'input, but no label was received!')
batch_size = image.shape[0]
if label.ndim != 2 or label.shape != (batch_size, self.label_dim):
                raise ValueError(f'The input label should have shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals that of the '
                                 f'images ({image.shape[0]}) and '
                                 f'`label_dim` equals {self.label_dim}!\n'
                                 f'But `{label.shape}` was received!')
label = label.to(dtype=torch.float32)
if self.use_embedding:
embed = self.embedding(label, impl=impl)
if self.normalize_embedding:
embed = self.norm(embed)
for i in range(self.mapping_layers):
embed = getattr(self, f'mapping{i}')(embed, impl=impl)
# Cast to `torch.float16` if needed.
if fp16_res is not None and self.resolution >= fp16_res:
image = image.to(torch.float16)
lod = self.lod.item() if lod is None else lod
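        # `lod` (level of detail) drives progressive growing: integer values
        # activate a single input head, while fractional values linearly blend
        # the two adjacent heads via `alpha` below.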
x = self.input0(image, impl=impl)
for res_log2 in range(self.final_res_log2, self.init_res_log2, -1):
res = 2 ** res_log2
# Cast to `torch.float16` if needed.
if fp16_res is not None and res >= fp16_res:
x = x.to(torch.float16)
else:
x = x.to(torch.float32)
idx = cur_lod = self.final_res_log2 - res_log2 # Block index
if cur_lod <= lod < cur_lod + 1:
x = getattr(self, f'input{idx}')(image, impl=impl)
elif cur_lod - 1 < lod < cur_lod:
alpha = lod - np.floor(lod)
y = getattr(self, f'input{idx}')(image, impl=impl)
x = y * alpha + x * (1 - alpha)
if lod < cur_lod + 1:
if self.architecture == 'skip' and idx > 0:
image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
# Cast to `torch.float16` if needed.
if fp16_res is not None and res >= fp16_res:
image = image.to(torch.float16)
else:
image = image.to(torch.float32)
y = getattr(self, f'input{idx}')(image, impl=impl)
x = x + y
if self.architecture == 'resnet':
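                    # Scale both branches by sqrt(0.5) so their sum keeps
                    # roughly unit variance, as in the StyleGAN2 residual
                    # discriminator.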
residual = getattr(self, f'residual{idx}')(
x, runtime_gain=np.sqrt(0.5), impl=impl)
x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
x = getattr(self, f'layer{2 * idx + 1}')(
x, runtime_gain=np.sqrt(0.5), impl=impl)
x = x + residual
else:
x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl)
if lod > cur_lod:
image = F.avg_pool2d(
image, kernel_size=2, stride=2, padding=0)
# Final output.
if fp16_res is not None: # Always use FP32 for the last block.
x = x.to(torch.float32)
if self.architecture == 'skip':
image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
if fp16_res is not None: # Always use FP32 for the last block.
image = image.to(torch.float32)
            # The final block corresponds to block index `idx + 1`, so use its
            # own FromRGB head.
            y = getattr(self, f'input{idx + 1}')(image, impl=impl)
x = x + y
x = self.mbstd(x)
x = getattr(self, f'layer{2 * idx + 2}')(x, impl=impl)
x = getattr(self, f'layer{2 * idx + 3}')(x, impl=impl)
x = self.output(x, impl=impl)
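        # Conditioning via projection: the score is the inner product between
        # the image feature and the label embedding, following the
        # projection-discriminator formulation.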
if self.use_embedding:
x = (x * embed).sum(dim=1, keepdim=True)
x = x / np.sqrt(self.embedding_dim)
elif self.label_dim > 0:
x = (x * label).sum(dim=1, keepdim=True)
results = {
'score': x,
'label': label
}
if self.use_embedding:
results['embedding'] = embed
return results


class PixelNormLayer(nn.Module):
"""Implements pixel-wise feature vector normalization layer."""

    def __init__(self, dim, eps):
super().__init__()
self.dim = dim
self.eps = eps

    def extra_repr(self):
return f'dim={self.dim}, epsilon={self.eps}'

    def forward(self, x):
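        # Normalize each feature vector to unit RMS along `dim`:
        #   y = x / sqrt(mean(x ** 2, dim) + eps).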
scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
return x * scale


class MiniBatchSTDLayer(nn.Module):
"""Implements the minibatch standard deviation layer."""

    def __init__(self, groups, new_channels, eps):
super().__init__()
self.groups = groups
self.new_channels = new_channels
self.eps = eps

    def extra_repr(self):
return (f'groups={self.groups}, '
f'new_channels={self.new_channels}, '
f'epsilon={self.eps}')

    def forward(self, x):
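        # Append per-group standard-deviation statistics (averaged over
        # spatial dimensions and channel groups) as extra channels, so the
        # discriminator can sense the variety within a minibatch
        # (the PGGAN/StyleGAN2 minibatch-stddev trick).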
if self.groups <= 1 or self.new_channels < 1:
return x
dtype = x.dtype
N, C, H, W = x.shape
G = min(self.groups, N) # Number of groups.
nC = self.new_channels # Number of channel groups.
c = C // nC # Channels per channel group.
y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW]
y = y - y.mean(dim=0) # [GnFcHW]
y = y.square().mean(dim=0) # [nFcHW]
y = (y + self.eps).sqrt() # [nFcHW]
y = y.mean(dim=(2, 3, 4)) # [nF]
y = y.reshape(-1, nC, 1, 1) # [nF11]
y = y.repeat(G, 1, H, W) # [NFHW]
x = torch.cat((x, y), dim=1) # [N(C+F)HW]
assert x.dtype == dtype
return x


class ConvLayer(nn.Module):
"""Implements the convolutional layer.

    If downsampling is needed (i.e., `scale_factor = 2`), the feature map will
    be filtered with `filter_kernel` first.
    """

    def __init__(self,
in_channels,
out_channels,
kernel_size,
add_bias,
scale_factor,
filter_kernel,
use_wscale,
wscale_gain,
lr_mul,
activation_type,
conv_clamp):
"""Initializes with layer settings.

        Args:
in_channels: Number of channels of the input tensor.
out_channels: Number of channels of the output tensor.
kernel_size: Size of the convolutional kernels.
add_bias: Whether to add bias onto the convolutional result.
scale_factor: Scale factor for downsampling. `1` means skip
downsampling.
filter_kernel: Kernel used for filtering.
use_wscale: Whether to use weight scaling.
wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning rate multiplier for both weight and bias.
activation_type: Type of activation.
conv_clamp: A threshold to clamp the output of convolution layers to
avoid overflow under FP16 training.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.add_bias = add_bias
self.scale_factor = scale_factor
self.filter_kernel = filter_kernel
self.use_wscale = use_wscale
self.wscale_gain = wscale_gain
self.lr_mul = lr_mul
self.activation_type = activation_type
self.conv_clamp = conv_clamp
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
fan_in = kernel_size * kernel_size * in_channels
wscale = wscale_gain / np.sqrt(fan_in)
if use_wscale:
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
self.wscale = wscale * lr_mul
else:
self.weight = nn.Parameter(
torch.randn(*weight_shape) * wscale / lr_mul)
self.wscale = lr_mul
if add_bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
self.bscale = lr_mul
else:
self.bias = None
self.act_gain = bias_act.activation_funcs[activation_type].def_gain
if scale_factor > 1:
assert filter_kernel is not None
self.register_buffer(
'filter', upfirdn2d.setup_filter(filter_kernel))
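            # Fold the convolution's `same` padding into the FIR filtering
            # step, following the padding arithmetic of the official
            # StyleGAN2 `upfirdn2d` ops.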
fh, fw = self.filter.shape
self.filter_padding = (
kernel_size // 2 + (fw - scale_factor + 1) // 2,
kernel_size // 2 + (fw - scale_factor) // 2,
kernel_size // 2 + (fh - scale_factor + 1) // 2,
kernel_size // 2 + (fh - scale_factor) // 2)

    def extra_repr(self):
return (f'in_ch={self.in_channels}, '
f'out_ch={self.out_channels}, '
f'ksize={self.kernel_size}, '
f'wscale_gain={self.wscale_gain:.3f}, '
f'bias={self.add_bias}, '
f'lr_mul={self.lr_mul:.3f}, '
f'downsample={self.scale_factor}, '
f'downsample_filter={self.filter_kernel}, '
f'act={self.activation_type}, '
f'clamp={self.conv_clamp}')

    def forward(self, x, runtime_gain=1.0, impl='cuda'):
dtype = x.dtype
weight = self.weight
if self.wscale != 1.0:
weight = weight * self.wscale
bias = None
if self.bias is not None:
bias = self.bias.to(dtype)
if self.bscale != 1.0:
bias = bias * self.bscale
if self.scale_factor == 1: # Native convolution without downsampling.
padding = self.kernel_size // 2
x = conv2d_gradfix.conv2d(
x, weight.to(dtype), stride=1, padding=padding, impl=impl)
else: # Convolution with downsampling.
down = self.scale_factor
f = self.filter
padding = self.filter_padding
# When kernel size = 1, use filtering function for downsampling.
if self.kernel_size == 1:
x = upfirdn2d.upfirdn2d(
x, f, down=down, padding=padding, impl=impl)
x = conv2d_gradfix.conv2d(
x, weight.to(dtype), stride=1, padding=0, impl=impl)
# When kernel size != 1, use stride convolution for downsampling.
else:
x = upfirdn2d.upfirdn2d(
x, f, down=1, padding=padding, impl=impl)
x = conv2d_gradfix.conv2d(
x, weight.to(dtype), stride=down, padding=0, impl=impl)
act_gain = self.act_gain * runtime_gain
act_clamp = None
if self.conv_clamp is not None:
act_clamp = self.conv_clamp * runtime_gain
x = bias_act.bias_act(x, bias,
act=self.activation_type,
gain=act_gain,
clamp=act_clamp,
impl=impl)
assert x.dtype == dtype
return x


class DenseLayer(nn.Module):
"""Implements the dense layer."""

    def __init__(self,
in_channels,
out_channels,
add_bias,
init_bias,
use_wscale,
wscale_gain,
lr_mul,
activation_type):
"""Initializes with layer settings.

        Args:
in_channels: Number of channels of the input tensor.
out_channels: Number of channels of the output tensor.
add_bias: Whether to add bias onto the fully-connected result.
init_bias: The initial bias value before training.
use_wscale: Whether to use weight scaling.
wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning rate multiplier for both weight and bias.
activation_type: Type of activation.
"""
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.add_bias = add_bias
self.init_bias = init_bias
self.use_wscale = use_wscale
self.wscale_gain = wscale_gain
self.lr_mul = lr_mul
self.activation_type = activation_type
weight_shape = (out_channels, in_channels)
wscale = wscale_gain / np.sqrt(in_channels)
if use_wscale:
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
self.wscale = wscale * lr_mul
else:
self.weight = nn.Parameter(
torch.randn(*weight_shape) * wscale / lr_mul)
self.wscale = lr_mul
if add_bias:
init_bias = np.float32(init_bias) / lr_mul
self.bias = nn.Parameter(torch.full([out_channels], init_bias))
self.bscale = lr_mul
else:
self.bias = None

    def extra_repr(self):
return (f'in_ch={self.in_channels}, '
f'out_ch={self.out_channels}, '
f'wscale_gain={self.wscale_gain:.3f}, '
f'bias={self.add_bias}, '
f'init_bias={self.init_bias}, '
f'lr_mul={self.lr_mul:.3f}, '
f'act={self.activation_type}')

    def forward(self, x, impl='cuda'):
dtype = x.dtype
if x.ndim != 2:
x = x.flatten(start_dim=1)
weight = self.weight.to(dtype) * self.wscale
bias = None
if self.bias is not None:
bias = self.bias.to(dtype)
if self.bscale != 1.0:
bias = bias * self.bscale
# Fast pass for linear activation.
if self.activation_type == 'linear' and bias is not None:
x = torch.addmm(bias.unsqueeze(0), x, weight.t())
else:
x = x.matmul(weight.t())
x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)
assert x.dtype == dtype
return x
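

# ------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the original
# module): builds an unconditional 256x256 discriminator and scores a random
# batch. Assumes `impl='ref'` (the native-PyTorch fallback of the official
# StyleGAN2 ops, per the class docstring) is available, so no custom CUDA
# kernels are required.
if __name__ == '__main__':
    D = VolumeGANDiscriminator(resolution=256)
    fake_images = torch.randn(2, 3, 256, 256)  # RGB images in [-1, 1].
    outputs = D(fake_images, impl='ref')
    print(outputs['score'].shape)  # Expected: torch.Size([2, 1]).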