# aet_demo/model.py

import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
class EncoderModule(nn.Module):
"""
Analysis module based on 2D conv U-Net
Inspired by https://github.com/haoheliu/voicefixer
Args:
config (dict): config
use_channel (bool): output channel feature or not
"""
def __init__(self, config, use_channel=False):
super().__init__()
self.channels = 1
self.use_channel = use_channel
        self.downsample_ratio = 2 ** 4  # four down blocks, each halving time and freq
self.down_block1 = DownBlockRes2D(
in_channels=self.channels,
out_channels=32,
downsample=(2, 2),
activation="relu",
momentum=0.01,
)
self.down_block2 = DownBlockRes2D(
in_channels=32,
out_channels=64,
downsample=(2, 2),
activation="relu",
momentum=0.01,
)
self.down_block3 = DownBlockRes2D(
in_channels=64,
out_channels=128,
downsample=(2, 2),
activation="relu",
momentum=0.01,
)
self.down_block4 = DownBlockRes2D(
in_channels=128,
out_channels=256,
downsample=(2, 2),
activation="relu",
momentum=0.01,
)
self.conv_block5 = ConvBlockRes2D(
in_channels=256,
out_channels=256,
size=3,
activation="relu",
momentum=0.01,
)
self.up_block1 = UpBlockRes2D(
in_channels=256,
out_channels=256,
stride=(2, 2),
activation="relu",
momentum=0.01,
)
self.up_block2 = UpBlockRes2D(
in_channels=256,
out_channels=128,
stride=(2, 2),
activation="relu",
momentum=0.01,
)
self.up_block3 = UpBlockRes2D(
in_channels=128,
out_channels=64,
stride=(2, 2),
activation="relu",
momentum=0.01,
)
self.up_block4 = UpBlockRes2D(
in_channels=64,
out_channels=32,
stride=(2, 2),
activation="relu",
momentum=0.01,
)
self.after_conv_block1 = ConvBlockRes2D(
in_channels=32,
out_channels=32,
size=3,
activation="relu",
momentum=0.01,
)
self.after_conv2 = nn.Conv2d(
in_channels=32,
out_channels=1,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
)
if config["general"]["feature_type"] == "melspec":
out_dim = config["preprocess"]["n_mels"]
elif config["general"]["feature_type"] == "vocfeats":
out_dim = config["preprocess"]["cep_order"] + 1
else:
raise NotImplementedError()
        # The projection acts on the frequency axis of the input mel
        # spectrogram, so in_features must equal n_mels.
        self.after_linear = nn.Linear(
            in_features=config["preprocess"]["n_mels"],
            out_features=out_dim,
            bias=True,
        )
if self.use_channel:
self.conv_channel = ConvBlockRes2D(
in_channels=256,
out_channels=256,
size=3,
activation="relu",
momentum=0.01,
)
def forward(self, x):
"""
Forward
Args:
mel spectrogram: (batch, 1, time, freq)
Return:
speech feature (mel spectrogram or mel cepstrum): (batch, 1, time, freq)
input of channel feature module (batch, 256, time, freq)
"""
origin_len = x.shape[2]
pad_len = (
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
- origin_len
)
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # Drop the last frequency bin so the freq axis is odd; the decoder's
        # transpose convolutions then match the skip tensors without pruning
        # that axis, and the bin is restored with F.pad after the decoder.
        x = x[..., 0 : x.shape[-1] - 1]
(x1_pool, x1) = self.down_block1(x)
(x2_pool, x2) = self.down_block2(x1_pool)
(x3_pool, x3) = self.down_block3(x2_pool)
(x4_pool, x4) = self.down_block4(x3_pool)
x_center = self.conv_block5(x4_pool)
x5 = self.up_block1(x_center, x4)
x6 = self.up_block2(x5, x3)
x7 = self.up_block3(x6, x2)
x8 = self.up_block4(x7, x1)
x = self.after_conv_block1(x8)
x = self.after_conv2(x)
x = F.pad(x, pad=(0, 1))
x = x[:, :, 0:origin_len, :]
x = self.after_linear(x)
if self.use_channel:
x_channel = self.conv_channel(x4_pool)
return x, x_channel
else:
return x
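

# A minimal shape-check sketch for EncoderModule (illustrative). It assumes a
# config with feature_type "melspec" and n_mels=80; the time axis is padded
# internally to a multiple of 16, so any input length works.
#
#   config = {"general": {"feature_type": "melspec"},
#             "preprocess": {"n_mels": 80}}
#   encoder = EncoderModule(config, use_channel=True)
#   mel = torch.randn(2, 1, 100, 80)   # (batch, 1, time, freq)
#   feat, ch_in = encoder(mel)
#   # feat:  (2, 1, 100, 80), the restored speech feature
#   # ch_in: (2, 256, 7, 4), the input to ChannelFeatureModule
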
class ChannelModule(nn.Module):
"""
Channel module based on 1D conv U-Net
Args:
config (dict): config
"""
def __init__(self, config):
super().__init__()
self.channels = 1
        self.downsample_ratio = 2 ** 6  # the input waveform is padded to a multiple of this length
self.down_block1 = DownBlockRes1D(
in_channels=self.channels,
out_channels=32,
downsample=2,
activation="relu",
momentum=0.01,
)
self.down_block2 = DownBlockRes1D(
in_channels=32,
out_channels=64,
downsample=2,
activation="relu",
momentum=0.01,
)
self.down_block3 = DownBlockRes1D(
in_channels=64,
out_channels=128,
downsample=2,
activation="relu",
momentum=0.01,
)
self.down_block4 = DownBlockRes1D(
in_channels=128,
out_channels=256,
downsample=2,
activation="relu",
momentum=0.01,
)
self.down_block5 = DownBlockRes1D(
in_channels=256,
out_channels=512,
downsample=2,
activation="relu",
momentum=0.01,
)
self.conv_block6 = ConvBlockRes1D(
in_channels=512,
out_channels=384,
size=3,
activation="relu",
momentum=0.01,
)
self.up_block1 = UpBlockRes1D(
in_channels=512,
out_channels=512,
stride=2,
activation="relu",
momentum=0.01,
)
self.up_block2 = UpBlockRes1D(
in_channels=512,
out_channels=256,
stride=2,
activation="relu",
momentum=0.01,
)
self.up_block3 = UpBlockRes1D(
in_channels=256,
out_channels=128,
stride=2,
activation="relu",
momentum=0.01,
)
self.up_block4 = UpBlockRes1D(
in_channels=128,
out_channels=64,
stride=2,
activation="relu",
momentum=0.01,
)
self.up_block5 = UpBlockRes1D(
in_channels=64,
out_channels=32,
stride=2,
activation="relu",
momentum=0.01,
)
self.after_conv_block1 = ConvBlockRes1D(
in_channels=32,
out_channels=32,
size=3,
activation="relu",
momentum=0.01,
)
self.after_conv2 = nn.Conv1d(
in_channels=32,
out_channels=1,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
def forward(self, x, h):
"""
Forward
Args:
clean waveform: (batch, n_channel (1), time)
channel feature: (batch, feature_dim)
Outputs:
degraded waveform: (batch, n_channel (1), time)
"""
x = x.unsqueeze(1)
origin_len = x.shape[2]
pad_len = (
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
- origin_len
)
        x = F.pad(x, pad=(0, pad_len))
        # Drop the last sample so the length is odd; each transpose convolution
        # in the decoder then matches its skip tensor exactly (see UpBlockRes1D).
        x = x[..., 0 : x.shape[-1] - 1]
(x1_pool, x1) = self.down_block1(x)
(x2_pool, x2) = self.down_block2(x1_pool)
(x3_pool, x3) = self.down_block3(x2_pool)
(x4_pool, x4) = self.down_block4(x3_pool)
(x5_pool, x5) = self.down_block5(x4_pool)
x_center = self.conv_block6(x5_pool)
        # Broadcast the channel embedding h along time and concatenate it with
        # the bottleneck features; the result must have up_block1's in_channels
        # (384 + feature_dim = 512).
        x_concat = torch.cat(
            (x_center, h.unsqueeze(2).expand(-1, -1, x_center.size(2))), dim=1
        )
x6 = self.up_block1(x_concat, x5)
x7 = self.up_block2(x6, x4)
x8 = self.up_block3(x7, x3)
x9 = self.up_block4(x8, x2)
x10 = self.up_block5(x9, x1)
x = self.after_conv_block1(x10)
x = self.after_conv2(x)
x = F.pad(x, pad=(0, 1))
x = x[..., 0:origin_len]
return x.squeeze(1)
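

# Usage sketch for ChannelModule (illustrative). The channel embedding h is
# produced by ChannelFeatureModule; its dimension must be 512 - 384 = 128 so
# that the bottleneck concatenation matches up_block1.
#
#   channel = ChannelModule(config)   # config is accepted but unused here
#   wav = torch.randn(2, 16000)       # (batch, time) clean waveform
#   h = torch.randn(2, 128)           # (batch, feature_dim)
#   degraded = channel(wav, h)        # (2, 16000)
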
class ChannelFeatureModule(nn.Module):
"""
Channel feature module based on 2D convolution layers
Args:
config (dict): config
"""
def __init__(self, config):
super().__init__()
self.conv_blocks_in = ConvBlockRes2D(
in_channels=256,
out_channels=512,
size=3,
activation="relu",
momentum=0.01,
)
self.down_block1 = DownBlockRes2D(
in_channels=512,
out_channels=256,
downsample=(2, 2),
activation="relu",
momentum=0.01,
)
self.down_block2 = DownBlockRes2D(
in_channels=256,
out_channels=256,
downsample=(2, 2),
activation="relu",
momentum=0.01,
)
self.conv_block_out = ConvBlockRes2D(
in_channels=256,
out_channels=128,
size=3,
activation="relu",
momentum=0.01,
)
self.avgpool2d = torch.nn.AdaptiveAvgPool2d(1)
def forward(self, x):
"""
Forward
Args:
output of analysis module: (batch, 256, time, freq)
Return:
channel feature: (batch, feature_dim)
"""
x = self.conv_blocks_in(x)
x, _ = self.down_block1(x)
x, _ = self.down_block2(x)
x = self.conv_block_out(x)
x = self.avgpool2d(x)
x = x.squeeze(3).squeeze(2)
return x
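

# Usage sketch for ChannelFeatureModule (illustrative). The adaptive average
# pool collapses time and frequency, so any spatial size works.
#
#   cfm = ChannelFeatureModule(config)   # config is accepted but unused here
#   ch_in = torch.randn(2, 256, 7, 4)    # x_channel from EncoderModule
#   h = cfm(ch_in)                       # (2, 128) channel feature
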
class ConvBlockRes2D(nn.Module):
def __init__(self, in_channels, out_channels, size, activation, momentum):
super().__init__()
self.activation = activation
        # kernel size may be given as an int or as a (square) tuple
        if isinstance(size, tuple):
            size = size[0]
        pad = size // 2
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(size, size),
stride=(1, 1),
dilation=(1, 1),
padding=(pad, pad),
bias=False,
)
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
self.conv2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=(size, size),
stride=(1, 1),
dilation=(1, 1),
padding=(pad, pad),
bias=False,
)
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
if in_channels != out_channels:
self.shortcut = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
)
self.is_shortcut = True
else:
self.is_shortcut = False
def forward(self, x):
origin = x
x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
if self.is_shortcut:
return self.shortcut(origin) + x
else:
return origin + x
class ConvBlockRes1D(nn.Module):
def __init__(self, in_channels, out_channels, size, activation, momentum):
super().__init__()
self.activation = activation
pad = size // 2
self.conv1 = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=size,
stride=1,
dilation=1,
padding=pad,
bias=False,
)
self.bn1 = nn.BatchNorm1d(in_channels, momentum=momentum)
self.conv2 = nn.Conv1d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=size,
stride=1,
dilation=1,
padding=pad,
bias=False,
)
self.bn2 = nn.BatchNorm1d(out_channels, momentum=momentum)
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
)
self.is_shortcut = True
else:
self.is_shortcut = False
def forward(self, x):
origin = x
x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
if self.is_shortcut:
return self.shortcut(origin) + x
else:
return origin + x
class DownBlockRes2D(nn.Module):
def __init__(self, in_channels, out_channels, downsample, activation, momentum):
super().__init__()
size = 3
self.conv_block1 = ConvBlockRes2D(
in_channels, out_channels, size, activation, momentum
)
self.conv_block2 = ConvBlockRes2D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block3 = ConvBlockRes2D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block4 = ConvBlockRes2D(
out_channels, out_channels, size, activation, momentum
)
self.avg_pool2d = torch.nn.AvgPool2d(downsample)
def forward(self, x):
encoder = self.conv_block1(x)
encoder = self.conv_block2(encoder)
encoder = self.conv_block3(encoder)
encoder = self.conv_block4(encoder)
encoder_pool = self.avg_pool2d(encoder)
return encoder_pool, encoder
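

# DownBlockRes2D returns both the pooled tensor (fed to the next encoder
# level) and the pre-pool tensor (kept as the U-Net skip connection), e.g.:
#
#   down = DownBlockRes2D(1, 32, (2, 2), "relu", 0.01)
#   pooled, skip = down(torch.randn(2, 1, 16, 80))
#   # pooled: (2, 32, 8, 40); skip: (2, 32, 16, 80)
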
class DownBlockRes1D(nn.Module):
def __init__(self, in_channels, out_channels, downsample, activation, momentum):
super().__init__()
size = 3
self.conv_block1 = ConvBlockRes1D(
in_channels, out_channels, size, activation, momentum
)
self.conv_block2 = ConvBlockRes1D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block3 = ConvBlockRes1D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block4 = ConvBlockRes1D(
out_channels, out_channels, size, activation, momentum
)
self.avg_pool1d = torch.nn.AvgPool1d(downsample)
def forward(self, x):
encoder = self.conv_block1(x)
encoder = self.conv_block2(encoder)
encoder = self.conv_block3(encoder)
encoder = self.conv_block4(encoder)
encoder_pool = self.avg_pool1d(encoder)
return encoder_pool, encoder
class UpBlockRes2D(nn.Module):
def __init__(self, in_channels, out_channels, stride, activation, momentum):
super().__init__()
size = 3
self.activation = activation
self.conv1 = torch.nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(size, size),
stride=stride,
padding=(0, 0),
output_padding=(0, 0),
bias=False,
dilation=(1, 1),
)
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
self.conv_block2 = ConvBlockRes2D(
out_channels * 2, out_channels, size, activation, momentum
)
self.conv_block3 = ConvBlockRes2D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block4 = ConvBlockRes2D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block5 = ConvBlockRes2D(
out_channels, out_channels, size, activation, momentum
)
def prune(self, x, both=False):
"""Prune the shape of x after transpose convolution."""
if both:
x = x[:, :, 0:-1, 0:-1]
else:
x = x[:, :, 0:-1, :]
return x
def forward(self, input_tensor, concat_tensor, both=False):
x = self.conv1(F.relu_(self.bn1(input_tensor)))
x = self.prune(x, both=both)
x = torch.cat((x, concat_tensor), dim=1)
x = self.conv_block2(x)
x = self.conv_block3(x)
x = self.conv_block4(x)
x = self.conv_block5(x)
return x
class UpBlockRes1D(nn.Module):
def __init__(self, in_channels, out_channels, stride, activation, momentum):
super().__init__()
size = 3
self.activation = activation
self.conv1 = torch.nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=size,
stride=stride,
padding=0,
output_padding=0,
bias=False,
dilation=1,
)
        self.bn1 = nn.BatchNorm1d(in_channels, momentum=momentum)
self.conv_block2 = ConvBlockRes1D(
out_channels * 2, out_channels, size, activation, momentum
)
self.conv_block3 = ConvBlockRes1D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block4 = ConvBlockRes1D(
out_channels, out_channels, size, activation, momentum
)
self.conv_block5 = ConvBlockRes1D(
out_channels, out_channels, size, activation, momentum
)
    def prune(self, x):
        """Prune the time axis of x after transpose convolution."""
        x = x[..., 0:-1]
        return x
def forward(self, input_tensor, concat_tensor):
x = self.conv1(F.relu_(self.bn1(input_tensor)))
        # No pruning is needed here: the input waveform is trimmed to an odd
        # length, so each transpose-conv output already matches its skip tensor.
        # x = self.prune(x)
x = torch.cat((x, concat_tensor), dim=1)
x = self.conv_block2(x)
x = self.conv_block3(x)
x = self.conv_block4(x)
x = self.conv_block5(x)
return x
class MultiScaleSpectralLoss(nn.Module):
"""
    Multi-scale spectral loss
https://openreview.net/forum?id=B1x1ma4tDr
Args:
config (dict): config
"""
def __init__(self, config):
super().__init__()
        try:
            self.use_linear = config["train"]["multi_scale_loss"]["use_linear"]
            self.gamma = config["train"]["multi_scale_loss"]["gamma"]
        except KeyError:
            self.use_linear = False
            self.gamma = 0.0  # unused when use_linear is False
self.fft_sizes = [2048, 512, 256, 128, 64]
self.spectrograms = []
for fftsize in self.fft_sizes:
self.spectrograms.append(
torchaudio.transforms.Spectrogram(
n_fft=fftsize, hop_length=fftsize // 4, power=2
)
)
self.spectrograms = nn.ModuleList(self.spectrograms)
self.criteria = nn.L1Loss()
self.eps = 1e-10
def forward(self, wav_out, wav_target):
"""
Forward
Args:
wav_out: output of channel module (batch, time)
wav_target: input degraded waveform (batch, time)
Return:
loss
"""
loss = 0.0
length = min(wav_out.size(1), wav_target.size(1))
for spectrogram in self.spectrograms:
S_out = spectrogram(wav_out[..., :length])
S_target = spectrogram(wav_target[..., :length])
log_S_out = torch.log(S_out + self.eps)
log_S_target = torch.log(S_target + self.eps)
if self.use_linear:
loss += self.criteria(S_out, S_target) + self.gamma * self.criteria(
log_S_out, log_S_target
)
else:
loss += self.criteria(log_S_out, log_S_target)
return loss
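

# Usage sketch for MultiScaleSpectralLoss (illustrative). For each FFT size
# the loss is L1(log(S_out + eps), log(S_target + eps)) on power spectrograms;
# with use_linear it becomes L1(S_out, S_target) + gamma * the log-domain term.
#
#   criterion = MultiScaleSpectralLoss({"train": {}})   # defaults to log-only
#   loss = criterion(torch.randn(2, 16000), torch.randn(2, 16000))  # scalar
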
class ReferenceEncoder(nn.Module):
    """
    Reference encoder for style tokens: strided 2D convolutions followed by a
    GRU, mapping a mel spectrogram to a fixed-size reference embedding.
    """
def __init__(
self, idim=80, ref_enc_filters=[32, 32, 64, 64, 128, 128], ref_dim=128
):
super().__init__()
K = len(ref_enc_filters)
filters = [1] + ref_enc_filters
convs = [
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1),
)
for i in range(K)
]
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(
[nn.BatchNorm2d(num_features=ref_enc_filters[i]) for i in range(K)]
)
out_channels = self.calculate_channels(idim, 3, 2, 1, K)
self.gru = nn.GRU(
input_size=ref_enc_filters[-1] * out_channels,
hidden_size=ref_dim,
batch_first=True,
)
self.n_mel_channels = idim
def forward(self, inputs):
out = inputs.view(inputs.size(0), 1, -1, self.n_mel_channels)
for conv, bn in zip(self.convs, self.bns):
out = conv(out)
out = bn(out)
out = F.relu(out)
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
N, T = out.size(0), out.size(1)
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
self.gru.flatten_parameters()
_, out = self.gru(out)
return out.squeeze(0)
    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Size of the frequency axis after n_convs strided convolutions
        (e.g., idim=80 with six convs: 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2)."""
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
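

# Shape sketch for ReferenceEncoder (illustrative), with the default idim=80:
#
#   ref_enc = ReferenceEncoder(idim=80, ref_dim=128)
#   mel = torch.randn(2, 100, 80)   # (batch, time, n_mels)
#   emb = ref_enc(mel)              # (2, 128), the final GRU hidden state
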
class STL(nn.Module):
    """
    Style token layer: attends over a bank of learned style tokens, using the
    reference embedding as the query (https://arxiv.org/abs/1803.09017).
    """
def __init__(self, ref_dim=128, num_heads=4, token_num=10, token_dim=128):
super().__init__()
self.embed = nn.Parameter(torch.FloatTensor(token_num, token_dim // num_heads))
d_q = ref_dim
d_k = token_dim // num_heads
self.attention = MultiHeadAttention(
query_dim=d_q, key_dim=d_k, num_units=token_dim, num_heads=num_heads
)
init.normal_(self.embed, mean=0, std=0.5)
def forward(self, inputs):
N = inputs.size(0)
query = inputs.unsqueeze(1)
keys = (
torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)
) # [N, token_num, token_embedding_size // num_heads]
style_embed = self.attention(query, keys)
return style_embed
class MultiHeadAttention(nn.Module):
"""
    Multi-head attention
https://github.com/KinglittleQ/GST-Tacotron
"""
def __init__(self, query_dim, key_dim, num_units, num_heads):
super().__init__()
self.num_units = num_units
self.num_heads = num_heads
self.key_dim = key_dim
self.W_query = nn.Linear(
in_features=query_dim, out_features=num_units, bias=False
)
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(
in_features=key_dim, out_features=num_units, bias=False
)
def forward(self, query, key):
"""
Forward
Args:
query: (batch, T_q, query_dim)
key: (batch, T_k, key_dim)
Return:
out: (N, T_q, num_units)
"""
querys = self.W_query(query) # [N, T_q, num_units]
keys = self.W_key(key) # [N, T_k, num_units]
values = self.W_value(key)
split_size = self.num_units // self.num_heads
querys = torch.stack(
torch.split(querys, split_size, dim=2), dim=0
) # [h, N, T_q, num_units/h]
keys = torch.stack(
torch.split(keys, split_size, dim=2), dim=0
) # [h, N, T_k, num_units/h]
values = torch.stack(
torch.split(values, split_size, dim=2), dim=0
) # [h, N, T_k, num_units/h]
# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim ** 0.5)
scores = F.softmax(scores, dim=3)
# out = score * V
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
0
) # [N, T_q, num_units]
return out
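

# Shape sketch for MultiHeadAttention (illustrative), matching how STL calls
# it with its default sizes (ref_dim=128, token_dim=128, num_heads=4):
#
#   attn = MultiHeadAttention(query_dim=128, key_dim=32, num_units=128,
#                             num_heads=4)
#   q = torch.randn(2, 1, 128)    # (N, T_q, query_dim)
#   k = torch.randn(2, 10, 32)    # (N, T_k, key_dim)
#   out = attn(q, k)              # (2, 1, 128) = (N, T_q, num_units)
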
class GSTModule(nn.Module):
def __init__(self, config):
super().__init__()
self.encoder_post = ReferenceEncoder(
idim=config["preprocess"]["n_mels"],
ref_dim=256,
)
self.stl = STL(ref_dim=256, num_heads=8, token_num=10, token_dim=128)
def forward(self, inputs):
acoustic_embed = self.encoder_post(inputs)
style_embed = self.stl(acoustic_embed)
return style_embed.squeeze(1)
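

# End-to-end sketch of the modules above (illustrative), following the data
# flow described in the docstrings: the analysis U-Net yields the speech
# feature and a channel representation, which conditions the 1D U-Net that
# re-applies the channel characteristics to a clean waveform. Assumes the
# same minimal config as in the EncoderModule sketch.
#
#   encoder = EncoderModule(config, use_channel=True)
#   cfm = ChannelFeatureModule(config)
#   channel = ChannelModule(config)
#   feat, ch_in = encoder(torch.randn(2, 1, 100, 80))
#   h = cfm(ch_in)                                # (2, 128)
#   degraded = channel(torch.randn(2, 16000), h)  # (2, 16000)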