import dotwiz
import torch
import torch.nn.functional as torchfunc
from torch.nn import Linear
from torch.nn import Sequential
from torch.nn import Tanh

from Modules.GeneralLayers.Conformer import Conformer
from Modules.GeneralLayers.LengthRegulator import LengthRegulator
from Modules.ToucanTTS.flow_matching import CFMDecoder
from Preprocessing.articulatory_features import get_feature_to_index_lookup
from Utility.utils import make_non_pad_mask


class ToucanTTS(torch.nn.Module):

    def __init__(self,
                 weights,
                 config):
        super().__init__()
        self.config = config
        config = dotwiz.DotWiz(config)

        input_feature_dimensions = config.input_feature_dimensions
        attention_dimension = config.attention_dimension
        attention_heads = config.attention_heads
        positionwise_conv_kernel_size = config.positionwise_conv_kernel_size
        use_scaled_positional_encoding = config.use_scaled_positional_encoding
        use_macaron_style_in_conformer = config.use_macaron_style_in_conformer
        use_cnn_in_conformer = config.use_cnn_in_conformer
        encoder_layers = config.encoder_layers
        encoder_units = config.encoder_units
        encoder_normalize_before = config.encoder_normalize_before
        encoder_concat_after = config.encoder_concat_after
        conformer_encoder_kernel_size = config.conformer_encoder_kernel_size
        transformer_enc_dropout_rate = config.transformer_enc_dropout_rate
        transformer_enc_positional_dropout_rate = config.transformer_enc_positional_dropout_rate
        transformer_enc_attn_dropout_rate = config.transformer_enc_attn_dropout_rate
        decoder_layers = config.decoder_layers
        decoder_units = config.decoder_units
        decoder_concat_after = config.decoder_concat_after
        conformer_decoder_kernel_size = config.conformer_decoder_kernel_size
        decoder_normalize_before = config.decoder_normalize_before
        transformer_dec_dropout_rate = config.transformer_dec_dropout_rate
        transformer_dec_positional_dropout_rate = config.transformer_dec_positional_dropout_rate
        transformer_dec_attn_dropout_rate = config.transformer_dec_attn_dropout_rate
        duration_predictor_layers = config.duration_predictor_layers
        duration_predictor_kernel_size = config.duration_predictor_kernel_size
        duration_predictor_dropout_rate = config.duration_predictor_dropout_rate
        pitch_predictor_layers = config.pitch_predictor_layers
        pitch_predictor_kernel_size = config.pitch_predictor_kernel_size
        pitch_predictor_dropout = config.pitch_predictor_dropout
        pitch_embed_kernel_size = config.pitch_embed_kernel_size
        pitch_embed_dropout = config.pitch_embed_dropout
        energy_predictor_layers = config.energy_predictor_layers
        energy_predictor_kernel_size = config.energy_predictor_kernel_size
        energy_predictor_dropout = config.energy_predictor_dropout
        energy_embed_kernel_size = config.energy_embed_kernel_size
        energy_embed_dropout = config.energy_embed_dropout
        cfm_filter_channels = config.cfm_filter_channels
        cfm_heads = config.cfm_heads
        cfm_layers = config.cfm_layers
        cfm_kernel_size = config.cfm_kernel_size
        cfm_p_dropout = config.cfm_p_dropout
        utt_embed_dim = config.utt_embed_dim
        lang_embs = config.lang_embs
        spec_channels = config.spec_channels
        embedding_integration = config.embedding_integration
        lang_emb_size = config.lang_emb_size
        integrate_language_embedding_into_encoder_out = config.integrate_language_embedding_into_encoder_out
        prosody_channels = config.prosody_channels

        if lang_embs is None or lang_embs == 0:
            lang_embs = None
            integrate_language_embedding_into_encoder_out = False
        if integrate_language_embedding_into_encoder_out:
            utt_embed_dim = utt_embed_dim + lang_emb_size

        self.input_feature_dimensions = input_feature_dimensions
        self.attention_dimension = attention_dimension
        self.use_scaled_pos_enc = use_scaled_positional_encoding
        self.multilingual_model = lang_embs is not None
        self.multispeaker_model = utt_embed_dim is not None
        self.integrate_language_embedding_into_encoder_out = integrate_language_embedding_into_encoder_out
        self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]

        articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension))

        self.encoder = Conformer(conformer_type="encoder",
                                 attention_dim=attention_dimension,
                                 attention_heads=attention_heads,
                                 linear_units=encoder_units,
                                 num_blocks=encoder_layers,
                                 input_layer=articulatory_feature_embedding,
                                 dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate,
                                 attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before,
                                 concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=True,
                                 cnn_module_kernel=conformer_encoder_kernel_size,
                                 zero_triu=False,
                                 utt_embed=utt_embed_dim,
                                 lang_embs=lang_embs,
                                 lang_emb_size=lang_emb_size,
                                 use_output_norm=True,
                                 embedding_integration=embedding_integration)
        self.duration_predictor = CFMDecoder(hidden_channels=prosody_channels,
                                             out_channels=1,
                                             filter_channels=prosody_channels,
                                             n_heads=1,
                                             n_layers=duration_predictor_layers,
                                             kernel_size=duration_predictor_kernel_size,
                                             p_dropout=duration_predictor_dropout_rate,
                                             gin_channels=utt_embed_dim)

        self.pitch_predictor = CFMDecoder(hidden_channels=prosody_channels,
                                          out_channels=1,
                                          filter_channels=prosody_channels,
                                          n_heads=1,
                                          n_layers=pitch_predictor_layers,
                                          kernel_size=pitch_predictor_kernel_size,
                                          p_dropout=pitch_predictor_dropout,
                                          gin_channels=utt_embed_dim)

        self.energy_predictor = CFMDecoder(hidden_channels=prosody_channels,
                                           out_channels=1,
                                           filter_channels=prosody_channels,
                                           n_heads=1,
                                           n_layers=energy_predictor_layers,
                                           kernel_size=energy_predictor_kernel_size,
                                           p_dropout=energy_predictor_dropout,
                                           gin_channels=utt_embed_dim)

        self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1,
                                                      out_channels=attention_dimension,
                                                      kernel_size=pitch_embed_kernel_size,
                                                      padding=(pitch_embed_kernel_size - 1) // 2),
                                      torch.nn.Dropout(pitch_embed_dropout))

        self.energy_embed = Sequential(torch.nn.Conv1d(in_channels=1,
                                                       out_channels=attention_dimension,
                                                       kernel_size=energy_embed_kernel_size,
                                                       padding=(energy_embed_kernel_size - 1) // 2),
                                       torch.nn.Dropout(energy_embed_dropout))

        self.length_regulator = LengthRegulator()

        self.decoder = Conformer(conformer_type="decoder",
                                 attention_dim=attention_dimension,
                                 attention_heads=attention_heads,
                                 linear_units=decoder_units,
                                 num_blocks=decoder_layers,
                                 input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate,
                                 positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate,
                                 normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_decoder_kernel_size,
                                 use_output_norm=embedding_integration not in ["AdaIN", "ConditionalLayerNorm"],
                                 utt_embed=utt_embed_dim,
                                 embedding_integration=embedding_integration)

        self.output_projection = torch.nn.Linear(attention_dimension, spec_channels)
        self.pitch_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels)
        self.energy_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels)
        self.duration_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels)

        self.flow_matching_decoder = CFMDecoder(hidden_channels=spec_channels,
                                                out_channels=spec_channels,
                                                filter_channels=cfm_filter_channels,
                                                n_heads=cfm_heads,
                                                n_layers=cfm_layers,
                                                kernel_size=cfm_kernel_size,
                                                p_dropout=cfm_p_dropout,
                                                gin_channels=utt_embed_dim)

        self.load_state_dict(weights)
        self.eval()
    def _forward(self,
                 text_tensors,
                 text_lengths,
                 gold_durations=None,
                 gold_pitch=None,
                 gold_energy=None,
                 duration_scaling_factor=1.0,
                 utterance_embedding=None,
                 lang_ids=None,
                 pitch_variance_scale=1.0,
                 energy_variance_scale=1.0,
                 pause_duration_scaling_factor=1.0,
                 prosody_creativity=0.1):
        # clamping to 1.0 is necessary because of the way modifiers are represented, to keep them identifiable
        text_tensors = torch.clamp(text_tensors, max=1.0)

        if not self.multilingual_model:
            lang_ids = None

        if not self.multispeaker_model:
            utterance_embedding = None
        if utterance_embedding is not None:
            utterance_embedding = torch.nn.functional.normalize(utterance_embedding)
            if self.integrate_language_embedding_into_encoder_out and lang_ids is not None:
                lang_embs = self.encoder.language_embedding(lang_ids)
                lang_embs = torch.nn.functional.normalize(lang_embs)
                utterance_embedding = torch.cat([lang_embs, utterance_embedding], dim=1).detach()

        # encoding the texts
        text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2)
        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)

        # predicting pitch, energy and durations (unless gold values are provided)
        reduced_pitch_space = torchfunc.dropout(self.pitch_latent_reduction(encoded_texts), p=0.1).transpose(1, 2)
        pitch_predictions = self.pitch_predictor(mu=reduced_pitch_space,
                                                 mask=text_masks.float(),
                                                 n_timesteps=10,
                                                 temperature=prosody_creativity,
                                                 c=utterance_embedding) if gold_pitch is None else gold_pitch
        pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
        embedded_pitch_curve = self.pitch_embed(pitch_predictions).transpose(1, 2)

        reduced_energy_space = torchfunc.dropout(self.energy_latent_reduction(encoded_texts + embedded_pitch_curve), p=0.1).transpose(1, 2)
        energy_predictions = self.energy_predictor(mu=reduced_energy_space,
                                                   mask=text_masks.float(),
                                                   n_timesteps=10,
                                                   temperature=prosody_creativity,
                                                   c=utterance_embedding) if gold_energy is None else gold_energy
        energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)
        embedded_energy_curve = self.energy_embed(energy_predictions).transpose(1, 2)

        reduced_duration_space = torchfunc.dropout(self.duration_latent_reduction(encoded_texts + embedded_pitch_curve + embedded_energy_curve), p=0.1).transpose(1, 2)
        predicted_durations = torch.clamp(torch.ceil(self.duration_predictor(mu=reduced_duration_space,
                                                                             mask=text_masks.float(),
                                                                             n_timesteps=10,
                                                                             temperature=prosody_creativity,
                                                                             c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations

        # modifying the predictions with the control parameters
        for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
            if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
                predicted_durations[0][phoneme_index] = 0
            if phoneme_vector[get_feature_to_index_lookup()["silence"]] == 1 and pause_duration_scaling_factor != 1.0:
                predicted_durations[0][phoneme_index] = torch.round(predicted_durations[0][phoneme_index].float() * pause_duration_scaling_factor).long()
        if duration_scaling_factor != 1.0:
            assert duration_scaling_factor > 0.0
            predicted_durations = torch.round(predicted_durations.float() * duration_scaling_factor).long()

        # enriching the text with pitch and energy info
        enriched_encoded_texts = encoded_texts + embedded_pitch_curve + embedded_energy_curve

        # upsampling the enriched text to frame level according to the durations
        upsampled_enriched_encoded_texts = self.length_regulator(enriched_encoded_texts, predicted_durations)

        # decoding the spectrogram
        decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, None, utterance_embedding=utterance_embedding)
        preliminary_spectrogram = self.output_projection(decoded_speech)

        refined_codec_frames = self.flow_matching_decoder(mu=preliminary_spectrogram.transpose(1, 2),
                                                          mask=make_non_pad_mask([len(decoded_speech[0])], device=decoded_speech.device).unsqueeze(-2),
                                                          n_timesteps=15,
                                                          temperature=0.1,  # low temperature, so the model follows the specified prosody curves better
                                                          c=None).transpose(1, 2)

        return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze()
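
    # For orientation, a sketch of what _forward does at inference time: the
    # articulatory feature vectors are encoded by the Conformer encoder; pitch,
    # energy and durations are then sampled one after another from the CFM
    # predictors, each conditioned on the encoder output enriched with the
    # prosody curves predicted so far; the length regulator upsamples the
    # enriched phoneme representations to frame level; and the Conformer
    # decoder plus the output projection produce a coarse spectrogram that the
    # flow-matching decoder refines into the final frames.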
    def forward(self,
                text,
                durations=None,
                pitch=None,
                energy=None,
                utterance_embedding=None,
                return_duration_pitch_energy=False,
                lang_id=None,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                pause_duration_scaling_factor=1.0,
                prosody_creativity=0.1):
        """
        Generate the sequence of spectrogram frames given the sequence of vectorized phonemes.

        Args:
            text: input sequence of vectorized phonemes
            durations: durations to be used (optional, if not provided, they will be predicted)
            pitch: token-averaged pitch curve to be used (optional, if not provided, it will be predicted)
            energy: token-averaged energy curve to be used (optional, if not provided, it will be predicted)
            utterance_embedding: embedding of speaker information
            return_duration_pitch_energy: whether to also return the predicted durations, pitch and energy for nicer plotting
            lang_id: id to be fed into the embedding layer that contains language information
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
            pause_duration_scaling_factor: reasonable values are 0.6 < scale < 1.4.
                                           scales the durations of pauses on top of the regular duration scaling
            prosody_creativity: sampling temperature of the CFM prosody predictors; higher values yield more varied
                                pitch, energy and duration predictions

        Returns:
            the spectrogram frames, plus the predicted durations, pitch and energy if return_duration_pitch_energy is set
        """
        # set up the batch axis
        text_length = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
        if durations is not None:
            durations = durations.unsqueeze(0).to(text.device)
        if pitch is not None:
            pitch = pitch.unsqueeze(0).to(text.device)
        if energy is not None:
            energy = energy.unsqueeze(0).to(text.device)
        if lang_id is not None:
            lang_id = lang_id.to(text.device)

        outs, \
            predicted_durations, \
            pitch_predictions, \
            energy_predictions = self._forward(text.unsqueeze(0),
                                               text_length,
                                               gold_durations=durations,
                                               gold_pitch=pitch,
                                               gold_energy=energy,
                                               utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None,
                                               lang_ids=lang_id,
                                               duration_scaling_factor=duration_scaling_factor,
                                               pitch_variance_scale=pitch_variance_scale,
                                               energy_variance_scale=energy_variance_scale,
                                               pause_duration_scaling_factor=pause_duration_scaling_factor,
                                               prosody_creativity=prosody_creativity)

        if return_duration_pitch_energy:
            return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions
        return outs.squeeze().transpose(0, 1)
    def store_inverse_all(self):
        def remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        # self.post_flow.store_inverse()  # we're no longer using glow, so this is deprecated
        self.apply(remove_weight_norm)


def _scale_variance(sequence, scale):
    if scale == 1.0:
        return sequence
    average = sequence[0][sequence[0] != 0.0].mean()
    sequence = sequence - average  # center the sequence around 0
    sequence = sequence * scale  # scale the spread of the sequence
    sequence = sequence + average  # move the center back to the original value, now with changed variance
    for sequence_index in range(len(sequence[0][0])):
        if sequence[0][0][sequence_index] < 0.0:
            sequence[0][0][sequence_index] = 0.0  # pitch and energy cannot be negative, so values pushed below zero are clamped
    return sequence
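
# A worked example of _scale_variance on a (1, 1, T) curve, assuming a
# hypothetical pitch sequence in which 0.0 marks unvoiced positions: the mean
# of the nonzero entries is 2.0, so scaling by 1.3 stretches the curve around
# that mean, and the unvoiced entry that was pushed negative is clamped back
# to zero.
# >>> curve = torch.tensor([[[0.0, 1.0, 2.0, 3.0]]])
# >>> _scale_variance(curve, 1.3)
# tensor([[[0.0000, 0.7000, 2.0000, 3.3000]]])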


def smooth_time_series(matrix, n_neighbors):
    """
    Smooth a 2D matrix along the time axis using a moving average.

    Parameters:
    - matrix (torch.Tensor): Input matrix (2D tensor) representing the time series.
    - n_neighbors (int): Number of neighboring rows on each side to include in the moving average.

    Returns:
    - torch.Tensor: Smoothed matrix.
    """
    smoothed_matrix = torch.zeros_like(matrix)
    for i in range(matrix.size(0)):
        lower = max(0, i - n_neighbors)
        upper = min(matrix.size(0), i + n_neighbors + 1)
        smoothed_matrix[i] = torch.mean(matrix[lower:upper], dim=0)
    return smoothed_matrix
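
# A small usage sketch: with n_neighbors=1 each row is averaged with its
# immediate neighbors, and the window simply shrinks at the boundaries.
# >>> m = torch.tensor([[0.0], [3.0], [6.0]])
# >>> smooth_time_series(m, 1)
# tensor([[1.5000],
#         [3.0000],
#         [4.5000]])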


def make_near_zero_to_zero(sequence):
    # snap values close to zero down to exactly zero, since zero is used as the marker for unvoiced/silent positions
    for index in range(len(sequence)):
        if sequence[index] < 0.2:
            sequence[index] = 0.0
    return sequence
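

if __name__ == "__main__":
    # Minimal inference sketch. The checkpoint path and the "model" / "config"
    # keys are assumptions about how the weights were saved, not something
    # this file guarantees; adapt them to the actual checkpoint format.
    checkpoint = torch.load("toucantts_checkpoint.pt", map_location="cpu")
    tts = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])

    # The text frontend (not part of this file) turns a sentence into a
    # (num_phonemes, input_feature_dimensions) tensor of articulatory feature
    # vectors; a zero tensor of the right shape stands in for it here.
    dummy_text = torch.zeros([10, tts.input_feature_dimensions])

    # Multispeaker / multilingual checkpoints additionally expect an
    # utterance_embedding of size utt_embed_dim and a lang_id tensor.
    with torch.inference_mode():
        spectrogram = tts(dummy_text)
    print(spectrogram.shape)  # (spec_channels, num_frames)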