# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from torchmetrics.classification import BinaryAccuracy
from models.vallex import Transpose
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
from modules.scaling import BalancedDoubleSwish, ScaledLinear
from modules.transformer import (
    BalancedBasicNorm,
    IdentityNorm,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
from .visualizer import visualize
IdentityNorm = IdentityNorm


class Transformer(nn.Module):
    """Seq2seq Transformer TTS used for debugging (no StopPredictor and no SpeakerEmbedding).

    Neural Speech Synthesis with Transformer Network
    https://arxiv.org/abs/1809.08895
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        scaling_xformers: bool = False,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multi-head attention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
          norm_first:
            Whether to use pre-norm (normalize before attention/feed-forward) layers.
          add_prenet:
            Whether to prepend convolutional/linear prenets to the encoder and decoder.
          scaling_xformers:
            Whether to build the encoder/decoder from the scaled variants
            (ScaledLinear, BalancedDoubleSwish, BalancedBasicNorm) instead of
            the vanilla ``nn.Transformer*`` modules.
        """
        super().__init__()

        self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x

        if add_prenet:
            self.encoder_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            self.decoder_prenet = nn.Sequential(
                nn.Linear(NUM_MEL_BINS, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, d_model),
            )

            assert scaling_xformers is False  # TODO: update this block
        else:
            self.encoder_prenet = nn.Identity()
            if scaling_xformers:
                self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
            else:
                self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)

        self.encoder_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
        )
        self.decoder_position = SinePositionalEmbedding(
            d_model, dropout=0.1, scale=False
        )

        if scaling_xformers:
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)
        else:
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )
            self.decoder = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)

        self.stop_accuracy_metric = BinaryAccuracy(
            threshold=0.5, multidim_average="global"
        )

        # self.apply(self._init_weights)

    # def _init_weights(self, module):
    #     if isinstance(module, (nn.Linear)):
    #         module.weight.data.normal_(mean=0.0, std=0.02)
    #         if isinstance(module, nn.Linear) and module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.LayerNorm):
    #         module.bias.data.zero_()
    #         module.weight.data.fill_(1.0)
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        y_lens: torch.Tensor,
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Dict[str, Any]]:
""" | |
Args: | |
x: | |
A 2-D tensor of shape (N, S). | |
x_lens: | |
A 1-D tensor of shape (N,). It contains the number of tokens in `x` | |
before padding. | |
y: | |
A 3-D tensor of shape (N, T, 8). | |
y_lens: | |
A 1-D tensor of shape (N,). It contains the number of tokens in `x` | |
before padding. | |
train_stage: | |
Not used in this model. | |
Returns: | |
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. | |
""" | |
        del train_stage

        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        assert torch.all(x_lens > 0)

        # NOTE: x has been padded in TextTokenCollater
        x_mask = make_pad_mask(x_lens).to(x.device)

        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        total_loss, metrics = 0.0, {}

        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_float = y_mask.type(torch.float32)
        data_mask = 1.0 - y_mask_float.unsqueeze(-1)

        # Training
        # AR Decoder
        def pad_y(y):
            y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
            # inputs, targets
            return y[:, :-1], y[:, 1:]
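        # Shift-by-one teacher forcing: per utterance, mel frames [f1, f2, f3]
        # become decoder inputs [0, f1, f2] and decoder targets [f1, f2, f3]
        # (a zero frame is prepended, then the sequence is split one step apart).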
        y, targets = pad_y(y * data_mask)  # mask padding as zeros

        y_emb = self.decoder_prenet(y)
        y_pos = self.decoder_position(y_emb)

        y_len = y_lens.max()
        tgt_mask = torch.triu(
            torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
            diagonal=1,
        )
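        # Upper-triangular boolean mask (True = blocked), e.g. for y_len = 3:
        #   [[False, True,  True ],
        #    [False, False, True ],
        #    [False, False, False]]
        # so decoder position t can only attend to positions <= t.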
        y_dec = self.decoder(
            y_pos,
            x,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=x_mask,
        )
        predict = self.predict_layer(y_dec)

        # loss
        total_loss = F.mse_loss(predict, targets, reduction=reduction)

        logits = self.stop_layer(y_dec).squeeze(-1)
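        # Stop-token objective: the target is 1.0 at padded (post-EOS) positions and
        # 0.0 at real frames; padded positions get weight 1.0 + 4.0 = 5.0
        # (presumably to emphasize the stop decision).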
        stop_loss = F.binary_cross_entropy_with_logits(
            logits,
            y_mask_float.detach(),
            weight=1.0 + y_mask_float.detach() * 4.0,
            reduction=reduction,
        )
        metrics["stop_loss"] = stop_loss.detach()

        stop_accuracy = self.stop_accuracy_metric(
            (torch.sigmoid(logits) >= 0.5).type(torch.int64),
            y_mask.type(torch.int64),
        )
        # icefall MetricsTracker.norm_items()
        metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
            torch.float32
        )

        return ((x, predict), total_loss + 100.0 * stop_loss, metrics)

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Any = None,
        **kwargs,
    ) -> torch.Tensor:
""" | |
Args: | |
x: | |
A 2-D tensor of shape (1, S). | |
x_lens: | |
A 1-D tensor of shape (1,). It contains the number of tokens in `x` | |
before padding. | |
Returns: | |
Return the predicted audio code matrix and cross-entropy loss. | |
""" | |
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert torch.all(x_lens > 0)

        x_mask = make_pad_mask(x_lens).to(x.device)

        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        # AR Decoder
        # TODO: manage decoder state across steps to avoid repetitive computation
        y = torch.zeros(
            [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
        )
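        # Start decoding from a single all-zero frame, matching the zero "go" frame
        # that pad_y() prepends to the targets during training.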
        while True:
            y_emb = self.decoder_prenet(y)
            y_pos = self.decoder_position(y_emb)

            tgt_mask = torch.triu(
                torch.ones(
                    y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
                ),
                diagonal=1,
            )

            y_dec = self.decoder(
                y_pos,
                x,
                tgt_mask=tgt_mask,
                memory_mask=None,
                memory_key_padding_mask=x_mask,
            )
            predict = self.predict_layer(y_dec[:, -1:])

            logits = self.stop_layer(y_dec[:, -1:]) > 0  # sigmoid(0.0) = 0.5
            if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
                print(
                    f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
                )
                break

            y = torch.concat([y, predict], dim=1)

        return y[:, 1:]

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        visualize(predicts, batch, output_dir, limit=limit)
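

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module): one plausible way
# to exercise forward() and inference() with random tensors. The hyper-parameters
# below (d_model=256, nhead=4, num_layers=2) and the batch shapes are arbitrary
# assumptions; NUM_TEXT_TOKENS / NUM_MEL_BINS come from .macros, so this block
# must be run in the package context (e.g. via `python -m ...`).
if __name__ == "__main__":
    model = Transformer(d_model=256, nhead=4, num_layers=2)

    x = torch.randint(0, NUM_TEXT_TOKENS, (2, 10))  # (N, S) padded text token ids
    x_lens = torch.tensor([10, 8])                  # valid token counts per utterance
    y = torch.randn(2, 50, NUM_MEL_BINS)            # (N, T, NUM_MEL_BINS) padded mel frames
    y_lens = torch.tensor([50, 40])                 # valid frame counts per utterance

    (_, predict), loss, metrics = model(x, x_lens, y, y_lens, reduction="sum")
    print(predict.shape, float(loss), metrics)

    with torch.no_grad():
        mel = model.inference(x[:1], x_lens[:1])    # batch of one, as in the docstring
    print(mel.shape)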