# Uploaded by tobiasc; commit ad16788 ("Initial commit").
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi (Nagoya University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder."""
import logging
import sys
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
def encode_mu_law(x, mu=256):
    """Quantize a waveform with mu-law companding.

    Args:
        x (ndarray): Audio signal with values in the range [-1, 1].
        mu (int): Number of quantization levels.

    Returns:
        ndarray: Integer codes (int64) in the range [0, mu - 1].

    """
    levels = mu - 1
    # companding: sign(x) * ln(1 + levels * |x|) / ln(1 + levels), in [-1, 1]
    companded = np.sign(x) * np.log(1 + levels * np.abs(x)) / np.log(1 + levels)
    # map [-1, 1] onto [0, levels] and round to the nearest integer code
    scaled = (companded + 1) / 2 * levels
    return np.floor(scaled + 0.5).astype(np.int64)
def decode_mu_law(y, mu=256):
    """Perform mu-law decoding.

    Args:
        y (array_like): Quantized audio signal with the range from 0 to mu - 1.
            Lists and other array-likes are accepted (converted with
            ``np.asarray``), matching what :func:`encode_mu_law` accepts.
        mu (int): Quantized level.

    Returns:
        ndarray: Audio signal, approximately in the range from -1 to 1.

    Note:
        Codes are shifted down by half a step (``y - 0.5``) before expansion.
        For ``mu=256`` this makes code 128 (the code ``encode_mu_law`` assigns
        to 0.0) decode to exactly 0. As a side effect code 0 decodes slightly
        below -1 (about -1.022) and code 255 slightly below 1 (about 0.978).

    """
    y = np.asarray(y)
    mu = mu - 1
    # map codes back onto the companded range (approximately [-1, 1])
    fx = (y - 0.5) / mu * 2 - 1
    # inverse companding: sign(fx) / mu * ((1 + mu) ** |fx| - 1)
    return np.sign(fx) / mu * ((1 + mu) ** np.abs(fx) - 1)
def initialize(m):
    """Initialize the parameters of convolution layers in place.

    ``nn.Conv1d`` weights get Xavier-uniform initialization and
    ``nn.ConvTranspose2d`` weights are set to 1.0. Biases of both are set
    to 0.0 when present; layers built with ``bias=False`` have
    ``bias is None`` and are skipped instead of raising. Any other module is
    left untouched, so the function can be passed to ``nn.Module.apply``.

    Args:
        m (torch.nn.Module): Torch module.

    """
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.ConvTranspose2d):
        # with UpSampling's single-channel (1, f) kernel and stride, an
        # all-ones kernel (and zero bias) makes the layer repeat each frame
        nn.init.constant_(m.weight, 1.0)
    else:
        return
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)
class OneHot(nn.Module):
    """Convert integer indices to one-hot vectors.

    Args:
        depth (int): Dimension of one-hot vector.

    """

    def __init__(self, depth):
        super(OneHot, self).__init__()
        self.depth = depth

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Long tensor variable with the shape (B, T). Values
                are reduced modulo ``depth`` before encoding.

        Returns:
            Tensor: Float tensor variable with the shape (B, T, depth).

        """
        x = x % self.depth
        x = torch.unsqueeze(x, 2)  # (B, T, 1): index tensor for scatter_ on dim 2
        x_onehot = x.new_zeros(x.size(0), x.size(1), self.depth, dtype=torch.float)
        return x_onehot.scatter_(2, x, 1)
class CausalConv1d(nn.Module):
    """1D dilated causal convolution.

    ``nn.Conv1d`` pads the input by ``(kernel_size - 1) * dilation`` frames on
    both sides. The trailing ``padding`` frames of its output are then cut
    off, so every output frame depends only on the current and past input
    frames and the time length is preserved.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Size of the convolution kernel.
        dilation (int): Dilation factor.
        bias (bool): Whether the convolution has a bias term.

    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
        super(CausalConv1d, self).__init__()
        pad = dilation * (kernel_size - 1)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = pad
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=pad,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape (B, in_channels, T).

        Returns:
            Tensor: Tensor with the shape (B, out_channels, T)

        """
        y = self.conv(x)
        # drop the frames produced by the right-hand padding (they see the future)
        return y[:, :, : -self.padding] if self.padding != 0 else y
class UpSampling(nn.Module):
    """Upsampling layer with deconvolution.

    The (C, T) feature map is treated as a one-channel image. A
    ``nn.ConvTranspose2d`` whose kernel and stride are both
    ``(1, upsampling_factor)`` expands each time frame into
    ``upsampling_factor`` frames.

    Args:
        upsampling_factor (int): Upsampling factor.
        bias (bool): Whether the transposed convolution has a bias term.

    """

    def __init__(self, upsampling_factor, bias=True):
        super(UpSampling, self).__init__()
        self.upsampling_factor = upsampling_factor
        self.bias = bias
        window = (1, upsampling_factor)
        self.conv = nn.ConvTranspose2d(1, 1, kernel_size=window, stride=window, bias=bias)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape (B, C, T)

        Returns:
            Tensor: Tensor with the shape (B, C, T') where T' = T * upsampling_factor.

        """
        image = torch.unsqueeze(x, 1)  # B x 1 x C x T
        upsampled = self.conv(image)  # B x 1 x C x T'
        return torch.squeeze(upsampled, 1)
class WaveNet(nn.Module):
    """Conditional wavenet.

    Args:
        n_quantize (int): Number of quantization.
        n_aux (int): Number of aux feature dimension.
        n_resch (int): Number of filter channels for residual block.
        n_skipch (int): Number of filter channels for skip connection.
        dilation_depth (int): Number of dilation depth
            (e.g. if set 10, max dilation = 2^(10-1)).
        dilation_repeat (int): Number of dilation repeat.
        kernel_size (int): Filter size of dilated causal convolution.
        upsampling_factor (int): Upsampling factor. 0 (default) disables the
            upsampling layer, so the auxiliary features are used as given.

    """

    def __init__(
        self,
        n_quantize=256,
        n_aux=28,
        n_resch=512,
        n_skipch=256,
        dilation_depth=10,
        dilation_repeat=3,
        kernel_size=2,
        upsampling_factor=0,
    ):
        super(WaveNet, self).__init__()
        self.n_aux = n_aux
        self.n_quantize = n_quantize
        self.n_resch = n_resch
        self.n_skipch = n_skipch
        self.kernel_size = kernel_size
        self.dilation_depth = dilation_depth
        self.dilation_repeat = dilation_repeat
        self.upsampling_factor = upsampling_factor
        # e.g. dilation_depth=3, dilation_repeat=2 -> [1, 2, 4, 1, 2, 4]
        self.dilations = [
            2 ** i for i in range(self.dilation_depth)
        ] * self.dilation_repeat
        # receptive field of the dilated residual stack (the initial causal
        # convolution applied in _preprocess is not counted here)
        self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1

        # for preprocessing
        self.onehot = OneHot(self.n_quantize)
        self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size)
        if self.upsampling_factor > 0:
            self.upsampling = UpSampling(self.upsampling_factor)

        # for residual blocks (one entry per dilation, in the same order)
        self.dil_sigmoid = nn.ModuleList()
        self.dil_tanh = nn.ModuleList()
        self.aux_1x1_sigmoid = nn.ModuleList()
        self.aux_1x1_tanh = nn.ModuleList()
        self.skip_1x1 = nn.ModuleList()
        self.res_1x1 = nn.ModuleList()
        for d in self.dilations:
            self.dil_sigmoid.append(
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            )
            self.dil_tanh.append(
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            )
            self.aux_1x1_sigmoid.append(nn.Conv1d(self.n_aux, self.n_resch, 1))
            self.aux_1x1_tanh.append(nn.Conv1d(self.n_aux, self.n_resch, 1))
            self.skip_1x1.append(nn.Conv1d(self.n_resch, self.n_skipch, 1))
            self.res_1x1.append(nn.Conv1d(self.n_resch, self.n_resch, 1))

        # for postprocessing
        self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1)
        self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1)

    def forward(self, x, h):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Quantized input waveform tensor with the shape (B, T).
            h (Tensor): Auxiliary feature tensor with the shape (B, n_aux, T),
                or (B, n_aux, T / upsampling_factor) when upsampling_factor > 0
                (the upsampling layer multiplies its time length by the factor).

        Returns:
            Tensor: Logits with the shape (B, T, n_quantize).

        """
        # preprocess
        output = self._preprocess(x)
        if self.upsampling_factor > 0:
            h = self.upsampling(h)

        # residual blocks; their skip outputs are summed
        skip_connections = []
        for i in range(len(self.dilations)):
            output, skip = self._residual_forward(output, h, *self._layer_modules(i))
            skip_connections.append(skip)

        # skip-connection part
        return self._postprocess(sum(skip_connections))

    @torch.no_grad()
    def generate(self, x, h, n_samples, interval=None, mode="sampling"):
        """Generate a waveform with fast generation algorithm.

        This generation is based on `Fast WaveNet Generation Algorithm`_. The
        network is run once over the initial waveform. Per-layer buffers then
        cache the past outputs each next layer needs, so every generated
        sample costs a single step through the network.

        Runs under ``torch.no_grad()``. The result is a numpy array, so no
        gradient can flow back through it, and recording the graph would keep
        growing memory through the cached buffers.

        Args:
            x (LongTensor): Initial waveform tensor with the shape (T,).
            h (Tensor): Auxiliary feature tensor with the shape (n_samples + T, n_aux).
                Its time axis is upsampled first when upsampling_factor > 0.
                If it is shorter than needed, it is extended by replicating
                its last frame.
            n_samples (int): Number of samples to be generated.
            interval (int, optional): Log interval.
            mode (str, optional): "sampling" or "argmax".

        Return:
            ndarray: Generated quantized waveform (n_samples).

        .. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482

        """
        # reshape inputs
        assert len(x.shape) == 1
        assert len(h.shape) == 2 and h.shape[1] == self.n_aux
        # validate the mode before any heavy computation
        if mode not in ("sampling", "argmax"):
            logging.error("mode should be sampling or argmax")
            sys.exit(1)
        x = x.unsqueeze(0)  # (1, T)
        h = h.transpose(0, 1).unsqueeze(0)  # (1, n_aux, T_h)

        # perform upsampling
        if self.upsampling_factor > 0:
            h = self.upsampling(h)

        # padding for shortage: the generation loop reads h up to time index
        # T + n_samples - 2, so h needs at least T + n_samples - 1 frames
        n_required = max(n_samples, x.size(1) + n_samples - 1)
        if h.size(2) < n_required:
            h = F.pad(h, (0, n_required - h.size(2)), "replicate")

        # left-pad x (and h, to keep them aligned) up to the receptive field
        n_pad = self.receptive_field - x.size(1)
        if n_pad > 0:
            x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2)
            h = F.pad(h, (n_pad, 0), "replicate")

        # prepare buffer: cache, for every layer, the past outputs that the
        # following layer needs
        output = self._preprocess(x)
        h_ = h[:, :, : x.size(1)]
        output_buffer = []
        buffer_size = []
        for i, d in enumerate(self.dilations):
            output, _ = self._residual_forward(output, h_, *self._layer_modules(i))
            if d == 2 ** (self.dilation_depth - 1):
                # the next layer restarts the dilation cycle at dilation 1
                buffer_size.append(self.kernel_size - 1)
            else:
                # the next layer has dilation 2 * d
                buffer_size.append(d * 2 * (self.kernel_size - 1))
            # the last frame is excluded: it is recomputed at the first step
            output_buffer.append(output[:, :, -buffer_size[i] - 1 : -1])

        # generate into a preallocated buffer (avoids re-copying all samples
        # with torch.cat at every step)
        n_init = x.size(1)
        # samples needed by the initial causal conv plus the first (dilation 1) layer
        n_context = 2 * self.kernel_size - 1
        samples = x.new_empty(n_init + n_samples)
        samples[:n_init] = x[0]
        start_time = time.time()
        for i in range(n_samples):
            t = n_init + i  # position of the sample generated at this step
            output = self._preprocess(samples[max(0, t - n_context) : t].unsqueeze(0))
            h_ = h[:, :, t - 1].contiguous().view(1, self.n_aux, 1)
            skip_connections = []
            for j in range(len(self.dilations)):
                output, skip = self._generate_residual_forward(
                    output, h_, *self._layer_modules(j)
                )
                output = torch.cat([output_buffer[j], output], dim=2)
                # keep the newest buffer_size[j] frames; an explicit start index
                # is used because a [-0:] slice would keep everything
                output_buffer[j] = output[:, :, output.size(2) - buffer_size[j] :]
                skip_connections.append(skip)

            # get predicted sample
            logits = self._postprocess(sum(skip_connections))[0]  # (1, n_quantize)
            if mode == "sampling":
                posterior = F.softmax(logits[-1], dim=0)
                dist = torch.distributions.Categorical(posterior)
                sample = dist.sample().unsqueeze(0)
            else:
                sample = logits.argmax(-1)
            samples[t : t + 1] = sample

            # show progress
            if interval is not None and (i + 1) % interval == 0:
                elapsed_time_per_sample = (time.time() - start_time) / interval
                logging.info(
                    "%d/%d estimated time = %.3f sec (%.3f sec / sample)",
                    i + 1,
                    n_samples,
                    (n_samples - i - 1) * elapsed_time_per_sample,
                    elapsed_time_per_sample,
                )
                start_time = time.time()

        # only the generated part (also correct, i.e. empty, for n_samples == 0)
        return samples[n_init:].cpu().numpy()

    def _layer_modules(self, i):
        """Return the modules of residual layer ``i`` in call-argument order."""
        return (
            self.dil_sigmoid[i],
            self.dil_tanh[i],
            self.aux_1x1_sigmoid[i],
            self.aux_1x1_tanh[i],
            self.skip_1x1[i],
            self.res_1x1[i],
        )

    def _preprocess(self, x):
        """Map quantized samples (B, T) to causal-conv features (B, n_resch, T)."""
        x = self.onehot(x).transpose(1, 2)  # B x n_quantize x T
        output = self.causal(x)
        return output

    def _postprocess(self, x):
        """Map summed skip outputs (B, n_skipch, T) to logits (B, T, n_quantize)."""
        output = F.relu(x)
        output = self.conv_post_1(output)
        output = F.relu(output)  # B x C x T
        output = self.conv_post_2(output).transpose(1, 2)  # B x T x C
        return output

    def _residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        """Apply one gated residual layer over the whole sequence.

        Returns:
            tuple: (residual output with the shape of ``x``, skip output).

        """
        output_sigmoid = dil_sigmoid(x)
        output_tanh = dil_tanh(x)
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        # gated activation: sigmoid(filter + aux) * tanh(gate + aux)
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x
        return output, skip

    def _generate_residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        """Apply one gated residual layer for the last frame of ``x`` only.

        ``x`` holds the cached past frames plus the current one. Only the last
        frame of the dilated convolutions is kept, so both returned tensors
        have time length 1.

        """
        output_sigmoid = dil_sigmoid(x)[:, :, -1:]
        output_tanh = dil_tanh(x)[:, :, -1:]
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x[:, :, -1:]  # B x C x 1
        return output, skip