import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import monotonic_align

from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss


def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    """Returns a (batch_size, max_length) boolean mask that is True at valid positions."""
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def convert_pad_shape(pad_shape):
    """Converts a [[left, right], ...] per-dim padding spec into the flat, reversed list that F.pad expects."""
    inverted_shape = pad_shape[::-1]
    return [item for sublist in inverted_shape for item in sublist]


def generate_path(duration, mask):
    """Expands per-token durations into a hard monotonic alignment path.

    Args:
        duration: (batch_size, t_x) integer-valued token durations.
        mask: (batch_size, t_x, t_y) attention mask.
    Returns:
        (batch_size, t_x, t_y) binary path where row i covers duration[i] consecutive frames.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Each row's mask covers frames [0, cum_duration); subtracting the previous
    # row's mask leaves only the frames belonging to the current token.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path
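# A minimal sketch of what `generate_path` produces; the function name
# `_demo_generate_path` and the toy values are illustrative, not part of the model.
def _demo_generate_path() -> torch.Tensor:
    duration = torch.tensor([[2.0, 1.0, 3.0]])  # 3 tokens lasting 2, 1, and 3 frames
    mask = torch.ones(1, 3, 6)                  # all (token, frame) positions are valid
    return generate_path(duration, mask)
    # tensor([[[1., 1., 0., 0., 0., 0.],
    #          [0., 0., 1., 0., 0., 0.],
    #          [0., 0., 0., 1., 1., 1.]]])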
# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
    def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.n_vocab = n_vocab
        self.mel_channels = mel_channels

        self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
        self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=3)
        self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, p_dropout, gin_channels)
        self.decoder = CFMDecoder(mel_channels + mel_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)

    @torch.inference_mode()
    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0):
        """
        Generates mel-spectrogram from text. Returns:
            1. encoder outputs
            2. decoder outputs
            3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            n_timesteps (int): number of ODE solver steps used by the flow matching decoder.
            temperature (float, optional): controls variance of terminal distribution.
            y (torch.Tensor): mel spectrogram of reference audio.
                shape: (batch_size, mel_channels, time)
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                    # Average mel spectrogram generated by the encoder
                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                    # Refined mel spectrogram improved by the CFM
                "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
                    # Alignment map between text and mel spectrogram
            }
        """
        # Get encoder outputs `mu_x` and log-scaled token durations `logw`,
        # conditioned on the style vector `c` from the reference mel.
        c = self.ref_encoder(y, None)
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        w = torch.exp(logw) * x_mask
        # Scale durations before rounding so they stay integral.
        w_ceil = torch.ceil(w * length_scale)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()

        # Using the obtained durations `w_ceil`, construct the alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample tracing the probability flow
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            "attn": attn[:, :, :y_max_length],
        }
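    # Note on the MAS score computed in `forward` below: with unit variance, the
    # per-pair log-likelihood log N(y_t; mu_s, I) expands into four terms that can
    # be computed for all (mel frame t, text token s) pairs at once:
    #   log N(y_t; mu_s, I) = -0.5 * d * log(2*pi)      (neg_cent1)
    #                         - 0.5 * sum_d y_t^2       (neg_cent2)
    #                         + sum_d y_t * mu_s        (neg_cent3)
    #                         - 0.5 * sum_d mu_s^2      (neg_cent4)
    # MAS then searches for the monotonic path through this score matrix with the
    # highest total likelihood.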
    def forward(self, x, x_lengths, y, y_lengths):
        """
        Computes 3 losses:
            1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
            2. prior loss: loss between mel-spectrogram and encoder outputs.
            3. flow matching loss: loss between mel-spectrogram and decoder outputs.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            y (torch.Tensor): batch of corresponding mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
                shape: (batch_size,)
        """
        # Get encoder outputs `mu_x` and log-scaled token durations `logw`
        y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
        c = self.ref_encoder(y, y_mask)
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        # Use MAS to find the most likely alignment `attn` between text and mel-spectrogram.
        # I'm not sure why the MAS code from Matcha-TTS and Grad-TTS fails to align in StableTTS,
        # so the code below is adapted from https://github.com/p0p4k/pflowtts_pytorch/blob/master/pflow/models/pflow_tts.py, which works.
        # Contributions that explain or fix this are welcome QAQ
        with torch.no_grad():
            # The Matcha/Grad-TTS formulation below computes the same Gaussian
            # log-likelihood (with text and mel axes swapped) but did not align here:
            # const = -0.5 * math.log(2 * math.pi) * self.mel_channels
            # factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            # y_square = torch.matmul(factor.transpose(1, 2), y**2)
            # y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            # mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
            # log_prior = y_square - y_mu_double + mu_square + const

            # Reciprocal of the squared prior scale; fixed to 1 (unit variance) here.
            # In models with a learned log-scale `logx`, this would be torch.exp(-2 * logx)
            # and neg_cent1 would subtract logx.
            s_p_sq_r = torch.ones_like(mu_x)  # [b, d, t_text]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True
            )  # [b, 1, t_text]
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True)
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4  # [b, T_mel, t_text]

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
            )  # [b, 1, T_mel, t_text]

        # Compute loss between predicted log-scaled durations and those obtained from MAS
        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Align encoded text with mel-spectrogram and get mu_y
        attn = attn.squeeze(1).transpose(1, 2)  # [b, t_text, T_mel]
        mu_y = torch.matmul(attn.transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)  # [b, d, T_mel]

        # Compute flow matching loss of the decoder
        diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y, c)

        # Compute prior loss between aligned encoder outputs and the target mel
        # (referred to as the prior loss in the Grad-TTS paper)
        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)

        return dur_loss, diff_loss, prior_loss, attn
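# The two sketches below are illustrative and not part of the model. The first
# checks, on given tensors, that the einsum-based MAS score used in `forward`
# equals the direct per-pair Gaussian log-likelihood log N(y_t; mu_s, I). The
# second shows a minimal synthesis call; all hyperparameter values here are
# assumptions, not the repo's reference config.
def _check_neg_cent(y: torch.Tensor, mu_x: torch.Tensor) -> None:
    # y: [b, d, T_mel], mu_x: [b, d, t_text]
    d = y.size(1)
    diff = y.unsqueeze(-1) - mu_x.unsqueeze(2)  # [b, d, T_mel, t_text]
    direct = -0.5 * d * math.log(2 * math.pi) - 0.5 * (diff**2).sum(dim=1)

    s_p_sq_r = torch.ones_like(mu_x)
    neg_cent = (
        torch.sum(-0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True)
        + torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
        + torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
        + torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True)
    )
    assert torch.allclose(direct, neg_cent, atol=1e-5)


def _demo_synthesise() -> torch.Tensor:
    model = StableTTS(
        n_vocab=100, mel_channels=80, hidden_channels=192, filter_channels=768,
        n_heads=2, n_enc_layers=6, n_dec_layers=6, kernel_size=3, p_dropout=0.1,
        gin_channels=256,
    ).eval()
    x = torch.randint(1, 100, (1, 20))  # dummy phoneme ids
    x_lengths = torch.tensor([20])
    ref_mel = torch.randn(1, 80, 200)   # reference mel providing the speaker style
    out = model.synthesise(x, x_lengths, n_timesteps=10, y=ref_mel)
    return out["decoder_outputs"]       # [1, 80, T_mel]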