# Spaces: Running on A10G
import math | |
from typing import Optional | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
class Mish(nn.Module):
    """Mish activation, applied element-wise: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        smooth = F.softplus(x)
        return x * smooth.tanh()
class DiffusionEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion steps.

    ``forward`` expects a 1-D tensor ``x`` of step values with shape ``(B,)``
    and returns a ``(B, d_denoiser)`` tensor. With ``half_dim = d_denoiser // 2``,
    the first ``half_dim`` columns hold ``sin(x * f_k)`` and the next
    ``half_dim`` columns hold ``cos(x * f_k)``, where
    ``f_k = exp(-k * log(10000) / (half_dim - 1))`` for ``k = 0 .. half_dim-1``.
    For an odd ``d_denoiser`` a single zero column is appended, so the output
    width always equals ``d_denoiser``.
    """

    def __init__(self, d_denoiser):
        super(DiffusionEmbedding, self).__init__()
        self.dim = d_denoiser

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # When half_dim == 1 the divisor would be 0. Only k = 0 exists then,
        # whose frequency is exp(0) = 1 for any scale, so clamping the divisor
        # to 1 avoids ZeroDivisionError without changing any other output.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        emb = x[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * half_dim columns. Pad one zero
            # column so downstream layers sized to `dim` receive `dim` features.
            emb = F.pad(emb, (0, 1))
        return emb
class LinearNorm(nn.Module):
    """``nn.Linear`` wrapper with Xavier-uniform weights and a zeroed optional bias.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: when truthy, the layer gets a bias vector initialised to zero;
            otherwise the layer has no bias.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        layer = nn.Linear(in_features, out_features, bias)
        nn.init.xavier_uniform_(layer.weight)
        # nn.Linear registers `bias` as None when it is created without one.
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)
        self.linear = layer

    def forward(self, x):
        return self.linear(x)
class ConvNorm(nn.Module):
    """``nn.Conv1d`` wrapper with Kaiming-normal weight initialisation.

    When ``padding`` is None, it is derived as
    ``dilation * (kernel_size - 1) // 2``. With stride 1 this keeps the
    sequence length unchanged, and it requires an odd ``kernel_size``
    (checked with ``assert``).

    Note: ``w_init_gain`` is accepted for signature compatibility but is not
    used by this implementation.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int((kernel_size - 1) * dilation // 2)
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        return self.conv(signal)
class ResidualBlock(nn.Module):
    """Gated, dilated residual block with optional step/condition inputs.

    Inputs are channel-first ``(B, residual_channels, T)`` tensors. The block:
    optionally adds a projected diffusion step, runs a kernel-3 dilated
    convolution to ``2 * residual_channels`` channels, optionally adds a
    projected condition, applies a ``sigmoid(gate) * tanh(filter)`` gate,
    and splits a final 1x1 projection into a residual and a skip output.

    Args:
        residual_channels: channel count of ``x`` and of both outputs.
        use_linear_bias: whether the diffusion-step projection has a bias.
        dilation: dilation of the kernel-3 convolution. Padding equals
            ``dilation``, so the sequence length is preserved.
        condition_channels: channel count of the optional ``condition`` input.
            When None, no condition projection is built.
        use_diffusion: whether to build the diffusion-step projection. ``None``
            (default) keeps the historical rule of building it only when
            ``condition_channels`` is given. Pass ``True`` to accept
            ``diffusion_step`` in an unconditioned block.
    """

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
        use_diffusion=None,
    ):
        super(ResidualBlock, self).__init__()
        if use_diffusion is None:
            # Historical behaviour: the step projection existed only alongside
            # the condition projection. Kept as the default so existing
            # checkpoints keep the same state_dict keys.
            use_diffusion = condition_channels is not None
        # Submodules are created in the original order, so parameter
        # registration (and RNG draws during init) are unchanged.
        self.conv_layer = ConvNorm(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )
        if use_diffusion:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
        if condition_channels is not None:
            self.condition_projection = ConvNorm(
                condition_channels, 2 * residual_channels, kernel_size=1
            )
        self.output_projection = ConvNorm(
            residual_channels, 2 * residual_channels, kernel_size=1
        )

    def forward(self, x, condition=None, diffusion_step=None):
        """Apply the block.

        Args:
            x: ``(B, residual_channels, T)`` input.
            condition: optional ``(B, condition_channels, T)`` tensor.
            diffusion_step: optional step features whose last dimension is
                ``residual_channels`` (presumably ``(B, residual_channels)``).
                They are broadcast over time via ``unsqueeze(-1)``.

        Returns:
            ``((x + residual) / sqrt(2), skip)``, both with
            ``residual_channels`` channels.

        Raises:
            ValueError: if ``diffusion_step`` or ``condition`` is given but
                the block was built without the matching projection.
        """
        y = x
        if diffusion_step is not None:
            step_proj = getattr(self, "diffusion_projection", None)
            if step_proj is None:
                raise ValueError(
                    "diffusion_step was given, but this ResidualBlock has no "
                    "diffusion projection (build it with condition_channels "
                    "or use_diffusion=True)"
                )
            y = y + step_proj(diffusion_step).unsqueeze(-1)

        y = self.conv_layer(y)

        if condition is not None:
            cond_proj = getattr(self, "condition_projection", None)
            if cond_proj is None:
                raise ValueError(
                    "condition was given, but this ResidualBlock was built "
                    "without condition_channels"
                )
            y = y + cond_proj(condition)

        gate, filt = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filt)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        # Scale the residual sum by 1/sqrt(2) to keep its variance from
        # growing as blocks are stacked.
        return (x + residual) / math.sqrt(2.0), skip
