from typing import Tuple
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor
def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
pad = kernel_size // 2
return (pad, pad - (kernel_size + 1) % 2)
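# For example (quick sanity check of the helper above):
#     calc_same_padding(3) -> (1, 1)   # odd kernel: symmetric padding
#     calc_same_padding(4) -> (2, 1)   # even kernel: one less pad on the right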
class ConvNorm(nn.Module):
"""A 1-dimensional convolutional layer with optional weight normalization.
This layer wraps a 1D convolutional layer from PyTorch and applies
optional weight normalization. The layer can be used in a similar way to
the convolutional layers in PyTorch's `torch.nn` module.
Args:
in_channels (int): The number of channels in the input signal.
out_channels (int): The number of channels in the output signal.
kernel_size (int, optional): The size of the convolving kernel.
Defaults to 1.
stride (int, optional): The stride of the convolution. Defaults to 1.
padding (int, optional): Zero-padding added to both sides of the input.
If `None`, the padding will be calculated so that the output has
the same length as the input. Defaults to `None`.
dilation (int, optional): Spacing between kernel elements. Defaults to 1.
bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`.
w_init_gain (str, optional): The weight initialization function to use.
Can be either 'linear' or 'relu'. Defaults to 'linear'.
use_weight_norm (bool, optional): If `True`, apply weight normalization
to the convolutional weights. Defaults to `False`.
Shapes:
- Input: :math:`[N, D, T]`
        - Output: :math:`[N, C_out, T]` where `C_out` is `out_channels`.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain="linear",
use_weight_norm=False,
):
super(ConvNorm, self).__init__() # pylint: disable=super-with-arguments
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.kernel_size = kernel_size
self.dilation = dilation
self.use_weight_norm = use_weight_norm
        self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
if self.use_weight_norm:
self.conv = nn.utils.parametrizations.weight_norm(self.conv)
def forward(self, signal, mask=None):
conv_signal = self.conv(signal)
if mask is not None:
            # re-zero the masked (padded) positions so the output matches zero-padding
conv_signal = conv_signal * mask
return conv_signal
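# Usage sketch for ConvNorm (shapes are illustrative only, not taken from this file):
#     conv = ConvNorm(80, 256, kernel_size=5, w_init_gain="relu", use_weight_norm=True)
#     x = torch.randn(2, 80, 100)   # (batch, in_channels, time)
#     y = conv(x)                   # -> (2, 256, 100); padding=None keeps the time length for odd kernels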
class ConvLSTMLinear(nn.Module):
def __init__(
self,
in_dim,
out_dim,
n_layers=2,
n_channels=256,
kernel_size=3,
p_dropout=0.1,
lstm_type="bilstm",
use_linear=True,
):
super(ConvLSTMLinear, self).__init__() # pylint: disable=super-with-arguments
self.out_dim = out_dim
self.lstm_type = lstm_type
self.use_linear = use_linear
self.dropout = nn.Dropout(p=p_dropout)
convolutions = []
for i in range(n_layers):
conv_layer = ConvNorm(
in_dim if i == 0 else n_channels,
n_channels,
kernel_size=kernel_size,
stride=1,
padding=int((kernel_size - 1) / 2),
dilation=1,
w_init_gain="relu",
)
            # keep only the underlying Conv1d, wrapped with weight normalization
            conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
convolutions.append(conv_layer)
self.convolutions = nn.ModuleList(convolutions)
if not self.use_linear:
n_channels = out_dim
if self.lstm_type != "":
use_bilstm = False
lstm_channels = n_channels
if self.lstm_type == "bilstm":
use_bilstm = True
lstm_channels = int(n_channels // 2)
self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm)
lstm_norm_fn_pntr = nn.utils.spectral_norm
self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
if self.lstm_type == "bilstm":
self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")
if self.use_linear:
self.dense = nn.Linear(n_channels, out_dim)
def run_padded_sequence(self, context, lens):
context_embedded = []
for b_ind in range(context.size()[0]): # TODO: speed up
curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
for conv in self.convolutions:
curr_context = self.dropout(F.relu(conv(curr_context)))
context_embedded.append(curr_context[0].transpose(0, 1))
context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
return context
def run_unsorted_inputs(self, fn, context, lens): # pylint: disable=no-self-use
lens_sorted, ids_sorted = torch.sort(lens, descending=True)
unsort_ids = [0] * lens.size(0)
for i in range(len(ids_sorted)): # pylint: disable=consider-using-enumerate
unsort_ids[ids_sorted[i]] = i
lens_sorted = lens_sorted.long().cpu()
context = context[ids_sorted]
context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True)
context = fn(context)[0]
context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]
# map back to original indices
context = context[unsort_ids]
return context
def forward(self, context, lens):
if context.size()[0] > 1:
context = self.run_padded_sequence(context, lens)
# to B, D, T
context = context.transpose(1, 2)
else:
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context)))
if self.lstm_type != "":
context = context.transpose(1, 2)
self.bilstm.flatten_parameters()
if lens is not None:
context = self.run_unsorted_inputs(self.bilstm, context, lens)
else:
context = self.bilstm(context)[0]
context = context.transpose(1, 2)
x_hat = context
if self.use_linear:
x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)
return x_hat
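# Usage sketch for ConvLSTMLinear (illustrative shapes; `lens` holds the valid length of each sample):
#     model = ConvLSTMLinear(in_dim=512, out_dim=4, n_layers=2, n_channels=256)
#     x = torch.randn(2, 512, 120)          # (batch, in_dim, time)
#     lens = torch.tensor([120, 97])
#     y = model(x, lens)                    # -> (2, 4, 120)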
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int):
super().__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class PointwiseConv1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
padding: int = 0,
bias: bool = True,
):
super().__init__()
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=padding,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.conv(x)
class BSConv1d(nn.Module):
"""https://arxiv.org/pdf/2003.13549.pdf"""
def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
super().__init__()
self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1)
self.depthwise = nn.Conv1d(
channels_out,
channels_out,
kernel_size=kernel_size,
padding=padding,
groups=channels_out,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1 = self.pointwise(x)
x2 = self.depthwise(x1)
return x2
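# BSConv factorizes a standard convolution into a 1x1 pointwise conv followed by a depthwise
# conv (the reverse order of the classic depthwise-separable block). Sketch with illustrative shapes:
#     conv = BSConv1d(80, 128, kernel_size=3, padding=1)
#     y = conv(torch.randn(2, 80, 100))     # -> (2, 128, 100)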
class BSConv2d(nn.Module):
"""https://arxiv.org/pdf/2003.13549.pdf"""
def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
super().__init__()
self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1)
self.depthwise = nn.Conv2d(
channels_out,
channels_out,
kernel_size=kernel_size,
padding=padding,
groups=channels_out,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1 = self.pointwise(x)
x2 = self.depthwise(x1)
return x2
class Conv1dGLU(nn.Module):
"""From DeepVoice 3"""
def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int):
super().__init__()
self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding)
self.embedding_proj = nn.Linear(embedding_dim, d_model)
self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0))
self.softsign = torch.nn.Softsign()
def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
x = x.permute((0, 2, 1))
residual = x
x = self.conv(x)
splitdim = 1
a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
embeddings = self.embedding_proj(embeddings).unsqueeze(2)
softsign = self.softsign(embeddings)
softsign = softsign.expand_as(a)
a = a + softsign
x = a * torch.sigmoid(b)
x = x + residual
x = x * self.sqrt
x = x.permute((0, 2, 1))
return x
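# Usage sketch for Conv1dGLU (illustrative shapes; note the channel-last input and the
# utterance-level conditioning embedding):
#     glu = Conv1dGLU(d_model=256, kernel_size=3, padding=1, embedding_dim=192)
#     x = torch.randn(2, 100, 256)          # (batch, time, d_model)
#     emb = torch.randn(2, 192)             # e.g. a speaker embedding
#     y = glu(x, emb)                       # -> (2, 100, 256)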
class ConvTransposed(nn.Module):
"""
A 1D convolutional transposed layer for PyTorch.
This layer applies a 1D convolutional transpose operation to its input tensor,
where the number of channels of the input tensor is the same as the number of channels of the output tensor.
Attributes:
in_channels (int): The number of channels in the input tensor.
out_channels (int): The number of channels in the output tensor.
kernel_size (int): The size of the convolutional kernel. Default: 1.
padding (int): The number of padding elements to add to the input tensor. Default: 0.
conv (BSConv1d): The 1D convolutional transpose layer.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 1,
padding: int = 0,
):
super().__init__()
self.conv = BSConv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=padding,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.contiguous().transpose(1, 2)
x = self.conv(x)
x = x.contiguous().transpose(1, 2)
return x
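# Usage sketch for ConvTransposed (channel-last in, channel-last out; illustrative shapes):
#     layer = ConvTransposed(256, 256, kernel_size=3, padding=1)
#     y = layer(torch.randn(2, 100, 256))   # -> (2, 100, 256)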
class DepthwiseConvModule(nn.Module):
def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3):
super().__init__()
padding = calc_same_padding(kernel_size)
self.depthwise = nn.Conv1d(
dim,
dim * expansion,
kernel_size=kernel_size,
padding=padding[0],
groups=dim,
)
self.act = nn.LeakyReLU(lrelu_slope)
self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0)
self.ln = nn.LayerNorm(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln(x)
x = x.permute((0, 2, 1))
x = self.depthwise(x)
x = self.act(x)
x = self.out(x)
x = x.permute((0, 2, 1))
return x
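# Usage sketch for DepthwiseConvModule (channel-last in/out because of the leading LayerNorm;
# illustrative shapes):
#     module = DepthwiseConvModule(dim=256, kernel_size=7)
#     y = module(torch.randn(2, 100, 256))  # -> (2, 100, 256)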
class AddCoords(nn.Module):
def __init__(self, rank: int, with_r: bool = False):
super().__init__()
self.rank = rank
self.with_r = with_r
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.rank == 1:
batch_size_shape, channel_in_shape, dim_x = x.shape # pylint: disable=unused-variable
xx_range = torch.arange(dim_x, dtype=torch.int32)
xx_channel = xx_range[None, None, :]
xx_channel = xx_channel.float() / (dim_x - 1)
xx_channel = xx_channel * 2 - 1
xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)
xx_channel = xx_channel.to(x.device)
out = torch.cat([x, xx_channel], dim=1)
if self.with_r:
rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
out = torch.cat([out, rr], dim=1)
elif self.rank == 2:
batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape
xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)
xx_range = torch.arange(dim_y, dtype=torch.int32)
yy_range = torch.arange(dim_x, dtype=torch.int32)
xx_range = xx_range[None, None, :, None]
yy_range = yy_range[None, None, :, None]
xx_channel = torch.matmul(xx_range, xx_ones)
yy_channel = torch.matmul(yy_range, yy_ones)
# transpose y
yy_channel = yy_channel.permute(0, 1, 3, 2)
xx_channel = xx_channel.float() / (dim_y - 1)
yy_channel = yy_channel.float() / (dim_x - 1)
xx_channel = xx_channel * 2 - 1
yy_channel = yy_channel * 2 - 1
xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)
xx_channel = xx_channel.to(x.device)
yy_channel = yy_channel.to(x.device)
out = torch.cat([x, xx_channel, yy_channel], dim=1)
if self.with_r:
rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
out = torch.cat([out, rr], dim=1)
elif self.rank == 3:
batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = x.shape
xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32)
yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32)
zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32)
xy_range = torch.arange(dim_y, dtype=torch.int32)
xy_range = xy_range[None, None, None, :, None]
yz_range = torch.arange(dim_z, dtype=torch.int32)
yz_range = yz_range[None, None, None, :, None]
zx_range = torch.arange(dim_x, dtype=torch.int32)
zx_range = zx_range[None, None, None, :, None]
xy_channel = torch.matmul(xy_range, xx_ones)
xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2)
yz_channel = torch.matmul(yz_range, yy_ones)
yz_channel = yz_channel.permute(0, 1, 3, 4, 2)
yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4)
zx_channel = torch.matmul(zx_range, zz_ones)
zx_channel = zx_channel.permute(0, 1, 4, 2, 3)
zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3)
xx_channel = xx_channel.to(x.device)
yy_channel = yy_channel.to(x.device)
zz_channel = zz_channel.to(x.device)
out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1)
if self.with_r:
rr = torch.sqrt(
torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2)
)
out = torch.cat([out, rr], dim=1)
else:
raise NotImplementedError
return out
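# AddCoords appends normalized coordinate channels (and optionally a radius channel) to the
# input, as in CoordConv. Sketch for rank 1 with illustrative shapes:
#     coords = AddCoords(rank=1, with_r=True)
#     y = coords(torch.randn(2, 80, 100))   # -> (2, 82, 100): input + x-coordinate + radius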
class CoordConv1d(nn.modules.conv.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
with_r: bool = False,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
self.rank = 1
self.addcoords = AddCoords(self.rank, with_r)
self.conv = nn.Conv1d(
in_channels + self.rank + int(with_r),
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.addcoords(x)
x = self.conv(x)
return x
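# Usage sketch for CoordConv1d (illustrative shapes): the coordinate channel is added internally,
# so the caller passes a plain (batch, in_channels, time) tensor.
#     conv = CoordConv1d(80, 256, kernel_size=3, padding=1)
#     y = conv(torch.randn(2, 80, 100))     # -> (2, 256, 100)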
class CoordConv2d(nn.modules.conv.Conv2d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
with_r: bool = False,
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
self.rank = 2
self.addcoords = AddCoords(self.rank, with_r)
self.conv = nn.Conv2d(
in_channels + self.rank + int(with_r),
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.addcoords(x)
x = self.conv(x)
return x
class LVCBlock(torch.nn.Module):
"""the location-variable convolutions"""
def __init__( # pylint: disable=dangerous-default-value
self,
in_channels,
cond_channels,
stride,
dilations=[1, 3, 9, 27],
lReLU_slope=0.2,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = len(dilations)
self.conv_kernel_size = conv_kernel_size
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=len(dilations),
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout,
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
)
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.parametrizations.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
output_padding=stride % 2,
)
),
)
self.conv_blocks = nn.ModuleList()
for dilation in dilations:
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
conv_kernel_size,
padding=dilation * (conv_kernel_size - 1) // 2,
dilation=dilation,
)
),
nn.LeakyReLU(lReLU_slope),
)
)
def forward(self, x, c):
"""forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
"""
_, in_channels, _ = x.shape # (B, c_g, L')
x = self.convt_pre(x) # (B, c_g, stride * L')
kernels, bias = self.kernel_predictor(c)
for i, conv in enumerate(self.conv_blocks):
output = conv(x) # (B, c_g, stride * L')
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
output = self.location_variable_convolution(
output, k, b, hop_size=self.cond_hop_length
) # (B, 2 * c_g, stride * L'): LVC
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
output[:, in_channels:, :]
) # (B, c_g, stride * L'): GAU
return x
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): # pylint: disable=no-self-use
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
        assert in_length == (kernel_length * hop_size), "input length must equal kernel_length * hop_size"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
x = x.unfold(
3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
return o
def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
parametrize.remove_parametrizations(self.convt_pre[1], "weight")
for block in self.conv_blocks:
parametrize.remove_parametrizations(block[1], "weight")
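# Usage sketch for LVCBlock (illustrative shapes). After the transposed-conv upsampling, the
# input length must equal cond_length * cond_hop_length for the location-variable convolution:
#     block = LVCBlock(in_channels=32, cond_channels=80, stride=8, cond_hop_length=256)
#     x = torch.randn(1, 32, 320)           # (batch, in_channels, L')
#     c = torch.randn(1, 80, 10)            # (batch, cond_channels, cond_length); 8 * 320 == 10 * 256
#     y = block(x, c)                       # -> (1, 32, 2560)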