# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Dict, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
# `make_pad_mask` originally comes from icefall.utils; a local fallback is defined
# below so this module does not require icefall at import time.
from torchmetrics.classification import BinaryAccuracy
from models.vallex import Transpose
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
from modules.scaling import BalancedDoubleSwish, ScaledLinear
from modules.transformer import (
BalancedBasicNorm,
IdentityNorm,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
from .visualizer import visualize
IdentityNorm = IdentityNorm
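

# The icefall import above is commented out in this repo, so provide a minimal
# local `make_pad_mask` with the same semantics (True at padded positions); this
# is a fallback sketch, not the icefall implementation itself.
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a bool mask of shape (N, T) that is True at padded positions."""
    assert lengths.ndim == 1, lengths.shape
    max_len = max(int(max_len), int(lengths.max()))
    # Compare a (1, T) grid of positions against each sequence length (N, 1).
    seq_range = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return seq_range >= lengths.unsqueeze(-1)
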
class Transformer(nn.Module):
"""It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
Neural Speech Synthesis with Transformer Network
https://arxiv.org/abs/1809.08895
"""
def __init__(
self,
d_model: int,
nhead: int,
num_layers: int,
norm_first: bool = True,
add_prenet: bool = False,
scaling_xformers: bool = False,
):
"""
Args:
d_model:
The number of expected features in the input (required).
nhead:
The number of heads in the multiheadattention models (required).
num_layers:
The number of sub-decoder-layers in the decoder (required).
"""
super().__init__()
self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
if add_prenet:
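            # Tacotron2-style text prenet: three Conv1d + BatchNorm + ReLU + Dropout
            # blocks followed by a final linear projection; Transpose() adapts
            # between (N, T, C) and (N, C, T) around the conv stack.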
self.encoder_prenet = nn.Sequential(
Transpose(),
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
nn.BatchNorm1d(d_model),
nn.ReLU(),
nn.Dropout(0.5),
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
nn.BatchNorm1d(d_model),
nn.ReLU(),
nn.Dropout(0.5),
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
nn.BatchNorm1d(d_model),
nn.ReLU(),
nn.Dropout(0.5),
Transpose(),
nn.Linear(d_model, d_model),
)
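            # Tacotron2-style mel prenet: two Linear + ReLU + Dropout(0.5) blocks,
            # then a projection from 256 back up to d_model.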
self.decoder_prenet = nn.Sequential(
nn.Linear(NUM_MEL_BINS, 256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(256, 256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(256, d_model),
)
assert scaling_xformers is False # TODO: update this block
else:
self.encoder_prenet = nn.Identity()
if scaling_xformers:
self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
else:
self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
self.encoder_position = SinePositionalEmbedding(
d_model,
dropout=0.1,
scale=False,
)
self.decoder_position = SinePositionalEmbedding(
d_model, dropout=0.1, scale=False
)
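        # scaling_xformers swaps in the rescaled layers from modules.scaling /
        # modules.transformer (ScaledLinear, BalancedDoubleSwish, IdentityNorm);
        # otherwise the vanilla nn.Transformer encoder/decoder layers are used.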
if scaling_xformers:
self.encoder = TransformerEncoder(
TransformerEncoderLayer(
d_model,
nhead,
dim_feedforward=d_model * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
linear1_self_attention_cls=ScaledLinear,
linear2_self_attention_cls=partial(
ScaledLinear, initial_scale=0.01
),
linear1_feedforward_cls=ScaledLinear,
linear2_feedforward_cls=partial(
ScaledLinear, initial_scale=0.01
),
activation=partial(
BalancedDoubleSwish,
channel_dim=-1,
max_abs=10.0,
min_prob=0.25,
),
layer_norm_cls=IdentityNorm,
),
num_layers=num_layers,
norm=BalancedBasicNorm(d_model) if norm_first else None,
)
self.decoder = nn.TransformerDecoder(
TransformerDecoderLayer(
d_model,
nhead,
dim_feedforward=d_model * 4,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
linear1_self_attention_cls=ScaledLinear,
linear2_self_attention_cls=partial(
ScaledLinear, initial_scale=0.01
),
linear1_feedforward_cls=ScaledLinear,
linear2_feedforward_cls=partial(
ScaledLinear, initial_scale=0.01
),
activation=partial(
BalancedDoubleSwish,
channel_dim=-1,
max_abs=10.0,
min_prob=0.25,
),
layer_norm_cls=IdentityNorm,
),
num_layers=num_layers,
norm=BalancedBasicNorm(d_model) if norm_first else None,
)
self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
self.stop_layer = nn.Linear(d_model, 1)
else:
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model,
nhead,
dim_feedforward=d_model * 4,
activation=F.relu,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
),
num_layers=num_layers,
norm=nn.LayerNorm(d_model) if norm_first else None,
)
self.decoder = nn.TransformerDecoder(
nn.TransformerDecoderLayer(
d_model,
nhead,
dim_feedforward=d_model * 4,
activation=F.relu,
dropout=0.1,
batch_first=True,
norm_first=norm_first,
),
num_layers=num_layers,
norm=nn.LayerNorm(d_model) if norm_first else None,
)
self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
self.stop_layer = nn.Linear(d_model, 1)
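        # Frame-level accuracy of the stop predictions, reported via `metrics`
        # in forward().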
self.stop_accuracy_metric = BinaryAccuracy(
threshold=0.5, multidim_average="global"
)
# self.apply(self._init_weights)
# def _init_weights(self, module):
# if isinstance(module, (nn.Linear)):
# module.weight.data.normal_(mean=0.0, std=0.02)
# if isinstance(module, nn.Linear) and module.bias is not None:
# module.bias.data.zero_()
# elif isinstance(module, nn.LayerNorm):
# module.bias.data.zero_()
# module.weight.data.fill_(1.0)
# elif isinstance(module, nn.Embedding):
# module.weight.data.normal_(mean=0.0, std=0.02)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: torch.Tensor,
y_lens: torch.Tensor,
reduction: str = "sum",
train_stage: int = 0,
**kwargs,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
"""
Args:
x:
A 2-D tensor of shape (N, S).
x_lens:
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
before padding.
y:
A 3-D tensor of shape (N, T, 8).
y_lens:
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
before padding.
train_stage:
Not used in this model.
Returns:
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
"""
del train_stage
assert x.ndim == 2, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.ndim == 3, y.shape
assert y_lens.ndim == 1, y_lens.shape
assert torch.all(x_lens > 0)
# NOTE: x has been padded in TextTokenCollater
x_mask = make_pad_mask(x_lens).to(x.device)
x = self.text_embedding(x)
x = self.encoder_prenet(x)
x = self.encoder_position(x)
x = self.encoder(x, src_key_padding_mask=x_mask)
total_loss, metrics = 0.0, {}
y_mask = make_pad_mask(y_lens).to(y.device)
y_mask_float = y_mask.type(torch.float32)
data_mask = 1.0 - y_mask_float.unsqueeze(-1)
# Training
# AR Decoder
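        # Teacher forcing: prepend one all-zero "go" frame along the time axis,
        # then feed the shifted sequence to the decoder and predict the original.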
def pad_y(y):
y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
# inputs, targets
return y[:, :-1], y[:, 1:]
y, targets = pad_y(y * data_mask) # mask padding as zeros
y_emb = self.decoder_prenet(y)
y_pos = self.decoder_position(y_emb)
y_len = y_lens.max()
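        # Causal (upper-triangular) mask so each decoder frame only attends to
        # earlier frames.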
tgt_mask = torch.triu(
torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
diagonal=1,
)
y_dec = self.decoder(
y_pos,
x,
tgt_mask=tgt_mask,
memory_key_padding_mask=x_mask,
)
predict = self.predict_layer(y_dec)
# loss
total_loss = F.mse_loss(predict, targets, reduction=reduction)
logits = self.stop_layer(y_dec).squeeze(-1)
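        # Stop-token BCE: padded frames are the positive (stop) class and get a
        # 5x weight (1 + 4) relative to real frames.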
stop_loss = F.binary_cross_entropy_with_logits(
logits,
y_mask_float.detach(),
weight=1.0 + y_mask_float.detach() * 4.0,
reduction=reduction,
)
metrics["stop_loss"] = stop_loss.detach()
stop_accuracy = self.stop_accuracy_metric(
(torch.sigmoid(logits) >= 0.5).type(torch.int64),
y_mask.type(torch.int64),
)
# icefall MetricsTracker.norm_items()
metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
torch.float32
)
return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
def inference(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: Any = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor of shape (1, S).
x_lens:
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
before padding.
Returns:
Return the predicted audio code matrix and cross-entropy loss.
"""
assert x.ndim == 2, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert torch.all(x_lens > 0)
x_mask = make_pad_mask(x_lens).to(x.device)
x = self.text_embedding(x)
x = self.encoder_prenet(x)
x = self.encoder_position(x)
x = self.encoder(x, src_key_padding_mask=x_mask)
# AR Decoder
        # TODO: manage decoder state to avoid re-running the full decoder each step
y = torch.zeros(
[x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
)
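        # Greedy AR loop: re-run the decoder over everything generated so far
        # (no incremental decoding), append the newly predicted frame, and stop
        # when the stop head fires or the output exceeds 10x the text length.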
while True:
y_emb = self.decoder_prenet(y)
y_pos = self.decoder_position(y_emb)
tgt_mask = torch.triu(
torch.ones(
y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
),
diagonal=1,
)
y_dec = self.decoder(
y_pos,
x,
tgt_mask=tgt_mask,
memory_mask=None,
memory_key_padding_mask=x_mask,
)
predict = self.predict_layer(y_dec[:, -1:])
logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
print(
f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
)
break
y = torch.concat([y, predict], dim=1)
return y[:, 1:]
def visualize(
self,
predicts: Tuple[torch.Tensor],
batch: Dict[str, Union[List, torch.Tensor]],
output_dir: str,
limit: int = 4,
) -> None:
visualize(predicts, batch, output_dir, limit=limit)
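

# A minimal smoke-test sketch (not part of the original module). It assumes the
# sibling `modules`/`models` packages are importable and that torchmetrics is
# installed; shapes follow the forward()/inference() docstrings, and the values
# below (d_model=128, lengths, etc.) are arbitrary illustration choices.
if __name__ == "__main__":
    model = Transformer(d_model=128, nhead=4, num_layers=2)
    x = torch.randint(0, NUM_TEXT_TOKENS, (2, 12))  # (N, S) text tokens
    x_lens = torch.tensor([12, 9])                  # valid lengths of x
    y = torch.randn(2, 40, NUM_MEL_BINS)            # (N, T, NUM_MEL_BINS) mel targets
    y_lens = torch.tensor([40, 33])                 # valid lengths of y
    (_, predict), loss, metrics = model(x, x_lens, y, y_lens)
    print(predict.shape, float(loss), metrics)

    model.eval()
    with torch.no_grad():
        mel = model.inference(x[:1], x_lens[:1])
    print(mel.shape)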