class WaveNet(nn.Module):
    """Stack of gated, dilated residual blocks with summed skip connections.

    Works on channel-first sequences ``(B, C, T)`` (all layers are 1-D
    convolutions or per-step linears). Can optionally condition on a
    diffusion step ``t`` and/or a ``condition`` tensor.

    Args:
        input_channels: channels of the input ``x``. If given and different
            from ``residual_channels``, a 1x1 input projection is added.
            ``None`` means ``x`` already has ``residual_channels`` channels.
        output_channels: channels of the output. If given and different from
            ``residual_channels``, a SiLU followed by a 1x1 output projection
            is applied at the end.
        residual_channels: hidden width of every residual block.
        residual_layers: number of residual blocks; must be at least 1.
        dilation_cycle: block ``i`` uses dilation ``2 ** (i % dilation_cycle)``.
            A falsy value (``None`` or 0) gives every block dilation 1.
        is_diffusion: build the sinusoidal step embedding plus MLP used to
            encode ``t``.
        condition_channels: channels of ``condition``. Passed to every
            residual block, which also uses it to decide whether to build its
            diffusion-step projection.

    Raises:
        ValueError: if ``residual_layers`` is less than 1.

    NOTE(review): with the ResidualBlock defaults, blocks get a diffusion-step
    projection only when ``condition_channels`` is set. Using
    ``is_diffusion=True`` without ``condition_channels`` and passing ``t``
    therefore fails inside the blocks.
    """

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()
        # forward() sums the skip outputs and divides by sqrt(layer count).
        # With zero layers that fails obscurely, so reject it up front.
        if residual_layers < 1:
            raise ValueError(
                f"residual_layers must be >= 1, got {residual_layers}"
            )

        # Input projection (only when the channel count must change)
        self.input_projection = None
        if input_channels is not None and input_channels != residual_channels:
            self.input_projection = ConvNorm(
                input_channels, residual_channels, kernel_size=1
            )
        if input_channels is None:
            input_channels = residual_channels
        self.input_channels = input_channels

        # Residual layers with cyclic exponential dilation
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                    condition_channels=condition_channels,
                )
                for i in range(residual_layers)
            ]
        )

        # Skip projection applied to the aggregated skip outputs
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection (only when the channel count must change)
        self.output_projection = None
        if output_channels is not None and output_channels != residual_channels:
            self.output_projection = ConvNorm(
                residual_channels, output_channels, kernel_size=1
            )

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        # Overrides the per-layer inits done by ConvNorm/LinearNorm.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Conv1d/Linear weights ~ truncated N(0, 0.02^2); zero biases."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        """Run the network.

        Args:
            x: input of shape ``(B, input_channels, T)``.
            t: optional 1-D tensor of diffusion steps ``(B,)``. Requires a
                model built with ``is_diffusion=True``.
            condition: optional conditioning tensor, forwarded to every
                residual block (presumably ``(B, condition_channels, T)``).

        Returns:
            Tensor of shape ``(B, output_channels, T)`` when an output
            projection exists, otherwise ``(B, residual_channels, T)``.

        Raises:
            ValueError: if ``t`` is given but the model has no diffusion
                embedding (built with ``is_diffusion=False``).
        """
        if self.input_projection is not None:
            x = self.input_projection(x)
        x = F.silu(x)

        if t is not None:
            if getattr(self, "diffusion_embedding", None) is None:
                raise ValueError(
                    "a diffusion step `t` was given, but this WaveNet was "
                    "built with is_diffusion=False"
                )
            t = self.mlp(self.diffusion_embedding(t))

        skips = []
        for layer in self.residual_layers:
            x, skip = layer(x, condition, t)
            skips.append(skip)

        # Normalise the skip sum by sqrt(#layers) to keep its scale
        # independent of network depth.
        x = torch.sum(torch.stack(skips), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = F.silu(x)
            x = self.output_projection(x)

        return x