diff --git a/Architectures/Aligner/Reconstructor.py b/Architectures/Aligner/Reconstructor.py deleted file mode 100644 index 066f886f4080c1347f731a70a0938d2e571939ca..0000000000000000000000000000000000000000 --- a/Architectures/Aligner/Reconstructor.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -import torch.multiprocessing -from torch.nn.utils.rnn import pack_padded_sequence -from torch.nn.utils.rnn import pad_packed_sequence - -from Utility.utils import make_non_pad_mask - - -class Reconstructor(torch.nn.Module): - - def __init__(self, - n_features=128, - num_symbols=145, - speaker_embedding_dim=192, - lstm_dim=256): - super().__init__() - self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, lstm_dim) - self.rnn1 = torch.nn.LSTM(lstm_dim, lstm_dim, batch_first=True, bidirectional=True) - self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True) - self.out_proj = torch.nn.Linear(2 * lstm_dim, n_features) - self.l1_criterion = torch.nn.L1Loss(reduction="none") - self.l2_criterion = torch.nn.MSELoss(reduction="none") - - def forward(self, x, lens, ys): - x = self.in_proj(x) - x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False) - x, _ = self.rnn1(x) - x, _ = self.rnn2(x) - x, _ = pad_packed_sequence(x, batch_first=True) - x = self.out_proj(x) - out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device) - out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() - out_weights /= ys.size(0) * ys.size(2) - l1_loss = self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum() - l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum() - return l1_loss + l2_loss - - -if __name__ == '__main__': - print(sum(p.numel() for p in Reconstructor().parameters() if p.requires_grad)) \ No newline at end of file diff --git a/Architectures/ToucanTTS/StochasticToucanTTS/README.md b/Architectures/ToucanTTS/StochasticToucanTTS/README.md deleted file mode 100644 index 61f7205ebe4866f568bd0a99c6cd37634fab8a63..0000000000000000000000000000000000000000 --- a/Architectures/ToucanTTS/StochasticToucanTTS/README.md +++ /dev/null @@ -1 +0,0 @@ -This is an experimental version of the TTS that uses normalizing flows to predict the prosody explicitly, so that we can still have the controllability of the explicit prosody predictors, but with much better naturalness and liveliness than we get from a deterministic predictor.
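A minimal, self-contained sketch of the idea described in that README (the project's actual StochasticVariancePredictor stacks convolutional spline flows; the class below, its names and sizes are illustrative only): a single conditional affine transform is trained as a tiny normalizing flow, so training scores the gold prosody value under a Gaussian in latent space, and at inference Gaussian noise is pushed back through the transform, with the noise scale acting as a sampling temperature that trades repeatability against liveliness.

import torch


class ToyProsodyFlow(torch.nn.Module):
    """Illustrative stand-in for a stochastic pitch/energy/duration predictor."""

    def __init__(self, text_dim=192):
        super().__init__()
        # predicts a shift and a log-scale from each encoded phone
        self.net = torch.nn.Sequential(torch.nn.Linear(text_dim, 64),
                                       torch.nn.Tanh(),
                                       torch.nn.Linear(64, 2))

    def forward(self, encoded_phones, target=None, noise_scale=0.3):
        shift, log_scale = self.net(encoded_phones).chunk(2, dim=-1)
        if target is not None:
            # training direction: map the gold value to latent space and
            # return its negative log-likelihood under a standard normal
            z = (target - shift) * torch.exp(-log_scale)
            return (0.5 * z ** 2 + log_scale).mean()
        # inference direction: sample noise and run the transform in reverse;
        # noise_scale is the sampling temperature
        z = torch.randn_like(shift) * noise_scale
        return shift + torch.exp(log_scale) * z


flow = ToyProsodyFlow()
encoded_phones = torch.randn(1, 10, 192)                  # [batch, phones, encoder dim]
nll = flow(encoded_phones, target=torch.randn(1, 10, 1))  # training objective
sample = flow(encoded_phones, noise_scale=0.3)            # one plausible prosody contour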
\ No newline at end of file diff --git a/Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTS.py b/Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTS.py deleted file mode 100644 index af9c8f5cd8bb9fcaef5c26590ff6ba8eabd324c2..0000000000000000000000000000000000000000 --- a/Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTS.py +++ /dev/null @@ -1,493 +0,0 @@ -import torch -from torch.nn import Linear -from torch.nn import Sequential -from torch.nn import Tanh - -from Architectures.GeneralLayers.Conformer import Conformer -from Architectures.GeneralLayers.LengthRegulator import LengthRegulator -from Architectures.ToucanTTS.Glow import Glow -from Architectures.ToucanTTS.StochasticToucanTTS.StochasticToucanTTSLoss import StochasticToucanTTSLoss -from Architectures.ToucanTTS.StochasticToucanTTS.StochasticVariancePredictor import StochasticVariancePredictor -from Preprocessing.articulatory_features import get_feature_to_index_lookup -from Utility.utils import initialize -from Utility.utils import make_non_pad_mask -from Utility.utils import make_pad_mask - - -class StochasticToucanTTS(torch.nn.Module): - """ - StochasticToucanTTS module, which is mostly just a FastSpeech 2 module, - but with lots of designs from different architectures accumulated - and some major components added to put a large focus on multilinguality. - - Original contributions: - - Inputs are configurations of the articulatory tract - - Word boundaries are modeled explicitly in the encoder end removed before the decoder - - Speaker embedding conditioning is derived from GST and Adaspeech 4 - - Responsiveness of variance predictors to utterance embedding is increased through conditional layer norm - - The final output receives a GAN discriminator feedback signal - - Stochastic Duration Prediction through a normalizing flow - - Stochastic Pitch Prediction through a normalizing flow - - Stochastic Energy prediction through a normalizing flow - - Contributions inspired from elsewhere: - - The PostNet is also a normalizing flow, like in PortaSpeech - - Pitch and energy values are averaged per-phone, as in FastPitch to enable great controllability - - The encoder and decoder are Conformers - - """ - - def __init__(self, - # network structure related - input_feature_dimensions=62, - output_spectrogram_channels=80, - attention_dimension=192, - attention_heads=4, - positionwise_conv_kernel_size=1, - use_scaled_positional_encoding=True, - init_type="xavier_uniform", - use_macaron_style_in_conformer=True, - use_cnn_in_conformer=True, - - # encoder - encoder_layers=6, - encoder_units=1536, - encoder_normalize_before=True, - encoder_concat_after=False, - conformer_encoder_kernel_size=7, - transformer_enc_dropout_rate=0.2, - transformer_enc_positional_dropout_rate=0.2, - transformer_enc_attn_dropout_rate=0.2, - - # decoder - decoder_layers=6, - decoder_units=1536, - decoder_concat_after=False, - conformer_decoder_kernel_size=31, - decoder_normalize_before=True, - transformer_dec_dropout_rate=0.2, - transformer_dec_positional_dropout_rate=0.2, - transformer_dec_attn_dropout_rate=0.2, - - # duration predictor - duration_predictor_layers=3, - duration_predictor_chans=256, - duration_predictor_kernel_size=3, - duration_predictor_dropout_rate=0.2, - - # pitch predictor - pitch_embed_kernel_size=1, - pitch_embed_dropout=0.0, - - # energy predictor - energy_embed_kernel_size=1, - energy_embed_dropout=0.0, - - # additional features - utt_embed_dim=192, - lang_embs=8000): - super().__init__() - - 
self.input_feature_dimensions = input_feature_dimensions - self.output_spectrogram_channels = output_spectrogram_channels - self.attention_dimension = attention_dimension - self.use_scaled_pos_enc = use_scaled_positional_encoding - self.multilingual_model = lang_embs is not None - self.multispeaker_model = utt_embed_dim is not None - - articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension)) - self.encoder = Conformer(conformer_type="encoder", - attention_dim=attention_dimension, - attention_heads=attention_heads, - linear_units=encoder_units, - num_blocks=encoder_layers, - input_layer=articulatory_feature_embedding, - dropout_rate=transformer_enc_dropout_rate, - positional_dropout_rate=transformer_enc_positional_dropout_rate, - attention_dropout_rate=transformer_enc_attn_dropout_rate, - normalize_before=encoder_normalize_before, - concat_after=encoder_concat_after, - positionwise_conv_kernel_size=positionwise_conv_kernel_size, - macaron_style=use_macaron_style_in_conformer, - use_cnn_module=use_cnn_in_conformer, - cnn_module_kernel=conformer_encoder_kernel_size, - zero_triu=False, - utt_embed=utt_embed_dim, - lang_embs=lang_embs, - use_output_norm=True) - - self.duration_flow = StochasticVariancePredictor(in_channels=attention_dimension, - kernel_size=3, - p_dropout=0.5, - n_flows=5, - conditioning_signal_channels=utt_embed_dim) - - self.pitch_flow = StochasticVariancePredictor(in_channels=attention_dimension, - kernel_size=5, - p_dropout=0.5, - n_flows=6, - conditioning_signal_channels=utt_embed_dim) - - self.energy_flow = StochasticVariancePredictor(in_channels=attention_dimension, - kernel_size=3, - p_dropout=0.5, - n_flows=3, - conditioning_signal_channels=utt_embed_dim) - - self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1, - out_channels=attention_dimension, - kernel_size=pitch_embed_kernel_size, - padding=(pitch_embed_kernel_size - 1) // 2), - torch.nn.Dropout(pitch_embed_dropout)) - - self.energy_embed = Sequential(torch.nn.Conv1d(in_channels=1, out_channels=attention_dimension, kernel_size=energy_embed_kernel_size, - padding=(energy_embed_kernel_size - 1) // 2), - torch.nn.Dropout(energy_embed_dropout)) - - self.length_regulator = LengthRegulator() - - self.decoder = Conformer(conformer_type="decoder", - attention_dim=attention_dimension, - attention_heads=attention_heads, - linear_units=decoder_units, - num_blocks=decoder_layers, - input_layer=None, - dropout_rate=transformer_dec_dropout_rate, - positional_dropout_rate=transformer_dec_positional_dropout_rate, - attention_dropout_rate=transformer_dec_attn_dropout_rate, - normalize_before=decoder_normalize_before, - concat_after=decoder_concat_after, - positionwise_conv_kernel_size=positionwise_conv_kernel_size, - macaron_style=use_macaron_style_in_conformer, - use_cnn_module=use_cnn_in_conformer, - cnn_module_kernel=conformer_decoder_kernel_size, - use_output_norm=False, - utt_embed=utt_embed_dim) - - self.feat_out = Linear(attention_dimension, output_spectrogram_channels) - - self.post_flow = Glow( - in_channels=output_spectrogram_channels, - hidden_channels=192, # post_glow_hidden - kernel_size=3, # post_glow_kernel_size - dilation_rate=1, - n_blocks=12, # post_glow_n_blocks (original 12 in paper) - n_layers=3, # post_glow_n_block_layers (original 3 in paper) - n_split=4, - n_sqz=2, - text_condition_channels=attention_dimension, - share_cond_layers=False, # post_share_cond_layers - share_wn_layers=4, - sigmoid_scale=False, - 
condition_integration_projection=torch.nn.Conv1d(output_spectrogram_channels + attention_dimension, attention_dimension, 5, padding=2) - ) - - # initialize parameters - self._reset_parameters(init_type=init_type) - if lang_embs is not None: - torch.nn.init.normal_(self.encoder.language_embedding.weight, mean=0, std=attention_dimension ** -0.5) - - self.criterion = StochasticToucanTTSLoss() - - def forward(self, - text_tensors, - text_lengths, - gold_speech, - speech_lengths, - gold_durations, - gold_pitch, - gold_energy, - utterance_embedding, - return_feats=False, - lang_ids=None, - run_glow=True - ): - """ - Args: - return_feats (Boolean): whether to return the predicted spectrogram - text_tensors (LongTensor): Batch of padded text vectors (B, Tmax). - text_lengths (LongTensor): Batch of lengths of each input (B,). - gold_speech (Tensor): Batch of padded target features (B, Lmax, odim). - speech_lengths (LongTensor): Batch of the lengths of each target (B,). - gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1). - gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1). - gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1). - run_glow (Boolean): Whether to run the PostNet. There should be a warmup phase in the beginning. - lang_ids (LongTensor): The language IDs used to access the language embedding table, if the model is multilingual - utterance_embedding (Tensor): Batch of embeddings to condition the TTS on, if the model is multispeaker - """ - before_outs, \ - after_outs, \ - duration_loss, \ - pitch_loss, \ - energy_loss, \ - glow_loss = self._forward(text_tensors=text_tensors, - text_lengths=text_lengths, - gold_speech=gold_speech, - speech_lengths=speech_lengths, - gold_durations=gold_durations, - gold_pitch=gold_pitch, - gold_energy=gold_energy, - utterance_embedding=utterance_embedding, - is_inference=False, - lang_ids=lang_ids, - run_glow=run_glow) - - # calculate loss - l1_loss = self.criterion(after_outs=after_outs, - before_outs=before_outs, - gold_spectrograms=gold_speech, - spectrogram_lengths=speech_lengths, - text_lengths=text_lengths) - - if return_feats: - if after_outs is None: - after_outs = before_outs - return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss, after_outs - return l1_loss, duration_loss, pitch_loss, energy_loss, glow_loss - - def _forward(self, - text_tensors, - text_lengths, - gold_speech=None, - speech_lengths=None, - gold_durations=None, - gold_pitch=None, - gold_energy=None, - is_inference=False, - utterance_embedding=None, - lang_ids=None, - run_glow=True): - - if not self.multilingual_model: - lang_ids = None - - if not self.multispeaker_model: - utterance_embedding = None - - # encoding the texts - text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2) - padding_masks = make_pad_mask(text_lengths, device=text_lengths.device) - encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) - - if is_inference: - variance_mask = torch.ones(size=[text_tensors.size(1)], device=text_tensors.device) - - # predicting pitch - pitch_predictions = self.pitch_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2) - for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)): - if phoneme_vector[get_feature_to_index_lookup()["voiced"]] == 0: - pitch_predictions[0][phoneme_index] = 0.0 - embedded_pitch_curve 
= self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2) - encoded_texts = encoded_texts + embedded_pitch_curve - - # predicting energy - energy_predictions = self.energy_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2) - embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2) - encoded_texts = encoded_texts + embedded_energy_curve - - # predicting durations - predicted_durations = self.duration_flow(encoded_texts.transpose(1, 2), variance_mask, w=None, g=utterance_embedding.unsqueeze(-1), reverse=True).squeeze(-1).transpose(1, 2).squeeze(-1) - predicted_durations = torch.ceil(torch.exp(predicted_durations)).long() - for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)): - if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1: - predicted_durations[0][phoneme_index] = 0 - - # predicting durations for text and upsampling accordingly - upsampled_enriched_encoded_texts = self.length_regulator(encoded_texts, predicted_durations) - - else: - # learning to predict pitch - idx = gold_pitch != 0 - pitch_mask = torch.logical_and(text_masks, idx.transpose(1, 2)) - scaled_pitch_targets = gold_pitch.detach().clone() - scaled_pitch_targets[idx] = torch.exp(gold_pitch[idx]) # we scale up, so that the log in the flow can handle the value ranges better. - pitch_flow_loss = torch.sum(self.pitch_flow(encoded_texts.transpose(1, 2).detach(), pitch_mask, w=scaled_pitch_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False)) - pitch_flow_loss = torch.sum(pitch_flow_loss / torch.sum(pitch_mask)) # weighted masking - embedded_pitch_curve = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2) - encoded_texts = encoded_texts + embedded_pitch_curve - - # learning to predict energy - idx = gold_energy != 0 - energy_mask = torch.logical_and(text_masks, idx.transpose(1, 2)) - scaled_energy_targets = gold_energy.detach().clone() - scaled_energy_targets[idx] = torch.exp(gold_energy[idx]) # we scale up, so that the log in the flow can handle the value ranges better. 
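# Illustrative aside, not part of the original file: the first transform the
# flow applies to its target is a Log (clamped at 1e-6), so the normalized gold
# pitch/energy values, which can be zero or negative, are exp-scaled first to
# keep them in a strictly positive, well-conditioned range. Ignoring the small
# learned offset subtracted inside the flow, this also means the reverse pass
# at inference already yields values in the original normalized scale, which is
# why the sampled pitch/energy above are used directly, while durations (not
# exp-scaled) come out in the log domain and are exp-ed and ceiled.
# Toy numbers:
#     gold = torch.tensor([-1.2, 0.0, 0.7])      # normalized, token-averaged
#     scaled = torch.exp(gold)                   # strictly positive target w
#     torch.log(torch.clamp_min(scaled, 1e-6))   # what the Log transform sees: gold again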
- energy_flow_loss = torch.sum(self.energy_flow(encoded_texts.transpose(1, 2).detach(), energy_mask, w=scaled_energy_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False)) - energy_flow_loss = torch.sum(energy_flow_loss / torch.sum(energy_mask)) # weighted masking - embedded_energy_curve = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2) - encoded_texts = encoded_texts + embedded_energy_curve - - # learning to predict durations - idx = gold_durations.unsqueeze(-1) != 0 - duration_mask = torch.logical_and(text_masks, idx.transpose(1, 2)) - duration_targets = gold_durations.unsqueeze(-1).detach().clone().float() - duration_flow_loss = torch.sum(self.duration_flow(encoded_texts.transpose(1, 2).detach(), duration_mask, w=duration_targets.transpose(1, 2), g=utterance_embedding.unsqueeze(-1), reverse=False)) - duration_flow_loss = torch.sum(duration_flow_loss / torch.sum(duration_mask)) # weighted masking - - upsampled_enriched_encoded_texts = self.length_regulator(encoded_texts, gold_durations) - - # decoding spectrogram - decoder_masks = make_non_pad_mask(speech_lengths, device=speech_lengths.device).unsqueeze(-2) if speech_lengths is not None and not is_inference else None - decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, decoder_masks, utterance_embedding=utterance_embedding) - decoded_spectrogram = self.feat_out(decoded_speech).view(decoded_speech.size(0), -1, self.output_spectrogram_channels) - - # refine spectrogram further with a normalizing flow (requires warmup, so it's not always on) - glow_loss = None - if run_glow: - if is_inference: - refined_spectrogram = self.post_flow(tgt_mels=None, - infer=is_inference, - mel_out=decoded_spectrogram, - encoded_texts=upsampled_enriched_encoded_texts, - tgt_nonpadding=None).squeeze() - else: - glow_loss = self.post_flow(tgt_mels=gold_speech, - infer=is_inference, - mel_out=decoded_spectrogram.detach().clone(), - encoded_texts=upsampled_enriched_encoded_texts.detach().clone(), - tgt_nonpadding=decoder_masks) - if is_inference: - return decoded_spectrogram.squeeze(), \ - refined_spectrogram.squeeze(), \ - predicted_durations.squeeze(), \ - pitch_predictions.squeeze(), \ - energy_predictions.squeeze() - else: - return decoded_spectrogram, \ - None, \ - duration_flow_loss, \ - pitch_flow_loss, \ - energy_flow_loss, \ - glow_loss - - @torch.inference_mode() - def inference(self, - text, - speech=None, - utterance_embedding=None, - return_duration_pitch_energy=False, - lang_id=None, - run_postflow=True): - """ - Args: - text (LongTensor): Input sequence of characters (T,). - speech (Tensor, optional): Feature sequence to extract style (N, idim). - return_duration_pitch_energy (Boolean): whether to return the list of predicted durations for nicer plotting - run_postflow (Boolean): Whether to run the PostNet. There should be a warmup phase in the beginning. 
- lang_id (LongTensor): The language ID used to access the language embedding table, if the model is multilingual - utterance_embedding (Tensor): Embedding to condition the TTS on, if the model is multispeaker - """ - self.eval() - x, y = text, speech - - # setup batch axis - ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) - xs, ys = x.unsqueeze(0), None - if y is not None: - ys = y.unsqueeze(0) - if lang_id is not None: - lang_id = lang_id.unsqueeze(0) - utterance_embeddings = utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None - - before_outs, \ - after_outs, \ - duration_predictions, \ - pitch_predictions, \ - energy_predictions = self._forward(xs, - ilens, - ys, - is_inference=True, - utterance_embedding=utterance_embeddings, - lang_ids=lang_id, - run_glow=run_postflow) # (1, L, odim) - self.train() - if after_outs is None: - after_outs = before_outs - if return_duration_pitch_energy: - return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions - return after_outs - - def _reset_parameters(self, init_type): - # initialize parameters - if init_type != "pytorch": - initialize(self, init_type) - - -if __name__ == '__main__': - print(sum(p.numel() for p in StochasticToucanTTS().parameters() if p.requires_grad)) - - print(" TESTING TRAINING ") - - print(" batchsize 3 ") - dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float() # [Batch, Sequence Length, Features per Phone] - dummy_text_lens = torch.LongTensor([2, 3, 3]) - - dummy_speech_batch = torch.randn([3, 30, 80]) # [Batch, Sequence Length, Spectrogram Buckets] - dummy_speech_lens = torch.LongTensor([10, 30, 20]) - - dummy_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5], [5, 5, 10]]) - dummy_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]], [[1.1], [1.2], [0.8]]]) - dummy_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]], [[1.1], [1.2], [0.8]]]) - - dummy_utterance_embed = torch.randn([3, 192]) # [Batch, Dimensions of Speaker Embedding] - dummy_language_id = torch.LongTensor([5, 3, 2]).unsqueeze(1) - - model = StochasticToucanTTS() - l1, dl, pl, el, gl = model(dummy_text_batch, - dummy_text_lens, - dummy_speech_batch, - dummy_speech_lens, - dummy_durations, - dummy_pitch, - dummy_energy, - utterance_embedding=dummy_utterance_embed, - lang_ids=dummy_language_id) - - loss = l1 + gl + dl + pl + el - print(loss) - loss.backward() - - # from Utility.utils import plot_grad_flow - - # plot_grad_flow(model.encoder.named_parameters()) - # plot_grad_flow(model.decoder.named_parameters()) - # plot_grad_flow(model.pitch_predictor.named_parameters()) - # plot_grad_flow(model.duration_predictor.named_parameters()) - # plot_grad_flow(model.post_flow.named_parameters()) - - print(" batchsize 2 ") - dummy_text_batch = torch.randint(low=0, high=2, size=[2, 3, 62]).float() # [Batch, Sequence Length, Features per Phone] - dummy_text_lens = torch.LongTensor([2, 3]) - - dummy_speech_batch = torch.randn([2, 30, 80]) # [Batch, Sequence Length, Spectrogram Buckets] - dummy_speech_lens = torch.LongTensor([10, 30]) - - dummy_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5]]) - dummy_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]]]) - dummy_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]]]) - - dummy_utterance_embed = torch.randn([2, 192]) # [Batch, Dimensions of Speaker Embedding] - dummy_language_id = torch.LongTensor([5, 3]).unsqueeze(1) - - model = StochasticToucanTTS() - l1, dl, pl, 
el, gl = model(dummy_text_batch, - dummy_text_lens, - dummy_speech_batch, - dummy_speech_lens, - dummy_durations, - dummy_pitch, - dummy_energy, - utterance_embedding=dummy_utterance_embed, - lang_ids=dummy_language_id) - - loss = l1 + gl + dl + el + pl - print(loss) - loss.backward() - - print(" TESTING INFERENCE ") - dummy_text_batch = torch.randint(low=0, high=2, size=[12, 62]).float() # [Sequence Length, Features per Phone] - dummy_utterance_embed = torch.randn([192]) # [Dimensions of Speaker Embedding] - dummy_language_id = torch.LongTensor([2]) - print(StochasticToucanTTS().inference(dummy_text_batch, - utterance_embedding=dummy_utterance_embed, - lang_id=dummy_language_id).shape) diff --git a/Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTSLoss.py b/Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTSLoss.py deleted file mode 100644 index 37bbaf00299a129dcc06cdfaf5841e6bb92fa3c5..0000000000000000000000000000000000000000 --- a/Architectures/ToucanTTS/StochasticToucanTTS/StochasticToucanTTSLoss.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -Taken from ESPNet -Adapted by Flux -""" - -import torch - -from Architectures.GeneralLayers.DurationPredictor import DurationPredictorLoss -from Utility.utils import make_non_pad_mask - - -class StochasticToucanTTSLoss(torch.nn.Module): - - def __init__(self): - super().__init__() - self.l1_criterion = torch.nn.L1Loss(reduction="none") - self.duration_criterion = DurationPredictorLoss(reduction="none") - self.mse_criterion = torch.nn.MSELoss(reduction="none") - - def forward(self, after_outs, before_outs, gold_spectrograms, spectrogram_lengths, text_lengths): - """ - Args: - after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim). - before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim). - gold_spectrograms (Tensor): Batch of target features (B, Lmax, odim). - spectrogram_lengths (LongTensor): Batch of the lengths of each target (B,). - text_lengths (LongTensor): Batch of the lengths of each input (B,). - - Returns: - Tensor: L1 loss value. 
- Tensor: Duration loss value - """ - - # calculate loss - l1_loss = self.l1_criterion(before_outs, gold_spectrograms) - if after_outs is not None: - l1_loss = l1_loss + self.l1_criterion(after_outs, gold_spectrograms) - - # make weighted mask and apply it - out_masks = make_non_pad_mask(spectrogram_lengths).unsqueeze(-1).to(gold_spectrograms.device) - out_masks = torch.nn.functional.pad(out_masks.transpose(1, 2), [0, gold_spectrograms.size(1) - out_masks.size(1), 0, 0, 0, 0], value=False).transpose(1, 2) - out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() - out_weights /= gold_spectrograms.size(0) * gold_spectrograms.size(2) - duration_masks = make_non_pad_mask(text_lengths).to(gold_spectrograms.device) - duration_weights = (duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()) - variance_masks = duration_masks.unsqueeze(-1) - variance_weights = duration_weights.unsqueeze(-1) - - # apply weight - l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum() - - return l1_loss diff --git a/Architectures/ToucanTTS/StochasticToucanTTS/StochasticVariancePredictor.py b/Architectures/ToucanTTS/StochasticToucanTTS/StochasticVariancePredictor.py deleted file mode 100644 index ea34679ad649061ee38187763241d53d49e8452a..0000000000000000000000000000000000000000 --- a/Architectures/ToucanTTS/StochasticToucanTTS/StochasticVariancePredictor.py +++ /dev/null @@ -1,440 +0,0 @@ -""" -Code taken and adapted from https://github.com/jaywalnut310/vits - -MIT License - -Copyright (c) 2021 Jaehyeon Kim - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
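The weighting in the masked L1 above (the same scheme as in the deleted Reconstructor) is easy to misread, so here is a self-contained sketch of what it computes for a single prediction tensor: each utterance is normalized by its own number of valid frames and by batch size times channel count, so the masked sum is exactly the mean of the per-utterance mean absolute errors. make_non_pad_mask is a project utility; a local stand-in keeps the snippet runnable on its own.

import torch


def non_pad_mask(lens, max_len):
    # local stand-in for Utility.utils.make_non_pad_mask
    return torch.arange(max_len)[None, :] < lens[:, None]           # [B, Lmax]


B, Lmax, odim = 2, 6, 3
pred = torch.randn(B, Lmax, odim)
gold = torch.randn(B, Lmax, odim)
lens = torch.LongTensor([6, 4])                                     # valid frames per utterance

masks = non_pad_mask(lens, Lmax).unsqueeze(-1)                      # [B, Lmax, 1]
weights = masks.float() / masks.sum(dim=1, keepdim=True).float()    # 1 / valid frames
weights = weights / (gold.size(0) * gold.size(2))                   # 1 / (batch * channels)
l1 = torch.nn.L1Loss(reduction="none")(pred, gold)
loss = l1.mul(weights).masked_select(masks).sum()

# the same number, computed the obvious way: mean of per-utterance means
per_utt = torch.stack([l1[i, :int(lens[i])].mean() for i in range(B)])
print(torch.allclose(loss, per_utt.mean()))                         # True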
-""" - -import math - -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - -DEFAULT_MIN_BIN_WIDTH = 1e-3 -DEFAULT_MIN_BIN_HEIGHT = 1e-3 -DEFAULT_MIN_DERIVATIVE = 1e-3 - - -class StochasticVariancePredictor(nn.Module): - def __init__(self, in_channels, kernel_size, p_dropout, n_flows=4, conditioning_signal_channels=0): - super().__init__() - self.in_channels = in_channels - self.filter_channels = in_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.n_flows = n_flows - self.gin_channels = conditioning_signal_channels if conditioning_signal_channels is not None else 0 - - self.log_flow = Log() - self.flows = nn.ModuleList() - self.flows.append(ElementwiseAffine(2)) - for i in range(n_flows): - self.flows.append(ConvFlow(2, in_channels, kernel_size, n_layers=3)) - self.flows.append(Flip()) - - self.post_pre = nn.Conv1d(1, in_channels, 1) - self.post_proj = nn.Conv1d(in_channels, in_channels, 1) - self.post_convs = DDSConv(in_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - self.post_flows = nn.ModuleList() - self.post_flows.append(ElementwiseAffine(2)) - for i in range(4): - self.post_flows.append(ConvFlow(2, in_channels, kernel_size, n_layers=3)) - self.post_flows.append(Flip()) - - self.pre = nn.Conv1d(in_channels, in_channels, 1) - self.proj = nn.Conv1d(in_channels, in_channels, 1) - self.convs = DDSConv(in_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - if self.gin_channels != 0: - self.cond = nn.Conv1d(self.gin_channels, in_channels, 1) - - def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=0.3): - x = self.pre(x) - if g is not None: - g = torch.detach(g) - x = x + self.cond(g) - x = self.convs(x, x_mask) - x = self.proj(x) * x_mask - - if not reverse: - flows = self.flows - assert w is not None - - logdet_tot_q = 0 - h_w = self.post_pre(w) - h_w = self.post_convs(h_w, x_mask) - h_w = self.post_proj(h_w) * x_mask - e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask - z_q = e_q - for flow in self.post_flows: - z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) - logdet_tot_q += logdet_q - z_u, z1 = torch.split(z_q, [1, 1], 1) - u = torch.sigmoid(z_u) * x_mask - z0 = (w - u) * x_mask - logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) - logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q - - logdet_tot = 0 - z0, logdet = self.log_flow(z0, x_mask) - logdet_tot += logdet - z = torch.cat([z0, z1], 1) - for flow in flows: - z, logdet = flow(z, x_mask, g=x, reverse=reverse) - logdet_tot = logdet_tot + logdet - nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot - return nll + logq # [b] - else: - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale - # noise scale 0.8 derived from coqui implementation, but dropped to 0.3 during testing. Might not be ideal yet. 
- for flow in flows: - z = flow(z, x_mask, g=x, reverse=reverse) - z0, z1 = torch.split(z, [1, 1], 1) - logw = z0 - return logw - - -class Log(nn.Module): - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = torch.log(torch.clamp_min(x, 1e-6)) * x_mask - logdet = torch.sum(-y, [1, 2]) - return y, logdet - else: - x = torch.exp(x) * x_mask - return x - - -class Flip(nn.Module): - def forward(self, x, *args, reverse=False, **kwargs): - x = torch.flip(x, [1]) - if not reverse: - logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) - return x, logdet - else: - return x - - -class DDSConv(nn.Module): - """ - Dialted and Depth-Separable Convolution - """ - - def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.drop = nn.Dropout(p_dropout) - self.convs_sep = nn.ModuleList() - self.convs_1x1 = nn.ModuleList() - self.norms_1 = nn.ModuleList() - self.norms_2 = nn.ModuleList() - for i in range(n_layers): - dilation = kernel_size ** i - padding = (kernel_size * dilation - dilation) // 2 - self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, - groups=channels, dilation=dilation, padding=padding - )) - self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) - self.norms_1.append(LayerNorm(channels)) - self.norms_2.append(LayerNorm(channels)) - - def forward(self, x, x_mask, g=None): - if g is not None: - x = x + g - for i in range(self.n_layers): - y = self.convs_sep[i](x * x_mask) - y = self.norms_1[i](y) - y = F.gelu(y) - y = self.convs_1x1[i](y) - y = self.norms_2[i](y) - y = F.gelu(y) - y = self.drop(y) - x = x + y - return x * x_mask - - -class ConvFlow(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): - super().__init__() - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.num_bins = num_bins - self.tail_bound = tail_bound - self.half_channels = in_channels // 2 - - self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) - self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) - self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask, g=None, reverse=False): - x0, x1 = torch.split(x, [self.half_channels] * 2, 1) - h = self.pre(x0) - h = self.convs(h, x_mask, g=g) - h = self.proj(h) * x_mask - - b, c, t = x0.shape - h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
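# Illustrative aside, not part of the original file: with the default
# num_bins=10, the 1x1 projection above emits 3 * num_bins - 1 = 29 values per
# half-channel and frame, split right below into 10 bin widths, 10 bin heights
# and 9 interior knot derivatives for the rational-quadratic spline. The two
# boundary derivatives are not predicted because the 'linear' tails later pin
# them to a constant with softplus(constant) + min_derivative == 1, so the
# spline continues with slope 1 outside [-tail_bound, tail_bound]. Shape check:
#     import math, torch
#     num_bins, filter_channels = 10, 192
#     b, c, t = 2, 1, 7                               # batch, half_channels, frames
#     h = torch.randn(b, c * (num_bins * 3 - 1), t)   # projection output
#     h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c, t, 3*num_bins - 1]
#     widths = h[..., :num_bins] / math.sqrt(filter_channels)
#     heights = h[..., num_bins:2 * num_bins] / math.sqrt(filter_channels)
#     derivatives = h[..., 2 * num_bins:]             # shapes: (...,10) (...,10) (...,9)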
- - unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) - unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels) - unnormalized_derivatives = h[..., 2 * self.num_bins:] - - x1, logabsdet = piecewise_rational_quadratic_transform(x1, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=reverse, - tails='linear', - tail_bound=self.tail_bound - ) - - x = torch.cat([x0, x1], 1) * x_mask - logdet = torch.sum(logabsdet * x_mask, [1, 2]) - if not reverse: - return x, logdet - else: - return x - - -class ElementwiseAffine(nn.Module): - def __init__(self, channels): - super().__init__() - self.channels = channels - self.m = nn.Parameter(torch.zeros(channels, 1)) - self.logs = nn.Parameter(torch.zeros(channels, 1)) - - def forward(self, x, x_mask, reverse=False, **kwargs): - if not reverse: - y = self.m + torch.exp(self.logs) * x - y = y * x_mask - logdet = torch.sum(self.logs * x_mask, [1, 2]) - return y, logdet - else: - x = (x - self.m) * torch.exp(-self.logs) * x_mask - return x - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-5): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = nn.Parameter(torch.ones(channels)) - self.beta = nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - x = x.transpose(1, -1) - x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) - return x.transpose(1, -1) - - -def piecewise_rational_quadratic_transform(inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails=None, - tail_bound=1., - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE): - if tails is None: - spline_fn = rational_quadratic_spline - spline_kwargs = {} - else: - spline_fn = unconstrained_rational_quadratic_spline - spline_kwargs = { - 'tails' : tails, - 'tail_bound': tail_bound - } - - outputs, logabsdet = spline_fn( - inputs=inputs, - unnormalized_widths=unnormalized_widths, - unnormalized_heights=unnormalized_heights, - unnormalized_derivatives=unnormalized_derivatives, - inverse=inverse, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - **spline_kwargs - ) - return outputs, logabsdet - - -def rational_quadratic_spline(inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - left=0., right=1., bottom=0., top=1., - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE): - if torch.min(inputs) < left or torch.max(inputs) > right: - raise ValueError('Input to a transform is not within its domain') - - num_bins = unnormalized_widths.shape[-1] - - if min_bin_width * num_bins > 1.0: - raise ValueError('Minimal bin width too large for the number of bins') - if min_bin_height * num_bins > 1.0: - raise ValueError('Minimal bin height too large for the number of bins') - - widths = F.softmax(unnormalized_widths, dim=-1) - widths = min_bin_width + (1 - min_bin_width * num_bins) * widths - cumwidths = torch.cumsum(widths, dim=-1) - cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) - cumwidths = (right - left) * cumwidths + left - cumwidths[..., 0] = left - cumwidths[..., -1] = right - widths = cumwidths[..., 1:] - cumwidths[..., :-1] - - derivatives = min_derivative + F.softplus(unnormalized_derivatives) - - heights = 
F.softmax(unnormalized_heights, dim=-1) - heights = min_bin_height + (1 - min_bin_height * num_bins) * heights - cumheights = torch.cumsum(heights, dim=-1) - cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) - cumheights = (top - bottom) * cumheights + bottom - cumheights[..., 0] = bottom - cumheights[..., -1] = top - heights = cumheights[..., 1:] - cumheights[..., :-1] - - if inverse: - bin_idx = searchsorted(cumheights, inputs)[..., None] - else: - bin_idx = searchsorted(cumwidths, inputs)[..., None] - - input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] - input_bin_widths = widths.gather(-1, bin_idx)[..., 0] - - input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] - delta = heights / widths - input_delta = delta.gather(-1, bin_idx)[..., 0] - - input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] - input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] - - input_heights = heights.gather(-1, bin_idx)[..., 0] - - if inverse: - a = (((inputs - input_cumheights) * (input_derivatives - + input_derivatives_plus_one - - 2 * input_delta) - + input_heights * (input_delta - input_derivatives))) - b = (input_heights * input_derivatives - - (inputs - input_cumheights) * (input_derivatives - + input_derivatives_plus_one - - 2 * input_delta)) - c = - input_delta * (inputs - input_cumheights) - - discriminant = b.pow(2) - 4 * a * c - assert (discriminant >= 0).all() - - root = (2 * c) / (-b - torch.sqrt(discriminant)) - outputs = root * input_bin_widths + input_cumwidths - - theta_one_minus_theta = root * (1 - root) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta) - derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - root).pow(2)) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, -logabsdet - else: - theta = (inputs - input_cumwidths) / input_bin_widths - theta_one_minus_theta = theta * (1 - theta) - - numerator = input_heights * (input_delta * theta.pow(2) - + input_derivatives * theta_one_minus_theta) - denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) - * theta_one_minus_theta) - outputs = input_cumheights + numerator / denominator - - derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) - + 2 * input_delta * theta_one_minus_theta - + input_derivatives * (1 - theta).pow(2)) - logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) - - return outputs, logabsdet - - -def searchsorted(bin_locations, inputs, eps=1e-6): - bin_locations[..., -1] += eps - return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 - - -def unconstrained_rational_quadratic_spline(inputs, - unnormalized_widths, - unnormalized_heights, - unnormalized_derivatives, - inverse=False, - tails='linear', - tail_bound=1., - min_bin_width=DEFAULT_MIN_BIN_WIDTH, - min_bin_height=DEFAULT_MIN_BIN_HEIGHT, - min_derivative=DEFAULT_MIN_DERIVATIVE): - inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) - outside_interval_mask = ~inside_interval_mask - - outputs = torch.zeros_like(inputs) - logabsdet = torch.zeros_like(inputs) - - if tails == 'linear': - unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) - constant = np.log(np.exp(1 - min_derivative) - 1) - unnormalized_derivatives[..., 0] = constant - 
unnormalized_derivatives[..., -1] = constant - - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 - else: - raise RuntimeError('{} tails are not implemented.'.format(tails)) - - outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( - inputs=inputs[inside_interval_mask], - unnormalized_widths=unnormalized_widths[inside_interval_mask, :], - unnormalized_heights=unnormalized_heights[inside_interval_mask, :], - unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], - inverse=inverse, - left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative - ) - - return outputs, logabsdet diff --git a/Architectures/__init__.py b/Architectures/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/InferenceInterfaces/ControllableInterface.py b/InferenceInterfaces/ControllableInterface.py index bab941b57c070060fbe694ad4bbdb974f0d8b400..ae68ceb5ddabe4fc3805a7dabbce114d58abde33 100644 --- a/InferenceInterfaces/ControllableInterface.py +++ b/InferenceInterfaces/ControllableInterface.py @@ -2,8 +2,8 @@ import os import torch -from Architectures.ControllabilityGAN.GAN import GanWrapper from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface +from Modules.ControllabilityGAN.GAN import GanWrapper from Utility.storage_config import MODELS_DIR @@ -15,7 +15,7 @@ class ControllableInterface: else: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}" - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = "cuda" if gpu_id != "cpu" else "cpu" self.model = ToucanTTSInterface(device=self.device, tts_model_path="Meta") self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device=self.device) self.generated_speaker_embeds = list() @@ -25,9 +25,11 @@ class ControllableInterface: def read(self, prompt, + reference_audio, language, accent, voice_seed, + prosody_creativity, duration_scaling_factor, pause_duration_scaling_factor, pitch_variance_scale, @@ -37,24 +39,29 @@ class ControllableInterface: emb_slider_3, emb_slider_4, emb_slider_5, - emb_slider_6 + emb_slider_6, + loudness_in_db ): if self.current_language != language: self.model.set_phonemizer_language(language) + print(f"switched phonemizer language to {language}") self.current_language = language if self.current_accent != accent: self.model.set_accent_language(accent) + print(f"switched accent language to {accent}") self.current_accent = accent - - self.wgan.set_latent(voice_seed) - controllability_vector = torch.tensor([emb_slider_1, - emb_slider_2, - emb_slider_3, - emb_slider_4, - emb_slider_5, - emb_slider_6], dtype=torch.float32) - embedding = self.wgan.modify_embed(controllability_vector) - self.model.set_utterance_embedding(embedding=embedding) + if reference_audio is None: + self.wgan.set_latent(voice_seed) + controllability_vector = torch.tensor([emb_slider_1, + emb_slider_2, + emb_slider_3, + emb_slider_4, + emb_slider_5, + emb_slider_6], dtype=torch.float32) + embedding = self.wgan.modify_embed(controllability_vector) + self.model.set_utterance_embedding(embedding=embedding) + else: + self.model.set_utterance_embedding(reference_audio) phones = self.model.text2phone.get_phone_string(prompt) if len(phones) > 1800: @@ -92,15 +99,15 @@ class 
ControllableInterface: if self.current_accent != "eng": self.model.set_accent_language("eng") self.current_accent = "eng" - print("\n\n") - print(prompt) - print(language) - print("\n\n") + + print(prompt + "\n\n") wav, sr, fig = self.model(prompt, input_is_phones=False, duration_scaling_factor=duration_scaling_factor, pitch_variance_scale=pitch_variance_scale, energy_variance_scale=energy_variance_scale, pause_duration_scaling_factor=pause_duration_scaling_factor, - return_plot_as_filepath=True) + return_plot_as_filepath=True, + prosody_creativity=prosody_creativity, + loudness_in_db=loudness_in_db) return sr, wav, fig diff --git a/InferenceInterfaces/ToucanTTSInterface.py b/InferenceInterfaces/ToucanTTSInterface.py index 6ebae8a696e03bd3418d542fd89981263c90837e..271a99ef1565a37661c0ca070c3ff93b6d94f30b 100644 --- a/InferenceInterfaces/ToucanTTSInterface.py +++ b/InferenceInterfaces/ToucanTTSInterface.py @@ -1,19 +1,17 @@ import itertools import os -import warnings +import librosa import matplotlib.pyplot as plt import pyloudnorm import sounddevice import soundfile import torch -with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from speechbrain.pretrained import EncoderClassifier - from torchaudio.transforms import Resample +from speechbrain.pretrained import EncoderClassifier +from torchaudio.transforms import Resample -from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS -from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN +from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS +from Modules.Vocoder.HiFiGAN_Generator import HiFiGAN from Preprocessing.AudioPreprocessor import AudioPreprocessor from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend from Preprocessing.TextFrontend import get_language_id @@ -29,7 +27,6 @@ class ToucanTTSInterface(torch.nn.Module): tts_model_path=os.path.join(MODELS_DIR, f"ToucanTTS_Meta", "best.pt"), # path to the ToucanTTS checkpoint or just a shorthand if run standalone vocoder_model_path=os.path.join(MODELS_DIR, f"Vocoder", "best.pt"), # path to the Vocoder checkpoint language="eng", # initial language of the model, can be changed later with the setter methods - enhance=None # legacy argument ): super().__init__() self.device = device @@ -40,7 +37,7 @@ class ToucanTTSInterface(torch.nn.Module): ################################ # build text to phone # ################################ - self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True) + self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True, device=device) ##################################### # load phone to features model # @@ -92,8 +89,12 @@ class ToucanTTSInterface(torch.nn.Module): speaker_embs = list() for path in path_to_reference_audio: wave, sr = soundfile.read(path) + if len(wave.shape) > 1: # oh no, we found a stereo audio! + if len(wave[0]) == 2: # let's figure out whether we need to switch the axes + wave = wave.transpose() # if yes, we switch the axes. 
+ wave = librosa.to_mono(wave) wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(torch.tensor(wave, device=self.device, dtype=torch.float32)) - speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).unsqueeze(0)).squeeze() + speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).squeeze().unsqueeze(0)).squeeze() speaker_embs.append(speaker_embedding) self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs) @@ -105,10 +106,10 @@ class ToucanTTSInterface(torch.nn.Module): self.set_accent_language(lang_id=lang_id) def set_phonemizer_language(self, lang_id): - self.text2phone.change_lang(language=lang_id, add_silence_to_end=True) + self.text2phone = ArticulatoryCombinedTextFrontend(language=lang_id, add_silence_to_end=True, device=self.device) def set_accent_language(self, lang_id): - if lang_id in ['ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so']: + if lang_id in {'ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so'}: if lang_id == 'vi-so' or lang_id == 'vi-ctr': lang_id = 'vie' elif lang_id == 'spa-lat': @@ -120,7 +121,7 @@ class ToucanTTSInterface(torch.nn.Module): elif lang_id == 'en-sc' or lang_id == 'en-us': lang_id = 'eng' else: - # no clue where these others are even coming from, they are not in ISO 639-2 + # no clue where these others are even coming from, they are not in ISO 639-3 lang_id = 'eng' self.lang_id = get_language_id(lang_id).to(self.device) @@ -138,7 +139,7 @@ class ToucanTTSInterface(torch.nn.Module): input_is_phones=False, return_plot_as_filepath=False, loudness_in_db=-24.0, - glow_sampling_temperature=0.2): + prosody_creativity=0.1): """ duration_scaling_factor: reasonable values are 0.8 < scale < 1.2. 
1.0 means no scaling happens, higher values increase durations for the whole @@ -154,16 +155,16 @@ class ToucanTTSInterface(torch.nn.Module): phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device)) mel, durations, pitch, energy = self.phone2mel(phones, return_duration_pitch_energy=True, - utterance_embedding=self.default_utterance_embedding.to(self.device), + utterance_embedding=self.default_utterance_embedding, durations=durations, pitch=pitch, energy=energy, - lang_id=self.lang_id.to(self.device), + lang_id=self.lang_id, duration_scaling_factor=duration_scaling_factor, pitch_variance_scale=pitch_variance_scale, energy_variance_scale=energy_variance_scale, pause_duration_scaling_factor=pause_duration_scaling_factor, - glow_sampling_temperature=glow_sampling_temperature) + prosody_creativity=prosody_creativity) wave, _, _ = self.vocoder(mel.unsqueeze(0)) wave = wave.squeeze().cpu() @@ -177,63 +178,56 @@ class ToucanTTSInterface(torch.nn.Module): pass if view or return_plot_as_filepath: - try: - fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5)) + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5)) - ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu') - ax.yaxis.set_visible(False) - duration_splits, label_positions = cumsum_durations(durations.cpu().numpy()) - ax.xaxis.grid(True, which='minor') - ax.set_xticks(label_positions, minor=False) - if input_is_phones: - phones = text.replace(" ", "|") - else: - phones = self.text2phone.get_phone_string(text, for_plot_labels=True) - try: - ax.set_xticklabels(phones) - except IndexError: - pass - word_boundaries = list() - for label_index, phone in enumerate(phones): - if phone == "|": - word_boundaries.append(label_positions[label_index]) + ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu') + ax.yaxis.set_visible(False) + duration_splits, label_positions = cumsum_durations(durations.cpu().numpy()) + ax.xaxis.grid(True, which='minor') + ax.set_xticks(label_positions, minor=False) + if input_is_phones: + phones = text.replace(" ", "|") + else: + phones = self.text2phone.get_phone_string(text, for_plot_labels=True) + try: + ax.set_xticklabels(phones) + except IndexError: + pass + except ValueError: + pass + word_boundaries = list() + for label_index, phone in enumerate(phones): + if phone == "|": + word_boundaries.append(label_positions[label_index]) - try: - prev_word_boundary = 0 - word_label_positions = list() - for word_boundary in word_boundaries: - word_label_positions.append((word_boundary + prev_word_boundary) / 2) - prev_word_boundary = word_boundary - word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2) + try: + prev_word_boundary = 0 + word_label_positions = list() + for word_boundary in word_boundaries: + word_label_positions.append((word_boundary + prev_word_boundary) / 2) + prev_word_boundary = word_boundary + word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2) - secondary_ax = ax.secondary_xaxis('bottom') - secondary_ax.tick_params(axis="x", direction="out", pad=24) - secondary_ax.set_xticks(word_label_positions, minor=False) - secondary_ax.set_xticklabels(text.split()) - secondary_ax.tick_params(axis='x', colors='orange') - secondary_ax.xaxis.label.set_color('orange') - except ValueError: - ax.set_title(text) - except IndexError: - ax.set_title(text) - except RuntimeError: - ax.set_title(text) + secondary_ax = ax.secondary_xaxis('bottom') + secondary_ax.tick_params(axis="x", direction="out", pad=24) + 
secondary_ax.set_xticks(word_label_positions, minor=False) + secondary_ax.set_xticklabels(text.split()) + secondary_ax.tick_params(axis='x', colors='orange') + secondary_ax.xaxis.label.set_color('orange') + except ValueError: + ax.set_title(text) + except IndexError: + ax.set_title(text) - ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5) - ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0) - plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0) - ax.set_aspect("auto") - except: - pass + ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5) + ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0) + plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0) + ax.set_aspect("auto") if return_plot_as_filepath: - try: - plt.savefig("tmp.png") - plt.close() - except: - pass + plt.savefig("tmp.png") + plt.close() return wave, sr, "tmp.png" - return wave, sr def read_to_file(self, @@ -247,7 +241,7 @@ class ToucanTTSInterface(torch.nn.Module): dur_list=None, pitch_list=None, energy_list=None, - glow_sampling_temperature=0.2): + prosody_creativity=0.1): """ Args: silent: Whether to be verbose about the process @@ -259,12 +253,19 @@ class ToucanTTSInterface(torch.nn.Module): duration_scaling_factor: reasonable values are 0.8 < scale < 1.2. 1.0 means no scaling happens, higher values increase durations for the whole utterance, lower values decrease durations for the whole utterance. + pause_duration_scaling_factor: reasonable values are 0.8 < scale < 1.2. + 1.0 means no scaling happens, higher values increase durations for the pauses, + lower values decrease durations for the whole utterance. pitch_variance_scale: reasonable values are 0.6 < scale < 1.4. 1.0 means no scaling happens, higher values increase variance of the pitch curve, lower values decrease variance of the pitch curve. energy_variance_scale: reasonable values are 0.6 < scale < 1.4. 1.0 means no scaling happens, higher values increase variance of the energy curve, lower values decrease variance of the energy curve. + prosody_creativity: sampling temperature of the generative model that comes up with the pitch, energy and + durations. Higher values mena more variance, lower temperature means less variance across + generations. reasonable values are between 0.0 and 1.2, anything higher makes the voice + sound very weird. 
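A hedged usage sketch of the renamed control (the example sentence and output filenames are placeholders, and the checkpoints are the interface defaults): lower prosody_creativity gives conservative, repeatable prosody, higher values draw more varied pitch, energy and duration samples from the flows.

import torch
from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface

tts = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu",
                         language="eng")
for creativity in (0.1, 0.8):
    # render the same sentence twice with different sampling temperatures
    tts.read_to_file(text_list=["An example sentence to render twice."],
                     file_location=f"sample_creativity_{creativity}.wav",
                     prosody_creativity=creativity)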
""" if not dur_list: dur_list = [] @@ -272,7 +273,7 @@ class ToucanTTSInterface(torch.nn.Module): pitch_list = [] if not energy_list: energy_list = [] - silence = torch.zeros([14300]) + silence = torch.zeros([400]) wav = silence.clone() for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list): if text.strip() != "": @@ -286,7 +287,7 @@ class ToucanTTSInterface(torch.nn.Module): pitch_variance_scale=pitch_variance_scale, energy_variance_scale=energy_variance_scale, pause_duration_scaling_factor=pause_duration_scaling_factor, - glow_sampling_temperature=glow_sampling_temperature) + prosody_creativity=prosody_creativity) spoken_sentence = torch.tensor(spoken_sentence).cpu() wav = torch.cat((wav, spoken_sentence, silence), 0) soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16") @@ -298,7 +299,7 @@ class ToucanTTSInterface(torch.nn.Module): pitch_variance_scale=1.0, energy_variance_scale=1.0, blocking=False, - glow_sampling_temperature=0.2): + prosody_creativity=0.1): if text.strip() == "": return wav, sr = self(text, @@ -306,7 +307,7 @@ class ToucanTTSInterface(torch.nn.Module): duration_scaling_factor=duration_scaling_factor, pitch_variance_scale=pitch_variance_scale, energy_variance_scale=energy_variance_scale, - glow_sampling_temperature=glow_sampling_temperature) + prosody_creativity=prosody_creativity) silence = torch.zeros([sr // 2]) wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy() sounddevice.play(float2pcm(wav), samplerate=sr) diff --git a/InferenceInterfaces/UtteranceCloner.py b/InferenceInterfaces/UtteranceCloner.py index 409d0efe6e962b1fe3d68f40dce418a38d1badb3..cddeda0efed47f18abeb5586d36bd7cf8840419b 100644 --- a/InferenceInterfaces/UtteranceCloner.py +++ b/InferenceInterfaces/UtteranceCloner.py @@ -4,11 +4,11 @@ import numpy import soundfile as sf import torch -from Architectures.Aligner.Aligner import Aligner -from Architectures.ToucanTTS.DurationCalculator import DurationCalculator -from Architectures.ToucanTTS.EnergyCalculator import EnergyCalculator -from Architectures.ToucanTTS.PitchCalculator import Parselmouth from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface +from Modules.Aligner.Aligner import Aligner +from Modules.ToucanTTS.DurationCalculator import DurationCalculator +from Modules.ToucanTTS.EnergyCalculator import EnergyCalculator +from Modules.ToucanTTS.PitchCalculator import Parselmouth from Preprocessing.AudioPreprocessor import AudioPreprocessor from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend from Preprocessing.articulatory_features import get_feature_to_index_lookup @@ -26,7 +26,7 @@ class UtteranceCloner: def __init__(self, model_id, device, language="eng"): self.tts = ToucanTTSInterface(device=device, tts_model_path=model_id) self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, cut_silence=False) - self.tf = ArticulatoryCombinedTextFrontend(language=language) + self.tf = ArticulatoryCombinedTextFrontend(language=language, device=device) self.device = device acoustic_checkpoint_path = os.path.join(MODELS_DIR, "Aligner", "aligner.pt") self.aligner_weights = torch.load(acoustic_checkpoint_path, map_location=device)["asr_model"] @@ -43,6 +43,7 @@ class UtteranceCloner: self.acoustic_model = Aligner() self.acoustic_model = self.acoustic_model.to(self.device) self.acoustic_model.load_state_dict(self.aligner_weights) + self.acoustic_model.eval() self.parsel = Parselmouth(reduction_factor=1, fs=16000) self.energy_calc 
= EnergyCalculator(reduction_factor=1, fs=16000) self.dc = DurationCalculator(reduction_factor=1) @@ -50,10 +51,11 @@ class UtteranceCloner: def extract_prosody(self, transcript, ref_audio_path, lang="eng", on_line_fine_tune=True): if on_line_fine_tune: self.acoustic_model.load_state_dict(self.aligner_weights) + self.acoustic_model.eval() wave, sr = sf.read(ref_audio_path) if self.tf.language != lang: - self.tf = ArticulatoryCombinedTextFrontend(language=lang) + self.tf = ArticulatoryCombinedTextFrontend(language=lang, device=self.device) if self.ap.input_sr != sr: self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=False) try: diff --git a/InferenceInterfaces/audioseal_wm_16bits.yaml b/InferenceInterfaces/audioseal_wm_16bits.yaml deleted file mode 100644 index 64fc1b2c1a759d23aeaf430889566f225f60c159..0000000000000000000000000000000000000000 --- a/InferenceInterfaces/audioseal_wm_16bits.yaml +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -name: audioseal_wm_16bits -model_type: seanet -checkpoint: "https://dl.fbaipublicfiles.com/audioseal/6edcf62f/generator.pth" -nbits: 16 -seanet: - activation: ELU - activation_params: - alpha: 1.0 - causal: false - channels: 1 - compress: 2 - dilation_base: 2 - dimension: 128 - disable_norm_outer_blocks: 0 - kernel_size: 7 - last_kernel_size: 7 - lstm: 2 - n_filters: 32 - n_residual_layers: 1 - norm: weight_norm - norm_params: { } - pad_mode: constant - ratios: - - 8 - - 5 - - 4 - - 2 - residual_kernel_size: 3 - true_skip: true -decoder: - final_activation: null - final_activation_params: null - trim_right_ratio: 1.0 diff --git a/Architectures/Aligner/Aligner.py b/Modules/Aligner/Aligner.py similarity index 77% rename from Architectures/Aligner/Aligner.py rename to Modules/Aligner/Aligner.py index 8396061e9e4d50d95c0932641803a5cff4dd3de3..f027b773f7a2e1f8eab08a9055a04ffd42a5b456 100644 --- a/Architectures/Aligner/Aligner.py +++ b/Modules/Aligner/Aligner.py @@ -1,27 +1,31 @@ """ taken and adapted from https://github.com/as-ideas/DeepForcedAligner + +refined with insights from https://www.audiolabs-erlangen.de/resources/NLUI/2023-ICASSP-eval-alignment-tts +EVALUATING SPEECH–PHONEME ALIGNMENT AND ITS IMPACT ON NEURAL TEXT-TO-SPEECH SYNTHESIS +by Frank Zalkow, Prachi Govalkar, Meinard Muller, Emanuel A. P. 
Habets, Christian Dittmar """ import matplotlib.pyplot as plt import numpy as np import torch import torch.multiprocessing -import torch.nn as nn from torch.nn import CTCLoss from torch.nn.utils.rnn import pack_padded_sequence from torch.nn.utils.rnn import pad_packed_sequence from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend +from Utility.utils import make_non_pad_mask -class BatchNormConv(nn.Module): +class BatchNormConv(torch.nn.Module): def __init__(self, in_channels: int, out_channels: int, kernel_size: int): super().__init__() - self.conv = nn.Conv1d( + self.conv = torch.nn.Conv1d( in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, bias=False) - self.bnorm = nn.BatchNorm1d(out_channels) - self.relu = nn.ReLU() + self.bnorm = torch.nn.SyncBatchNorm.convert_sync_batchnorm(torch.nn.BatchNorm1d(out_channels)) + self.relu = torch.nn.ReLU() def forward(self, x): x = x.transpose(1, 2) @@ -37,22 +41,23 @@ class Aligner(torch.nn.Module): def __init__(self, n_features=128, num_symbols=145, - lstm_dim=512, - conv_dim=512): + conv_dim=512, + lstm_dim=512): super().__init__() - self.convs = nn.ModuleList([ + self.convs = torch.nn.ModuleList([ BatchNormConv(n_features, conv_dim, 3), - nn.Dropout(p=0.5), + torch.nn.Dropout(p=0.5), BatchNormConv(conv_dim, conv_dim, 3), - nn.Dropout(p=0.5), + torch.nn.Dropout(p=0.5), BatchNormConv(conv_dim, conv_dim, 3), - nn.Dropout(p=0.5), + torch.nn.Dropout(p=0.5), BatchNormConv(conv_dim, conv_dim, 3), - nn.Dropout(p=0.5), + torch.nn.Dropout(p=0.5), BatchNormConv(conv_dim, conv_dim, 3), - nn.Dropout(p=0.5), + torch.nn.Dropout(p=0.5), ]) - self.rnn = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True) + self.rnn1 = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True) + self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True) self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols) self.tf = ArticulatoryCombinedTextFrontend(language="eng") self.ctc_loss = CTCLoss(blank=144, zero_infinity=True) @@ -61,14 +66,17 @@ class Aligner(torch.nn.Module): def forward(self, x, lens=None): for conv in self.convs: x = conv(x) - if lens is not None: x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False) - x, _ = self.rnn(x) + x, _ = self.rnn1(x) + x, _ = self.rnn2(x) if lens is not None: x, _ = pad_packed_sequence(x, batch_first=True) x = self.proj(x) + if lens is not None: + out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(x.device) + x = x * out_masks.float() return x @@ -88,15 +96,12 @@ class Aligner(torch.nn.Module): pred_max = pred[:, tokens] # run monotonic alignment search - alignment_matrix = binarize_alignment(pred_max) if save_img_for_debug is not None: phones = list() for index in tokens: - for phone in self.tf.phone_to_id: - if self.tf.phone_to_id[phone] == index: - phones.append(phone) + phones.append(self.tf.id_to_phone[index]) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5)) ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis') @@ -115,7 +120,6 @@ class Aligner(torch.nn.Module): return alignment_matrix - def binarize_alignment(alignment_prob): """ # Implementation by: @@ -152,13 +156,5 @@ def binarize_alignment(alignment_prob): if __name__ == '__main__': - tf = ArticulatoryCombinedTextFrontend(language="eng") - from Preprocessing.HiFiCodecAudioPreprocessor import CodecAudioPreprocessor - - cap = CodecAudioPreprocessor(input_sr=-1) - dummy_codebook_indexes = 
torch.randint(low=0, high=1023, size=[9, 20]) - codebook_frames = cap.indexes_to_codec_frames(dummy_codebook_indexes) - alignment = Aligner().inference(codebook_frames.transpose(0, 1), tokens=tf.string_to_tensor("Hello world")) - print(alignment.shape) - plt.imshow(alignment, origin="lower", cmap="GnBu") - plt.show() + print(sum(p.numel() for p in Aligner().parameters() if p.requires_grad)) + print(Aligner()(x=torch.randn(size=[3, 30, 128]), lens=torch.LongTensor([20, 30, 10])).shape) diff --git a/Architectures/Aligner/CodecAlignerDataset.py b/Modules/Aligner/CodecAlignerDataset.py similarity index 81% rename from Architectures/Aligner/CodecAlignerDataset.py rename to Modules/Aligner/CodecAlignerDataset.py index 861353263ee42d73ecd8cbd6d7491826dc311aac..c8afa439bfec2e110f4d2a55210c462270704c0c 100644 --- a/Architectures/Aligner/CodecAlignerDataset.py +++ b/Modules/Aligner/CodecAlignerDataset.py @@ -32,6 +32,7 @@ class CodecAlignerDataset(Dataset): allow_unknown_symbols=False, gpu_count=1, rank=0): + self.gpu_count = gpu_count self.rank = rank if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache: @@ -50,9 +51,10 @@ class CodecAlignerDataset(Dataset): self.lang = lang self.device = device self.cache_dir = cache_dir - self.tf = ArticulatoryCombinedTextFrontend(language=self.lang) + self.tf = ArticulatoryCombinedTextFrontend(language=self.lang, device=device) cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu') self.speaker_embeddings = cache[2] + self.filepaths = cache[3] self.datapoints = cache[0] if self.gpu_count > 1: # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank. @@ -85,6 +87,7 @@ class CodecAlignerDataset(Dataset): if type(path_to_transcript_dict) != dict: path_to_transcript_dict = path_to_transcript_dict() # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary. torch.multiprocessing.set_start_method('spawn', force=True) + torch.multiprocessing.set_sharing_strategy('file_system') resource_manager = Manager() self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict) key_list = list(self.path_to_transcript_dict.keys()) @@ -93,6 +96,13 @@ class CodecAlignerDataset(Dataset): fisher_yates_shuffle(key_list) # build cache print("... 
building dataset cache ...") + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround + # careful: assumes 16kHz or 8kHz audio + _, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', # make sure it gets downloaded during single-processing first, if it's not already downloaded + model='silero_vad', + force_reload=False, + onnx=False, + verbose=False) self.result_pool = resource_manager.list() # make processes key_splits = list() @@ -176,8 +186,8 @@ class CodecAlignerDataset(Dataset): torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets # this to false globally during model loading rather than using inference mode or no_grad silero_model = silero_model.to(device) - silence = torch.zeros([16000 // 4], device=device) - tf = ArticulatoryCombinedTextFrontend(language=lang) + silence = torch.zeros([16000 // 8]).to(device) + tf = ArticulatoryCombinedTextFrontend(language=lang, device=device) _, sr = sf.read(path_list[0]) assumed_sr = sr ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device) @@ -186,13 +196,15 @@ class CodecAlignerDataset(Dataset): for path in tqdm(path_list): if self.path_to_transcript_dict[path].strip() == "": continue - try: wave, sr = sf.read(path) except: print(f"Problem with an audio file: {path}") continue + if len(wave.shape) > 1: # oh no, we found a stereo audio! + if len(wave[0]) == 2: # let's figure out whether we need to switch the axes + wave = wave.transpose() # if yes, we switch the axes. wave = librosa.to_mono(wave) if sr != assumed_sr: @@ -210,16 +222,19 @@ class CodecAlignerDataset(Dataset): if verbose: print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.") continue - - # remove silences from front and back, then add constant 1/4th second silences back to front and back - with torch.no_grad(): + with torch.inference_mode(): speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000) try: + silence_timestamps = invert_segments(speech_timestamps, len(norm_wave)) + for silence_timestamp in silence_timestamps: + begin = silence_timestamp['start'] + end = silence_timestamp['end'] + norm_wave = torch.cat([norm_wave[:begin], torch.zeros([end - begin], device=device), norm_wave[end:]]) result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']] except IndexError: print("Audio might be too short to cut silences from front and back.") continue - wave = torch.cat([silence, result, silence]) + norm_wave = torch.cat([silence, result, silence]) # raw audio preprocessing is done transcript = self.path_to_transcript_dict[path] @@ -238,10 +253,10 @@ class CodecAlignerDataset(Dataset): # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample. 
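The hunk above changes the silence handling during cache building: instead of only trimming leading and trailing silence, every region that Silero VAD did not mark as speech is zeroed out before the trim. Below is a minimal standalone sketch of that masking step, using made-up VAD timestamps and in-place zeroing instead of the torch.cat construction from the patch; invert_segments mirrors the helper added at the bottom of CodecAlignerDataset.py.

import torch

def invert_segments(segments, total_duration):
    # complement of the speech segments: the stretches that contain no speech
    if not segments:
        return [{'start': 0, 'end': total_duration}]
    inverted, previous_end = [], 0
    for segment in segments:
        if previous_end < segment['start']:
            inverted.append({'start': previous_end, 'end': segment['start']})
        previous_end = segment['end']
    if previous_end < total_duration:
        inverted.append({'start': previous_end, 'end': total_duration})
    return inverted

# made-up VAD output for a 2 second waveform at 16 kHz
norm_wave = torch.randn(32000)
speech_timestamps = [{'start': 4000, 'end': 12000}, {'start': 18000, 'end': 28000}]

# zero everything that was not marked as speech, then cut the outer silence away
for silence_span in invert_segments(speech_timestamps, len(norm_wave)):
    norm_wave[silence_span['start']:silence_span['end']] = 0.0
result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]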
continue - cached_speech = ap.audio_to_codebook_indexes(audio=wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy() + cached_speech = ap.audio_to_codebook_indexes(audio=norm_wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy() process_internal_dataset_chunk.append([cached_text, cached_speech, - result.cpu().detach().numpy(), + norm_wave.cpu().detach().numpy(), path]) self.result_pool.append(process_internal_dataset_chunk) @@ -256,16 +271,44 @@ class CodecAlignerDataset(Dataset): codes = codes.transpose(0, 1) return tokens, \ - token_len, \ - codes, \ - None, \ - self.speaker_embeddings[index] + token_len, \ + codes, \ + None, \ + self.speaker_embeddings[index] def __len__(self): return len(self.datapoints) + def remove_samples(self, list_of_samples_to_remove): + for remove_id in sorted(list_of_samples_to_remove, reverse=True): + self.datapoints.pop(remove_id) + self.speaker_embeddings.pop(remove_id) + self.filepaths.pop(remove_id) + torch.save((self.datapoints, None, self.speaker_embeddings, self.filepaths), + os.path.join(self.cache_dir, "aligner_train_cache.pt")) + print("Dataset updated!") + def fisher_yates_shuffle(lst): for i in range(len(lst) - 1, 0, -1): j = random.randint(0, i) lst[i], lst[j] = lst[j], lst[i] + + +def invert_segments(segments, total_duration): + if not segments: + return [{'start': 0, 'end': total_duration}] + + inverted_segments = [] + previous_end = 0 + + for segment in segments: + start = segment['start'] + if previous_end < start: + inverted_segments.append({'start': previous_end, 'end': start}) + previous_end = segment['end'] + + if previous_end < total_duration: + inverted_segments.append({'start': previous_end, 'end': total_duration}) + + return inverted_segments \ No newline at end of file diff --git a/Architectures/Aligner/README.md b/Modules/Aligner/README.md similarity index 100% rename from Architectures/Aligner/README.md rename to Modules/Aligner/README.md diff --git a/Modules/Aligner/Reconstructor.py b/Modules/Aligner/Reconstructor.py new file mode 100644 index 0000000000000000000000000000000000000000..05aac1d0d958f366edef90cd74f76c17847c5172 --- /dev/null +++ b/Modules/Aligner/Reconstructor.py @@ -0,0 +1,33 @@ +import torch +import torch.multiprocessing + +from Utility.utils import make_non_pad_mask + + +class Reconstructor(torch.nn.Module): + + def __init__(self, + n_features=128, + num_symbols=145, + speaker_embedding_dim=192, + hidden_dim=256): + super().__init__() + self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, hidden_dim) + self.hidden_proj = torch.nn.Linear(hidden_dim, hidden_dim) + self.out_proj = torch.nn.Linear(hidden_dim, n_features) + self.l1_criterion = torch.nn.L1Loss(reduction="none") + + def forward(self, x, lens, ys): + x = self.in_proj(x) + x = torch.nn.functional.leaky_relu(x) + x = self.hidden_proj(x) + x = torch.nn.functional.leaky_relu(x) + x = self.out_proj(x) + out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device) + out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() + out_weights /= ys.size(0) * ys.size(2) + return self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum() + + +if __name__ == '__main__': + print(sum(p.numel() for p in Reconstructor().parameters() if p.requires_grad)) \ No newline at end of file diff --git a/Architectures/Aligner/__init__.py b/Modules/Aligner/__init__.py similarity index 100% rename from Architectures/Aligner/__init__.py rename to Modules/Aligner/__init__.py diff --git 
a/Architectures/Aligner/autoaligner_train_loop.py b/Modules/Aligner/autoaligner_train_loop.py similarity index 97% rename from Architectures/Aligner/autoaligner_train_loop.py rename to Modules/Aligner/autoaligner_train_loop.py index cf10bb6fc4997d19426c11282738a1abebe3ce00..ed46ff03f19845d501776c0717e4a71262c6614d 100644 --- a/Architectures/Aligner/autoaligner_train_loop.py +++ b/Modules/Aligner/autoaligner_train_loop.py @@ -8,8 +8,8 @@ from torch.optim import RAdam from torch.utils.data.dataloader import DataLoader from tqdm import tqdm -from Architectures.Aligner.Aligner import Aligner -from Architectures.Aligner.Reconstructor import Reconstructor +from Modules.Aligner.Aligner import Aligner +from Modules.Aligner.Reconstructor import Reconstructor from Preprocessing.AudioPreprocessor import AudioPreprocessor from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor @@ -152,6 +152,8 @@ def train_loop(train_dataset, optim_asr.zero_grad() if use_reconstruction: optim_tts.zero_grad() + if gpu_count > 1: + torch.distributed.barrier() loss.backward() torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0) if use_reconstruction: diff --git a/Architectures/ControllabilityGAN/GAN.py b/Modules/ControllabilityGAN/GAN.py similarity index 66% rename from Architectures/ControllabilityGAN/GAN.py rename to Modules/ControllabilityGAN/GAN.py index 082106c99f0ca27069c3993dc57341025567c6e7..ac268cb4bd5a9955cda61d3e205e7d6e96e1a2e2 100644 --- a/Architectures/ControllabilityGAN/GAN.py +++ b/Modules/ControllabilityGAN/GAN.py @@ -1,12 +1,11 @@ import torch -from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan +from Modules.ControllabilityGAN.wgan.init_wgan import create_wgan -class GanWrapper(torch.nn.Module): +class GanWrapper: - def __init__(self, path_wgan, device, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, path_wgan, device): self.device = device self.path_wgan = path_wgan @@ -20,27 +19,41 @@ class GanWrapper(torch.nn.Module): self.U = self.compute_controllability() self.z_list = list() + for _ in range(1100): - self.z_list.append(self.wgan.G.module.sample_latent(1, 32).to("cpu")) + self.z_list.append(self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8)) self.z = self.z_list[0] def set_latent(self, seed): self.z = self.z = self.z_list[seed] def reset_default_latent(self): - self.z = self.wgan.G.module.sample_latent(1, 32).to("cpu") + self.z = self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=0.8) def load_model(self, path): gan_checkpoint = torch.load(path, map_location="cpu") self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device) - self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict']) - self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict']) + # Create a new state dict without 'module.' prefix + new_state_dict_G = {} + for key, value in gan_checkpoint['generator_state_dict'].items(): + # Remove 'module.' prefix + new_key = key.replace('module.', '') + new_state_dict_G[new_key] = value + + new_state_dict_D = {} + for key, value in gan_checkpoint['critic_state_dict'].items(): + # Remove 'module.' 
prefix + new_key = key.replace('module.', '') + new_state_dict_D[new_key] = value + + self.wgan.G.load_state_dict(new_state_dict_G) + self.wgan.D.load_state_dict(new_state_dict_D) self.mean = gan_checkpoint["dataset_mean"] self.std = gan_checkpoint["dataset_std"] - def compute_controllability(self, n_samples=50000): + def compute_controllability(self, n_samples=100000): _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True) intermediate = intermediate.cpu() z = z.cpu() @@ -69,7 +82,7 @@ class GanWrapper(torch.nn.Module): def modify_embed(self, x): self.wgan.G.eval() z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x) - embed_modified = self.wgan.G.module.forward(z_new.unsqueeze(0).to(self.device)) + embed_modified = self.wgan.G.forward(z_new.unsqueeze(0).to(self.device)) if self.normalize: embed_modified = inverse_normalize( embed_modified.cpu(), diff --git a/Architectures/ControllabilityGAN/__init__.py b/Modules/ControllabilityGAN/__init__.py similarity index 100% rename from Architectures/ControllabilityGAN/__init__.py rename to Modules/ControllabilityGAN/__init__.py diff --git a/Architectures/ControllabilityGAN/dataset/__init__.py b/Modules/ControllabilityGAN/dataset/__init__.py similarity index 100% rename from Architectures/ControllabilityGAN/dataset/__init__.py rename to Modules/ControllabilityGAN/dataset/__init__.py diff --git a/Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py b/Modules/ControllabilityGAN/dataset/speaker_embeddings_dataset.py similarity index 100% rename from Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py rename to Modules/ControllabilityGAN/dataset/speaker_embeddings_dataset.py diff --git a/Architectures/ControllabilityGAN/wgan/__init__.py b/Modules/ControllabilityGAN/wgan/__init__.py similarity index 100% rename from Architectures/ControllabilityGAN/wgan/__init__.py rename to Modules/ControllabilityGAN/wgan/__init__.py diff --git a/Architectures/ControllabilityGAN/wgan/init_weights.py b/Modules/ControllabilityGAN/wgan/init_weights.py similarity index 100% rename from Architectures/ControllabilityGAN/wgan/init_weights.py rename to Modules/ControllabilityGAN/wgan/init_weights.py diff --git a/Architectures/ControllabilityGAN/wgan/init_wgan.py b/Modules/ControllabilityGAN/wgan/init_wgan.py similarity index 90% rename from Architectures/ControllabilityGAN/wgan/init_wgan.py rename to Modules/ControllabilityGAN/wgan/init_wgan.py index da345b6af6bf237daba480efad1d57247eb26dbf..9caddd6d5bbaaad0a92968dfabbf651b3b260fa7 100644 --- a/Architectures/ControllabilityGAN/wgan/init_wgan.py +++ b/Modules/ControllabilityGAN/wgan/init_wgan.py @@ -1,7 +1,7 @@ import torch -from Architectures.ControllabilityGAN.wgan.resnet_init import init_resnet -from Architectures.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost +from Modules.ControllabilityGAN.wgan.resnet_init import init_resnet +from Modules.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost def create_wgan(parameters, device, optimizer='adam'): diff --git a/Architectures/ControllabilityGAN/wgan/resnet_1.py b/Modules/ControllabilityGAN/wgan/resnet_1.py similarity index 97% rename from Architectures/ControllabilityGAN/wgan/resnet_1.py rename to Modules/ControllabilityGAN/wgan/resnet_1.py index 6f5aef2f79aca66a4b61675718bde809b1ef9b48..ffd5e75b461ee876017d713b7e07ba1104b4fae0 100644 --- a/Architectures/ControllabilityGAN/wgan/resnet_1.py +++ b/Modules/ControllabilityGAN/wgan/resnet_1.py @@ 
-76,8 +76,8 @@ class ResNet_G(nn.Module): return out, l_1 return out - def sample_latent(self, n_samples, z_size): - return torch.randn((n_samples, z_size)) + def sample_latent(self, n_samples, z_size, temperature=0.7): + return torch.randn((n_samples, z_size)) * temperature class ResNet_D(nn.Module): diff --git a/Architectures/ControllabilityGAN/wgan/resnet_init.py b/Modules/ControllabilityGAN/wgan/resnet_init.py similarity index 61% rename from Architectures/ControllabilityGAN/wgan/resnet_init.py rename to Modules/ControllabilityGAN/wgan/resnet_init.py index 0ff7e8f228f15d1de629a8c2355cf73f1908681b..890bece8156fd627419c1ce0bcba2b48b1a7532a 100644 --- a/Architectures/ControllabilityGAN/wgan/resnet_init.py +++ b/Modules/ControllabilityGAN/wgan/resnet_init.py @@ -1,7 +1,7 @@ -from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_D -from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_G -from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_D -from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_G +from Modules.ControllabilityGAN.wgan.init_weights import weights_init_D +from Modules.ControllabilityGAN.wgan.init_weights import weights_init_G +from Modules.ControllabilityGAN.wgan.resnet_1 import ResNet_D +from Modules.ControllabilityGAN.wgan.resnet_1 import ResNet_G def init_resnet(parameters): diff --git a/Architectures/ControllabilityGAN/wgan/wgan_qc.py b/Modules/ControllabilityGAN/wgan/wgan_qc.py similarity index 93% rename from Architectures/ControllabilityGAN/wgan/wgan_qc.py rename to Modules/ControllabilityGAN/wgan/wgan_qc.py index 27fd608adb000aacd005cde52eac1b58c0ee19f6..a7e99d057bf7a6f68627c627cdcc998cb3b6db04 100644 --- a/Architectures/ControllabilityGAN/wgan/wgan_qc.py +++ b/Modules/ControllabilityGAN/wgan/wgan_qc.py @@ -3,7 +3,6 @@ import time import numpy as np import torch -import torch.nn as nn import torch.optim as optim from cvxopt import matrix from cvxopt import solvers @@ -11,13 +10,12 @@ from cvxopt import sparse from cvxopt import spmatrix from torch.autograd import grad as torch_grad from tqdm import tqdm -import spaces -class WassersteinGanQuadraticCost(torch.nn.Module): +class WassersteinGanQuadraticCost: - def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations, data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations, + data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0): self.G = generator self.G_opt = gen_optimizer self.D = discriminator @@ -46,8 +44,8 @@ class WassersteinGanQuadraticCost(torch.nn.Module): self.Kr = np.sqrt(self.K) self.LAMBDA = 2 * self.Kr * gamma * 2 - self.G = nn.DataParallel(self.G.to(self.device)) - self.D = nn.DataParallel(self.D.to(self.device)) + self.G = self.G.to(self.device) + self.D = self.D.to(self.device) self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal) self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal) @@ -245,10 +243,7 @@ class WassersteinGanQuadraticCost(torch.nn.Module): latent_samples = latent_samples.to(self.device) if nograd: with torch.no_grad(): - if isinstance(self.G, torch.nn.parallel.DataParallel): - generated_data = self.G.module(latent_samples, return_intermediate=return_intermediate) - else: 
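Dropping the nn.DataParallel wrappers here is also why load_model now strips the 'module.' prefix from the checkpoint keys: state dicts saved from a wrapped generator or critic prefix every parameter name with 'module.', which a plain module cannot load. A small sketch of that prefix handling, using stand-in tensors rather than real GAN weights:

import torch

# a checkpoint written from a DataParallel-wrapped model names its parameters
# "module.weight" and "module.bias"; stand-in tensors are used here for brevity
checkpoint = {'module.weight': torch.randn(4, 4), 'module.bias': torch.zeros(4)}

plain_model = torch.nn.Linear(4, 4)
stripped = {key.replace('module.', '', 1): value for key, value in checkpoint.items()}
plain_model.load_state_dict(stripped)  # loads cleanly once the prefix is removed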
- generated_data = self.G(latent_samples, return_intermediate=return_intermediate) + generated_data = self.G(latent_samples, return_intermediate=return_intermediate) else: generated_data = self.G(latent_samples) self.G.train() diff --git a/Architectures/EmbeddingModel/GST.py b/Modules/EmbeddingModel/GST.py similarity index 99% rename from Architectures/EmbeddingModel/GST.py rename to Modules/EmbeddingModel/GST.py index b0f2435bd91aaf93a3de3ab6c779409972f0b907..eb23d4ad9d4cdf800119969fb0811bf5126f616e 100644 --- a/Architectures/EmbeddingModel/GST.py +++ b/Modules/EmbeddingModel/GST.py @@ -3,7 +3,7 @@ import torch -from Architectures.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention +from Modules.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention class GSTStyleEncoder(torch.nn.Module): diff --git a/Architectures/EmbeddingModel/README.md b/Modules/EmbeddingModel/README.md similarity index 100% rename from Architectures/EmbeddingModel/README.md rename to Modules/EmbeddingModel/README.md diff --git a/Architectures/EmbeddingModel/StyleEmbedding.py b/Modules/EmbeddingModel/StyleEmbedding.py similarity index 95% rename from Architectures/EmbeddingModel/StyleEmbedding.py rename to Modules/EmbeddingModel/StyleEmbedding.py index d7154010e3b9c4945dc76a6cefdcb0fc8541ef06..46a2120a29e2a8b4407ff7ae520123a13e14adc0 100644 --- a/Architectures/EmbeddingModel/StyleEmbedding.py +++ b/Modules/EmbeddingModel/StyleEmbedding.py @@ -1,7 +1,7 @@ import torch -from Architectures.EmbeddingModel.GST import GSTStyleEncoder -from Architectures.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder +from Modules.EmbeddingModel.GST import GSTStyleEncoder +from Modules.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder class StyleEmbedding(torch.nn.Module): diff --git a/Architectures/EmbeddingModel/StyleTTSEncoder.py b/Modules/EmbeddingModel/StyleTTSEncoder.py similarity index 100% rename from Architectures/EmbeddingModel/StyleTTSEncoder.py rename to Modules/EmbeddingModel/StyleTTSEncoder.py diff --git a/Architectures/EmbeddingModel/__init__.py b/Modules/EmbeddingModel/__init__.py similarity index 100% rename from Architectures/EmbeddingModel/__init__.py rename to Modules/EmbeddingModel/__init__.py diff --git a/Architectures/GeneralLayers/Attention.py b/Modules/GeneralLayers/Attention.py similarity index 100% rename from Architectures/GeneralLayers/Attention.py rename to Modules/GeneralLayers/Attention.py diff --git a/Architectures/GeneralLayers/ConditionalLayerNorm.py b/Modules/GeneralLayers/ConditionalLayerNorm.py similarity index 100% rename from Architectures/GeneralLayers/ConditionalLayerNorm.py rename to Modules/GeneralLayers/ConditionalLayerNorm.py diff --git a/Architectures/GeneralLayers/Conformer.py b/Modules/GeneralLayers/Conformer.py similarity index 79% rename from Architectures/GeneralLayers/Conformer.py rename to Modules/GeneralLayers/Conformer.py index 21a9e852b783c5243dc17e7ec3b1150f3734a3cc..0489d8f1116e4b9749d966c44d87e09602b8fe28 100644 --- a/Architectures/GeneralLayers/Conformer.py +++ b/Modules/GeneralLayers/Conformer.py @@ -4,16 +4,16 @@ Taken from ESPNet, but heavily modified import torch -from Architectures.GeneralLayers.Attention import RelPositionMultiHeadedAttention -from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d -from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm -from Architectures.GeneralLayers.Convolution import ConvolutionModule -from 
Architectures.GeneralLayers.EncoderLayer import EncoderLayer -from Architectures.GeneralLayers.LayerNorm import LayerNorm -from Architectures.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d -from Architectures.GeneralLayers.MultiSequential import repeat -from Architectures.GeneralLayers.PositionalEncoding import RelPositionalEncoding -from Architectures.GeneralLayers.Swish import Swish +from Modules.GeneralLayers.Attention import RelPositionMultiHeadedAttention +from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d +from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm +from Modules.GeneralLayers.Convolution import ConvolutionModule +from Modules.GeneralLayers.EncoderLayer import EncoderLayer +from Modules.GeneralLayers.LayerNorm import LayerNorm +from Modules.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d +from Modules.GeneralLayers.MultiSequential import repeat +from Modules.GeneralLayers.PositionalEncoding import RelPositionalEncoding +from Modules.GeneralLayers.Swish import Swish from Utility.utils import integrate_with_utt_embed @@ -84,8 +84,12 @@ class Conformer(torch.nn.Module): self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim)) if lang_embs is not None: self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size) - self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim) + if lang_emb_size == attention_dim: + self.language_embedding_projection = lambda x: x + else: + self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim) self.language_emb_norm = LayerNorm(attention_dim) + # self-attention module definition encoder_selfattn_layer = RelPositionMultiHeadedAttention encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu) @@ -138,21 +142,28 @@ class Conformer(torch.nn.Module): if isinstance(xs, tuple): x, pos_emb = xs[0], xs[1] if self.conformer_type != "encoder": - x = integrate_with_utt_embed(hs=x, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration) + x = integrate_with_utt_embed(hs=x, + utt_embeddings=utterance_embedding, + projection=self.decoder_embedding_projections[encoder_index], + embedding_training=self.use_conditional_layernorm_embedding_integration) xs = (x, pos_emb) else: if self.conformer_type != "encoder": - xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration) + xs = integrate_with_utt_embed(hs=xs, + utt_embeddings=utterance_embedding, + projection=self.decoder_embedding_projections[encoder_index], + embedding_training=self.use_conditional_layernorm_embedding_integration) xs, masks = encoder(xs, masks) if isinstance(xs, tuple): xs = xs[0] - if self.use_output_norm and not (self.utt_embed and self.conformer_type == "encoder"): - xs = self.output_norm(xs) - if self.utt_embed and self.conformer_type == "encoder": - xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, - projection=self.encoder_embedding_projection, embedding_training=self.use_conditional_layernorm_embedding_integration) + xs = integrate_with_utt_embed(hs=xs, + utt_embeddings=utterance_embedding, + projection=self.encoder_embedding_projection, + 
embedding_training=self.use_conditional_layernorm_embedding_integration) + elif self.use_output_norm: + xs = self.output_norm(xs) return xs, masks diff --git a/Architectures/GeneralLayers/Convolution.py b/Modules/GeneralLayers/Convolution.py similarity index 95% rename from Architectures/GeneralLayers/Convolution.py rename to Modules/GeneralLayers/Convolution.py index 29d6c8c17a8c66fa3e57fdb02b767fc2870f2269..12f23c81f71aae773ddcde1bb84a156d656ea492 100644 --- a/Architectures/GeneralLayers/Convolution.py +++ b/Modules/GeneralLayers/Convolution.py @@ -24,7 +24,7 @@ class ConvolutionModule(nn.Module): self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, ) - self.norm = nn.BatchNorm1d(channels) + self.norm = nn.SyncBatchNorm.convert_sync_batchnorm(nn.BatchNorm1d(channels)) self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) self.activation = activation diff --git a/Architectures/GeneralLayers/DurationPredictor.py b/Modules/GeneralLayers/DurationPredictor.py similarity index 96% rename from Architectures/GeneralLayers/DurationPredictor.py rename to Modules/GeneralLayers/DurationPredictor.py index 871f3bb2e1e2f571ab2e24f3233e9a22d009b043..5a6e25739d16ff6dfb55b86c5ce0fca0e3a365ac 100644 --- a/Architectures/GeneralLayers/DurationPredictor.py +++ b/Modules/GeneralLayers/DurationPredictor.py @@ -5,9 +5,9 @@ import torch -from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d -from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm -from Architectures.GeneralLayers.LayerNorm import LayerNorm +from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d +from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm +from Modules.GeneralLayers.LayerNorm import LayerNorm from Utility.utils import integrate_with_utt_embed diff --git a/Architectures/GeneralLayers/EncoderLayer.py b/Modules/GeneralLayers/EncoderLayer.py similarity index 98% rename from Architectures/GeneralLayers/EncoderLayer.py rename to Modules/GeneralLayers/EncoderLayer.py index e21d35ae120527e049c37fd9516fd0d860ea5e7f..8008358fe8e306752b9d45647f824ca758de491e 100644 --- a/Architectures/GeneralLayers/EncoderLayer.py +++ b/Modules/GeneralLayers/EncoderLayer.py @@ -7,7 +7,7 @@ import torch from torch import nn -from Architectures.GeneralLayers.LayerNorm import LayerNorm +from Modules.GeneralLayers.LayerNorm import LayerNorm class EncoderLayer(nn.Module): diff --git a/Architectures/GeneralLayers/LayerNorm.py b/Modules/GeneralLayers/LayerNorm.py similarity index 100% rename from Architectures/GeneralLayers/LayerNorm.py rename to Modules/GeneralLayers/LayerNorm.py diff --git a/Architectures/GeneralLayers/LengthRegulator.py b/Modules/GeneralLayers/LengthRegulator.py similarity index 100% rename from Architectures/GeneralLayers/LengthRegulator.py rename to Modules/GeneralLayers/LengthRegulator.py diff --git a/Architectures/GeneralLayers/MultiLayeredConv1d.py b/Modules/GeneralLayers/MultiLayeredConv1d.py similarity index 100% rename from Architectures/GeneralLayers/MultiLayeredConv1d.py rename to Modules/GeneralLayers/MultiLayeredConv1d.py diff --git a/Architectures/GeneralLayers/MultiSequential.py b/Modules/GeneralLayers/MultiSequential.py similarity index 100% rename from Architectures/GeneralLayers/MultiSequential.py rename to 
Modules/GeneralLayers/MultiSequential.py diff --git a/Architectures/GeneralLayers/PositionalEncoding.py b/Modules/GeneralLayers/PositionalEncoding.py similarity index 100% rename from Architectures/GeneralLayers/PositionalEncoding.py rename to Modules/GeneralLayers/PositionalEncoding.py diff --git a/Architectures/GeneralLayers/PositionwiseFeedForward.py b/Modules/GeneralLayers/PositionwiseFeedForward.py similarity index 100% rename from Architectures/GeneralLayers/PositionwiseFeedForward.py rename to Modules/GeneralLayers/PositionwiseFeedForward.py diff --git a/Architectures/GeneralLayers/README.md b/Modules/GeneralLayers/README.md similarity index 100% rename from Architectures/GeneralLayers/README.md rename to Modules/GeneralLayers/README.md diff --git a/Architectures/GeneralLayers/ResidualBlock.py b/Modules/GeneralLayers/ResidualBlock.py similarity index 100% rename from Architectures/GeneralLayers/ResidualBlock.py rename to Modules/GeneralLayers/ResidualBlock.py diff --git a/Architectures/GeneralLayers/ResidualStack.py b/Modules/GeneralLayers/ResidualStack.py similarity index 100% rename from Architectures/GeneralLayers/ResidualStack.py rename to Modules/GeneralLayers/ResidualStack.py diff --git a/Architectures/GeneralLayers/STFT.py b/Modules/GeneralLayers/STFT.py similarity index 100% rename from Architectures/GeneralLayers/STFT.py rename to Modules/GeneralLayers/STFT.py diff --git a/Architectures/GeneralLayers/Swish.py b/Modules/GeneralLayers/Swish.py similarity index 100% rename from Architectures/GeneralLayers/Swish.py rename to Modules/GeneralLayers/Swish.py diff --git a/Architectures/GeneralLayers/VariancePredictor.py b/Modules/GeneralLayers/VariancePredictor.py similarity index 94% rename from Architectures/GeneralLayers/VariancePredictor.py rename to Modules/GeneralLayers/VariancePredictor.py index 52d2effd53c4da6cb4009c641daa11f8e8b6711f..ad187b15342a8b4b001ce971adfa1ace9f51948d 100644 --- a/Architectures/GeneralLayers/VariancePredictor.py +++ b/Modules/GeneralLayers/VariancePredictor.py @@ -6,9 +6,9 @@ from abc import ABC import torch -from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d -from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm -from Architectures.GeneralLayers.LayerNorm import LayerNorm +from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d +from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm +from Modules.GeneralLayers.LayerNorm import LayerNorm from Utility.utils import integrate_with_utt_embed diff --git a/Architectures/GeneralLayers/__init__.py b/Modules/GeneralLayers/__init__.py similarity index 100% rename from Architectures/GeneralLayers/__init__.py rename to Modules/GeneralLayers/__init__.py diff --git a/Architectures/README.md b/Modules/README.md similarity index 100% rename from Architectures/README.md rename to Modules/README.md diff --git a/Architectures/ToucanTTS/CodecDiscriminator.py b/Modules/ToucanTTS/CodecDiscriminator.py similarity index 100% rename from Architectures/ToucanTTS/CodecDiscriminator.py rename to Modules/ToucanTTS/CodecDiscriminator.py diff --git a/Architectures/ToucanTTS/CodecRefinementTransformer.py b/Modules/ToucanTTS/CodecRefinementTransformer.py similarity index 99% rename from Architectures/ToucanTTS/CodecRefinementTransformer.py rename to Modules/ToucanTTS/CodecRefinementTransformer.py index 98a19c6865176631ba736acfb691c59f8934a03a..f73176d60e264a92ca74b5912d074af77810b8b5 100644 --- a/Architectures/ToucanTTS/CodecRefinementTransformer.py +++ 
b/Modules/ToucanTTS/CodecRefinementTransformer.py @@ -1,6 +1,6 @@ import torch -from Architectures.GeneralLayers.Conformer import Conformer +from Modules.GeneralLayers.Conformer import Conformer class CodecRefinementTransformer(torch.nn.Module): @@ -151,7 +151,7 @@ def one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook): if __name__ == '__main__': - from Architectures.ToucanTTS.ToucanTTS import ToucanTTS + from Modules.ToucanTTS.ToucanTTS import ToucanTTS from Utility.utils import make_pad_mask # prepare dummy inputs diff --git a/Architectures/ToucanTTS/DurationCalculator.py b/Modules/ToucanTTS/DurationCalculator.py similarity index 100% rename from Architectures/ToucanTTS/DurationCalculator.py rename to Modules/ToucanTTS/DurationCalculator.py diff --git a/Architectures/ToucanTTS/EnergyCalculator.py b/Modules/ToucanTTS/EnergyCalculator.py similarity index 98% rename from Architectures/ToucanTTS/EnergyCalculator.py rename to Modules/ToucanTTS/EnergyCalculator.py index 373bcb902a90e57e1df05204a4d3c0fb9d185419..6861dc57028da046c9855ec94ccffe5f5c6be937 100644 --- a/Architectures/ToucanTTS/EnergyCalculator.py +++ b/Modules/ToucanTTS/EnergyCalculator.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F -from Architectures.GeneralLayers.STFT import STFT +from Modules.GeneralLayers.STFT import STFT from Utility.utils import pad_list diff --git a/Architectures/ToucanTTS/Glow.py b/Modules/ToucanTTS/Glow.py similarity index 99% rename from Architectures/ToucanTTS/Glow.py rename to Modules/ToucanTTS/Glow.py index fb5329ef4d7136fe175d005a1710679f53440ffd..fe9a2369cc07c181befb2a2aff8e587ca63bcf9d 100644 --- a/Architectures/ToucanTTS/Glow.py +++ b/Modules/ToucanTTS/Glow.py @@ -5,8 +5,8 @@ import torch.distributions as dist from torch import nn from torch.nn import functional as F -from Architectures.ToucanTTS import glow_utils -from Architectures.ToucanTTS.wavenet import WN +from Modules.ToucanTTS import glow_utils +from Modules.ToucanTTS.wavenet import WN class ActNorm(nn.Module): @@ -339,7 +339,7 @@ class Glow(nn.Module): use_weightnorm=use_weightnorm )) - def forward(self, tgt_mels, infer, mel_out, encoded_texts, tgt_nonpadding, glow_sampling_temperature=0.2): + def forward(self, tgt_mels, infer, mel_out, encoded_texts, tgt_nonpadding, glow_sampling_temperature=0.7): x_recon = mel_out.transpose(1, 2) g = x_recon B, _, T = g.shape diff --git a/Architectures/ToucanTTS/InferenceToucanTTS.py b/Modules/ToucanTTS/InferenceToucanTTS.py similarity index 69% rename from Architectures/ToucanTTS/InferenceToucanTTS.py rename to Modules/ToucanTTS/InferenceToucanTTS.py index 78491b83e52d1685450e6f9db4c9448374c3f143..0e593d72b2aba20aaa7ca6ab4c3bfb0cb9ce2bf2 100644 --- a/Architectures/ToucanTTS/InferenceToucanTTS.py +++ b/Modules/ToucanTTS/InferenceToucanTTS.py @@ -1,18 +1,14 @@ import dotwiz import torch +import torch.nn.functional as torchfunc from torch.nn import Linear from torch.nn import Sequential from torch.nn import Tanh -from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d -from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm -from Architectures.GeneralLayers.Conformer import Conformer -from Architectures.GeneralLayers.DurationPredictor import DurationPredictor -from Architectures.GeneralLayers.LengthRegulator import LengthRegulator -from Architectures.GeneralLayers.VariancePredictor import VariancePredictor -from Architectures.ToucanTTS.Glow import Glow +from Modules.GeneralLayers.Conformer import Conformer +from 
Modules.GeneralLayers.LengthRegulator import LengthRegulator +from Modules.ToucanTTS.flow_matching import CFMDecoder from Preprocessing.articulatory_features import get_feature_to_index_lookup -from Utility.utils import integrate_with_utt_embed from Utility.utils import make_non_pad_mask @@ -62,14 +58,24 @@ class ToucanTTS(torch.nn.Module): energy_predictor_dropout = config.energy_predictor_dropout energy_embed_kernel_size = config.energy_embed_kernel_size energy_embed_dropout = config.energy_embed_dropout + cfm_filter_channels = config.cfm_filter_channels + cfm_heads = config.cfm_heads + cfm_layers = config.cfm_layers + cfm_kernel_size = config.cfm_kernel_size + cfm_p_dropout = config.cfm_p_dropout utt_embed_dim = config.utt_embed_dim lang_embs = config.lang_embs + spec_channels = config.spec_channels embedding_integration = config.embedding_integration - glow_kernel_size = config.glow_kernel_size - glow_blocks = config.glow_blocks - glow_layers = config.glow_layers lang_emb_size = config.lang_emb_size integrate_language_embedding_into_encoder_out = config.integrate_language_embedding_into_encoder_out + prosody_channels = config.prosody_channels + + if lang_embs is None or lang_embs == 0: + lang_embs = None + integrate_language_embedding_into_encoder_out = False + if integrate_language_embedding_into_encoder_out: + utt_embed_dim = utt_embed_dim + lang_emb_size self.input_feature_dimensions = input_feature_dimensions self.attention_dimension = attention_dimension @@ -102,37 +108,32 @@ class ToucanTTS(torch.nn.Module): use_output_norm=True, embedding_integration=embedding_integration) - if self.integrate_language_embedding_into_encoder_out: - if embedding_integration == "AdaIN": - self.language_embedding_infusion = AdaIN1d(style_dim=lang_emb_size, num_features=attention_dimension) - elif embedding_integration == "ConditionalLayerNorm": - self.language_embedding_infusion = ConditionalLayerNorm(speaker_embedding_dim=lang_emb_size, hidden_dim=attention_dimension) - else: - self.language_embedding_infusion = torch.nn.Linear(attention_dimension + lang_emb_size, attention_dimension) - - self.duration_predictor = DurationPredictor(idim=attention_dimension, - n_layers=duration_predictor_layers, - n_chans=attention_dimension, - kernel_size=duration_predictor_kernel_size, - dropout_rate=duration_predictor_dropout_rate, - utt_embed_dim=utt_embed_dim, - embedding_integration=embedding_integration) - - self.pitch_predictor = VariancePredictor(idim=attention_dimension, - n_layers=pitch_predictor_layers, - n_chans=attention_dimension, - kernel_size=pitch_predictor_kernel_size, - dropout_rate=pitch_predictor_dropout, - utt_embed_dim=utt_embed_dim, - embedding_integration=embedding_integration) - - self.energy_predictor = VariancePredictor(idim=attention_dimension, - n_layers=energy_predictor_layers, - n_chans=attention_dimension, - kernel_size=energy_predictor_kernel_size, - dropout_rate=energy_predictor_dropout, - utt_embed_dim=utt_embed_dim, - embedding_integration=embedding_integration) + self.duration_predictor = CFMDecoder(hidden_channels=prosody_channels, + out_channels=1, + filter_channels=prosody_channels, + n_heads=1, + n_layers=duration_predictor_layers, + kernel_size=duration_predictor_kernel_size, + p_dropout=duration_predictor_dropout_rate, + gin_channels=utt_embed_dim) + + self.pitch_predictor = CFMDecoder(hidden_channels=prosody_channels, + out_channels=1, + filter_channels=prosody_channels, + n_heads=1, + n_layers=pitch_predictor_layers, + kernel_size=pitch_predictor_kernel_size, + 
p_dropout=pitch_predictor_dropout, + gin_channels=utt_embed_dim) + + self.energy_predictor = CFMDecoder(hidden_channels=prosody_channels, + out_channels=1, + filter_channels=prosody_channels, + n_heads=1, + n_layers=energy_predictor_layers, + kernel_size=energy_predictor_kernel_size, + p_dropout=energy_predictor_dropout, + gin_channels=utt_embed_dim) self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1, out_channels=attention_dimension, @@ -167,24 +168,19 @@ class ToucanTTS(torch.nn.Module): utt_embed=utt_embed_dim, embedding_integration=embedding_integration) - self.output_projection = torch.nn.Linear(attention_dimension, 128) - - self.post_flow = Glow( - in_channels=128, - hidden_channels=attention_dimension, # post_glow_hidden - kernel_size=glow_kernel_size, # post_glow_kernel_size - dilation_rate=1, - n_blocks=glow_blocks, # post_glow_n_blocks (original 12 in paper) - n_layers=glow_layers, # post_glow_n_block_layers (original 3 in paper) - n_split=4, - n_sqz=2, - text_condition_channels=attention_dimension, - share_cond_layers=False, # post_share_cond_layers - share_wn_layers=4, - sigmoid_scale=False, - condition_integration_projection=torch.nn.Conv1d(128 + attention_dimension, attention_dimension, 5, padding=2) - ) - + self.output_projection = torch.nn.Linear(attention_dimension, spec_channels) + self.pitch_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels) + self.energy_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels) + self.duration_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels) + + self.flow_matching_decoder = CFMDecoder(hidden_channels=spec_channels, + out_channels=spec_channels, + filter_channels=cfm_filter_channels, + n_heads=cfm_heads, + n_layers=cfm_layers, + kernel_size=cfm_kernel_size, + p_dropout=cfm_p_dropout, + gin_channels=utt_embed_dim) self.load_state_dict(weights) self.eval() @@ -200,28 +196,53 @@ class ToucanTTS(torch.nn.Module): pitch_variance_scale=1.0, energy_variance_scale=1.0, pause_duration_scaling_factor=1.0, - glow_sampling_temperature=0.2): + prosody_creativity=0.1): + + text_tensors = torch.clamp(text_tensors, max=1.0) + # this is necessary, because of the way we represent modifiers to keep them identifiable. 
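The torch.clamp call at the start of the new inference forward pass deserves a note: the comment suggests that modifier information is stored as values above 1.0 in the articulatory text tensor so that it stays identifiable, and clamping collapses those entries back to the plain 0/1 feature encoding before the encoder sees them. A tiny illustration with made-up feature values:

import torch

# made-up articulatory feature row in which one dimension was set to 2.0
# to mark a modifier; clamping restores the plain binary encoding
text_tensors = torch.tensor([[0.0, 1.0, 2.0, 0.0]])
print(torch.clamp(text_tensors, max=1.0))  # tensor([[0., 1., 1., 0.]])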
if not self.multilingual_model: lang_ids = None if not self.multispeaker_model: utterance_embedding = None - else: + + if utterance_embedding is not None: utterance_embedding = torch.nn.functional.normalize(utterance_embedding) + if self.integrate_language_embedding_into_encoder_out and lang_ids is not None: + lang_embs = self.encoder.language_embedding(lang_ids) + lang_embs = torch.nn.functional.normalize(lang_embs) + utterance_embedding = torch.cat([lang_embs, utterance_embedding], dim=1).detach() # encoding the texts text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2) encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) - if self.integrate_language_embedding_into_encoder_out: - lang_embs = self.encoder.language_embedding(lang_ids).squeeze(-1).detach() - encoded_texts = integrate_with_utt_embed(hs=encoded_texts, utt_embeddings=lang_embs, projection=self.language_embedding_infusion, embedding_training=self.use_conditional_layernorm_embedding_integration) - # predicting pitch, energy and durations - pitch_predictions = self.pitch_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_pitch is None else gold_pitch - energy_predictions = self.energy_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_energy is None else gold_energy - predicted_durations = self.duration_predictor.inference(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_durations is None else gold_durations + reduced_pitch_space = torchfunc.dropout(self.pitch_latent_reduction(encoded_texts), p=0.1).transpose(1, 2) + pitch_predictions = self.pitch_predictor(mu=reduced_pitch_space, + mask=text_masks.float(), + n_timesteps=10, + temperature=prosody_creativity, + c=utterance_embedding) if gold_pitch is None else gold_pitch + pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale) + embedded_pitch_curve = self.pitch_embed(pitch_predictions).transpose(1, 2) + + reduced_energy_space = torchfunc.dropout(self.energy_latent_reduction(encoded_texts + embedded_pitch_curve), p=0.1).transpose(1, 2) + energy_predictions = self.energy_predictor(mu=reduced_energy_space, + mask=text_masks.float(), + n_timesteps=10, + temperature=prosody_creativity, + c=utterance_embedding) if gold_energy is None else gold_energy + energy_predictions = _scale_variance(energy_predictions, energy_variance_scale) + embedded_energy_curve = self.energy_embed(energy_predictions).transpose(1, 2) + + reduced_duration_space = torchfunc.dropout(self.duration_latent_reduction(encoded_texts + embedded_pitch_curve + embedded_energy_curve), p=0.1).transpose(1, 2) + predicted_durations = torch.clamp(torch.ceil(self.duration_predictor(mu=reduced_duration_space, + mask=text_masks.float(), + n_timesteps=10, + temperature=prosody_creativity, + c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations # modifying the predictions with control parameters for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)): @@ -230,16 +251,10 @@ class ToucanTTS(torch.nn.Module): if phoneme_vector[get_feature_to_index_lookup()["silence"]] == 1 and pause_duration_scaling_factor != 1.0: predicted_durations[0][phoneme_index] = torch.round(predicted_durations[0][phoneme_index].float() * pause_duration_scaling_factor).long() if duration_scaling_factor != 1.0: - assert duration_scaling_factor > 0 + assert duration_scaling_factor > 0.0 
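The per-phone loop above, together with the global rescaling that follows just below, implements the user-facing duration controls: pause phones are rescaled separately via pause_duration_scaling_factor, and duration_scaling_factor then stretches or compresses the whole utterance. A standalone sketch of the same arithmetic, with a made-up silence flag standing in for the lookup into the articulatory feature vectors:

import torch

# made-up durations (in frames) for five phones; index 2 is a pause
predicted_durations = torch.tensor([[4, 7, 12, 6, 5]])
is_silence = [False, False, True, False, False]

pause_duration_scaling_factor = 0.5  # shorten only the pauses
duration_scaling_factor = 1.2        # slow the whole utterance down a little

for index, silent in enumerate(is_silence):
    if silent and pause_duration_scaling_factor != 1.0:
        predicted_durations[0][index] = torch.round(predicted_durations[0][index].float() * pause_duration_scaling_factor).long()
if duration_scaling_factor != 1.0:
    assert duration_scaling_factor > 0.0
    predicted_durations = torch.round(predicted_durations.float() * duration_scaling_factor).long()
print(predicted_durations)  # tensor([[5, 8, 7, 7, 6]])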
predicted_durations = torch.round(predicted_durations.float() * duration_scaling_factor).long() - pitch_predictions = make_near_zero_to_zero(pitch_predictions.squeeze(0)).unsqueeze(0) - energy_predictions = make_near_zero_to_zero(energy_predictions.squeeze(0)).unsqueeze(0) - pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale) - energy_predictions = _scale_variance(energy_predictions, energy_variance_scale) # enriching the text with pitch and energy info - embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2) - embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2) enriched_encoded_texts = encoded_texts + embedded_pitch_curve + embedded_energy_curve # predicting durations for text and upsampling accordingly @@ -248,9 +263,13 @@ class ToucanTTS(torch.nn.Module): # decoding spectrogram decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, None, utterance_embedding=utterance_embedding) - frames = self.output_projection(decoded_speech) + preliminary_spectrogram = self.output_projection(decoded_speech) - refined_codec_frames = self.post_flow(tgt_mels=None, infer=True, mel_out=frames, encoded_texts=upsampled_enriched_encoded_texts, tgt_nonpadding=None, glow_sampling_temperature=glow_sampling_temperature) + refined_codec_frames = self.flow_matching_decoder(mu=preliminary_spectrogram.transpose(1, 2), + mask=make_non_pad_mask([len(decoded_speech[0])], device=decoded_speech.device).unsqueeze(-2), + n_timesteps=15, + temperature=0.1, # low temperature, so the model follows the specified prosody curves better. + c=None).transpose(1, 2) return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze() @@ -267,7 +286,7 @@ class ToucanTTS(torch.nn.Module): pitch_variance_scale=1.0, energy_variance_scale=1.0, pause_duration_scaling_factor=1.0, - glow_sampling_temperature=0.2): + prosody_creativity=0.1): """ Generate the sequence of spectrogram frames given the sequence of vectorized phonemes. 
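pitch_variance_scale and energy_variance_scale are routed through the _scale_variance helper, which the patch adapts further down to the new tensor layout (the zero-clipping loop now indexes sequence[0][0]). The idea itself is unchanged: centre the prosody curve on its average, stretch its spread, move it back, and clip negative values. A standalone sketch of that behaviour follows; the real helper operates on the batched curve, and since the computation of the average is not shown in this hunk, a plain mean is assumed here.

import torch

def scale_variance(curve, scale):
    # centre the curve, stretch its spread, restore the mean, clip below zero
    if scale == 1.0:
        return curve
    average = curve.mean()
    return torch.clamp((curve - average) * scale + average, min=0.0)

pitch = torch.tensor([1.0, 1.2, 0.8, 1.5, 0.9])
print(scale_variance(pitch, 1.5))  # wider pitch excursions around the same average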
@@ -319,7 +338,7 @@ class ToucanTTS(torch.nn.Module): pitch_variance_scale=pitch_variance_scale, energy_variance_scale=energy_variance_scale, pause_duration_scaling_factor=pause_duration_scaling_factor, - glow_sampling_temperature=glow_sampling_temperature) + prosody_creativity=prosody_creativity) if return_duration_pitch_energy: return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions @@ -331,7 +350,8 @@ class ToucanTTS(torch.nn.Module): torch.nn.utils.remove_weight_norm(m) except ValueError: # this module didn't have weight norm return - self.post_flow.store_inverse() + + # self.post_flow.store_inverse() # we're no longer using glow, so this is deprecated self.apply(remove_weight_norm) @@ -342,9 +362,9 @@ def _scale_variance(sequence, scale): sequence = sequence - average # center sequence around 0 sequence = sequence * scale # scale the variance sequence = sequence + average # move center back to original with changed variance - for sequence_index in range(len(sequence[0])): - if sequence[0][sequence_index] < 0.0: - sequence[0][sequence_index] = 0.0 + for sequence_index in range(len(sequence[0][0])): + if sequence[0][0][sequence_index] < 0.0: + sequence[0][0][sequence_index] = 0.0 return sequence diff --git a/Architectures/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py b/Modules/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py similarity index 81% rename from Architectures/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py rename to Modules/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py index 2206f8a80ee7938334717110b29d631377a1c7d6..295f23153b0d74e06be49788e11dbbb1dabcc310 100644 --- a/Architectures/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py +++ b/Modules/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py @@ -1,5 +1,4 @@ import os.path -import pickle import torch @@ -16,14 +15,12 @@ class LanguageEmbeddingSpaceStructureLoss(torch.nn.Module): cc.create_tree_cache(cache_root="Preprocessing/multilinguality") if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'): cc.create_map_cache(cache_root="Preprocessing/multilinguality") - if not os.path.exists("Preprocessing/multilinguality/asp_dict.pkl"): - print("download asp file") # TODO downloader script with release self.tree_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json') self.map_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json') - with open("Preprocessing/multilinguality/asp_dict.pkl", 'rb') as dictfile: - self.asp_sim = pickle.load(dictfile) - self.lang_list = list(self.asp_sim.keys()) # list of all languages, to get lang_b's index + # with open("Preprocessing/multilinguality/asp_dict.pkl", 'rb') as dictfile: + # self.asp_sim = pickle.load(dictfile) + # self.lang_list = list(self.asp_sim.keys()) # list of all languages, to get lang_b's index self.largest_value_map_dist = 0.0 for _, values in self.map_dist.items(): @@ -64,11 +61,12 @@ class LanguageEmbeddingSpaceStructureLoss(torch.nn.Module): map_dist = self.map_dist[lang_2][lang_1] / self.largest_value_map_dist # Value Range Normalized ASP Dist - lang_2_idx = self.lang_list.index(lang_2) - asp_dist = 1.0 - self.asp_sim[lang_1][lang_2_idx] # it's a similarity measure that goes from 0 to 1, so we subtract it from 1 to turn it into a distance + # lang_2_idx = self.lang_list.index(lang_2) + # asp_dist = 1.0 - self.asp_sim[lang_1][lang_2_idx] # it's a similarity measure that goes from 0 to 1, so we subtract it from 1 to turn 
it into a distance # Average distance should be similar to embedding distance to bring some structure into the embedding-space - metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist) + torch.tensor(asp_dist)) / 3 + # metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist) + torch.tensor(asp_dist)) / 3 + metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist)) / 2 losses.append(torch.nn.functional.l1_loss(embed_dist, metric_distance)) - return sum(losses) / len(losses) \ No newline at end of file + return sum(losses) / len(losses) diff --git a/Architectures/ToucanTTS/PitchCalculator.py b/Modules/ToucanTTS/PitchCalculator.py similarity index 100% rename from Architectures/ToucanTTS/PitchCalculator.py rename to Modules/ToucanTTS/PitchCalculator.py diff --git a/Architectures/ToucanTTS/README.md b/Modules/ToucanTTS/README.md similarity index 100% rename from Architectures/ToucanTTS/README.md rename to Modules/ToucanTTS/README.md diff --git a/Modules/ToucanTTS/StochasticToucanTTSLoss.py b/Modules/ToucanTTS/StochasticToucanTTSLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..13f4fa01d7bc513e0a7e5a357754ddb9da39c6e8 --- /dev/null +++ b/Modules/ToucanTTS/StochasticToucanTTSLoss.py @@ -0,0 +1,40 @@ +""" +Taken from ESPNet +Adapted by Flux +""" + +import torch + +from Utility.utils import make_non_pad_mask + + +class StochasticToucanTTSLoss(torch.nn.Module): + + def __init__(self): + super().__init__() + self.l1_criterion = torch.nn.L1Loss(reduction="none") + + def forward(self, predicted_features, gold_features, features_lengths): + """ + Args: + predicted_features (Tensor): Batch of outputs (B, Lmax, odim). + gold_features (Tensor): Batch of target features (B, Lmax, odim). + features_lengths (LongTensor): Batch of the lengths of each target (B,). + + Returns: + Tensor: L1 loss value. 
+ """ + + # calculate loss + l1_loss = self.l1_criterion(predicted_features, gold_features) + + # make weighted mask and apply it + out_masks = make_non_pad_mask(features_lengths).unsqueeze(-1).to(gold_features.device) + out_masks = torch.nn.functional.pad(out_masks.transpose(1, 2), [0, gold_features.size(1) - out_masks.size(1), 0, 0, 0, 0], value=False).transpose(1, 2) + out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() + out_weights /= gold_features.size(0) * gold_features.size(2) + + # apply weight + l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum() + + return l1_loss diff --git a/Architectures/ToucanTTS/TTSDataset.py b/Modules/ToucanTTS/TTSDataset.py similarity index 69% rename from Architectures/ToucanTTS/TTSDataset.py rename to Modules/ToucanTTS/TTSDataset.py index 98eab5fe87152b17c2c3b9f9f85f39cc8df04529..878c6886d7141e4eca2305f7859105f2263e0837 100644 --- a/Architectures/ToucanTTS/TTSDataset.py +++ b/Modules/ToucanTTS/TTSDataset.py @@ -5,16 +5,15 @@ import torch from torch.utils.data import Dataset from tqdm import tqdm -from Architectures.Aligner.Aligner import Aligner -from Architectures.Aligner.CodecAlignerDataset import CodecAlignerDataset -from Architectures.ToucanTTS.DurationCalculator import DurationCalculator -from Architectures.ToucanTTS.EnergyCalculator import EnergyCalculator -from Architectures.ToucanTTS.PitchCalculator import Parselmouth +from Modules.Aligner.Aligner import Aligner +from Modules.Aligner.CodecAlignerDataset import CodecAlignerDataset +from Modules.ToucanTTS.DurationCalculator import DurationCalculator +from Modules.ToucanTTS.EnergyCalculator import EnergyCalculator +from Modules.ToucanTTS.PitchCalculator import Parselmouth from Preprocessing.AudioPreprocessor import AudioPreprocessor from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor from Preprocessing.TextFrontend import get_language_id from Preprocessing.articulatory_features import get_feature_to_index_lookup -from Utility.utils import remove_elements class TTSDataset(Dataset): @@ -107,14 +106,7 @@ class TTSDataset(Dataset): self.acoustic_model = Aligner() self.acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location="cpu")["asr_model"]) self.acoustic_model = self.acoustic_model.to(device) - - torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround - # careful: assumes 16kHz or 8kHz audio - silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad', force_reload=False, onnx=False, verbose=False) - (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks) = utils - torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets - # this to false globally during model loading rather than using inference_mode or no_grad - silero_model = silero_model.to(device) + self.acoustic_model.eval() # ========================================== # actual creation of datapoints starts here @@ -140,8 +132,6 @@ class TTSDataset(Dataset): text = self.dataset[index][0] - if annotate_silences: - text = self._annotate_silences(text, get_speech_timestamps, index, vis_dir, decoded_wave, device, features, silero_model, save_imgs, decoded_wave_length) cached_duration, ctc_loss = self._calculate_durations(text, index, os.path.join(vis_dir, "post_clean"), features, save_imgs) cached_energy = energy_calc(input_waves=torch.tensor(decoded_wave).unsqueeze(0).to(device), @@ -193,63 +183,6 @@ class 
TTSDataset(Dataset): sys.exit() del self.dataset - def _annotate_silences(self, text, get_speech_timestamps, index, vis_dir, decoded_wave, device, features, silero_model, save_imgs, decoded_wave_length): - """ - Takes in a text tensor and returns a text tensor with pauses added in all locations, where there are actually pauses in the speech signal. Unfortunately, this tends to make mistakes and not work quite as intended yet. I might revisit it in the future, if I see the need for extremely accurate labels for a small dataset of e.g. special data. - """ - text_with_pauses = list() - for phoneme_index, vector in enumerate(text): - # We add pauses before every word boundary, and later we remove the ones that were added too much - if vector[get_feature_to_index_lookup()["word-boundary"]] == 1: - if text[phoneme_index - 1][get_feature_to_index_lookup()["silence"]] != 1: - text_with_pauses.append([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0.]) - text_with_pauses.append(vector.numpy().tolist()) - else: - text_with_pauses.append(vector.numpy().tolist()) - text = torch.Tensor(text_with_pauses) - - cached_duration, _ = self._calculate_durations(text, index, os.path.join(vis_dir, "pre_clean"), features, save_imgs) - - cumsum = 0 - potential_silences = list() - phoneme_indexes_of_silences = list() - for phoneme_index, phone in enumerate(text): - if phone[get_feature_to_index_lookup()["silence"]] == 1 or phone[get_feature_to_index_lookup()["end of sentence"]] == 1 or phone[get_feature_to_index_lookup()["questionmark"]] == 1 or phone[get_feature_to_index_lookup()["exclamationmark"]] == 1 or phone[get_feature_to_index_lookup()["fullstop"]] == 1: - potential_silences.append([cumsum, cumsum + cached_duration[phoneme_index]]) - phoneme_indexes_of_silences.append(phoneme_index) - cumsum = cumsum + cached_duration[phoneme_index] - with torch.inference_mode(): - speech_timestamps = get_speech_timestamps(torch.Tensor(decoded_wave).to(device), silero_model, sampling_rate=16000) - vad_silences = list() - prev_end = 0 - for speech_segment in speech_timestamps: - if prev_end != 0: - vad_silences.append([prev_end, speech_segment["start"]]) - prev_end = speech_segment["end"] - # at this point we know all the silences and we know the legal silences. - # We have to transform them both into ratios, so we can compare them. - # If a silence overlaps with a legal silence, it can stay. - illegal_silences = list() - for silence_index, silence in enumerate(potential_silences): - illegal = True - start = silence[0] / len(features) - end = silence[1] / len(features) - for legal_silence in vad_silences: - legal_start = legal_silence[0] / decoded_wave_length - legal_end = legal_silence[1] / decoded_wave_length - if legal_start < start < legal_end or legal_start < end < legal_end: - illegal = False - break - if illegal: - # If it is an illegal silence, it is marked for removal in the original wave according to ration with real samplingrate. - illegal_silences.append(phoneme_indexes_of_silences[silence_index]) - - text = remove_elements(text, illegal_silences) # now we have all the silences where there should be silences and none where there shouldn't be any. We have to run the aligner again with this new information. 
- return text - def _calculate_durations(self, text, index, vis_dir, features, save_imgs): # We deal with the word boundaries by having 2 versions of text: with and without word boundaries. # We note the index of word boundaries and insert durations of 0 afterwards diff --git a/Architectures/ToucanTTS/ToucanTTS.py b/Modules/ToucanTTS/ToucanTTS.py similarity index 65% rename from Architectures/ToucanTTS/ToucanTTS.py rename to Modules/ToucanTTS/ToucanTTS.py index a2a39ba3315cfb80e7188a46fdea2e248b090612..849c81bb97de25b4604e75806b489bf4315732ff 100644 --- a/Architectures/ToucanTTS/ToucanTTS.py +++ b/Modules/ToucanTTS/ToucanTTS.py @@ -1,50 +1,36 @@ import torch +import torch.nn.functional as torchfunc from torch.nn import Linear from torch.nn import Sequential from torch.nn import Tanh -from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d -from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm -from Architectures.GeneralLayers.Conformer import Conformer -from Architectures.GeneralLayers.DurationPredictor import DurationPredictor -from Architectures.GeneralLayers.LengthRegulator import LengthRegulator -from Architectures.GeneralLayers.VariancePredictor import VariancePredictor -from Architectures.ToucanTTS.Glow import Glow -from Architectures.ToucanTTS.ToucanTTSLoss import ToucanTTSLoss +from Modules.GeneralLayers.Conformer import Conformer +from Modules.GeneralLayers.LengthRegulator import LengthRegulator +from Modules.ToucanTTS.StochasticToucanTTSLoss import StochasticToucanTTSLoss +from Modules.ToucanTTS.flow_matching import CFMDecoder from Preprocessing.articulatory_features import get_feature_to_index_lookup from Utility.utils import initialize -from Utility.utils import integrate_with_utt_embed from Utility.utils import make_non_pad_mask -from Utility.utils import make_pad_mask class ToucanTTS(torch.nn.Module): """ - ToucanTTS module, which is mostly just a FastSpeech 2 module, + ToucanTTS module, which is based on a FastSpeech 2 module, but with lots of designs from different architectures accumulated - and some major components added to put a large focus on multilinguality. - - Original contributions: - - Inputs are configurations of the articulatory tract - - Word boundaries are modeled explicitly in the encoder end removed before the decoder - - Speaker embedding conditioning is derived from GST and Adaspeech 4 - - Responsiveness of variance predictors to utterance embedding is increased through conditional layer norm - - The final output receives a GAN discriminator feedback signal + and some major components added to put a large focus on + multilinguality and controllability. 
Contributions inspired from elsewhere: - - The PostNet is also a normalizing flow, like in PortaSpeech + - The Decoder is a flow matching network, like in Matcha-TTS and StableTTS - Pitch and energy values are averaged per-phone, as in FastPitch to enable great controllability - - The encoder and decoder are Conformers + - The encoder and decoder are Conformers, like in ESPnet - Things that were tried, but showed inferior performance so far: - - Stochastic Duration Prediction - - Stochastic Pitch Prediction - - Stochastic Energy prediction """ def __init__(self, # network structure related - input_feature_dimensions=62, + input_feature_dimensions=64, + spec_channels=128, attention_dimension=384, attention_heads=4, positionwise_conv_kernel_size=1, @@ -74,33 +60,36 @@ class ToucanTTS(torch.nn.Module): transformer_dec_attn_dropout_rate=0.1, # duration predictor + prosody_channels=8, duration_predictor_layers=3, - duration_predictor_kernel_size=3, + duration_predictor_kernel_size=5, duration_predictor_dropout_rate=0.2, # pitch predictor - pitch_predictor_layers=5, + pitch_predictor_layers=3, pitch_predictor_kernel_size=5, - pitch_predictor_dropout=0.3, + pitch_predictor_dropout=0.2, pitch_embed_kernel_size=1, pitch_embed_dropout=0.0, # energy predictor - energy_predictor_layers=3, + energy_predictor_layers=2, energy_predictor_kernel_size=3, - energy_predictor_dropout=0.5, + energy_predictor_dropout=0.2, energy_embed_kernel_size=1, energy_embed_dropout=0.0, - # post glow - glow_kernel_size=5, - glow_blocks=12, - glow_layers=3, + # cfm decoder + cfm_filter_channels=256, + cfm_heads=4, + cfm_layers=3, + cfm_kernel_size=5, + cfm_p_dropout=0.1, # additional features utt_embed_dim=192, # 192 dim speaker embedding + 16 dim prosody embedding optionally (see older version, this one doesn't use the prosody embedding) lang_embs=8000, - lang_emb_size=16, + lang_emb_size=32, # lower dimensions seem to work better integrate_language_embedding_into_encoder_out=True, embedding_integration="AdaIN", # ["AdaIN" | "ConditionalLayerNorm" | "ConcatProject"] ): @@ -144,16 +133,26 @@ class ToucanTTS(torch.nn.Module): "energy_predictor_dropout" : energy_predictor_dropout, "energy_embed_kernel_size" : energy_embed_kernel_size, "energy_embed_dropout" : energy_embed_dropout, + "spec_channels" : spec_channels, + "cfm_filter_channels" : cfm_filter_channels, + "prosody_channels" : prosody_channels, + "cfm_heads" : cfm_heads, + "cfm_layers" : cfm_layers, + "cfm_kernel_size" : cfm_kernel_size, + "cfm_p_dropout" : cfm_p_dropout, "utt_embed_dim" : utt_embed_dim, "lang_embs" : lang_embs, "lang_emb_size" : lang_emb_size, "embedding_integration" : embedding_integration, - "glow_kernel_size" : glow_kernel_size, - "glow_blocks" : glow_blocks, - "glow_layers" : glow_layers, "integrate_language_embedding_into_encoder_out": integrate_language_embedding_into_encoder_out } + if lang_embs is None or lang_embs == 0: + lang_embs = None + integrate_language_embedding_into_encoder_out = False + if integrate_language_embedding_into_encoder_out: + utt_embed_dim = utt_embed_dim + lang_emb_size + self.input_feature_dimensions = input_feature_dimensions self.attention_dimension = attention_dimension self.use_scaled_pos_enc = use_scaled_positional_encoding @@ -185,35 +184,6 @@ class ToucanTTS(torch.nn.Module): use_output_norm=True, embedding_integration=embedding_integration) - if self.integrate_language_embedding_into_encoder_out: - if embedding_integration == "AdaIN": - self.language_embedding_infusion = AdaIN1d(style_dim=lang_emb_size, 
num_features=attention_dimension) - elif embedding_integration == "ConditionalLayerNorm": - self.language_embedding_infusion = ConditionalLayerNorm(speaker_embedding_dim=lang_emb_size, hidden_dim=attention_dimension) - else: - self.language_embedding_infusion = torch.nn.Linear(attention_dimension + lang_emb_size, attention_dimension) - - self.duration_predictor = DurationPredictor(idim=attention_dimension, n_layers=duration_predictor_layers, - n_chans=attention_dimension, - kernel_size=duration_predictor_kernel_size, - dropout_rate=duration_predictor_dropout_rate, - utt_embed_dim=utt_embed_dim, - embedding_integration=embedding_integration) - - self.pitch_predictor = VariancePredictor(idim=attention_dimension, n_layers=pitch_predictor_layers, - n_chans=attention_dimension, - kernel_size=pitch_predictor_kernel_size, - dropout_rate=pitch_predictor_dropout, - utt_embed_dim=utt_embed_dim, - embedding_integration=embedding_integration) - - self.energy_predictor = VariancePredictor(idim=attention_dimension, n_layers=energy_predictor_layers, - n_chans=attention_dimension, - kernel_size=energy_predictor_kernel_size, - dropout_rate=energy_predictor_dropout, - utt_embed_dim=utt_embed_dim, - embedding_integration=embedding_integration) - self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1, out_channels=attention_dimension, kernel_size=pitch_embed_kernel_size, @@ -245,31 +215,55 @@ class ToucanTTS(torch.nn.Module): utt_embed=utt_embed_dim, embedding_integration=embedding_integration) - # due to the nature of the residual vector quantization, we have to predict the codebooks in a hierarchical way. - self.output_projection = torch.nn.Linear(attention_dimension, 128) - - self.post_flow = Glow( - in_channels=128, - hidden_channels=attention_dimension, # post_glow_hidden - kernel_size=glow_kernel_size, # post_glow_kernel_size - dilation_rate=1, - n_blocks=glow_blocks, # post_glow_n_blocks (original 12 in paper) - n_layers=glow_layers, # post_glow_n_block_layers (original 3 in paper) - n_split=4, - n_sqz=2, - text_condition_channels=attention_dimension, - share_cond_layers=False, # post_share_cond_layers - share_wn_layers=4, - sigmoid_scale=False, - condition_integration_projection=torch.nn.Conv1d(128 + attention_dimension, attention_dimension, 5, padding=2) - ) + self.output_projection = torch.nn.Linear(attention_dimension, spec_channels) + self.pitch_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels) + self.energy_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels) + self.duration_latent_reduction = torch.nn.Linear(attention_dimension, prosody_channels) # initialize parameters self._reset_parameters(init_type=init_type) if lang_embs is not None: torch.nn.init.normal_(self.encoder.language_embedding.weight, mean=0, std=attention_dimension ** -0.5) - self.criterion = ToucanTTSLoss() + # the following modules have their own init function, so they come AFTER the init. 
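+ # A minimal usage sketch of the CFMDecoder prosody predictors set up below (values and variable names here are
+ # illustrative only; the call signatures are the ones defined in Modules/ToucanTTS/flow_matching.py in this diff):
+ #
+ #     predictor = CFMDecoder(hidden_channels=8, out_channels=1, filter_channels=8, n_heads=1,
+ #                            n_layers=3, kernel_size=5, p_dropout=0.2, gin_channels=192)
+ #     mu = torch.randn(1, 8, 20)      # reduced prosody latent: (batch, prosody_channels, phones)
+ #     mask = torch.ones(1, 1, 20)     # non-padding mask over phones
+ #     cond = torch.randn(1, 192)      # utterance embedding
+ #     # training: conditional flow matching loss against the gold contour
+ #     loss, _ = predictor.compute_loss(x1=torch.randn(1, 1, 20), mask=mask, mu=mu, c=cond)
+ #     # inference: sample a contour via the fixed-step Euler ODE solve
+ #     contour = predictor(mu=mu, mask=mask, n_timesteps=10, temperature=1.0, c=cond)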
+ + self.duration_predictor = CFMDecoder(hidden_channels=prosody_channels, + out_channels=1, + filter_channels=prosody_channels, + n_heads=1, + n_layers=duration_predictor_layers, + kernel_size=duration_predictor_kernel_size, + p_dropout=duration_predictor_dropout_rate, + gin_channels=utt_embed_dim) + + self.pitch_predictor = CFMDecoder(hidden_channels=prosody_channels, + out_channels=1, + filter_channels=prosody_channels, + n_heads=1, + n_layers=pitch_predictor_layers, + kernel_size=pitch_predictor_kernel_size, + p_dropout=pitch_predictor_dropout, + gin_channels=utt_embed_dim) + + self.energy_predictor = CFMDecoder(hidden_channels=prosody_channels, + out_channels=1, + filter_channels=prosody_channels, + n_heads=1, + n_layers=energy_predictor_layers, + kernel_size=energy_predictor_kernel_size, + p_dropout=energy_predictor_dropout, + gin_channels=utt_embed_dim) + + self.flow_matching_decoder = CFMDecoder(hidden_channels=spec_channels, + out_channels=spec_channels, + filter_channels=cfm_filter_channels, + n_heads=cfm_heads, + n_layers=cfm_layers, + kernel_size=cfm_kernel_size, + p_dropout=cfm_p_dropout, + gin_channels=utt_embed_dim) + + self.criterion = StochasticToucanTTSLoss() def forward(self, text_tensors, @@ -282,7 +276,7 @@ class ToucanTTS(torch.nn.Module): utterance_embedding, return_feats=False, lang_ids=None, - run_glow=True + run_stochastic=True ): """ Args: @@ -296,39 +290,32 @@ class ToucanTTS(torch.nn.Module): gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1). lang_ids (LongTensor): The language IDs used to access the language embedding table, if the model is multilingual utterance_embedding (Tensor): Batch of embeddings to condition the TTS on, if the model is multispeaker - run_glow (Bool): Whether to detach the inputs to the normalizing flow for stability. + run_stochastic (Bool): Whether to detach the inputs to the normalizing flow for stability. 
""" outs, \ - glow_loss, \ - predicted_durations, \ - predicted_pitch, \ - predicted_energy = self._forward(text_tensors=text_tensors, - text_lengths=text_lengths, - gold_speech=gold_speech, - speech_lengths=speech_lengths, - gold_durations=gold_durations, - gold_pitch=gold_pitch, - gold_energy=gold_energy, - utterance_embedding=utterance_embedding, - is_inference=False, - lang_ids=lang_ids, - run_glow=run_glow) + stochastic_loss, \ + duration_loss, \ + pitch_loss, \ + energy_loss = self._forward(text_tensors=text_tensors, + text_lengths=text_lengths, + gold_speech=gold_speech, + speech_lengths=speech_lengths, + gold_durations=gold_durations, + gold_pitch=gold_pitch, + gold_energy=gold_energy, + utterance_embedding=utterance_embedding, + is_inference=False, + lang_ids=lang_ids, + run_stochastic=run_stochastic) # calculate loss - regression_loss, duration_loss, pitch_loss, energy_loss = self.criterion(predicted_features=outs, - gold_features=gold_speech, - features_lengths=speech_lengths, - text_lengths=text_lengths, - gold_durations=gold_durations, - predicted_durations=predicted_durations, - predicted_pitch=predicted_pitch, - predicted_energy=predicted_energy, - gold_pitch=gold_pitch, - gold_energy=gold_energy) + regression_loss = self.criterion(predicted_features=outs, + gold_features=gold_speech, + features_lengths=speech_lengths) if return_feats: - return regression_loss, glow_loss, duration_loss, pitch_loss, energy_loss, outs - return regression_loss, glow_loss, duration_loss, pitch_loss, energy_loss + return regression_loss, stochastic_loss, duration_loss, pitch_loss, energy_loss, outs + return regression_loss, stochastic_loss, duration_loss, pitch_loss, energy_loss def _forward(self, text_tensors, @@ -341,38 +328,48 @@ class ToucanTTS(torch.nn.Module): is_inference=False, utterance_embedding=None, lang_ids=None, - run_glow=False): + run_stochastic=False): + + text_tensors = torch.clamp(text_tensors, max=1.0) + # this is necessary, because of the way we represent modifiers to keep them identifiable. 
if not self.multilingual_model: lang_ids = None if not self.multispeaker_model: utterance_embedding = None - else: + + if utterance_embedding is not None: utterance_embedding = torch.nn.functional.normalize(utterance_embedding) + if self.integrate_language_embedding_into_encoder_out and lang_ids is not None: + lang_embs = self.encoder.language_embedding(lang_ids) + lang_embs = torch.nn.functional.normalize(lang_embs) + utterance_embedding = torch.cat([lang_embs, utterance_embedding], dim=1).detach() # encoding the texts text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2) - padding_masks = make_pad_mask(text_lengths, device=text_lengths.device) encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids) - if self.integrate_language_embedding_into_encoder_out: - lang_embs = self.encoder.language_embedding(lang_ids).squeeze(-1) - encoded_texts = integrate_with_utt_embed(hs=encoded_texts, utt_embeddings=lang_embs, projection=self.language_embedding_infusion, embedding_training=self.use_conditional_layernorm_embedding_integration) - if is_inference: # predicting pitch, energy and durations - pitch_predictions = self.pitch_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) - energy_predictions = self.energy_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) - predicted_durations = self.duration_predictor.inference(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) + reduced_pitch_space = torchfunc.dropout(self.pitch_latent_reduction(encoded_texts), p=0.1).transpose(1, 2) + pitch_predictions = self.pitch_predictor(mu=reduced_pitch_space, mask=text_masks.float(), n_timesteps=10, temperature=1.0, c=utterance_embedding) + embedded_pitch_curve = self.pitch_embed(pitch_predictions).transpose(1, 2) + + reduced_energy_space = torchfunc.dropout(self.energy_latent_reduction(encoded_texts + embedded_pitch_curve), p=0.1).transpose(1, 2) + energy_predictions = self.energy_predictor(mu=reduced_energy_space, mask=text_masks.float(), n_timesteps=10, temperature=1.0, c=utterance_embedding) + embedded_energy_curve = self.energy_embed(energy_predictions).transpose(1, 2) + + reduced_duration_space = torchfunc.dropout(self.duration_latent_reduction(encoded_texts + embedded_pitch_curve + embedded_energy_curve), p=0.1).transpose(1, 2) + predicted_durations = self.duration_predictor(mu=reduced_duration_space, mask=text_masks.float(), n_timesteps=10, temperature=1.0, c=utterance_embedding) + predicted_durations = torch.clamp(torch.ceil(predicted_durations), min=0.0).long().squeeze(1) # modifying the predictions for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)): if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1: predicted_durations[0][phoneme_index] = 0 + # enriching the text with pitch and energy info - embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2) - embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2) enriched_encoded_texts = encoded_texts + embedded_pitch_curve + embedded_energy_curve # predicting durations for text and upsampling accordingly @@ -380,12 +377,26 @@ class ToucanTTS(torch.nn.Module): else: # training with teacher forcing - pitch_predictions = self.pitch_predictor(encoded_texts.detach(), padding_mask=padding_masks.unsqueeze(-1), utt_embed=utterance_embedding) - energy_predictions = self.energy_predictor(encoded_texts, 
padding_mask=padding_masks.unsqueeze(-1), utt_embed=utterance_embedding) - predicted_durations = self.duration_predictor(encoded_texts, padding_mask=padding_masks, utt_embed=utterance_embedding) - + reduced_pitch_space = torchfunc.dropout(self.pitch_latent_reduction(encoded_texts), p=0.1).transpose(1, 2) + pitch_loss, _ = self.pitch_predictor.compute_loss(mu=reduced_pitch_space, + x1=gold_pitch.transpose(1, 2), + mask=text_masks.float(), + c=utterance_embedding) embedded_pitch_curve = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2) + + reduced_energy_space = torchfunc.dropout(self.energy_latent_reduction(encoded_texts + embedded_pitch_curve), p=0.1).transpose(1, 2) + energy_loss, _ = self.energy_predictor.compute_loss(mu=reduced_energy_space, + x1=gold_energy.transpose(1, 2), + mask=text_masks.float(), + c=utterance_embedding) embedded_energy_curve = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2) + + reduced_duration_space = torchfunc.dropout(self.duration_latent_reduction(encoded_texts + embedded_pitch_curve + embedded_energy_curve), p=0.1).transpose(1, 2) + duration_loss, _ = self.duration_predictor.compute_loss(mu=reduced_duration_space, + x1=gold_durations.unsqueeze(-1).transpose(1, 2).float(), + mask=text_masks.float(), + c=utterance_embedding) + enriched_encoded_texts = encoded_texts + embedded_energy_curve + embedded_pitch_curve upsampled_enriched_encoded_texts = self.length_regulator(enriched_encoded_texts, gold_durations) @@ -397,8 +408,12 @@ class ToucanTTS(torch.nn.Module): preliminary_spectrogram = self.output_projection(decoded_speech) if is_inference: - if run_glow: - refined_codec_frames = self.post_flow(tgt_mels=gold_speech, infer=is_inference, mel_out=preliminary_spectrogram, encoded_texts=upsampled_enriched_encoded_texts, tgt_nonpadding=None) + if run_stochastic: + refined_codec_frames = self.flow_matching_decoder(mu=preliminary_spectrogram.transpose(1, 2), + mask=make_non_pad_mask([len(decoded_speech[0])], device=decoded_speech.device).unsqueeze(-2).float(), + n_timesteps=15, + temperature=0.2, + c=None).transpose(1, 2) else: refined_codec_frames = preliminary_spectrogram return refined_codec_frames, \ @@ -406,15 +421,18 @@ class ToucanTTS(torch.nn.Module): pitch_predictions.squeeze(), \ energy_predictions.squeeze() else: - if run_glow: - glow_loss = self.post_flow(tgt_mels=gold_speech, infer=is_inference, mel_out=preliminary_spectrogram, encoded_texts=upsampled_enriched_encoded_texts, tgt_nonpadding=decoder_masks) + if run_stochastic: + stochastic_loss, _ = self.flow_matching_decoder.compute_loss(x1=gold_speech.transpose(1, 2), + mask=decoder_masks.float(), + mu=preliminary_spectrogram.transpose(1, 2).detach(), + c=None) else: - glow_loss = None + stochastic_loss = None return preliminary_spectrogram, \ - glow_loss, \ - predicted_durations, \ - pitch_predictions, \ - energy_predictions + stochastic_loss, \ + duration_loss, \ + pitch_loss, \ + energy_loss @torch.inference_mode() def inference(self, @@ -423,7 +441,7 @@ class ToucanTTS(torch.nn.Module): utterance_embedding=None, return_duration_pitch_energy=False, lang_id=None, - run_glow=True): + run_stochastic=True): """ Args: text (LongTensor): Input sequence of characters (T,). 
@@ -431,7 +449,7 @@ class ToucanTTS(torch.nn.Module): return_duration_pitch_energy (Boolean): whether to return the list of predicted durations for nicer plotting lang_id (LongTensor): The language ID used to access the language embedding table, if the model is multilingual utterance_embedding (Tensor): Embedding to condition the TTS on, if the model is multispeaker - run_glow (bool): whether to use the output of the glow or of the out_projection to generate codec frames + run_stochastic (bool): whether to use the output of the stochastic or of the out_projection to generate codec frames """ self.eval() @@ -451,7 +469,7 @@ class ToucanTTS(torch.nn.Module): is_inference=True, utterance_embedding=utterance_embeddings, lang_ids=lang_id, - run_glow=run_glow) # (1, L, odim) + run_stochastic=run_stochastic) # (1, L, odim) self.train() if return_duration_pitch_energy: @@ -465,25 +483,16 @@ class ToucanTTS(torch.nn.Module): def reset_postnet(self, init_type="xavier_uniform"): # useful for after they explode - initialize(self.post_flow, init_type) + initialize(self.flow_matching_decoder, init_type) if __name__ == '__main__': model = ToucanTTS() print(sum(p.numel() for p in model.parameters() if p.requires_grad)) - print(sum(p.numel() for p in model.post_flow.parameters() if p.requires_grad)) - - print(" TESTING INFERENCE ") - dummy_text_batch = torch.randint(low=0, high=2, size=[12, 62]).float() # [Sequence Length, Features per Phone] - dummy_utterance_embed = torch.randn([192]) # [Dimensions of Speaker Embedding] - dummy_language_id = torch.LongTensor([2]) - print(model.inference(dummy_text_batch, - utterance_embedding=dummy_utterance_embed, - lang_id=dummy_language_id).shape) print(" TESTING TRAINING ") - dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float() # [Batch, Sequence Length, Features per Phone] + dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 64]).float() # [Batch, Sequence Length, Features per Phone] dummy_text_lens = torch.LongTensor([2, 3, 3]) dummy_speech_batch = torch.randn([3, 30, 128]) # [Batch, Sequence Length, Spectrogram Buckets] @@ -506,6 +515,14 @@ if __name__ == '__main__': utterance_embedding=dummy_utterance_embed, lang_ids=dummy_language_id) - loss = ce + dl + pl + el + loss = ce + dl + pl + el + fl print(loss) loss.backward() + + print(" TESTING INFERENCE ") + dummy_text_batch = torch.randint(low=0, high=2, size=[12, 64]).float() # [Sequence Length, Features per Phone] + dummy_utterance_embed = torch.randn([192]) # [Dimensions of Speaker Embedding] + dummy_language_id = torch.LongTensor([2]) + print(model.inference(dummy_text_batch, + utterance_embedding=dummy_utterance_embed, + lang_id=dummy_language_id).shape) diff --git a/Architectures/ToucanTTS/ToucanTTSLoss.py b/Modules/ToucanTTS/ToucanTTSLoss.py similarity index 97% rename from Architectures/ToucanTTS/ToucanTTSLoss.py rename to Modules/ToucanTTS/ToucanTTSLoss.py index 6d324093a088ffafb9067bc623b4def44daf40bc..4a182a6f6121314c93f3fa57ea8a04ae0d4b258d 100644 --- a/Architectures/ToucanTTS/ToucanTTSLoss.py +++ b/Modules/ToucanTTS/ToucanTTSLoss.py @@ -5,7 +5,7 @@ Adapted by Flux import torch -from Architectures.GeneralLayers.DurationPredictor import DurationPredictorLoss +from Modules.GeneralLayers.DurationPredictor import DurationPredictorLoss from Utility.utils import make_non_pad_mask diff --git a/Architectures/ToucanTTS/StochasticToucanTTS/__init__.py b/Modules/ToucanTTS/__init__.py similarity index 100% rename from Architectures/ToucanTTS/StochasticToucanTTS/__init__.py rename to 
Modules/ToucanTTS/__init__.py diff --git a/Modules/ToucanTTS/dit.py b/Modules/ToucanTTS/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..339cba8b70fbe90a7e346e3ebb87e29635543242 --- /dev/null +++ b/Modules/ToucanTTS/dit.py @@ -0,0 +1,219 @@ +""" +Copied from https://github.com/KdaiP/StableTTS by https://github.com/KdaiP + +https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/dit.py + +Code is under MIT License +""" + +# References: +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py +# https://github.com/jaywalnut310/vits/blob/main/attentions.py +# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.drop = nn.Dropout(p_dropout) + self.act1 = nn.GELU(approximate="tanh") + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = self.act1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0.): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(x) + v = self.conv_v(x) + + x = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head] + key = self.key_rotary_pe(key) + + output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output + + +# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390 +class DiTConVBlock(nn.Module): + """ + A DiT block with 
adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_channels, out_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_channels + out_channels, elementwise_affine=False, eps=1e-6) + self.attn = MultiHeadAttention(hidden_channels + out_channels, hidden_channels + out_channels, num_heads, p_dropout) + self.norm2 = nn.LayerNorm(hidden_channels + out_channels, elementwise_affine=False, eps=1e-6) + self.mlp = FFN(hidden_channels + out_channels, hidden_channels + out_channels, filter_channels, kernel_size, p_dropout=p_dropout) + self.adaLN_modulation = nn.Sequential( + nn.Linear(gin_channels, hidden_channels + out_channels) if gin_channels != hidden_channels + out_channels else nn.Identity(), + nn.SiLU(), + nn.Linear(hidden_channels + out_channels, 6 * (hidden_channels + out_channels), bias=True) + ) + + def forward(self, x, c, x_mask): + """ + Args: + x : [batch_size, channel, time] + c : [batch_size, channel] + x_mask : [batch_size, 1, time] + return the same shape as x + """ + x = x * x_mask + attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time] + # attn_mask = attn_mask.to(torch.bool) + if c is not None: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1] + x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1, 2)).transpose(1, 2), shift_msa, scale_msa), attn_mask) * x_mask + # x = x.masked_fill(~x_mask, 0.0) + x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1, 2)).transpose(1, 2), shift_mlp, scale_mlp), x_mask) * x_mask + else: + # no condition version + x = x + self.attn(self.norm1(x.transpose(1, 2)).transpose(1, 2), attn_mask) + x = x + self.mlp(self.norm1(x.transpose(1, 2)).transpose(1, 2), x_mask) + return x + + @staticmethod + def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. 
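+ As a concrete illustration of what the forward pass below computes: for the feature pair
+ $(x^{(i)}, x^{(i + d/2)})$ of a token at position $m$, the output pair is
+ $(x^{(i)} \cos m\theta_i - x^{(i + d/2)} \sin m\theta_i,\; x^{(i + d/2)} \cos m\theta_i + x^{(i)} \sin m\theta_i)$,
+ i.e. a rotation of that pair by the angle $m\theta_i$.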
+ """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = x.permute(2, 0, 1, 3) # b h t d -> t b h d + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
+ x_rope, x_pass = x[..., : self.d], x[..., self.d:] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) + + return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d + + +class Transpose(nn.Identity): + """(N, T, D) -> (N, D, T)""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.transpose(1, 2) diff --git a/Modules/ToucanTTS/dit_wrapper.py b/Modules/ToucanTTS/dit_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..762939dde84fa8f5be95ef6cd801afca8a673e5f --- /dev/null +++ b/Modules/ToucanTTS/dit_wrapper.py @@ -0,0 +1,177 @@ +""" +Copied from https://github.com/KdaiP/StableTTS by https://github.com/KdaiP + +https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/estimator.py + +Code is under MIT License +""" + +import math + +import torch +import torch.nn as nn + +from Modules.ToucanTTS.dit import DiTConVBlock + + +class DitWrapper(nn.Module): + """ add FiLM layer to condition time embedding to DiT """ + + def __init__(self, hidden_channels, out_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0): + super().__init__() + self.time_fusion = FiLMLayer(hidden_channels, out_channels, time_channels) + self.conv1 = ConvNeXtBlock(hidden_channels, out_channels, filter_channels, gin_channels) + self.conv2 = ConvNeXtBlock(hidden_channels, out_channels, filter_channels, gin_channels) + self.conv3 = ConvNeXtBlock(hidden_channels, out_channels, filter_channels, gin_channels) + self.block = DiTConVBlock(hidden_channels, out_channels, hidden_channels, num_heads, kernel_size, p_dropout, gin_channels) + + def forward(self, x, c, t, x_mask): + x = self.time_fusion(x, t) * x_mask + x = self.conv1(x, c, x_mask) + x = self.conv2(x, c, x_mask) + x = self.conv3(x, c, x_mask) + x = self.block(x, c, x_mask) + return x + + +class FiLMLayer(nn.Module): + """ + Feature-wise Linear Modulation (FiLM) layer + Reference: https://arxiv.org/abs/1709.07871 + """ + + def __init__(self, in_channels, out_channels, cond_channels): + super(FiLMLayer, self).__init__() + self.in_channels = in_channels + self.film = nn.Conv1d(cond_channels, (in_channels + out_channels) * 2, 1) + + def forward(self, x, c): + gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1) + return gamma * x + beta + + +class ConvNeXtBlock(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, gin_channels): + super().__init__() + self.dwconv = nn.Conv1d(in_channels + out_channels, in_channels + out_channels, kernel_size=7, padding=3, groups=in_channels + out_channels) + self.norm = StyleAdaptiveLayerNorm(in_channels + out_channels, gin_channels) + self.pwconv = nn.Sequential(nn.Linear(in_channels + out_channels, filter_channels), + nn.GELU(), + nn.Linear(filter_channels, in_channels + out_channels)) + + def forward(self, x, c, x_mask) -> torch.Tensor: + residual = x + x = self.dwconv(x) * x_mask + if c is not None: + x = self.norm(x.transpose(1, 2), c) + else: + x = x.transpose(1, 2) + x = self.pwconv(x).transpose(1, 2) + x = residual + x + return x * x_mask + + +class StyleAdaptiveLayerNorm(nn.Module): + def __init__(self, in_channels, cond_channels): + """ + Style Adaptive Layer Normalization (SALN) module. 
+ + Parameters: + in_channels: The number of channels in the input feature maps. + cond_channels: The number of channels in the conditioning input. + """ + super(StyleAdaptiveLayerNorm, self).__init__() + self.in_channels = in_channels + + self.saln = nn.Linear(cond_channels, in_channels * 2, 1) + self.norm = nn.LayerNorm(in_channels, elementwise_affine=False) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.saln.bias.data[:self.in_channels], 1) + nn.init.constant_(self.saln.bias.data[self.in_channels:], 0) + + def forward(self, x, c): + gamma, beta = torch.chunk(self.saln(c.unsqueeze(1)), chunks=2, dim=-1) + return gamma * self.norm(x) + beta + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels): + super().__init__() + + self.layer = nn.Sequential( + nn.Linear(in_channels, filter_channels), + nn.SiLU(inplace=True), + nn.Linear(filter_channels, out_channels) + ) + + def forward(self, x): + return self.layer(x) + + +# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py +class Decoder(nn.Module): + def __init__(self, hidden_channels, out_channels, filter_channels, dropout=0.05, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0): + super().__init__() + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + + self.time_embeddings = SinusoidalPosEmb(hidden_channels) + self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels) + + self.blocks = nn.ModuleList([DitWrapper(hidden_channels, out_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)]) + self.final_proj = nn.Conv1d(hidden_channels + out_channels, out_channels, 1) + + self.initialize_weights() + + def initialize_weights(self): + for block in self.blocks: + nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0) + + def forward(self, x, mask, mu, t, c): + """Forward pass of the UNet1DConditional model. 
+ + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + c (_type_): shape (batch_size, gin_channels) + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + t = self.time_mlp(self.time_embeddings(t)) + + x = torch.cat((x, mu), dim=1) + + for block in self.blocks: + x = block(x, c, t, mask) + + output = self.final_proj(x * mask) + + return output * mask diff --git a/Modules/ToucanTTS/flow_matching.py b/Modules/ToucanTTS/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ba1af423816e7d75c6715bb2f2f107c58d6a20 --- /dev/null +++ b/Modules/ToucanTTS/flow_matching.py @@ -0,0 +1,134 @@ +""" +Copied from https://github.com/KdaiP/StableTTS by https://github.com/KdaiP + +https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/flow_matching.py + +Code is under MIT License +""" + +import imageio +import torch +import torch.nn.functional as F + +from Modules.ToucanTTS.dit_wrapper import Decoder +from Utility.utils import plot_spec_tensor + + +# copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121 +def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor: + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py +class CFMDecoder(torch.nn.Module): + def __init__(self, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels): + super().__init__() + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.gin_channels = gin_channels + self.sigma_min = 1e-4 + + self.estimator = Decoder(hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels) + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + c (torch.Tensor, optional): shape: (batch_size, gin_channels) + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + size = list(mu.size()) + size[1] = self.out_channels + z = torch.randn(size=size).to(mu.device) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, c=c) + + def solve_euler(self, x, t_span, mu, mask, c, plot_solutions=False): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + c (torch.Tensor, optional): speaker condition. 
+ shape: (batch_size, gin_channels) + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + sol = [] + + for step in range(1, len(t_span)): + + dphi_dt = self.estimator(x, mask, mu, t, c) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + if plot_solutions: + create_plot_of_all_solutions(sol) + + return sol[-1] + + def compute_loss(self, x1, mask, mu, c): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + c (torch.Tensor, optional): speaker condition. + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), c), + u, + reduction="sum") / (torch.sum(mask) * u.shape[1]) + return loss, y + + +def create_plot_of_all_solutions(sol): + gif_collector = list() + for step_index, solution in enumerate(sol): + unbatched_solution = solution[0] # remove the batch axis (if there are more than one element in the batch, we only take the first) + plot_spec_tensor(unbatched_solution, "tmp", step_index, title=step_index + 1) + gif_collector.append(imageio.v2.imread(f"tmp/{step_index}.png")) + for _ in range(10): + gif_collector.append(gif_collector[-1]) + imageio.mimsave("tmp/animation.gif", gif_collector, fps=6, loop=0) diff --git a/Architectures/ToucanTTS/glow_utils.py b/Modules/ToucanTTS/glow_utils.py similarity index 100% rename from Architectures/ToucanTTS/glow_utils.py rename to Modules/ToucanTTS/glow_utils.py diff --git a/Architectures/ToucanTTS/toucantts_meta_train_loop.py b/Modules/ToucanTTS/toucantts_meta_train_loop.py similarity index 61% rename from Architectures/ToucanTTS/toucantts_meta_train_loop.py rename to Modules/ToucanTTS/toucantts_meta_train_loop.py index c439a958eda4b028111b24d0f77e184cec9e71f1..657cc4c092fd77b55cc9f7edfae01e361df3664f 100644 --- a/Architectures/ToucanTTS/toucantts_meta_train_loop.py +++ b/Modules/ToucanTTS/toucantts_meta_train_loop.py @@ -4,7 +4,7 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data.dataloader import DataLoader from tqdm import tqdm -from Architectures.ToucanTTS.LanguageEmbeddingSpaceStructureLoss import LanguageEmbeddingSpaceStructureLoss +from Modules.ToucanTTS.LanguageEmbeddingSpaceStructureLoss import LanguageEmbeddingSpaceStructureLoss from Preprocessing.AudioPreprocessor import AudioPreprocessor from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor from Utility.WarmupScheduler import ToucanWarmupScheduler as WarmupScheduler @@ -49,6 +49,7 @@ def train_loop(net, train_samplers, gpu_count, use_less_loss, + freeze_lang_embs ): """ see train loop arbiter for explanations of the arguments @@ -66,17 +67,19 @@ def train_loop(net, if use_less_loss: less_loss = LanguageEmbeddingSpaceStructureLoss() - pretrained_language_codes = ['eng', 'deu', 'fra', 'spa', 'cmn', 'por', 'pol', 'ita', 'nld', 'ell', 'fin', 'vie', 'rus', 'hun', 'bem', 'swh', 'amh', 'wol', 'mal', 'chv', 'iba', 'jav', 'fon', 'hau', 'lbb', 'kik', 'lin', 'lug', 'luo', 'sxb', 'yor', 'nya', 'loz', 
'toi', 'afr', 'arb', 'asm', 'ast', 'azj', 'bel', 'bul', 'ben', 'bos', 'cat', 'ceb', 'sdh', - 'ces', 'cym', 'dan', 'ekk', 'pes', 'fil', 'gle', 'glg', 'guj', 'heb', 'hin', 'hrv', 'hye', 'ind', 'ibo', 'isl', 'kat', 'kam', 'kea', 'kaz', 'khm', 'kan', 'kor', 'ltz', 'lao', 'lit', 'lvs', 'mri', 'mkd', 'xng', 'mar', 'zsm', 'mlt', 'oci', 'ory', 'pan', 'pst', 'ron', 'snd', 'slk', 'slv', 'sna', 'som', 'srp', 'swe', 'tam', - 'tel', 'tgk', 'tur', 'ukr', 'umb', 'urd', 'uzn', 'bhd', 'kfs', 'dgo', 'gbk', 'bgc', 'xnr', 'kfx', 'mjl', 'bfz', 'acf', 'bss', 'inb', 'nca', 'quh', 'wap', 'acr', 'bus', 'dgr', 'maz', 'nch', 'qul', 'tav', 'wmw', 'acu', 'byr', 'dik', 'iou', 'mbb', 'ncj', 'qvc', 'tbc', 'xed', 'agd', 'bzh', 'djk', 'ipi', 'mbc', 'ncl', 'qve', - 'tbg', 'xon', 'agg', 'bzj', 'dop', 'jac', 'mbh', 'ncu', 'qvh', 'tbl', 'xtd', 'agn', 'caa', 'jic', 'mbj', 'ndj', 'qvm', 'tbz', 'xtm', 'agr', 'cab', 'emp', 'jiv', 'mbt', 'nfa', 'qvn', 'tca', 'yaa', 'agu', 'cap', 'jvn', 'mca', 'ngp', 'qvs', 'tcs', 'yad', 'aia', 'car', 'ese', 'mcb', 'ngu', 'qvw', 'yal', 'cax', 'kaq', 'mcd', - 'nhe', 'qvz', 'tee', 'ycn', 'ake', 'cbc', 'far', 'mco', 'qwh', 'yka', 'alp', 'cbi', 'kdc', 'mcp', 'nhu', 'qxh', 'ame', 'cbr', 'gai', 'kde', 'mcq', 'nhw', 'qxn', 'tew', 'yre', 'amf', 'cbs', 'gam', 'kdl', 'mdy', 'nhy', 'qxo', 'tfr', 'yva', 'amk', 'cbt', 'geb', 'kek', 'med', 'nin', 'rai', 'zaa', 'apb', 'cbu', 'glk', 'ken', - 'mee', 'nko', 'rgu', 'zab', 'apr', 'cbv', 'meq', 'tgo', 'zac', 'arl', 'cco', 'gng', 'kje', 'met', 'nlg', 'rop', 'tgp', 'zad', 'grc', 'klv', 'mgh', 'nnq', 'rro', 'zai', 'ata', 'cek', 'gub', 'kmu', 'mib', 'noa', 'ruf', 'tna', 'zam', 'atb', 'cgc', 'guh', 'kne', 'mie', 'not', 'rug', 'tnk', 'zao', 'atg', 'chf', 'knf', 'mih', - 'npl', 'tnn', 'zar', 'awb', 'chz', 'gum', 'knj', 'mil', 'sab', 'tnp', 'zas', 'cjo', 'guo', 'ksr', 'mio', 'obo', 'seh', 'toc', 'zav', 'azg', 'cle', 'gux', 'kue', 'mit', 'omw', 'sey', 'tos', 'zaw', 'azz', 'cme', 'gvc', 'kvn', 'miz', 'ood', 'sgb', 'tpi', 'zca', 'bao', 'cni', 'gwi', 'kwd', 'mkl', 'shp', 'tpt', 'zga', 'bba', - 'cnl', 'gym', 'kwf', 'mkn', 'ote', 'sja', 'trc', 'ziw', 'bbb', 'cnt', 'gyr', 'kwi', 'mop', 'otq', 'snn', 'ttc', 'zlm', 'cof', 'hat', 'kyc', 'mox', 'pab', 'snp', 'tte', 'zos', 'bgt', 'con', 'kyf', 'mpm', 'pad', 'tue', 'zpc', 'bjr', 'cot', 'kyg', 'mpp', 'soy', 'tuf', 'zpl', 'bjv', 'cpa', 'kyq', 'mpx', 'pao', 'tuo', 'zpm', - 'bjz', 'cpb', 'hlt', 'kyz', 'mqb', 'pib', 'spp', 'zpo', 'bkd', 'cpu', 'hns', 'lac', 'mqj', 'pir', 'spy', 'txq', 'zpu', 'blz', 'crn', 'hto', 'lat', 'msy', 'pjt', 'sri', 'txu', 'zpz', 'bmr', 'cso', 'hub', 'lex', 'mto', 'pls', 'srm', 'udu', 'ztq', 'bmu', 'ctu', 'lgl', 'muy', 'poi', 'srn', 'zty', 'bnp', 'cuc', 'lid', 'mxb', - 'stp', 'upv', 'zyp', 'boa', 'cui', 'huu', 'mxq', 'sus', 'ura', 'boj', 'cuk', 'huv', 'llg', 'mxt', 'poy', 'suz', 'urb', 'box', 'cwe', 'hvn', 'prf', 'urt', 'bpr', 'cya', 'ign', 'lww', 'myk', 'ptu', 'usp', 'bps', 'daa', 'ikk', 'maj', 'myy', 'vid', 'bqc', 'dah', 'nab', 'qub', 'tac', 'bqp', 'ded', 'imo', 'maq', 'nas', 'quf', - 'taj', 'vmy'] + pretrained_language_codes = [ + "eng", "deu", "fra", "spa", "cmn", "por", "pol", "ita", "nld", "ell", "fin", "vie", "jpn", "rus", "hun", "asm", "ben", "brx", "dgo", "guj", "hin", "kan", "kas", "knn", "mai", "mal", "mni", "mar", "nep", "ory", "pan", "san", "sat", "snd", "tam", "tel", "urd", "bem", "swh", "amh", "wol", "chv", "iba", "jav", "fon", "hau", "lbb", + "kik", "lin", "lug", "luo", "sxb", "yor", "nya", "loz", "toi", "afr", "arb", "ast", "azj", "bel", "bul", "bos", "cat", "ceb", "sdh", "ces", "cym", "dan", "ekk", 
"pes", "fil", "gle", "glg", "heb", "hrv", "hye", "ind", "ibo", "isl", "kat", "kam", "kea", "kaz", "khm", "kor", "ltz", "lao", "lit", "lvs", "mri", "mkd", "xng", "zsm", + "mlt", "oci", "pst", "ron", "slk", "slv", "sna", "som", "srp", "swe", "tgk", "tur", "ukr", "umb", "uzn", "bhd", "kfs", "gbk", "bgc", "xnr", "kfx", "mjl", "bfz", "acf", "bss", "inb", "nca", "quh", "wap", "acr", "bus", "dgr", "maz", "nch", "qul", "tav", "wmw", "acu", "byr", "dik", "iou", "mbb", "ncj", "qvc", "tbc", "xed", "agd", + "bzh", "djk", "ipi", "mbc", "ncl", "qve", "tbg", "xon", "agg", "bzj", "dop", "jac", "mbh", "ncu", "qvh", "tbl", "xtd", "agn", "caa", "jic", "mbj", "ndj", "qvm", "tbz", "xtm", "agr", "cab", "emp", "jiv", "mbt", "nfa", "qvn", "tca", "yaa", "agu", "cap", "jvn", "mca", "ngp", "qvs", "tcs", "yad", "aia", "car", "ese", "mcb", "ngu", + "qvw", "yal", "cax", "kaq", "mcd", "nhe", "qvz", "tee", "ycn", "ake", "cbc", "far", "mco", "qwh", "yka", "alp", "cbi", "kdc", "mcp", "nhu", "qxh", "ame", "cbr", "gai", "kde", "mcq", "nhw", "qxn", "tew", "yre", "amf", "cbs", "gam", "kdl", "mdy", "nhy", "qxo", "tfr", "yva", "amk", "cbt", "geb", "kek", "med", "nin", "rai", "zaa", + "apb", "cbu", "glk", "ken", "mee", "nko", "rgu", "zab", "apr", "cbv", "meq", "tgo", "zac", "arl", "cco", "gng", "kje", "met", "nlg", "rop", "tgp", "zad", "grc", "klv", "mgh", "nnq", "rro", "zai", "ata", "cek", "gub", "kmu", "mib", "noa", "ruf", "tna", "zam", "atb", "cgc", "guh", "kne", "mie", "not", "rug", "tnk", "zao", "atg", + "chf", "knf", "mih", "npl", "tnn", "zar", "awb", "chz", "gum", "knj", "mil", "sab", "tnp", "zas", "cjo", "guo", "ksr", "mio", "obo", "seh", "toc", "zav", "azg", "cle", "gux", "kue", "mit", "omw", "sey", "tos", "zaw", "azz", "cme", "gvc", "kvn", "miz", "ood", "sgb", "tpi", "zca", "bao", "cni", "gwi", "kwd", "mkl", "shp", "tpt", + "zga", "bba", "cnl", "gym", "kwf", "mkn", "ote", "sja", "trc", "ziw", "bbb", "cnt", "gyr", "kwi", "mop", "otq", "snn", "ttc", "zlm", "cof", "hat", "kyc", "mox", "pab", "snp", "tte", "zos", "bgt", "con", "kyf", "mpm", "pad", "tue", "zpc", "bjr", "cot", "kyg", "mpp", "soy", "tuf", "zpl", "bjv", "cpa", "kyq", "mpx", "pao", "tuo", + "zpm", "bjz", "cpb", "hlt", "kyz", "mqb", "pib", "spp", "zpo", "bkd", "cpu", "hns", "lac", "mqj", "pir", "spy", "txq", "zpu", "blz", "crn", "hto", "lat", "msy", "pjt", "sri", "txu", "zpz", "bmr", "cso", "hub", "lex", "mto", "pls", "srm", "udu", "ztq", "bmu", "ctu", "lgl", "muy", "poi", "srn", "zty", "bnp", "cuc", "lid", "mxb", + "stp", "upv", "zyp", "boa", "cui", "huu", "mxq", "sus", "ura", "boj", "cuk", "huv", "llg", "mxt", "poy", "suz", "urb", "box", "cwe", "hvn", "prf", "urt", "bpr", "cya", "ign", "lww", "myk", "ptu", "usp", "bps", "daa", "ikk", "maj", "myy", "vid", "bqc", "dah", "nab", "qub", "tac", "bqp", "ded", "imo", "maq", "nas", "quf", "taj", + "vmy" + ] pretrained_language_ids = list() # an alternative to the valid_language_ids for language_code in pretrained_language_codes: pretrained_language_ids.append(less_loss.iso_codes_to_ids[language_code]) @@ -102,37 +105,42 @@ def train_loop(net, batch_sampler_train = torch.utils.data.BatchSampler(sampler, 1, drop_last=True) train_loaders.append(DataLoader(dataset=dataset, batch_sampler=batch_sampler_train, - num_workers=0, + num_workers=0, # has to be 0, otherwise copies of the dataset are created, which is not feasible for large scale trainings. This is not optimal for small trainings, but necessary for scalability. 
pin_memory=True, prefetch_factor=None, collate_fn=lambda x: x[0])) train_iters.append(iter(train_loaders[-1])) # embedding training is not supported here - optimizer = torch.optim.Adam([p for name, p in model.named_parameters() if 'post_flow' not in name], lr=lr) - flow_optimizer = torch.optim.Adam(model.post_flow.parameters(), lr=lr) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) scheduler = WarmupScheduler(optimizer, peak_lr=lr, warmup_steps=warmup_steps, max_steps=steps) - flow_scheduler = WarmupScheduler(flow_optimizer, peak_lr=lr, warmup_steps=(warmup_steps // 4), max_steps=steps) steps_run_previously = 0 regression_losses_total = list() - glow_losses_total = list() + stochastic_losses_total = list() duration_losses_total = list() pitch_losses_total = list() energy_losses_total = list() - less_losses_total = list() if resume: path_to_checkpoint = get_most_recent_checkpoint(checkpoint_dir=save_directory) if path_to_checkpoint is not None: check_dict = torch.load(path_to_checkpoint, map_location=device) - model.load_state_dict(check_dict["model"]) - if not fine_tune: + if freeze_lang_embs: + filtered_state_dict = {} + for name, param in check_dict["model"].items(): + if name in model.state_dict(): + if param.size() == model.state_dict()[name].size(): + filtered_state_dict[name] = param + print(f"Loading parameter {name}") + model.load_state_dict(filtered_state_dict, strict=False) + model.encoder.language_embedding.weight.requires_grad = False # and we never reset that. + else: + model.load_state_dict(check_dict["model"]) + if not fine_tune and not freeze_lang_embs: optimizer.load_state_dict(check_dict["optimizer"]) scheduler.load_state_dict(check_dict["scheduler"]) - flow_optimizer.load_state_dict(check_dict["flow_optimizer"]) - flow_scheduler.load_state_dict(check_dict["flow_scheduler"]) steps_run_previously = check_dict["step_counter"] if steps_run_previously > steps: print("Desired steps already reached in loaded checkpoint.") @@ -143,23 +151,29 @@ def train_loop(net, # Actual train loop starts here # ============================= - if not fine_tune and not resume and use_less_loss: + if not fine_tune and not resume and use_less_loss and not freeze_lang_embs: print("Priming the language embedding space...") + original_lr = optimizer.param_groups[0]['lr'] + pretraining_lr = 0.001 + for param_group in optimizer.param_groups: + param_group['lr'] = pretraining_lr less_values = list() - for i in tqdm(range(warmup_steps * 2)): + for i in tqdm(range(warmup_steps * 8)): language_ids = random.sample(valid_language_ids, batch_size) language_embeddings = model.encoder.language_embedding(torch.LongTensor(language_ids).to(device)) less_value_unsupervised = less_loss(language_ids, language_embeddings) - optimizer.zero_grad() less_values.append(less_value_unsupervised.item()) + optimizer.zero_grad() less_value_unsupervised.backward() optimizer.step() if i % warmup_steps // 2 == 0: print(sum(less_values) / len(less_values)) less_values = list() + for param_group in optimizer.param_groups: + param_group['lr'] = original_lr for step_counter in tqdm(range(steps_run_previously, steps)): - run_glow = step_counter > (warmup_steps * 2) + run_stochastic = step_counter > warmup_steps * 2 batches = [] while len(batches) < batch_size: @@ -185,7 +199,7 @@ def train_loop(net, lang_ids = batch[8].squeeze(1).to(device) speech_batch = list() # I wish this could be done in the collate function or in the getitem, but using DL models in multiprocessing on very large datasets causes just way too many issues. 
- for speech_sample in speech_indexes: + for index, speech_sample in enumerate(speech_indexes): with torch.inference_mode(): wave = ap.indexes_to_audio(speech_sample.int().to(device)).detach() mel = spec_extractor.audio_to_mel_spec_tensor(wave, explicit_sampling_rate=16000).transpose(0, 1).detach().cpu() @@ -199,7 +213,7 @@ def train_loop(net, # step (i.e. iterations of inner loop = 1) utterance_embedding = batch[9].to(device) - regression_loss, glow_loss, duration_loss, pitch_loss, energy_loss = net( + regression_loss, stochastic_loss, duration_loss, pitch_loss, energy_loss = net( text_tensors=text_tensors, text_lengths=text_lengths, gold_speech=gold_speech, @@ -210,15 +224,9 @@ def train_loop(net, utterance_embedding=utterance_embedding, lang_ids=lang_ids, return_feats=False, - run_glow=run_glow + run_stochastic=run_stochastic ) - if use_less_loss: - language_embeddings_seen = model.encoder.language_embedding(lang_ids) - language_ids = random.sample(valid_language_ids, batch_size) - language_embeddings_random = model.encoder.language_embedding(torch.LongTensor(language_ids).to(device)) - less_value = less_loss(lang_ids.cpu().squeeze().tolist() + language_ids, torch.cat([language_embeddings_seen, language_embeddings_random], dim=0)) - # then we directly update our meta-parameters without # the need for any task specific parameters @@ -230,44 +238,29 @@ def train_loop(net, train_loss = train_loss + duration_loss train_loss = train_loss + pitch_loss train_loss = train_loss + energy_loss - if use_less_loss: - train_loss = train_loss + less_value * 2 - - if glow_loss is not None: # even if run_glow is true, this can still happen if the log prob cannot be calculated. - if torch.isnan(glow_loss) or torch.isinf(glow_loss): - print("Glow loss turned to NaN! Skipping this batch ...") + if stochastic_loss is not None: + if torch.isnan(stochastic_loss) or torch.isinf(stochastic_loss): + print("Flow loss turned to NaN! Skipping this batch ...") continue - - train_loss = train_loss + glow_loss - - if glow_loss < 0.0: - glow_losses_total.append(glow_loss.item()) - else: - glow_losses_total.append(0.1) # just to avoid super large numbers during plotting that mess up the scaling + train_loss = train_loss + stochastic_loss + stochastic_losses_total.append(stochastic_loss.item()) else: - glow_losses_total.append(0) + stochastic_losses_total.append(0) regression_losses_total.append(regression_loss.item()) duration_losses_total.append(duration_loss.item()) pitch_losses_total.append(pitch_loss.item()) energy_losses_total.append(energy_loss.item()) - if use_less_loss: - less_losses_total.append(less_value.item()) optimizer.zero_grad() - flow_optimizer.zero_grad() if type(train_loss) is float: print("There is no loss for this step! 
Skipping ...") continue train_loss.backward() - torch.nn.utils.clip_grad_norm_([p for name, p in model.named_parameters() if 'post_flow' not in name], 1.0, error_if_nonfinite=False) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0, error_if_nonfinite=False) optimizer.step() scheduler.step() - if glow_loss is not None: - torch.nn.utils.clip_grad_norm_(model.post_flow.parameters(), 1.0, error_if_nonfinite=False) - flow_optimizer.step() - flow_scheduler.step() if step_counter % steps_per_checkpoint == 0 and step_counter != 0: # ============================== @@ -281,29 +274,27 @@ def train_loop(net, net.eval() default_embedding = datasets[0][0][9].to(device) print("Reconstruction Loss: {}".format(round(sum(regression_losses_total) / len(regression_losses_total), 3))) + print("Steps: {}\n".format(step_counter)) torch.save({ - "model" : model.state_dict(), - "optimizer" : optimizer.state_dict(), - "scheduler" : scheduler.state_dict(), - "flow_optimizer": flow_optimizer.state_dict(), - "flow_scheduler": flow_scheduler.state_dict(), - "step_counter" : step_counter, - "default_emb" : default_embedding, - "config" : model.config + "model" : model.state_dict(), + "optimizer" : optimizer.state_dict(), + "scheduler" : scheduler.state_dict(), + "step_counter": step_counter, + "default_emb" : default_embedding, + "config" : model.config }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter))) delete_old_checkpoints(save_directory, keep=5) if use_wandb: wandb.log({ - "regression_loss" : round(sum(regression_losses_total) / len(regression_losses_total), 5), - "glow_loss" : round(sum(glow_losses_total) / len(glow_losses_total), 5), - "duration_loss" : round(sum(duration_losses_total) / len(duration_losses_total), 5), - "pitch_loss" : round(sum(pitch_losses_total) / len(pitch_losses_total), 5), - "energy_loss" : round(sum(energy_losses_total) / len(energy_losses_total), 5), - "embedding_structure_loss": 0.0 if len(less_losses_total) == 0 else round(sum(less_losses_total) / len(less_losses_total), 5), - "learning_rate" : optimizer.param_groups[0]['lr'] + "regression_loss": round(sum(regression_losses_total) / len(regression_losses_total), 5), + "stochastic_loss": round(sum(stochastic_losses_total) / len(stochastic_losses_total), 5), + "duration_loss" : round(sum(duration_losses_total) / len(duration_losses_total), 5), + "pitch_loss" : round(sum(pitch_losses_total) / len(pitch_losses_total), 5), + "energy_loss" : round(sum(energy_losses_total) / len(energy_losses_total), 5), + "learning_rate" : optimizer.param_groups[0]['lr'] }, step=step_counter) try: @@ -313,7 +304,7 @@ def train_loop(net, step=step_counter, lang=lang, default_emb=default_embedding, - run_glow=run_glow) + run_stochastic=run_stochastic) if use_wandb: wandb.log({ "progress_plot": wandb.Image(path_to_most_recent_plot) @@ -329,8 +320,14 @@ def train_loop(net, net.train() regression_losses_total = list() - glow_losses_total = list() + stochastic_losses_total = list() duration_losses_total = list() pitch_losses_total = list() energy_losses_total = list() - less_losses_total = list() + if gpu_count > 1: + # just to be extra sure tht all models are synchronous + torch.distributed.barrier() + checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=save_directory, n=1) + check_dict = torch.load(checkpoint_paths[0], map_location=device) + model.load_state_dict(check_dict["model"]) + torch.distributed.barrier() diff --git a/Architectures/ToucanTTS/toucantts_train_loop.py b/Modules/ToucanTTS/toucantts_train_loop.py 
similarity index 80% rename from Architectures/ToucanTTS/toucantts_train_loop.py rename to Modules/ToucanTTS/toucantts_train_loop.py index 3227959f7a65c9f2abd46a42dea7bef7573413f4..d749d50c2264c2c86d816862073363c4e14d9645 100644 --- a/Architectures/ToucanTTS/toucantts_train_loop.py +++ b/Modules/ToucanTTS/toucantts_train_loop.py @@ -70,7 +70,7 @@ def train_loop(net, batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True) train_loader = DataLoader(dataset=train_dataset, batch_sampler=batch_sampler_train, - num_workers=0, + num_workers=0, # has to be 0, otherwise copies of the dataset are created, which is not feasible for large scale trainings. This is not optimal for small trainings, but necessary for scalability. pin_memory=True, prefetch_factor=None, collate_fn=collate_and_pad) @@ -83,11 +83,9 @@ def train_loop(net, model = net.module else: model = net - optimizer = torch.optim.Adam([p for name, p in model.named_parameters() if 'post_flow' not in name], lr=lr) - flow_optimizer = torch.optim.Adam(model.post_flow.parameters(), lr=lr) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) scheduler = WarmupScheduler(optimizer, peak_lr=lr, warmup_steps=warmup_steps, max_steps=steps) - flow_scheduler = WarmupScheduler(flow_optimizer, peak_lr=lr, warmup_steps=(warmup_steps // 4), max_steps=steps) epoch = 0 if resume: @@ -98,12 +96,10 @@ def train_loop(net, if not fine_tune: optimizer.load_state_dict(check_dict["optimizer"]) scheduler.load_state_dict(check_dict["scheduler"]) - flow_optimizer.load_state_dict(check_dict["flow_optimizer"]) - flow_scheduler.load_state_dict(check_dict["flow_scheduler"]) step_counter = check_dict["step_counter"] start_time = time.time() regression_losses_total = list() - glow_losses_total = list() + stochastic_losses_total = list() duration_losses_total = list() pitch_losses_total = list() energy_losses_total = list() @@ -130,11 +126,11 @@ def train_loop(net, speech_batch.append(gold_speech_sample) gold_speech = pad_sequence(speech_batch, batch_first=True).to(device) - run_glow = step_counter > (warmup_steps * 2) or fine_tune + run_stochastic = (step_counter > warmup_steps * 2) or fine_tune train_loss = 0.0 utterance_embedding = batch[9].to(device) - regression_loss, glow_loss, duration_loss, pitch_loss, energy_loss = net( + regression_loss, stochastic_loss, duration_loss, pitch_loss, energy_loss = net( text_tensors=text_tensors, text_lengths=text_lengths, gold_speech=gold_speech, @@ -145,7 +141,7 @@ def train_loop(net, utterance_embedding=utterance_embedding, lang_ids=lang_ids, return_feats=False, - run_glow=run_glow + run_stochastic=run_stochastic ) if torch.isnan(regression_loss) or torch.isnan(duration_loss) or torch.isnan(pitch_loss) or torch.isnan(energy_loss): @@ -162,34 +158,25 @@ def train_loop(net, pitch_losses_total.append(pitch_loss.item()) energy_losses_total.append(energy_loss.item()) - if glow_loss is not None: + if stochastic_loss is not None: - if torch.isnan(glow_loss): - print("Glow loss turned to NaN! Skipping this batch ...") + if torch.isnan(stochastic_loss): + print("Flow loss turned to NaN! 
Skipping this batch ...") continue - if glow_loss < 0.0: - glow_losses_total.append(glow_loss.item()) - else: - glow_losses_total.append(0.1) - - train_loss = train_loss + glow_loss + stochastic_losses_total.append(stochastic_loss.item()) + train_loss = train_loss + stochastic_loss else: - glow_losses_total.append(0) + stochastic_losses_total.append(0) optimizer.zero_grad() - flow_optimizer.zero_grad() if type(train_loss) is float: print("There is no loss for this step! Skipping ...") continue train_loss.backward() - torch.nn.utils.clip_grad_norm_([p for name, p in model.named_parameters() if 'post_flow' not in name], 1.0, error_if_nonfinite=False) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0, error_if_nonfinite=False) optimizer.step() scheduler.step() - if glow_loss is not None: - torch.nn.utils.clip_grad_norm_(model.post_flow.parameters(), 1.0, error_if_nonfinite=False) - flow_optimizer.step() - flow_scheduler.step() step_counter += 1 if step_counter % steps_per_checkpoint == 0: # evaluation interval is happening @@ -197,34 +184,32 @@ def train_loop(net, net.eval() default_embedding = train_dataset[0][9].to(device) torch.save({ - "model" : model.state_dict(), - "optimizer" : optimizer.state_dict(), - "step_counter" : step_counter, - "scheduler" : scheduler.state_dict(), - "flow_optimizer": flow_optimizer.state_dict(), - "flow_scheduler": flow_scheduler.state_dict(), - "default_emb" : default_embedding, - "config" : model.config + "model" : model.state_dict(), + "optimizer" : optimizer.state_dict(), + "step_counter": step_counter, + "scheduler" : scheduler.state_dict(), + "default_emb" : default_embedding, + "config" : model.config }, os.path.join(save_directory, "checkpoint_{}.pt".format(step_counter))) delete_old_checkpoints(save_directory, keep=5) print(f"\nEpoch: {epoch}") print(f"Time elapsed: {round((time.time() - start_time) / 60)} Minutes") - print(f"Reconstruction Loss: {round(sum(regression_losses_total) / len(regression_losses_total), 4)}") + print("Reconstruction Loss: {}".format(round(sum(regression_losses_total) / len(regression_losses_total), 3))) print(f"Steps: {step_counter}\n") if use_wandb: wandb.log({ "regression_loss": round(sum(regression_losses_total) / len(regression_losses_total), 5), - "glow_loss" : round(sum(glow_losses_total) / len(glow_losses_total), 5), + "stochastic_loss": round(sum(stochastic_losses_total) / len(stochastic_losses_total), 5), "duration_loss" : round(sum(duration_losses_total) / len(duration_losses_total), 5), "pitch_loss" : round(sum(pitch_losses_total) / len(pitch_losses_total), 5), "energy_loss" : round(sum(energy_losses_total) / len(energy_losses_total), 5), "learning_rate" : optimizer.param_groups[0]['lr'] }, step=step_counter) regression_losses_total = list() - glow_losses_total = list() + stochastic_losses_total = list() duration_losses_total = list() pitch_losses_total = list() energy_losses_total = list() @@ -235,7 +220,7 @@ def train_loop(net, step=step_counter, lang=lang, default_emb=default_embedding, - run_glow=run_glow) + run_stochastic=run_stochastic) if use_wandb: wandb.log({ "progress_plot": wandb.Image(path_to_most_recent_plot) @@ -249,5 +234,11 @@ def train_loop(net, return # DONE net.train() - + if gpu_count > 1: + # just to be extra sure that all models are synchronous + torch.distributed.barrier() + checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=save_directory, n=1) + check_dict = torch.load(checkpoint_paths[0], map_location=device) + model.load_state_dict(check_dict["model"]) +
torch.distributed.barrier() print("\n\n\nEPOCH COMPLETE\n\n\n") diff --git a/Architectures/ToucanTTS/toucantts_train_loop_arbiter.py b/Modules/ToucanTTS/toucantts_train_loop_arbiter.py similarity index 83% rename from Architectures/ToucanTTS/toucantts_train_loop_arbiter.py rename to Modules/ToucanTTS/toucantts_train_loop_arbiter.py index 8b52156482486805a8c7d07578a5391482dc36a0..3cee24bd4ace106536cce7a7275ebff447b111d6 100644 --- a/Architectures/ToucanTTS/toucantts_train_loop_arbiter.py +++ b/Modules/ToucanTTS/toucantts_train_loop_arbiter.py @@ -1,7 +1,7 @@ import torch -from Architectures.ToucanTTS.toucantts_meta_train_loop import train_loop as multi_language_loop -from Architectures.ToucanTTS.toucantts_train_loop import train_loop as mono_language_loop +from Modules.ToucanTTS.toucantts_meta_train_loop import train_loop as multi_language_loop +from Modules.ToucanTTS.toucantts_train_loop import train_loop as mono_language_loop def train_loop(net, # an already initialized ToucanTTS model that should be trained. @@ -19,13 +19,14 @@ def train_loop(net, # an already initialized ToucanTTS model that should be tra path_to_checkpoint=None, # path to a trained checkpoint to either continue training or fine-tune from. lr=0.0001, # learning rate of the model. resume=False, # whether to automatically load the most recent checkpoint and resume training from it. - warmup_steps=12000, # how many steps until the learning rate reaches the specified value and starts decreasing again. + warmup_steps=4000, # how many steps until the learning rate reaches the specified value and starts decreasing again. use_wandb=False, # whether to use online experiment tracking with weights and biases. Requires prior CLI login. - batch_size=8, # how many samples to put into one batch. Higher batch size is more stable, but requires more VRAM. + batch_size=32, # how many samples to put into one batch. Higher batch size is more stable, but requires more VRAM. eval_lang="eng", # in which language the evaluation sentence is to be plotted. fine_tune=False, # whether to use the provided checkpoint as basis for fine-tuning. steps=200000, # how many updates to run until training is completed use_less_loss=False, # whether to use the loss that enforces a structure in the language embedding space + freeze_lang_embs=False, # whether to use the language embeddings from a checkpoint without modifying them, to maintain compatibility with the zero-shot method. This treats language embeddings from the given checkpoint as constants. 
): torch.multiprocessing.set_start_method('spawn', force=True) if type(datasets) != list: @@ -48,6 +49,7 @@ def train_loop(net, # an already initialized ToucanTTS model that should be tra use_wandb=use_wandb, gpu_count=gpu_count, use_less_loss=use_less_loss, + freeze_lang_embs=freeze_lang_embs ) else: mono_language_loop(net=net, diff --git a/Architectures/ToucanTTS/wavenet.py b/Modules/ToucanTTS/wavenet.py similarity index 100% rename from Architectures/ToucanTTS/wavenet.py rename to Modules/ToucanTTS/wavenet.py diff --git a/Architectures/Vocoder/AMP.py b/Modules/Vocoder/AMP.py similarity index 98% rename from Architectures/Vocoder/AMP.py rename to Modules/Vocoder/AMP.py index 9e8021ef80c305968ef61e6a0c88d7003ebf71ba..c6ae0b4e6b67af8ac596d33b0996fbbafcb268a4 100644 --- a/Architectures/Vocoder/AMP.py +++ b/Modules/Vocoder/AMP.py @@ -11,7 +11,7 @@ from torch.nn import Conv1d from torch.nn.utils import remove_weight_norm from torch.nn.utils import weight_norm -from Architectures.Vocoder.Snake import SnakeBeta +from Modules.Vocoder.Snake import SnakeBeta LRELU_SLOPE = 0.1 diff --git a/Architectures/Vocoder/AdversarialLoss.py b/Modules/Vocoder/AdversarialLoss.py similarity index 100% rename from Architectures/Vocoder/AdversarialLoss.py rename to Modules/Vocoder/AdversarialLoss.py diff --git a/Architectures/Vocoder/Avocodo_Discriminators.py b/Modules/Vocoder/Avocodo_Discriminators.py similarity index 99% rename from Architectures/Vocoder/Avocodo_Discriminators.py rename to Modules/Vocoder/Avocodo_Discriminators.py index 4fa73984f995af67a305a8305c5b4bef8bc07904..9a9477e69eb33ec55dc6696e179754c58279d6ee 100644 --- a/Architectures/Vocoder/Avocodo_Discriminators.py +++ b/Modules/Vocoder/Avocodo_Discriminators.py @@ -14,7 +14,7 @@ from torch.nn import Conv1d from torch.nn.utils import spectral_norm from torch.nn.utils import weight_norm -from Architectures.Vocoder.SAN_modules import SANConv1d +from Modules.Vocoder.SAN_modules import SANConv1d def get_padding(kernel_size, dilation=1): @@ -238,7 +238,7 @@ class PQMF(torch.nn.Module): for k in range(N): constant_factor = (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - - ((taps - 1) / 2)) # TODO: (taps - 1) -> taps + ((taps - 1) / 2)) phase = (-1) ** k * np.pi / 4 H[k] = 2 * QMF * np.cos(constant_factor + phase) diff --git a/Architectures/Vocoder/Avocodo_LICENSE b/Modules/Vocoder/Avocodo_LICENSE similarity index 100% rename from Architectures/Vocoder/Avocodo_LICENSE rename to Modules/Vocoder/Avocodo_LICENSE diff --git a/Architectures/Vocoder/FeatureMatchingLoss.py b/Modules/Vocoder/FeatureMatchingLoss.py similarity index 100% rename from Architectures/Vocoder/FeatureMatchingLoss.py rename to Modules/Vocoder/FeatureMatchingLoss.py diff --git a/Architectures/Vocoder/HiFiGAN_Dataset.py b/Modules/Vocoder/HiFiGAN_Dataset.py similarity index 91% rename from Architectures/Vocoder/HiFiGAN_Dataset.py rename to Modules/Vocoder/HiFiGAN_Dataset.py index 391b5c6ecc4558cb8e7c612370b999f961061a41..c7ba32c3277abd5546d01ed793e5bfb96d6e527f 100644 --- a/Architectures/Vocoder/HiFiGAN_Dataset.py +++ b/Modules/Vocoder/HiFiGAN_Dataset.py @@ -9,7 +9,6 @@ import soundfile as sf import torch import torchaudio from torch.utils.data import Dataset -from torchvision.transforms.v2 import GaussianBlur from tqdm import tqdm from Preprocessing.AudioPreprocessor import AudioPreprocessor @@ -57,9 +56,7 @@ class HiFiGANDataset(Dataset): process_list[-1].start() for process in process_list: process.join() - self.blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0)) # 
simulating the smoothness of a generated spectrogram # self.masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True) # up to 16 consecutive bands can be masked, each element in the batch gets a different mask. Taken out because it seems too extreme. - self.spec_augs = [self.blurrer, lambda x: x, lambda x: x, lambda x: x, lambda x: x] self.wave_augs = [random_pitch_shifter, polarity_inverter, lambda x: x, lambda x: x, lambda x: x, lambda x: x] # just some data augmentation self.wave_distortions = [CodecSimulator(), lambda x: x, lambda x: x, lambda x: x, lambda x: x] # simulating the fact, that we train the TTS on codec-compressed waves print("{} eligible audios found".format(len(self.waves))) @@ -108,9 +105,6 @@ class HiFiGANDataset(Dataset): melspec = self.melspec_ap.audio_to_mel_spec_tensor(resampled_segment, explicit_sampling_rate=16000, normalize=False).transpose(0, 1)[:-1].transpose(0, 1) - if self.use_random_corruption: - # augmentations for the spec - melspec = random.choice(self.spec_augs)(melspec.unsqueeze(0)).squeeze(0) return segment.detach(), melspec.detach() except RuntimeError: print("encountered a runtime error, using fallback strategy") @@ -148,7 +142,6 @@ if __name__ == '__main__': normalize=False).transpose(0, 1)[:-1].transpose(0, 1) cs = CodecSimulator() - blurrer = GaussianBlur(kernel_size=(5, 5), sigma=(0.5, 2.0)) masker = torchaudio.transforms.FrequencyMasking(freq_mask_param=16, iid_masks=True) # up to 8 consecutive bands can be masked # testing codec simulator @@ -159,12 +152,6 @@ if __name__ == '__main__': plt.title("Codec Simulator") plt.show() - # testing Gaussian blur - blurred_spec = blurrer(spec.unsqueeze(0)).squeeze(0) - plt.imshow(blurred_spec.cpu().numpy(), origin="lower", cmap='GnBu') - plt.title("Blurred Spec") - plt.show() - # testing spectrogram masking for _ in range(5): masked_spec = masker(spec.unsqueeze(0)).squeeze(0) diff --git a/Architectures/Vocoder/HiFiGAN_Discriminators.py b/Modules/Vocoder/HiFiGAN_Discriminators.py similarity index 98% rename from Architectures/Vocoder/HiFiGAN_Discriminators.py rename to Modules/Vocoder/HiFiGAN_Discriminators.py index bd7337ba304c5cfe37d7c087bd6356908ee01b04..d6c874ab423b6f4da637bc1ef213b6e8b4a372fe 100644 --- a/Architectures/Vocoder/HiFiGAN_Discriminators.py +++ b/Modules/Vocoder/HiFiGAN_Discriminators.py @@ -9,10 +9,10 @@ import copy import torch import torch.nn.functional as F -from Architectures.Vocoder.Avocodo_Discriminators import MultiCoMBDiscriminator -from Architectures.Vocoder.Avocodo_Discriminators import MultiSubBandDiscriminator -from Architectures.Vocoder.SAN_modules import SANConv1d -from Architectures.Vocoder.SAN_modules import SANConv2d +from Modules.Vocoder.Avocodo_Discriminators import MultiCoMBDiscriminator +from Modules.Vocoder.Avocodo_Discriminators import MultiSubBandDiscriminator +from Modules.Vocoder.SAN_modules import SANConv1d +from Modules.Vocoder.SAN_modules import SANConv2d class HiFiGANPeriodDiscriminator(torch.nn.Module): diff --git a/Architectures/Vocoder/HiFiGAN_Generator.py b/Modules/Vocoder/HiFiGAN_Generator.py similarity index 98% rename from Architectures/Vocoder/HiFiGAN_Generator.py rename to Modules/Vocoder/HiFiGAN_Generator.py index 280f432f813a1a063da34ca9767197a2fd742e56..39ad04072fd35a78d918044ab47ffd5df2ba042d 100644 --- a/Architectures/Vocoder/HiFiGAN_Generator.py +++ b/Modules/Vocoder/HiFiGAN_Generator.py @@ -7,7 +7,7 @@ import torch -from Architectures.GeneralLayers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock +from 
Modules.GeneralLayers.ResidualBlock import HiFiGANResidualBlock as ResidualBlock class HiFiGAN(torch.nn.Module): @@ -185,4 +185,4 @@ class HiFiGAN(torch.nn.Module): if __name__ == "__main__": hifi = HiFiGAN() - print(f"HiFiGAN parameter count: {sum(p.numel() for p in hifi.parameters() if p.requires_grad)}") + print(f"HiFiGAN parameter count: {sum(p.numel() for p in hifi.parameters() if p.requires_grad)}") \ No newline at end of file diff --git a/Architectures/Vocoder/HiFiGAN_LICENSE b/Modules/Vocoder/HiFiGAN_LICENSE similarity index 100% rename from Architectures/Vocoder/HiFiGAN_LICENSE rename to Modules/Vocoder/HiFiGAN_LICENSE diff --git a/Architectures/Vocoder/HiFiGAN_train_loop.py b/Modules/Vocoder/HiFiGAN_train_loop.py similarity index 95% rename from Architectures/Vocoder/HiFiGAN_train_loop.py rename to Modules/Vocoder/HiFiGAN_train_loop.py index 6b7fe7f796662529dc382e6da8adcce739b95bb8..c0837fa6fd2070bb255f0e362bb16db414739952 100644 --- a/Architectures/Vocoder/HiFiGAN_train_loop.py +++ b/Modules/Vocoder/HiFiGAN_train_loop.py @@ -8,10 +8,10 @@ from torch.optim.lr_scheduler import MultiStepLR from torch.utils.data.dataloader import DataLoader from tqdm import tqdm -from Architectures.Vocoder.AdversarialLoss import discriminator_adv_loss -from Architectures.Vocoder.AdversarialLoss import generator_adv_loss -from Architectures.Vocoder.FeatureMatchingLoss import feature_loss -from Architectures.Vocoder.MelSpecLoss import MelSpectrogramLoss +from Modules.Vocoder.AdversarialLoss import discriminator_adv_loss +from Modules.Vocoder.AdversarialLoss import generator_adv_loss +from Modules.Vocoder.FeatureMatchingLoss import feature_loss +from Modules.Vocoder.MelSpecLoss import MelSpectrogramLoss from Utility.utils import delete_old_checkpoints from Utility.utils import get_most_recent_checkpoint from run_weight_averaging import average_checkpoints @@ -44,10 +44,10 @@ def train_loop(generator, g.train() d.train() - optimizer_g = torch.optim.RAdam(g.parameters(), betas=(0.5, 0.9), lr=0.001, weight_decay=0.0) - scheduler_g = MultiStepLR(optimizer_g, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000]) - optimizer_d = torch.optim.RAdam(d.parameters(), betas=(0.5, 0.9), lr=0.0005, weight_decay=0.0) - scheduler_d = MultiStepLR(optimizer_d, gamma=0.5, milestones=[500000, 1000000, 1200000, 1400000]) + optimizer_g = torch.optim.RAdam(g.parameters(), betas=(0.5, 0.9), lr=0.0005, weight_decay=0.0) + scheduler_g = MultiStepLR(optimizer_g, gamma=0.5, milestones=[400000, 8000000, 000000, 1000000]) + optimizer_d = torch.optim.RAdam(d.parameters(), betas=(0.5, 0.9), lr=0.00025, weight_decay=0.0) + scheduler_d = MultiStepLR(optimizer_d, gamma=0.5, milestones=[400000, 8000000, 000000, 1000000]) train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, diff --git a/Architectures/Vocoder/MelSpecLoss.py b/Modules/Vocoder/MelSpecLoss.py similarity index 100% rename from Architectures/Vocoder/MelSpecLoss.py rename to Modules/Vocoder/MelSpecLoss.py diff --git a/Architectures/Vocoder/README.md b/Modules/Vocoder/README.md similarity index 100% rename from Architectures/Vocoder/README.md rename to Modules/Vocoder/README.md diff --git a/Architectures/Vocoder/SAN_LICENSE b/Modules/Vocoder/SAN_LICENSE similarity index 100% rename from Architectures/Vocoder/SAN_LICENSE rename to Modules/Vocoder/SAN_LICENSE diff --git a/Architectures/Vocoder/SAN_modules.py b/Modules/Vocoder/SAN_modules.py similarity index 100% rename from Architectures/Vocoder/SAN_modules.py rename to Modules/Vocoder/SAN_modules.py diff 
--git a/Architectures/Vocoder/Snake.py b/Modules/Vocoder/Snake.py similarity index 100% rename from Architectures/Vocoder/Snake.py rename to Modules/Vocoder/Snake.py diff --git a/Architectures/ToucanTTS/__init__.py b/Modules/Vocoder/__init__.py similarity index 100% rename from Architectures/ToucanTTS/__init__.py rename to Modules/Vocoder/__init__.py diff --git a/Architectures/Vocoder/__init__.py b/Modules/__init__.py similarity index 100% rename from Architectures/Vocoder/__init__.py rename to Modules/__init__.py diff --git a/Preprocessing/Codec/config_16k_320d.json b/Preprocessing/Codec/config_16k_320d.json new file mode 100644 index 0000000000000000000000000000000000000000..f63f5b5c6364d22b7d35791ff7b749864a6db7e6 --- /dev/null +++ b/Preprocessing/Codec/config_16k_320d.json @@ -0,0 +1,65 @@ +{ + "resblock": "1", + "num_gpus": 8, + "batch_size": 64, + "learning_rate": 0.0002, + "adam_b1": 0.5, + "adam_b2": 0.9, + "lr_decay": 0.98, + "seed": 1234, + "upsample_rates": [ + 8, + 5, + 4, + 2 + ], + "upsample_kernel_sizes": [ + 16, + 11, + 8, + 4 + ], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "segment_size": 16000, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 200, + "win_size": 800, + "sampling_rate": 16000, + "n_code_groups": 2, + "n_codes": 1024, + "codebook_loss_lambda": 1.0, + "commitment_loss_lambda": 0.25, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + "num_workers": 12, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} \ No newline at end of file diff --git a/Preprocessing/Codec/config_24k_320d.json b/Preprocessing/Codec/config_24k_320d.json new file mode 100644 index 0000000000000000000000000000000000000000..e2996a80abe522f201f2d25b519431fca7680bb4 --- /dev/null +++ b/Preprocessing/Codec/config_24k_320d.json @@ -0,0 +1,65 @@ +{ + "resblock": "1", + "num_gpus": 8, + "batch_size": 80, + "learning_rate": 0.0002, + "adam_b1": 0.5, + "adam_b2": 0.9, + "lr_decay": 0.98, + "seed": 1234, + "upsample_rates": [ + 8, + 5, + 4, + 2 + ], + "upsample_kernel_sizes": [ + 16, + 11, + 8, + 4 + ], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "segment_size": 16000, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 240, + "win_size": 1024, + "sampling_rate": 24000, + "n_code_groups": 2, + "n_codes": 1024, + "codebook_loss_lambda": 1.0, + "commitment_loss_lambda": 0.25, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": null, + "num_workers": 12, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1 + } +} diff --git a/Preprocessing/TextFrontend.py b/Preprocessing/TextFrontend.py index 946e8e887ce32f5c559aa56aafcc6f6729a8387e..e96b47b1d42f00a790c1e689cf5a2217e2bcfb7d 100644 --- a/Preprocessing/TextFrontend.py +++ b/Preprocessing/TextFrontend.py @@ -2,6 +2,7 @@ import json +import logging import re import torch @@ -29,7 +30,8 @@ class ArticulatoryCombinedTextFrontend: use_lexical_stress=True, silent=True, add_silence_to_end=True, - use_word_boundaries=True): + use_word_boundaries=True, + device="cpu"): """ Mostly preparing ID lookups """ @@ -67,7 +69,7 @@ class ArticulatoryCombinedTextFrontend: elif register_to_height[first_tone] < register_to_height[second_tone] > 
register_to_height[third_tone]: self.peaking_perms.append(first_tone + second_tone + third_tone) - if language == "eng": + if language == "eng" or language == "en-us": self.g2p_lang = "en-us" # English as spoken in USA self.expand_abbreviations = english_text_expansion self.phonemizer = "espeak" @@ -87,540 +89,8 @@ class ArticulatoryCombinedTextFrontend: self.expand_abbreviations = lambda x: x self.phonemizer = "espeak" - elif language == "fin": - self.g2p_lang = "fi" # Finnish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "rus": - self.g2p_lang = "ru" # Russian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "hun": - self.g2p_lang = "hu" # Hungarian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "nld": - self.g2p_lang = "nl" # Dutch - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "fra": - self.g2p_lang = "fr-fr" # French - self.expand_abbreviations = remove_french_spacing - self.phonemizer = "espeak" - - elif language == "ita": - self.g2p_lang = "it" # Italian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "por": - self.g2p_lang = "pt" # Portuguese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "pol": - self.g2p_lang = "pl" # Polish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "cmn": - self.g2p_lang = "cmn" # Mandarin - self.expand_abbreviations = convert_kanji_to_pinyin_mandarin - self.phonemizer = "dragonmapper" - - elif language == "vie": - self.g2p_lang = "vi" # Northern Vietnamese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ukr": - self.g2p_lang = "uk" # Ukrainian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "pes": - self.g2p_lang = "fa" # Western Farsi - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "afr": - self.g2p_lang = "af" # Afrikaans - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "aln": - self.g2p_lang = "sq" # Albanian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "amh": - self.g2p_lang = "am" # Amharic - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "arb": - self.g2p_lang = "ar" # Arabic - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "arg": - self.g2p_lang = "an" # Aragonese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "hye": - self.g2p_lang = "hy" # East Armenian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "hyw": - self.g2p_lang = "hyw" # West Armenian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "azj": - self.g2p_lang = "az" # Azerbaijani - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "bak": - self.g2p_lang = "ba" # Bashkir - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "eus": - self.g2p_lang = "eu" # Basque - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "bel": - self.g2p_lang = "be" # Belarusian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif 
language == "ben": - self.g2p_lang = "bn" # Bengali - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "bpy": - self.g2p_lang = "bpy" # Bishnupriya Manipuri - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "bos": - self.g2p_lang = "bs" # Bosnian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "bul": - self.g2p_lang = "bg" # Bulgarian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "mya": - self.g2p_lang = "my" # Burmese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "chr": - self.g2p_lang = "chr" # Cherokee - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "yue": - self.g2p_lang = "yue" # Chinese Cantonese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "hak": - self.g2p_lang = "hak" # Chinese Hakka - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "haw": - self.g2p_lang = "haw" # Hawaiian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "hrv": - self.g2p_lang = "hr" # Croatian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ces": - self.g2p_lang = "cs" # Czech - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "dan": - self.g2p_lang = "da" # Danish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ekk": - self.g2p_lang = "et" # Estonian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "gle": - self.g2p_lang = "ga" # Gaelic Irish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "gla": - self.g2p_lang = "gd" # Gaelic Scottish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "kat": - self.g2p_lang = "ka" # Georgian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "kal": - self.g2p_lang = "kl" # Greenlandic - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "guj": - self.g2p_lang = "gu" # Gujarati - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "heb": - self.g2p_lang = "he" # Hebrew - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "hin": - self.g2p_lang = "hi" # Hindi - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "isl": - self.g2p_lang = "is" # Icelandic - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ind": - self.g2p_lang = "id" # Indonesian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "jpn": - self.g2p_lang = "ja" # Japanese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "kan": - self.g2p_lang = "kn" # Kannada - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "knn": - self.g2p_lang = "kok" # Konkani - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "kor": - self.g2p_lang = "ko" # Korean - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ckb": - self.g2p_lang = "ku" # Kurdish - 
self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "kaz": - self.g2p_lang = "kk" # Kazakh - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "kir": - self.g2p_lang = "ky" # Kyrgyz - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "lat": - self.g2p_lang = "la" # Latin - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ltz": - self.g2p_lang = "lb" # Luxembourgish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "lvs": - self.g2p_lang = "lv" # Latvian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "lit": - self.g2p_lang = "lt" # Lithuanian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "mri": - self.g2p_lang = "mi" # Māori - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "mkd": - self.g2p_lang = "mk" # Macedonian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "zlm": - self.g2p_lang = "ms" # Malay - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "mal": - self.g2p_lang = "ml" # Malayalam - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "mlt": - self.g2p_lang = "mt" # Maltese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "mar": - self.g2p_lang = "mr" # Marathi - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "nci": - self.g2p_lang = "nci" # Nahuatl - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "npi": - self.g2p_lang = "ne" # Nepali - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "nob": - self.g2p_lang = "nb" # Norwegian Bokmål - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "nog": - self.g2p_lang = "nog" # Nogai - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ory": - self.g2p_lang = "or" # Oriya - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "gaz": - self.g2p_lang = "om" # Oromo - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "pap": - self.g2p_lang = "pap" # Papiamento - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "pan": - self.g2p_lang = "pa" # Punjabi - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ron": - self.g2p_lang = "ro" # Romanian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "lav": - self.g2p_lang = "ru-lv" # Russian Latvia - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "srp": - self.g2p_lang = "sr" # Serbian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "tsn": - self.g2p_lang = "tn" # Setswana - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "snd": - self.g2p_lang = "sd" # Sindhi - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "slk": - self.g2p_lang = "sk" # Slovak - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "slv": - 
self.g2p_lang = "sl" # Slovenian - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "smj": - self.g2p_lang = "smj" # Lule Saami - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "swh": - self.g2p_lang = "sw" # Swahili - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "swe": - self.g2p_lang = "sv" # Swedish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "tam": - self.g2p_lang = "ta" # Tamil - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "tha": - self.g2p_lang = "th" # Thai - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "tuk": - self.g2p_lang = "tk" # Turkmen - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "tat": - self.g2p_lang = "tt" # Tatar - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "tel": - self.g2p_lang = "te" # Telugu - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "tur": - self.g2p_lang = "tr" # Turkish - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "uig": - self.g2p_lang = "ug" # Uyghur - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "urd": - self.g2p_lang = "ur" # Urdu - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "uzn": - self.g2p_lang = "uz" # Uzbek - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "cym": - self.g2p_lang = "cy" # Welsh - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - else: - # blanket solution for the rest - self.g2p_lang = language - self.phonemizer = "transphone" - self.expand_abbreviations = lambda x: x - self.transphone = read_g2p() - - # remember to also update get_language_id() below when adding something here, as well as the get_example_sentence function - - if self.phonemizer == "espeak": - try: - self.phonemizer_backend = EspeakBackend(language=self.g2p_lang, - punctuation_marks=';:,.!?¡¿—…"«»“”~/。【】、‥،؟“”؛', - preserve_punctuation=True, - language_switch='remove-flags', - with_stress=self.use_stress) - except RuntimeError: - print("Error in loading espeak! \n" - "Maybe espeak is not installed on your system? 
\n" - "Falling back to transphone.") - from transphone.g2p import read_g2p - self.g2p_lang = self.language - self.phonemizer = "transphone" - self.expand_abbreviations = lambda x: x - self.transphone = read_g2p() - self.phone_to_vector = generate_feature_table() - self.phone_to_id = get_phone_to_id() - self.id_to_phone = {v: k for k, v in self.phone_to_id.items()} - self.text_vector_to_phone_cache = dict() - - def change_lang(self, language, - use_explicit_eos=True, - use_lexical_stress=True, - silent=True, - add_silence_to_end=True, - use_word_boundaries=True): - self.language = language - self.use_explicit_eos = use_explicit_eos - self.use_stress = use_lexical_stress - self.add_silence_to_end = add_silence_to_end - self.use_word_boundaries = use_word_boundaries - from transphone.g2p import read_g2p - - register_to_height = { - "˥": 5, - "˦": 4, - "˧": 3, - "˨": 2, - "˩": 1 - } - self.rising_perms = list() - self.falling_perms = list() - self.peaking_perms = list() - self.dipping_perms = list() - - for first_tone in ["˥", "˦", "˧", "˨", "˩"]: - for second_tone in ["˥", "˦", "˧", "˨", "˩"]: - if register_to_height[first_tone] > register_to_height[second_tone]: - self.falling_perms.append(first_tone + second_tone) - else: - self.rising_perms.append(first_tone + second_tone) - for third_tone in ["˥", "˦", "˧", "˨", "˩"]: - if register_to_height[first_tone] > register_to_height[second_tone] < register_to_height[third_tone]: - self.dipping_perms.append(first_tone + second_tone + third_tone) - elif register_to_height[first_tone] < register_to_height[second_tone] > register_to_height[third_tone]: - self.peaking_perms.append(first_tone + second_tone + third_tone) - - if language == "eng": - self.g2p_lang = "en-us" # English as spoken in USA - self.expand_abbreviations = english_text_expansion - self.phonemizer = "espeak" - - elif language == "deu": - self.g2p_lang = "de" # German - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "ell": - self.g2p_lang = "el" # Greek - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" - - elif language == "spa": - self.g2p_lang = "es" # Spanish + elif language == "spa-lat": + self.g2p_lang = "es-419" # Spanish self.expand_abbreviations = lambda x: x self.phonemizer = "espeak" @@ -649,6 +119,16 @@ class ArticulatoryCombinedTextFrontend: self.expand_abbreviations = remove_french_spacing self.phonemizer = "espeak" + elif language == "fr-be": + self.g2p_lang = "fr-be" # French + self.expand_abbreviations = remove_french_spacing + self.phonemizer = "espeak" + + elif language == "fr-sw": + self.g2p_lang = "fr-ch" # French + self.expand_abbreviations = remove_french_spacing + self.phonemizer = "espeak" + elif language == "ita": self.g2p_lang = "it" # Italian self.expand_abbreviations = lambda x: x @@ -659,6 +139,11 @@ class ArticulatoryCombinedTextFrontend: self.expand_abbreviations = lambda x: x self.phonemizer = "espeak" + elif language == "pt-br": + self.g2p_lang = "pt-br" # Portuguese + self.expand_abbreviations = lambda x: x + self.phonemizer = "espeak" + elif language == "pol": self.g2p_lang = "pl" # Polish self.expand_abbreviations = lambda x: x @@ -674,6 +159,16 @@ class ArticulatoryCombinedTextFrontend: self.expand_abbreviations = lambda x: x self.phonemizer = "espeak" + elif language == "vi-ctr": + self.g2p_lang = "vi-vn-x-central" # Central Vietnamese + self.expand_abbreviations = lambda x: x + self.phonemizer = "espeak" + + elif language == "vi-so": + self.g2p_lang = "vi-vn-x-south" # Southern 
Vietnamese + self.expand_abbreviations = lambda x: x + self.phonemizer = "espeak" + elif language == "ukr": self.g2p_lang = "uk" # Ukrainian self.expand_abbreviations = lambda x: x @@ -814,6 +309,11 @@ class ArticulatoryCombinedTextFrontend: self.expand_abbreviations = lambda x: x self.phonemizer = "espeak" + elif language == "en-sc": + self.g2p_lang = "en-gb-scotland" + self.expand_abbreviations = lambda x: x + self.phonemizer = "espeak" + elif language == "kat": self.g2p_lang = "ka" # Georgian self.expand_abbreviations = lambda x: x @@ -850,9 +350,13 @@ class ArticulatoryCombinedTextFrontend: self.phonemizer = "espeak" elif language == "jpn": - self.g2p_lang = "ja" # Japanese - self.expand_abbreviations = lambda x: x - self.phonemizer = "espeak" + import pykakasi + + self.kakasi = pykakasi.Kakasi() # this is not a satisfactory solution, but it is the best one I could come up with so far. + self.expand_abbreviations = lambda x: " ".join([chunk["hepburn"] for chunk in self.kakasi.convert(x)]) + self.g2p_lang = language + self.phonemizer = "transphone" + self.transphone = read_g2p(device=device) elif language == "kan": self.g2p_lang = "kn" # Kannada @@ -1076,20 +580,22 @@ class ArticulatoryCombinedTextFrontend: else: # blanket solution for the rest + print("Using Transphone. A specialized phonemizer might work better.") self.g2p_lang = language self.phonemizer = "transphone" self.expand_abbreviations = lambda x: x - self.transphone = read_g2p() + self.transphone = read_g2p(device=device) # remember to also update get_language_id() below when adding something here, as well as the get_example_sentence function if self.phonemizer == "espeak": try: self.phonemizer_backend = EspeakBackend(language=self.g2p_lang, - punctuation_marks=';:,.!?¡¿—…"«»“”~/。【】、‥،؟“”؛', + punctuation_marks=';:,.!?¡¿—…()"«»“”~/。【】、‥،؟“”؛', preserve_punctuation=True, language_switch='remove-flags', - with_stress=self.use_stress) + with_stress=self.use_stress, + logger=logging.getLogger(__file__)) except RuntimeError: print("Error in loading espeak! \n" "Maybe espeak is not installed on your system? 
\n" @@ -1158,50 +664,96 @@ class ArticulatoryCombinedTextFrontend: for char in phones: # affects following phoneme ----------------- - if char == '\u02C8': + if char.strip() == '\u02C8': # primary stress stressed_flag = True # affects previous phoneme ----------------- - elif char == '\u02D0': + elif char.strip() == '\u02D0': # lengthened phones_vector[-1][get_feature_to_index_lookup()["lengthened"]] = 1 - elif char == '\u02D1': + elif char.strip() == '\u02D1': # half length phones_vector[-1][get_feature_to_index_lookup()["half-length"]] = 1 - elif char == '\u0306': + elif char.strip() == '\u0306': # shortened phones_vector[-1][get_feature_to_index_lookup()["shortened"]] = 1 - elif char == '̃': + elif char.strip() == '̃' and phones_vector[-1][get_feature_to_index_lookup()["nasal"]] != 1: # nasalized (vowel) - phones_vector[-1][get_feature_to_index_lookup()["nasal"]] = 1 - elif char == "̧": + phones_vector[-1][get_feature_to_index_lookup()["nasal"]] = 2 + elif char.strip() == "̧" != phones_vector[-1][get_feature_to_index_lookup()["palatal"]] != 1: # palatalized - phones_vector[-1][get_feature_to_index_lookup()["palatal"]] = 1 - elif char == "˥": + phones_vector[-1][get_feature_to_index_lookup()["palatal"]] = 2 + elif char.strip() == "ʷ" and phones_vector[-1][get_feature_to_index_lookup()["labial-velar"]] != 1: + # labialized + phones_vector[-1][get_feature_to_index_lookup()["labial-velar"]] = 2 + elif char.strip() == "ʰ" and phones_vector[-1][get_feature_to_index_lookup()["aspirated"]] != 1: + # aspirated + phones_vector[-1][get_feature_to_index_lookup()["aspirated"]] = 2 + elif char.strip() == "ˠ" and phones_vector[-1][get_feature_to_index_lookup()["velar"]] != 1: + # velarized + phones_vector[-1][get_feature_to_index_lookup()["velar"]] = 2 + elif char.strip() == "ˁ" and phones_vector[-1][get_feature_to_index_lookup()["pharyngal"]] != 1: + # pharyngealized + phones_vector[-1][get_feature_to_index_lookup()["pharyngal"]] = 2 + elif char.strip() == "ˀ" and phones_vector[-1][get_feature_to_index_lookup()["glottal"]] != 1: + # glottalized + phones_vector[-1][get_feature_to_index_lookup()["glottal"]] = 2 + elif char.strip() == "ʼ" and phones_vector[-1][get_feature_to_index_lookup()["ejective"]] != 1: + # ejective + phones_vector[-1][get_feature_to_index_lookup()["ejective"]] = 2 + elif char.strip() == "̹" and phones_vector[-1][get_feature_to_index_lookup()["rounded"]] != 1: + # rounding + phones_vector[-1][get_feature_to_index_lookup()["rounded"]] = 2 + elif char.strip() == "̞" and phones_vector[-1][get_feature_to_index_lookup()["open"]] != 1: + # open + phones_vector[-1][get_feature_to_index_lookup()["open"]] = 2 + elif char.strip() == "̪" and phones_vector[-1][get_feature_to_index_lookup()["dental"]] != 1: + # dental + phones_vector[-1][get_feature_to_index_lookup()["dental"]] = 2 + elif char.strip() == "̬" and phones_vector[-1][get_feature_to_index_lookup()["voiced"]] != 1: + # voiced + phones_vector[-1][get_feature_to_index_lookup()["voiced"]] = 2 + elif char.strip() == "̝" and phones_vector[-1][get_feature_to_index_lookup()["close"]] != 1: + # closed + phones_vector[-1][get_feature_to_index_lookup()["close"]] = 2 + elif char.strip() == "̰" and phones_vector[-1][get_feature_to_index_lookup()["glottal"]] != 1 and phones_vector[-1][get_feature_to_index_lookup()["epiglottal"]] != 1: + # laryngalization + phones_vector[-1][get_feature_to_index_lookup()["glottal"]] = 2 + phones_vector[-1][get_feature_to_index_lookup()["epiglottal"]] = 2 + elif char.strip() == "̈" and 
phones_vector[-1][get_feature_to_index_lookup()["central"]] != 1: + # centralization + phones_vector[-1][get_feature_to_index_lookup()["central"]] = 2 + elif char.strip() == "̜" and phones_vector[-1][get_feature_to_index_lookup()["unrounded"]] != 1: + # unrounded + phones_vector[-1][get_feature_to_index_lookup()["unrounded"]] = 2 + elif char.strip() == "̥" and phones_vector[-1][get_feature_to_index_lookup()["unvoiced"]] != 1: + # voiceless + phones_vector[-1][get_feature_to_index_lookup()["unvoiced"]] = 2 + elif char.strip() == "˥": # very high tone phones_vector[-1][get_feature_to_index_lookup()["very-high-tone"]] = 1 - elif char == "˦": + elif char.strip() == "˦": # high tone phones_vector[-1][get_feature_to_index_lookup()["high-tone"]] = 1 - elif char == "˧": + elif char.strip() == "˧": # mid tone phones_vector[-1][get_feature_to_index_lookup()["mid-tone"]] = 1 - elif char == "˨": + elif char.strip() == "˨": # low tone phones_vector[-1][get_feature_to_index_lookup()["low-tone"]] = 1 - elif char == "˩": + elif char.strip() == "˩": # very low tone phones_vector[-1][get_feature_to_index_lookup()["very-low-tone"]] = 1 - elif char == "⭧": + elif char.strip() == "⭧": # rising tone phones_vector[-1][get_feature_to_index_lookup()["rising-tone"]] = 1 - elif char == "⭨": + elif char.strip() == "⭨": # falling tone phones_vector[-1][get_feature_to_index_lookup()["falling-tone"]] = 1 - elif char == "⮁": + elif char.strip() == "⮁": # peaking tone phones_vector[-1][get_feature_to_index_lookup()["peaking-tone"]] = 1 - elif char == "⮃": + elif char.strip() == "⮃": # dipping tone phones_vector[-1][get_feature_to_index_lookup()["dipping-tone"]] = 1 else: @@ -1212,7 +764,12 @@ class ArticulatoryCombinedTextFrontend: print("unknown phoneme: {}".format(char)) else: phones_vector.append(self.phone_to_vector[char].copy()) # leave error handling to elsewhere - + # the following lines try to emulate whispering by removing all voiced features + # phones_vector[-1][get_feature_to_index_lookup()["voiced"]] = 0 + # phones_vector[-1][get_feature_to_index_lookup()["unvoiced"]] = 1 + # the following lines explore what would happen, if the system is told to produce sounds a human cannot + # for dim, _ in enumerate(phones_vector[-1]): + # phones_vector[-1][dim] = 1 if stressed_flag: stressed_flag = False phones_vector[-1][get_feature_to_index_lookup()["stressed"]] = 1 @@ -1259,6 +816,7 @@ class ArticulatoryCombinedTextFrontend: (" ;", "~"), ("-", "~"), ("·", " "), + ("`", ""), # symbols that indicate a pause or silence ('"', "~"), (" - ", "~ "), @@ -1299,7 +857,14 @@ class ArticulatoryCombinedTextFrontend: phones = phones.replace('5', "˧˩˧") phones = phones.replace('6', "˧˩˨ʔ") # very weird tone, because the tone introduces another phoneme phones = phones.replace('7', "˧") - # TODO add more of this handling for more tonal languages + elif self.g2p_lang == "yue": + phones = phones.replace('1', "˥") + phones = phones.replace('2', "˧˥") + phones = phones.replace('3', "˧") + phones = phones.replace('4', "˧˩") + phones = phones.replace('5', "˩˧") + phones = phones.replace('6', "˨") + # more of this handling for more tonal languages can be added here, simply make an elif statement and check for the language. 
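# A minimal sketch of the pattern that the comment above describes: handling tones for another
# language boils down to substituting the phonemizer's tone digits with IPA tone letters, just like
# the 'vie' and 'yue' branches. The helper below and the mapping in the usage comment are
# illustrative placeholders only, not taken from any real phonemizer's output.
def replace_tone_digits(phones: str, digit_to_tone_letters: dict) -> str:
    """Swap phonemizer tone digits for IPA tone letters, following the 'vie' and 'yue' branches above."""
    for digit, tone_letters in digit_to_tone_letters.items():
        phones = phones.replace(digit, tone_letters)
    return phones
# hypothetical usage inside such an elif branch:
# phones = replace_tone_digits(phones, {'1': "˥", '2': "˧˥", '3': "˨˩"})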
return self.postprocess_phoneme_string(phones, for_feature_extraction, include_eos_symbol, for_plot_labels) def postprocess_phoneme_string(self, phoneme_string, for_feature_extraction, include_eos_symbol, for_plot_labels): @@ -1330,6 +895,8 @@ class ArticulatoryCombinedTextFrontend: # latin script punctuation ("/", " "), ("—", ""), + ("(", "~"), + (")", "~"), ("...", "…"), ("\n", ", "), ("\t", " "), @@ -1338,8 +905,13 @@ class ArticulatoryCombinedTextFrontend: ("«", '"'), ("»", '"'), # unifying some phoneme representations + ("N", "ŋ"), # somehow transphone doesn't transform this to IPA ("ɫ", "l"), # alveolopalatal ("ɚ", "ə"), + ("g", "ɡ"), + ("ε", "e"), + ("ʦ", "ts"), + ("ˤ", "ˁ"), ('ᵻ', 'ɨ'), ("ɧ", "ç"), # velopalatal ("ɥ", "j"), # labiopalatal @@ -1370,11 +942,8 @@ class ArticulatoryCombinedTextFrontend: (";", "~"), (",", "~") # make sure this remains the final one when adding new ones ] - unsupported_ipa_characters = {'̹', '̙', '̞', '̯', '̤', '̪', '̩', '̠', '̟', 'ꜜ', - '̬', '̽', 'ʰ', '|', '̝', '•', 'ˠ', '↘', - '‖', '̰', '‿', 'ᷝ', '̈', 'ᷠ', '̜', 'ʷ', - '̚', '↗', 'ꜛ', '̻', '̥', 'ˁ', '̘', '͡', '̺'} - # TODO support more of these. Problem: bridge over to aligner ID lookups after modifying the feature vector + unsupported_ipa_characters = {'̙', '̯', '̤', '̩', '̠', '̟', 'ꜜ', '̽', '|', '•', '↘', + '‖', '‿', 'ᷝ', 'ᷠ', '̚', '↗', 'ꜛ', '̻', '̘', '͡', '̺'} # https://en.wikipedia.org/wiki/IPA_number for char in unsupported_ipa_characters: replacements.append((char, "")) @@ -1398,6 +967,22 @@ class ArticulatoryCombinedTextFrontend: ('⮃', ""), # dipping ('⮁', ""), # peaking ('̃', ""), # nasalizing + ("̧", ""), # palatalized + ("ʷ", ""), # labialized + ("ʰ", ""), # aspirated + ("ˠ", ""), # velarized + ("ˁ", ""), # pharyngealized + ("ˀ", ""), # glottalized + ("ʼ", ""), # ejective + ("̹", ""), # rounding + ("̞", ""), # open + ("̪", ""), # dental + ("̬", ""), # voiced + ("̝", ""), # closed + ("̰", ""), # laryngalization + ("̈", ""), # centralization + ("̜", ""), # unrounded + ("̥", ""), # voiceless ] for replacement in replacements: phoneme_string = phoneme_string.replace(replacement[0], replacement[1]) @@ -1443,11 +1028,12 @@ class ArticulatoryCombinedTextFrontend: if immutable_vector in self.text_vector_to_phone_cache: tokens.append(self.phone_to_id[self.text_vector_to_phone_cache[immutable_vector]]) continue - if vector[get_feature_to_index_lookup()["vowel"]] == 1 and vector[get_feature_to_index_lookup()["nasal"]] == 1: - # for the sake of alignment, we ignore the difference between nasalized vowels and regular vowels - features[get_feature_to_index_lookup()["nasal"]] = 0 features = features[13:] # the first 12 dimensions are for modifiers, so we ignore those when trying to find the phoneme in the ID lookup + for index in range(len(features)): + if features[index] == 2: + # we remove all features that stem from a modifier, so we can map back to the unmodified sound + features[index] = 0 for phone in self.phone_to_vector: if features == self.phone_to_vector[phone][13:]: tokens.append(self.phone_to_id[phone]) @@ -1492,33 +1078,33 @@ def get_language_id(language): except FileNotFoundError: iso_codes_to_ids = load_json_from_path("iso_lookup.json")[-1] if language not in iso_codes_to_ids: - print("Please specify the language as ISO 639-2 code (https://en.wikipedia.org/wiki/List_of_ISO_639-2_codes)") + print("Please specify the language as ISO 639-3 code (https://en.wikipedia.org/wiki/List_of_ISO_639-3_codes)") return None return torch.LongTensor([iso_codes_to_ids[language]]) if __name__ == '__main__': + 
print("\n\nEnglish Test") tf = ArticulatoryCombinedTextFrontend(language="eng") tf.string_to_tensor("This is a complex sentence, it even has a pause! But can it do this? Nice.", view=True) - tf = ArticulatoryCombinedTextFrontend(language="deu") - tf.string_to_tensor("Alles klar, jetzt testen wir einen deutschen Satz. Ich hoffe es gibt nicht mehr viele unspezifizierte Phoneme.", view=True) - + print("\n\nChinese Test") tf = ArticulatoryCombinedTextFrontend(language="cmn") tf.string_to_tensor("这是一个复杂的句子,它甚至包含一个停顿。", view=True) tf.string_to_tensor("李绅 《悯农》 锄禾日当午, 汗滴禾下土。 谁知盘中餐, 粒粒皆辛苦。", view=True) tf.string_to_tensor("巴 拔 把 爸 吧", view=True) + print("\n\nVietnamese Test") tf = ArticulatoryCombinedTextFrontend(language="vie") tf.string_to_tensor("Xin chào thế giới, quả là một ngày tốt lành để học nói tiếng Việt!", view=True) tf.string_to_tensor("ba bà bá bạ bả bã", view=True) - tf = ArticulatoryCombinedTextFrontend(language="fra") - tf.string_to_tensor("Je ne te fais pas un dessin.", view=True) - print(tf.get_phone_string("Je ne te fais pas un dessin.")) + print("\n\nJapanese Test") + tf = ArticulatoryCombinedTextFrontend(language="jpn") + tf.string_to_tensor("医師会がなくても、近隣の病院なら紹介してくれると思います。", view=True) + print(tf.get_phone_string("医師会がなくても、近隣の病院なら紹介してくれると思います。")) + print("\n\nZero-Shot Test") tf = ArticulatoryCombinedTextFrontend(language="acr") - tf.string_to_tensor("I don't know this language, but this is just a dummy anyway.", view=True) - print(tf.get_phone_string("I don't know this language, but this is just a dummy anyway.")) - - print(get_language_id("eng")) + tf.string_to_tensor("I don't know this language, but this is just a dummy text anyway.", view=True) + print(tf.get_phone_string("I don't know this language, but this is just a dummy text anyway.")) diff --git a/Preprocessing/UtteranceEmbeddingExtractor.py b/Preprocessing/UtteranceEmbeddingExtractor.py index 62f41229245f23e6d4a2ea4a55a47d1ce83213b5..81e7b0061c068231a391bc11a3ecfbe71658c78c 100644 --- a/Preprocessing/UtteranceEmbeddingExtractor.py +++ b/Preprocessing/UtteranceEmbeddingExtractor.py @@ -5,7 +5,7 @@ import torch.multiprocessing from speechbrain.pretrained import EncoderClassifier from torchaudio.transforms import Resample -from Architectures.EmbeddingModel.StyleEmbedding import StyleEmbedding +from Modules.EmbeddingModel.StyleEmbedding import StyleEmbedding from Preprocessing.HiFiCodecAudioPreprocessor import CodecAudioPreprocessor from Utility.storage_config import MODELS_DIR diff --git a/Preprocessing/articulatory_features.py b/Preprocessing/articulatory_features.py index 52063a901cd8389268a3f8f6319f623c4b15cb2a..ecc35a6fc20b8ed26b5fd5f382dd3896ce267ced 100644 --- a/Preprocessing/articulatory_features.py +++ b/Preprocessing/articulatory_features.py @@ -811,6 +811,14 @@ def get_phone_to_id(): phone_to_id = dict() for index, phone in enumerate("~#?!ǃ.ɜəaðɛɪŋɔɒɾʃθʊʌʒæbʔdefghijklmnɳopɡɹrstuvwxzʀøçɐœyʏɑcɲɣʎβʝɟqɕɭɵʑʋʁɨʂɓʙɗɖχʛʟɽɢɠǂɦǁĩʍʕɻʄũɤɶõʡʈʜɱɯǀɸʘʐɰɘħɞʉɴʢѵ"): phone_to_id[phone] = index + # the following lines fix an issue with the aligner: While the different punctuation marks have + # different effects on their context, their realization in the signal is typically just silence. + # Since this is common for all of them, the CTC objective malfunctions for our purposes of + # alignment search. So it turned out that it's better to map all punctuation marks to silence. 
+ phone_to_id["#"] = phone_to_id["~"] + phone_to_id["?"] = phone_to_id["~"] + phone_to_id["!"] = phone_to_id["~"] + phone_to_id["."] = phone_to_id["~"] return phone_to_id @@ -894,10 +902,12 @@ def get_feature_to_index_lookup(): "implosive" : 57, "vibrant" : 58, "click" : 59, + "ejective" : 60, # TYPE - "unvoiced" : 60, - "voiced" : 61, + "aspirated" : 61, + "unvoiced" : 62, + "voiced" : 63, } @@ -933,8 +943,8 @@ def generate_feature_table(): phone_to_vector = dict() for ipa in ipa_to_phonemefeats: if len(ipa) == 1: - phone_to_vector[ipa] = [0] * (13 + sum([len(values) for values in [feat_to_val_set[feat] for feat in feat_to_val_set]])) - # there are 13 features which do not occur in the vectors, because they are context dependent and not lexical + phone_to_vector[ipa] = [0] * (15 + sum([len(values) for values in [feat_to_val_set[feat] for feat in feat_to_val_set]])) + # 15 features come from modifiers, not from lexical sounds, so we have to add them to the ones we encounter naturally in the lexical sounds for feat in ipa_to_phonemefeats[ipa]: if ipa_to_phonemefeats[ipa][feat] in value_to_index: phone_to_vector[ipa][value_to_index[ipa_to_phonemefeats[ipa][feat]]] = 1 diff --git a/Preprocessing/multilinguality/MetricMetaLearner.py b/Preprocessing/multilinguality/MetricMetaLearner.py index 8edafaaa240c48b6a72d0eb05f538ab997f0fb61..4fe14abdd8a92414e7c3678e589af2f0ef290b68 100644 --- a/Preprocessing/multilinguality/MetricMetaLearner.py +++ b/Preprocessing/multilinguality/MetricMetaLearner.py @@ -3,25 +3,21 @@ import os import pickle import random +import kan import torch from tqdm import tqdm -from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS -from Utility.storage_config import MODELS_DIR +from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS from Utility.utils import load_json_from_path class MetricsCombiner(torch.nn.Module): - def __init__(self): + def __init__(self, m): super().__init__() - self.scoring_function = torch.nn.Sequential(torch.nn.Linear(3, 8), - torch.nn.Tanh(), - torch.nn.Linear(8, 8), - torch.nn.Tanh(), - torch.nn.Linear(8, 1)) + self.scoring_function = kan.KAN(width=[3, 5, 1], grid=5, k=5, seed=m) def forward(self, x): - return self.scoring_function(x) + return self.scoring_function(x.squeeze()) class EnsembleModel(torch.nn.Module): @@ -64,12 +60,13 @@ def create_learned_cache(model_path, cache_root="."): print_intermediate_results = False # ensemble preparation - for _ in range(10): - model_list.append(MetricsCombiner()) - model_list[-1].train() - optim = torch.optim.Adam(model_list[-1].parameters(), lr=0.00005) + n_models = 5 + print(f"Training ensemble of {n_models} models for learned distance metric.") + for m in range(n_models): + model_list.append(MetricsCombiner(m)) + optim = torch.optim.Adam(model_list[-1].parameters(), lr=0.0005) running_loss = list() - for epoch in tqdm(range(35)): + for epoch in tqdm(range(35), desc=f"MetricsCombiner {m + 1}/{n_models} - Epoch"): for i in range(1000): # we have no dataloader, so first we build a batch embedding_distance_batch = list() @@ -109,6 +106,9 @@ def create_learned_cache(model_path, cache_root="."): print("\n\n") running_loss = list() + # model_list[-1].scoring_function.plot(folder=f"kan_vis_{m}", beta=5000) + # plt.show() + # Time to see if the final ensemble is any good ensemble = EnsembleModel(model_list) @@ -163,7 +163,7 @@ def create_learned_cache(model_path, cache_root="."): _asp_dist = 1.0 - asp_sim[lang_1][lang_list.index(lang_2)] metric_distance = torch.tensor([_tree_dist, _map_dist, _asp_dist], 
dtype=torch.float32) with torch.inference_mode(): - predicted_distance = ensemble(metric_distance) + predicted_distance = ensemble(metric_distance.unsqueeze(0)).squeeze() language_to_language_to_learned_distance[lang_1][lang_2] = predicted_distance.item() except ValueError: continue @@ -175,4 +175,4 @@ def create_learned_cache(model_path, cache_root="."): if __name__ == '__main__': - create_learned_cache(os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")) # MODELS_DIR must be absolute path, the relative path will fail at this location + create_learned_cache("../../Models/ToucanTTS_Meta/best.pt") diff --git a/Preprocessing/multilinguality/README.md b/Preprocessing/multilinguality/README.md index 55a4445ad7c9404c0fd3b235795ef0477eaaafcd..a88aac6266abedc2c93d92f4edf513e581a3d419 100644 --- a/Preprocessing/multilinguality/README.md +++ b/Preprocessing/multilinguality/README.md @@ -1,18 +1,21 @@ ## Zero-Shot Approximation of Language Embeddings This directory contains all scripts that are needed to reproduce the meta learning for zero-shot part of our system. These scripts allow you to predict representations of languages purely based on distances between them, as measured by a variety of linguistically informed metrics, or even better, a learned combination thereof. -### Learned distance metric -If you want to use a learned distance metric, you need to run `MetricMetaLearner.py` first to generate a lookup file for the learned distances. - -Note: **The learned distances are (obviously) only useful for the model it was trained on**, i.e., different Toucan models require different learned-distance lookups. - ### Applying zero-shot approximation to a trained model Use `run_zero_shot_lang_emb_injection.py` to update the language embeddings of a trained model for all languages that were *not* seen during training (by default, `supervised_languages.json` is used to determine which languages *were* seen). See the script for arguments that can be passed (e.g. to use a custom model path). Here is an example: ``` +cd IMS-Toucan/ python run_zero_shot_lang_emb_injection.py -m -d -k ``` -By default, the updated model is saved with a modified filename in the same directory. \ No newline at end of file +By default, the updated model is saved with a modified filename in the same directory. + +### Cached distance lookups +In order to apply any zero-shot approximation, cache files for distance lookups are required. + +The ASP lookup file (`asp_dict.pkl`) needs to be downloaded from the release page. All other cache files are automatically generated as required when running `run_zero_shot_lang_emb_injection.py`. + +**Note:** While the map, tree, and inverse ASP distances are model independent, **the learned distance lookup is only applicable for the model it was trained on**, i.e., different Toucan models require different learned-distance lookups. If you want to apply zero-shot approximation to a new model, make sure that you are not using an outdated, pre-existing learned distance lookup, but instead train a new learned distance metric. 
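To retrain it, a minimal sketch (assuming the repository root as working directory, that `Preprocessing.multilinguality` is importable as a package, and a placeholder checkpoint path) could call the `create_learned_cache` function from `MetricMetaLearner.py` directly:

```
# Minimal sketch: rebuild the learned-distance lookup for a newly trained Toucan model.
# The checkpoint path is a placeholder; point it at the model the lookup should belong to.
from Preprocessing.multilinguality.MetricMetaLearner import create_learned_cache

create_learned_cache(model_path="Models/ToucanTTS_MyModel/best.pt",
                     cache_root="Preprocessing/multilinguality")
```

This should produce the `lang_1_to_lang_2_to_learned_dist.json` lookup in the given `cache_root`, which is where the other cache-creation and zero-shot scripts look for it.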
diff --git a/Preprocessing/multilinguality/create_distance_lookups.py b/Preprocessing/multilinguality/create_distance_lookups.py index cd1f27113832f028a06acb079c1a61349752de62..ee658b3040e8e06dda5c45a3e4543a8589b54341 100644 --- a/Preprocessing/multilinguality/create_distance_lookups.py +++ b/Preprocessing/multilinguality/create_distance_lookups.py @@ -28,16 +28,14 @@ class CacheCreator: self.pair_to_depth = dict() for pair in tqdm(self.pairs, desc="Generating tree pairs"): self.pair_to_tree_similarity[pair] = len(set(iso_to_family_memberships[pair[0]]).intersection(set(iso_to_family_memberships[pair[1]]))) - self.pair_to_depth[pair] = len(iso_to_family_memberships[pair[0]]) + len(iso_to_family_memberships[pair[1]]) - lang_1_to_lang_2_to_tree_dist = dict() - for pair in self.pair_to_tree_similarity: + for pair in tqdm(self.pair_to_tree_similarity): lang_1 = pair[0] lang_2 = pair[1] if self.pair_to_tree_similarity[pair] == 2: dist = 1.0 else: - dist = 1 - ((self.pair_to_tree_similarity[pair] * 2) / self.pair_to_depth[pair]) + dist = 1.0 - (self.pair_to_tree_similarity[pair] / max(len(iso_to_family_memberships[pair[0]]), len(iso_to_family_memberships[pair[1]]))) if lang_1 not in lang_1_to_lang_2_to_tree_dist.keys(): lang_1_to_lang_2_to_tree_dist[lang_1] = dict() lang_1_to_lang_2_to_tree_dist[lang_1][lang_2] = dist @@ -69,7 +67,7 @@ class CacheCreator: def create_oracle_cache(self, model_path, cache_root="."): """Oracle language-embedding distance of supervised languages is only used for evaluation, not usable for zero-shot. - + Note: The generated oracle cache is only valid for the given `model_path`!""" loss_fn = torch.nn.MSELoss(reduction="mean") self.pair_to_oracle_dist = dict() @@ -101,10 +99,10 @@ class CacheCreator: self.create_tree_cache(cache_root="Preprocessing/multilinguality") if not os.path.exists(os.path.join(self.cache_root, "lang_1_to_lang_2_to_map_dist.json")): self.create_map_cache(cache_root="Preprocessing/multilinguality") - if not os.path.exists(os.path.join(self.cache_root, "lang_1_to_lang_2_to_learned_dist.json")): - self.create_learned_cache(model_path=model_path, cache_root="Preprocessing/multilinguality") if not os.path.exists(os.path.join(self.cache_root, "asp_dict.pkl")): raise FileNotFoundError("asp_dict.pkl must be downloaded separately.") + if not os.path.exists(os.path.join(self.cache_root, "lang_1_to_lang_2_to_learned_dist.json")): + self.create_learned_cache(model_path=model_path, cache_root="Preprocessing/multilinguality") if create_oracle: if not os.path.exists(os.path.join(self.cache_root, "lang_1_to_lang_2_to_oracle_dist.json")): if not model_path: diff --git a/Preprocessing/multilinguality/create_lang_dist_dataset.py b/Preprocessing/multilinguality/create_lang_dist_dataset.py index 197160257668248c51bfa811108200eeec304b94..c9d2d8ff60df8c89de9e5982e2d4b3f2c7db4671 100644 --- a/Preprocessing/multilinguality/create_lang_dist_dataset.py +++ b/Preprocessing/multilinguality/create_lang_dist_dataset.py @@ -191,8 +191,6 @@ if __name__ == "__main__": default_model_path = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt") # MODELS_DIR must be absolute path, the relative path will fail at this location parser = argparse.ArgumentParser() parser.add_argument("--model_path", "-m", type=str, default=default_model_path, help="model path from which to obtain pretrained language embeddings") - parser.add_argument("--learned_dist_path", type=str, default="lang_1_to_lang_2_to_learned_dist.json", - help="filepath of JSON file containing the meta-learned pairwise 
distances") args = parser.parse_args() dc = LangDistDatasetCreator(args.model_path) diff --git a/Preprocessing/multilinguality/eval_lang_emb_approximation.py b/Preprocessing/multilinguality/eval_lang_emb_approximation.py index b2a36c46c383d290231c8819b73b5012145f7430..36b0518c0b2fd6391c86805b16fc2f22074031de 100644 --- a/Preprocessing/multilinguality/eval_lang_emb_approximation.py +++ b/Preprocessing/multilinguality/eval_lang_emb_approximation.py @@ -85,7 +85,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe if __name__ == "__main__": - default_model_path = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt") # MODELS_DIR must be absolute path, the relative path will fail at this location + default_model_path = os.path.join("../..", MODELS_DIR, "ToucanTTS_Meta", "best.pt") # MODELS_DIR must be absolute path, the relative path will fail at this location parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default=default_model_path, help="model path that should be used for creating oracle lang emb distance cache") parser.add_argument("--min_n_langs", type=int, default=5, help="minimum amount of languages used for averaging") diff --git a/Preprocessing/multilinguality/generate_zero_shot_lang_embs.py b/Preprocessing/multilinguality/generate_zero_shot_lang_embs.py index aac0f6e3f618a9c678a86e94069d48f9d27158f3..28c5c79d05f16a0821f7e643d8c49b85e626d378 100644 --- a/Preprocessing/multilinguality/generate_zero_shot_lang_embs.py +++ b/Preprocessing/multilinguality/generate_zero_shot_lang_embs.py @@ -1,12 +1,15 @@ -import torch +import argparse +import json +import os + import numpy as np import pandas as pd -import json -import argparse +import torch from tqdm import tqdm -import os + from Utility.storage_config import MODELS_DIR + def approximate_and_inject_language_embeddings(model_path, df, iso_lookup, min_n_langs=5, max_n_langs=25, threshold_percentile=50): # load pretrained language_embeddings model = torch.load(model_path, map_location="cpu") @@ -48,7 +51,7 @@ def approximate_and_inject_language_embeddings(model_path, df, iso_lookup, min_n threshold = np.percentile(df[closest_dist_columns[-1]], threshold_percentile) print(f"threshold: {threshold:.4f}") for row in tqdm(df.itertuples(), total=df.shape[0], desc="Approximating language embeddings"): - avg_emb = torch.zeros([16]) + avg_emb = torch.zeros([32]) # If you change the size of the language embedding in the model, you need to change the size here as well. 
TODO automate this dists = [getattr(row, d) for i, d in enumerate(closest_dist_columns) if i < min_n_langs or getattr(row, d) < threshold] langs = [getattr(row, l) for l in closest_lang_columns[:len(dists)]] diff --git a/Preprocessing/multilinguality/iso_to_fullname.json b/Preprocessing/multilinguality/iso_to_fullname.json index 1eed5486d6f6eb58e412584b06977d1f383a5931..bc1aa1847a5b17700d0da7b1e31168eba398b3c4 100644 --- a/Preprocessing/multilinguality/iso_to_fullname.json +++ b/Preprocessing/multilinguality/iso_to_fullname.json @@ -4601,7 +4601,7 @@ "obu": "Obulom", "oca": "Ocaina", "och": "Old Chinese", - "oci": "Occitan (post 1500)", + "oci": "Occitan", "ocu": "Atzingo Matlatzinca", "odk": "Od", "odt": "Old Dutch", diff --git a/Preprocessing/multilinguality/supervised_languages.json b/Preprocessing/multilinguality/supervised_languages.json index f950590314d3283cad126d01be81558494c75c1b..bcb83864fc39b31a0cc1975c8baab18e0d93f9c6 100644 --- a/Preprocessing/multilinguality/supervised_languages.json +++ b/Preprocessing/multilinguality/supervised_languages.json @@ -1,464 +1,473 @@ [ - "eng", - "deu", - "fra", - "spa", - "cmn", - "por", - "pol", - "ita", - "nld", - "ell", - "fin", - "vie", - "rus", - "hun", - "bem", - "swh", - "amh", - "wol", - "mal", - "chv", - "iba", - "jav", - "fon", - "hau", - "lbb", - "kik", - "lin", - "lug", - "luo", - "sxb", - "yor", - "nya", - "loz", - "toi", - "afr", - "arb", - "asm", - "ast", - "azj", - "bel", - "bul", - "ben", - "bos", - "cat", - "ceb", - "sdh", - "ces", - "cym", - "dan", - "ekk", - "pes", - "fil", - "gle", - "glg", - "guj", - "heb", - "hin", - "hrv", - "hye", - "ind", - "ibo", - "isl", - "kat", - "kam", - "kea", - "kaz", - "khm", - "kan", - "kor", - "ltz", - "lao", - "lit", - "lvs", - "mri", - "mkd", - "xng", - "mar", - "zsm", - "mlt", - "oci", - "ory", - "pan", - "pst", - "ron", - "snd", - "slk", - "slv", - "sna", - "som", - "srp", - "swe", - "tam", - "tel", - "tgk", - "tur", - "ukr", - "umb", - "urd", - "uzn", - "bhd", - "kfs", - "dgo", - "gbk", - "bgc", - "xnr", - "kfx", - "mjl", - "bfz", - "acf", - "bss", - "inb", - "nca", - "quh", - "wap", - "acr", - "bus", - "dgr", - "maz", - "nch", - "qul", - "tav", - "wmw", - "acu", - "byr", - "dik", - "iou", - "mbb", - "ncj", - "qvc", - "tbc", - "xed", - "agd", - "bzh", - "djk", - "ipi", - "mbc", - "ncl", - "qve", - "tbg", - "xon", - "agg", - "bzj", - "dop", - "jac", - "mbh", - "ncu", - "qvh", - "tbl", - "xtd", - "agn", - "caa", - "jic", - "mbj", - "ndj", - "qvm", - "tbz", - "xtm", - "agr", - "cab", - "emp", - "jiv", - "mbt", - "nfa", - "qvn", - "tca", - "yaa", - "agu", - "cap", - "jvn", - "mca", - "ngp", - "qvs", - "tcs", - "yad", - "aia", - "car", - "ese", - "mcb", - "ngu", - "qvw", - "yal", - "cax", - "kaq", - "mcd", - "nhe", - "qvz", - "tee", - "ycn", - "ake", - "cbc", - "far", - "mco", - "qwh", - "yka", - "alp", - "cbi", - "kdc", - "mcp", - "nhu", - "qxh", - "ame", - "cbr", - "gai", - "kde", - "mcq", - "nhw", - "qxn", - "tew", - "yre", - "amf", - "cbs", - "gam", - "kdl", - "mdy", - "nhy", - "qxo", - "tfr", - "yva", - "amk", - "cbt", - "geb", - "kek", - "med", - "nin", - "rai", - "zaa", - "apb", - "cbu", - "glk", - "ken", - "mee", - "nko", - "rgu", - "zab", - "apr", - "cbv", - "meq", - "tgo", - "zac", - "arl", - "cco", - "gng", - "kje", - "met", - "nlg", - "rop", - "tgp", - "zad", - "grc", - "klv", - "mgh", - "nnq", - "rro", - "zai", - "ata", - "cek", - "gub", - "kmu", - "mib", - "noa", - "ruf", - "tna", - "zam", - "atb", - "cgc", - "guh", - "kne", - "mie", - "not", - "rug", - "tnk", - "zao", - "atg", 
- "chf", - "knf", - "mih", - "npl", - "tnn", - "zar", - "awb", - "chz", - "gum", - "knj", - "mil", - "sab", - "tnp", - "zas", - "cjo", - "guo", - "ksr", - "mio", - "obo", - "seh", - "toc", - "zav", - "azg", - "cle", - "gux", - "kue", - "mit", - "omw", - "sey", - "tos", - "zaw", - "azz", - "cme", - "gvc", - "kvn", - "miz", - "ood", - "sgb", - "tpi", - "zca", - "bao", - "cni", - "gwi", - "kwd", - "mkl", - "shp", - "tpt", - "zga", - "bba", - "cnl", - "gym", - "kwf", - "mkn", - "ote", - "sja", - "trc", - "ziw", - "bbb", - "cnt", - "gyr", - "kwi", - "mop", - "otq", - "snn", - "ttc", - "zlm", - "cof", - "hat", - "kyc", - "mox", - "pab", - "snp", - "tte", - "zos", - "bgt", - "con", - "kyf", - "mpm", - "pad", - "tue", - "zpc", - "bjr", - "cot", - "kyg", - "mpp", - "soy", - "tuf", - "zpl", - "bjv", - "cpa", - "kyq", - "mpx", - "pao", - "tuo", - "zpm", - "bjz", - "cpb", - "hlt", - "kyz", - "mqb", - "pib", - "spp", - "zpo", - "bkd", - "cpu", - "hns", - "lac", - "mqj", - "pir", - "spy", - "txq", - "zpu", - "blz", - "crn", - "hto", - "lat", - "msy", - "pjt", - "sri", - "txu", - "zpz", - "bmr", - "cso", - "hub", - "lex", - "mto", - "pls", - "srm", - "udu", - "ztq", - "bmu", - "ctu", - "lgl", - "muy", - "poi", - "srn", - "zty", - "bnp", - "cuc", - "lid", - "mxb", - "stp", - "upv", - "zyp", - "boa", - "cui", - "huu", - "mxq", - "sus", - "ura", - "boj", - "cuk", - "huv", - "llg", - "mxt", - "poy", - "suz", - "urb", - "box", - "cwe", - "hvn", - "prf", - "urt", - "bpr", - "cya", - "ign", - "lww", - "myk", - "ptu", - "usp", - "bps", - "daa", - "ikk", - "maj", - "myy", - "vid", - "bqc", - "dah", - "nab", - "qub", - "tac", - "bqp", - "ded", - "imo", - "maq", - "nas", - "quf", - "taj", - "vmy" + "eng", + "deu", + "fra", + "spa", + "cmn", + "por", + "pol", + "ita", + "nld", + "ell", + "fin", + "vie", + "jpn", + "rus", + "hun", + "asm", + "ben", + "brx", + "dgo", + "guj", + "hin", + "kan", + "kas", + "knn", + "mai", + "mal", + "mni", + "mar", + "nep", + "ory", + "pan", + "san", + "sat", + "snd", + "tam", + "tel", + "urd", + "bem", + "swh", + "amh", + "wol", + "chv", + "iba", + "jav", + "fon", + "hau", + "lbb", + "kik", + "lin", + "lug", + "luo", + "sxb", + "yor", + "nya", + "loz", + "toi", + "afr", + "arb", + "ast", + "azj", + "bel", + "bul", + "bos", + "cat", + "ceb", + "sdh", + "ces", + "cym", + "dan", + "ekk", + "pes", + "fil", + "gle", + "glg", + "heb", + "hrv", + "hye", + "ind", + "ibo", + "isl", + "kat", + "kam", + "kea", + "kaz", + "khm", + "kor", + "ltz", + "lao", + "lit", + "lvs", + "mri", + "mkd", + "xng", + "zsm", + "mlt", + "oci", + "pst", + "ron", + "slk", + "slv", + "sna", + "som", + "srp", + "swe", + "tgk", + "tur", + "ukr", + "umb", + "uzn", + "bhd", + "kfs", + "gbk", + "bgc", + "xnr", + "kfx", + "mjl", + "bfz", + "acf", + "bss", + "inb", + "nca", + "quh", + "wap", + "acr", + "bus", + "dgr", + "maz", + "nch", + "qul", + "tav", + "wmw", + "acu", + "byr", + "dik", + "iou", + "mbb", + "ncj", + "qvc", + "tbc", + "xed", + "agd", + "bzh", + "djk", + "ipi", + "mbc", + "ncl", + "qve", + "tbg", + "xon", + "agg", + "bzj", + "dop", + "jac", + "mbh", + "ncu", + "qvh", + "tbl", + "xtd", + "agn", + "caa", + "jic", + "mbj", + "ndj", + "qvm", + "tbz", + "xtm", + "agr", + "cab", + "emp", + "jiv", + "mbt", + "nfa", + "qvn", + "tca", + "yaa", + "agu", + "cap", + "jvn", + "mca", + "ngp", + "qvs", + "tcs", + "yad", + "aia", + "car", + "ese", + "mcb", + "ngu", + "qvw", + "yal", + "cax", + "kaq", + "mcd", + "nhe", + "qvz", + "tee", + "ycn", + "ake", + "cbc", + "far", + "mco", + "qwh", + "yka", + "alp", + "cbi", + "kdc", 
+ "mcp", + "nhu", + "qxh", + "ame", + "cbr", + "gai", + "kde", + "mcq", + "nhw", + "qxn", + "tew", + "yre", + "amf", + "cbs", + "gam", + "kdl", + "mdy", + "nhy", + "qxo", + "tfr", + "yva", + "amk", + "cbt", + "geb", + "kek", + "med", + "nin", + "rai", + "zaa", + "apb", + "cbu", + "glk", + "ken", + "mee", + "nko", + "rgu", + "zab", + "apr", + "cbv", + "meq", + "tgo", + "zac", + "arl", + "cco", + "gng", + "kje", + "met", + "nlg", + "rop", + "tgp", + "zad", + "grc", + "klv", + "mgh", + "nnq", + "rro", + "zai", + "ata", + "cek", + "gub", + "kmu", + "mib", + "noa", + "ruf", + "tna", + "zam", + "atb", + "cgc", + "guh", + "kne", + "mie", + "not", + "rug", + "tnk", + "zao", + "atg", + "chf", + "knf", + "mih", + "npl", + "tnn", + "zar", + "awb", + "chz", + "gum", + "knj", + "mil", + "sab", + "tnp", + "zas", + "cjo", + "guo", + "ksr", + "mio", + "obo", + "seh", + "toc", + "zav", + "azg", + "cle", + "gux", + "kue", + "mit", + "omw", + "sey", + "tos", + "zaw", + "azz", + "cme", + "gvc", + "kvn", + "miz", + "ood", + "sgb", + "tpi", + "zca", + "bao", + "cni", + "gwi", + "kwd", + "mkl", + "shp", + "tpt", + "zga", + "bba", + "cnl", + "gym", + "kwf", + "mkn", + "ote", + "sja", + "trc", + "ziw", + "bbb", + "cnt", + "gyr", + "kwi", + "mop", + "otq", + "snn", + "ttc", + "zlm", + "cof", + "hat", + "kyc", + "mox", + "pab", + "snp", + "tte", + "zos", + "bgt", + "con", + "kyf", + "mpm", + "pad", + "tue", + "zpc", + "bjr", + "cot", + "kyg", + "mpp", + "soy", + "tuf", + "zpl", + "bjv", + "cpa", + "kyq", + "mpx", + "pao", + "tuo", + "zpm", + "bjz", + "cpb", + "hlt", + "kyz", + "mqb", + "pib", + "spp", + "zpo", + "bkd", + "cpu", + "hns", + "lac", + "mqj", + "pir", + "spy", + "txq", + "zpu", + "blz", + "crn", + "hto", + "lat", + "msy", + "pjt", + "sri", + "txu", + "zpz", + "bmr", + "cso", + "hub", + "lex", + "mto", + "pls", + "srm", + "udu", + "ztq", + "bmu", + "ctu", + "lgl", + "muy", + "poi", + "srn", + "zty", + "bnp", + "cuc", + "lid", + "mxb", + "stp", + "upv", + "zyp", + "boa", + "cui", + "huu", + "mxq", + "sus", + "ura", + "boj", + "cuk", + "huv", + "llg", + "mxt", + "poy", + "suz", + "urb", + "box", + "cwe", + "hvn", + "prf", + "urt", + "bpr", + "cya", + "ign", + "lww", + "myk", + "ptu", + "usp", + "bps", + "daa", + "ikk", + "maj", + "myy", + "vid", + "bqc", + "dah", + "nab", + "qub", + "tac", + "bqp", + "ded", + "imo", + "maq", + "nas", + "quf", + "taj", + "vmy" ] \ No newline at end of file diff --git a/Preprocessing/multilinguality/visualize_distances.py b/Preprocessing/multilinguality/visualize_distances.py new file mode 100644 index 0000000000000000000000000000000000000000..fe46cec2fb9a3a61ba0bbd19dd1e27fcfa656af2 --- /dev/null +++ b/Preprocessing/multilinguality/visualize_distances.py @@ -0,0 +1,206 @@ +import os +import pickle + +import matplotlib.pyplot as plt +import networkx as nx +import torch +from tqdm import tqdm + +from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS +from Utility.utils import load_json_from_path + +distance_types = ["tree", "asp", "map", "learned", "l1"] +modes = ["plot_all", "plot_neighbors"] +neighbor = "Latin" +num_neighbors = 12 +distance_type = distance_types[0] # switch here +mode = modes[1] +edge_threshold = 0.01 +# TODO histograms to figure out a good threshold + +cache_root = "." 
+supervised_iso_codes = load_json_from_path(os.path.join(cache_root, "supervised_languages.json")) + +if distance_type == "l1": + iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1] + model_path = "../../Models/ToucanTTS_Meta/best.pt" + checkpoint = torch.load(model_path, map_location='cpu') + embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding + embedding_provider.requires_grad_(False) + l1_dist = dict() + seen_langs = set() + for lang_1 in supervised_iso_codes: + if lang_1 not in seen_langs: + seen_langs.add(lang_1) + l1_dist[lang_1] = dict() + for lang_2 in supervised_iso_codes: + if lang_2 not in seen_langs: # it's symmetric + l1_dist[lang_1][lang_2] = torch.nn.functional.mse_loss(embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze(), embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze()) + largest_value_l1_dist = 0.0 + for _, values in l1_dist.items(): + for _, value in values.items(): + largest_value_l1_dist = max(largest_value_l1_dist, value) + for key1 in l1_dist: + for key2 in l1_dist[key1]: + l1_dist[key1][key2] = l1_dist[key1][key2] / largest_value_l1_dist + distance_measure = l1_dist + +if distance_type == "tree": + tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json") + tree_dist = load_json_from_path(tree_lookup_path) + distance_measure = tree_dist + +if distance_type == "map": + map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json") + map_dist = load_json_from_path(map_lookup_path) + largest_value_map_dist = 0.0 + for _, values in map_dist.items(): + for _, value in values.items(): + largest_value_map_dist = max(largest_value_map_dist, value) + for key1 in map_dist: + for key2 in map_dist[key1]: + map_dist[key1][key2] = map_dist[key1][key2] / largest_value_map_dist + distance_measure = map_dist + +if distance_type == "learned": + learned_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json") + learned_dist = load_json_from_path(learned_lookup_path) + largest_value_learned_dist = 0.0 + for _, values in learned_dist.items(): + for _, value in values.items(): + largest_value_learned_dist = max(largest_value_learned_dist, value) + for key1 in learned_dist: + for key2 in learned_dist[key1]: + learned_dist[key1][key2] = learned_dist[key1][key2] / largest_value_learned_dist + distance_measure = learned_dist + +if distance_type == "asp": + asp_dict_path = os.path.join(cache_root, "asp_dict.pkl") + with open(asp_dict_path, 'rb') as dictfile: + asp_sim = pickle.load(dictfile) + lang_list = list(asp_sim.keys()) + asp_dist = dict() + seen_langs = set() + for lang_1 in lang_list: + if lang_1 not in seen_langs: + seen_langs.add(lang_1) + asp_dist[lang_1] = dict() + for index, lang_2 in enumerate(lang_list): + if lang_2 not in seen_langs: # it's symmetric + asp_dist[lang_1][lang_2] = 1 - asp_sim[lang_1][index] + distance_measure = asp_dist + +iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json")) +distances = list() + +for lang_1 in distance_measure: + if lang_1 not in iso_codes_to_names: + continue + if lang_1 not in supervised_iso_codes and iso_codes_to_names[lang_1] != neighbor: + continue + for lang_2 in distance_measure[lang_1]: + try: + if lang_2 not in supervised_iso_codes and iso_codes_to_names[lang_2] != neighbor: + continue + except KeyError: + continue + distances.append((iso_codes_to_names[lang_1], iso_codes_to_names[lang_2], 
distance_measure[lang_1][lang_2])) + +# Create a graph +G = nx.Graph() + +# Add edges along with distances as weights +min_dist = min(d for _, _, d in distances) +max_dist = max(d for _, _, d in distances) +normalized_distances = [(entity1, entity2, (d - min_dist) / (max_dist - min_dist)) for entity1, entity2, d in distances] + +if mode == "plot_neighbors": + fullnames = list() + fullnames.append(neighbor) + for code in supervised_iso_codes: + fullnames.append(iso_codes_to_names[code]) + supervised_iso_codes = fullnames + d_dist = list() + for entity1, entity2, d in tqdm(normalized_distances): + if (neighbor == entity2 or neighbor == entity1) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes): + if entity1 != entity2: + d_dist.append(d) + thresh = sorted(d_dist)[num_neighbors] + # distance_scores = sorted(d_dist)[:num_neighbors] + neighbors = list() + for entity1, entity2, d in tqdm(normalized_distances): + if (d < thresh and (neighbor == entity2 or neighbor == entity1)) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes): + neighbors.append(entity1) + neighbors.append(entity2) + unique_neighbors = list(set(neighbors)) + unique_neighbors.remove(neighbor) + for entity1, entity2, d in tqdm(normalized_distances): + if (neighbor == entity2 or neighbor == entity1) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes): + if entity1 != entity2 and d < thresh: + spring_tension = ((thresh - d) ** 2) * 20000 # for vis purposes + print(f"{d}-->{spring_tension}") + G.add_edge(entity1, entity2, weight=spring_tension) + for entity1, entity2, d in tqdm(normalized_distances): + if (entity2 in unique_neighbors and entity1 in unique_neighbors) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes): + if entity1 != entity2: + spring_tension = 1 - d + G.add_edge(entity1, entity2, weight=spring_tension) + + # Draw the graph + pos = nx.spring_layout(G, weight="weight") # Positions for all nodes + edges = G.edges(data=True) + + # Draw nodes + nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01) + + # Draw edges with labels + edges_connected_to_specific_node = [(u, v) for u, v in G.edges() if u == neighbor or v == neighbor] + # nx.draw_networkx_edges(G, pos, alpha=0.1) + nx.draw_networkx_edges(G, pos, edgelist=edges_connected_to_specific_node, edge_color='red', alpha=0.3, width=3) + for u, v, d in edges: + if u == neighbor or v == neighbor: + nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): round((thresh - (d['weight'] / 20000) ** (1 / 2)) * 10, 2)}, font_color="red", alpha=0.3) # reverse modifications + else: + pass + # nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['weight']}) + + # Draw node labels + nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green') + nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red') + + plt.title(f'Graph of {distance_type} Distances') + + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + plt.tight_layout(pad=0) + + plt.savefig("avg.png", dpi=300) + plt.show() + + + +elif mode == "plot_all": + for entity1, entity2, d in tqdm(normalized_distances): + if d < edge_threshold and entity1 != entity2: + spring_tension = edge_threshold - d + G.add_edge(entity1, entity2, weight=spring_tension) + + # Draw the graph + pos = nx.spring_layout(G, weight="weight") # Positions for all nodes + edges = G.edges(data=True) + + # Draw nodes + nx.draw_networkx_nodes(G, pos, node_size=1, 
alpha=0.01) + + # Draw edges with labels + nx.draw_networkx_edges(G, pos, alpha=0.1, edge_color="blue") + # nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['weight'] for u, v, d in edges}) + + # Draw node labels + nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif') + + plt.title(f'Graph of {distance_type} Distances') + + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + plt.tight_layout(pad=0) + + plt.show() diff --git a/Utility/Scorer.py b/Utility/Scorer.py index 1651bf1d32ad457d7cf4529b1c54e65379fe70d7..52bb4262b5f58c3c78ef7f6384d277ea6bf0bc3a 100644 --- a/Utility/Scorer.py +++ b/Utility/Scorer.py @@ -6,11 +6,14 @@ find mispronunciations or errors in the labels. The TTS scorer can help you find outliers in the audio part of text-audio pairs. """ +import math +import statistics + import torch import torch.multiprocessing from tqdm import tqdm -from Architectures.ToucanTTS.ToucanTTS import ToucanTTS +from Modules.ToucanTTS.ToucanTTS import ToucanTTS from Preprocessing.AudioPreprocessor import AudioPreprocessor from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor from Utility.corpus_preparation import prepare_tts_corpus @@ -76,8 +79,8 @@ class TTSScorer: utterance_embedding=utterance_embedding, lang_ids=lang_ids, return_feats=False, - run_glow=False) - loss = regression_loss + duration_loss + pitch_loss + energy_loss # we omit the glow loss + run_stochastic=False) + loss = regression_loss # + duration_loss + pitch_loss + energy_loss # we omit the stochastic loss except TypeError: loss = torch.tensor(torch.nan) if torch.isnan(loss): @@ -125,6 +128,28 @@ class TTSScorer: self.current_dset.remove_samples(remove_ids) self.nans_removed = True + def remove_samples_with_loss_three_std_devs_higher_than_avg(self): + if self.current_dset is None: + print("Please run the scoring first.") + else: + if self.nans_removed: + print("Indexes are no longer accurate. Please re-run the scoring. \n\n" + "This function also removes NaNs, so if you want to remove the NaN samples and the outliers, only call this one here.") + else: + remove_ids = list() + remove_ids.extend(self.nan_indexes) + scores_without_nans = [value for value in list(self.path_to_score.values()) if not math.isnan(value)] + avg = statistics.mean(scores_without_nans) + std = statistics.stdev(scores_without_nans) + thresh = avg + (3 * std) + for path in self.path_to_score: + if not math.isnan(self.path_to_score[path]): + if self.path_to_score[path] > thresh: # we found an outlier! 
+ remove_ids.append(self.path_to_id[path]) + print(f"removing {len(remove_ids)} outliers!") + self.current_dset.remove_samples(remove_ids) + self.nans_removed = True + def remove_nans(self): if self.nans_removed: print("NaNs have already been removed!") diff --git a/Utility/corpus_preparation.py b/Utility/corpus_preparation.py index 747122e47f43d1f18a678e7086b9f4e7738169fa..c55fbfc5b1d5427eea01307d5352df7c8cfa9354 100644 --- a/Utility/corpus_preparation.py +++ b/Utility/corpus_preparation.py @@ -1,8 +1,8 @@ import torch.multiprocessing -from Architectures.Aligner.CodecAlignerDataset import CodecAlignerDataset -from Architectures.Aligner.autoaligner_train_loop import train_loop as train_aligner -from Architectures.ToucanTTS.TTSDataset import TTSDataset +from Modules.Aligner.CodecAlignerDataset import CodecAlignerDataset +from Modules.Aligner.autoaligner_train_loop import train_loop as train_aligner +from Modules.ToucanTTS.TTSDataset import TTSDataset from Utility.path_to_transcript_dicts import * from Utility.storage_config import MODELS_DIR @@ -13,7 +13,7 @@ def prepare_aligner_corpus(transcript_dict, corpus_dir, lang, device, phone_inpu return CodecAlignerDataset(transcript_dict, cache_dir=corpus_dir, lang=lang, - loading_processes=10, # this can be increased for massive clusters, but the overheads that are introduced are kind of not really worth it + loading_processes=5, # this can be increased for massive clusters, but the overheads that are introduced are kind of not really worth it device=device, phone_input=phone_input, gpu_count=gpu_count, diff --git a/Utility/path_to_transcript_dicts.py b/Utility/path_to_transcript_dicts.py index e65d87a6ac8ec20dc3fb9d2f7f0480197669a0a3..fececfbdfe7a7b61955492410c8962fa2bc63a24 100644 --- a/Utility/path_to_transcript_dicts.py +++ b/Utility/path_to_transcript_dicts.py @@ -1,6 +1,10 @@ import glob +import json import os import random +import xml.etree.ElementTree as ET +from csv import DictReader +from pathlib import Path import torch @@ -66,8 +70,156 @@ def build_path_to_transcript_dict_hui_template(root): return path_to_transcript +def indic_voices_template(root, lang): + path_to_transcript = dict() + transcripts = list() + import json + for jpath in [f"{root}/{lang}/metadata_test.json", + f"{root}/{lang}/metadata_train.json"]: + with open(jpath, encoding='utf-8', mode='r') as jfile: + for line in jfile.read().split("\n"): + if line.strip() != "": + transcripts.append(json.loads(line)) + for transcript in transcripts: + path = f"{root}/{lang}/{lang}/wavs/{transcript['filepath']}" + norm_text = transcript["normalized"] + path_to_transcript[path] = norm_text + return path_to_transcript + + # ENGLISH +def build_path_to_transcript_dict_ears(re_cache=False): + transcript_for_ears = { + "emo_adoration_sentences" : "You're just the sweetest person I know and I am so happy to call you my friend. I had the best time with you, I just adore you. I love this gift, thank you!", + "emo_amazement_sentences" : "I just love how you can play guitar. You're so impressive. I admire your abilities so much.", + "emo_amusement_sentences" : "The sound that baby just made was quite amusing. I liked that stand up comic, I found her pretty funny. What a fun little show to watch!", + "emo_anger_sentences" : "I'm so mad right now I could punch a hole in the wall. I can't believe he said that, he's such a jerk! There's a stop sign there and parents are just letting their kids run around!", + "emo_confusion_sentences" : "Huh, what is going on over here? What is this? 
Where are we going?", + "emo_contentment_sentences" : "I really enjoyed dinner tonight, it was quite nice. Everything is working out just fine. I'm good either way.", + "emo_cuteness_sentences" : "Look at that cute little kitty cat! Oh my goodness, she's so cute! That's the cutest thing I've ever seen!", + "emo_desire_sentences" : "Mmm that chocolate fudge lava cake looks divine. I want that car so badly. I can't wait to see you again.", + "emo_disappointment_sentences": "I'm so disappointed in myself. I wish I had worked harder. I had such higher expectations for you. I really was hoping you were better than this.", + "emo_disgust_sentences" : "I have never seen anything grosser than this in my entire life. This is the worst dinner I've ever had. Yuck, I can't even look at that.", + "emo_distress_sentences" : "Oh god, I am not sure if we are going to make this flight on time. This is all too stressful to handle right now. I don't know where anything is and I'm running late.", + "emo_embarassment_sentences" : "I don't know what happened, I followed the recipe perfectly but the cake just deflated. I'm so embarrassed. I hope no one saw that, I'd be mortified if they did.", + "emo_extasy_sentences" : "This is the most exciting thing I've ever seen in my life! I can't believe I got to see that. I'm so excited, I've never been there before.", + "emo_fear_sentences" : "Did you hear that sound? I'm afraid someone or something is outside. Oh my gosh, what is that? What do you think is going to happen if we don't run?", + "emo_guilt_sentences" : "I'm sorry I did that to you. I really didn't mean to hurt you. I feel horrible that happened to you.", + "emo_interest_sentences" : "Hmm, I wonder what that cookie tastes like. Oh, what is that over there? So what exactly is it that you do?", + "emo_neutral_sentences" : "That wall in the living room is white. There is one more piece of bread in the pantry. The store closes at 8pm tonight.", + "emo_pain_sentences" : "Oh, this headache is the worst one I've ever had! My foot hurts so badly right now! I'm in terrible pain from that medication.", + "emo_pride_sentences" : "That was all me, I'm the one who found the project, created the company and made it succeed. I have worked hard to get here and I deserve it. I'm really proud of how well you did.", + "emo_realization_sentences" : "Wow, I never know that the body was made up of 75% water. Did you know that a flamingo is actually white but turns pink because it eats too many shrimp? Apparently dolphins sleep with one eye open.", + "emo_relief_sentences" : "I'm so relieved my taxes are done. That was so stressful. I'm so relieved that is over with. Thank goodness that's all done.", + "emo_sadness_sentences" : "I am so upset by the state of the world. I hope it gets better soon. I really miss her, life isn't the same without her. I'm sorry for your loss.", + "emo_serenity_sentences" : "This has been the most peaceful day of my life. I am very calm right now. I'm going to relax and take a nap here on the beach.", + "rainbow_01_fast" : "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. The rainbow is a division of white light into many beautiful colors.", + "rainbow_01_highpitch" : "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. The rainbow is a division of white light into many beautiful colors.", + "rainbow_01_loud" : "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. 
The rainbow is a division of white light into many beautiful colors.", + "rainbow_01_lowpitch" : "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. The rainbow is a division of white light into many beautiful colors.", + "rainbow_01_regular" : "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. The rainbow is a division of white light into many beautiful colors.", + "rainbow_01_slow" : "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. The rainbow is a division of white light into many beautiful colors.", + "rainbow_01_whisper" : "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow. The rainbow is a division of white light into many beautiful colors.", + "rainbow_02_fast" : "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon. There is, according to legend, a boiling pot of gold at one end.", + "rainbow_02_highpitch" : "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon. There is, according to legend, a boiling pot of gold at one end.", + "rainbow_02_loud" : "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon. There is, according to legend, a boiling pot of gold at one end.", + "rainbow_02_lowpitch" : "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon. There is, according to legend, a boiling pot of gold at one end.", + "rainbow_02_regular" : "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon. There is, according to legend, a boiling pot of gold at one end.", + "rainbow_02_slow" : "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon. There is, according to legend, a boiling pot of gold at one end.", + "rainbow_02_whisper" : "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon. There is, according to legend, a boiling pot of gold at one end.", + "rainbow_03_fast" : "People look, but no one ever finds it. When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow. Throughout the centuries people have explained the rainbow in various ways.", + "rainbow_03_highpitch" : "People look, but no one ever finds it. When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow. Throughout the centuries people have explained the rainbow in various ways.", + "rainbow_03_loud" : "People look, but no one ever finds it. When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow. Throughout the centuries people have explained the rainbow in various ways.", + "rainbow_03_lowpitch" : "People look, but no one ever finds it. When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow. Throughout the centuries people have explained the rainbow in various ways.", + "rainbow_03_regular" : "People look, but no one ever finds it. When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow. 
Throughout the centuries people have explained the rainbow in various ways.", + "rainbow_03_slow" : "People look, but no one ever finds it. When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow. Throughout the centuries people have explained the rainbow in various ways.", + "rainbow_03_whisper" : "People look, but no one ever finds it. When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow. Throughout the centuries people have explained the rainbow in various ways.", + "rainbow_04_fast" : "Some have accepted it as a miracle without physical explanation. To the Hebrews it was a token that there would be no more universal floods. The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.", + "rainbow_04_highpitch" : "Some have accepted it as a miracle without physical explanation. To the Hebrews it was a token that there would be no more universal floods. The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.", + "rainbow_04_loud" : "Some have accepted it as a miracle without physical explanation. To the Hebrews it was a token that there would be no more universal floods. The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.", + "rainbow_04_lowpitch" : "Some have accepted it as a miracle without physical explanation. To the Hebrews it was a token that there would be no more universal floods. The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.", + "rainbow_04_regular" : "Some have accepted it as a miracle without physical explanation. To the Hebrews it was a token that there would be no more universal floods. The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.", + "rainbow_04_slow" : "Some have accepted it as a miracle without physical explanation. To the Hebrews it was a token that there would be no more universal floods. The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.", + "rainbow_04_whisper" : "Some have accepted it as a miracle without physical explanation. To the Hebrews it was a token that there would be no more universal floods. The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.", + "rainbow_05_fast" : "The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. Others have tried to explain the phenomenon physically. Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.", + "rainbow_05_highpitch" : "The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. Others have tried to explain the phenomenon physically. Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.", + "rainbow_05_loud" : "The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. Others have tried to explain the phenomenon physically. Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.", + "rainbow_05_lowpitch" : "The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. Others have tried to explain the phenomenon physically. 
Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.", + "rainbow_05_regular" : "The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. Others have tried to explain the phenomenon physically. Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.", + "rainbow_05_slow" : "The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. Others have tried to explain the phenomenon physically. Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.", + "rainbow_05_whisper" : "The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky. Others have tried to explain the phenomenon physically. Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.", + "rainbow_06_fast" : "Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. Many complicated ideas about the rainbow have been formed.", + "rainbow_06_highpitch" : "Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. Many complicated ideas about the rainbow have been formed.", + "rainbow_06_loud" : "Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. Many complicated ideas about the rainbow have been formed.", + "rainbow_06_lowpitch" : "Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. Many complicated ideas about the rainbow have been formed.", + "rainbow_06_regular" : "Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. Many complicated ideas about the rainbow have been formed.", + "rainbow_06_slow" : "Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. Many complicated ideas about the rainbow have been formed.", + "rainbow_06_whisper" : "Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows. Many complicated ideas about the rainbow have been formed.", + "rainbow_07_fast" : "The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases. The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.", + "rainbow_07_highpitch" : "The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases. The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.", + "rainbow_07_loud" : "The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases. The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.", + "rainbow_07_lowpitch" : "The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases. 
The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.", + "rainbow_07_regular" : "The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases. The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.", + "rainbow_07_slow" : "The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases. The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.", + "rainbow_07_whisper" : "The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases. The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.", + "rainbow_08_fast" : "If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow. This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.", + "rainbow_08_highpitch" : "If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow. This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.", + "rainbow_08_loud" : "If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow. This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.", + "rainbow_08_lowpitch" : "If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow. This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.", + "rainbow_08_regular" : "If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow. This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.", + "rainbow_08_slow" : "If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow. This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.", + "rainbow_08_whisper" : "If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow. This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.", + "sentences_01_fast" : "I will not stay here. God, we simply must dress the character. Stay, stay, I will go myself. May one ask what it is for. He rushed to the window and opened the movable pane.", + "sentences_01_highpitch" : "I will not stay here. God, we simply must dress the character. Stay, stay, I will go myself. May one ask what it is for. 
He rushed to the window and opened the movable pane.", + "sentences_01_loud" : "I will not stay here. God, we simply must dress the character. Stay, stay, I will go myself. May one ask what it is for. He rushed to the window and opened the movable pane.", + "sentences_01_lowpitch" : "I will not stay here. God, we simply must dress the character. Stay, stay, I will go myself. May one ask what it is for. He rushed to the window and opened the movable pane.", + "sentences_01_regular" : "I will not stay here. God, we simply must dress the character. Stay, stay, I will go myself. May one ask what it is for. He rushed to the window and opened the movable pane.", + "sentences_01_slow" : "I will not stay here. God, we simply must dress the character. Stay, stay, I will go myself. May one ask what it is for. He rushed to the window and opened the movable pane.", + "sentences_01_whisper" : "I will not stay here. God, we simply must dress the character. Stay, stay, I will go myself. May one ask what it is for. He rushed to the window and opened the movable pane.", + "sentences_02_fast" : "It might happen, he added with an involuntary smile. It is sold, sir, was again his laconic reply. And you must have some water, my dear fellow. What is that flying about? Who wants a dead cert for the Gold cup?", + "sentences_02_highpitch" : "It might happen, he added with an involuntary smile. It is sold, sir, was again his laconic reply. And you must have some water, my dear fellow. What is that flying about? Who wants a dead cert for the Gold cup?", + "sentences_02_loud" : "It might happen, he added with an involuntary smile. It is sold, sir, was again his laconic reply. And you must have some water, my dear fellow. What is that flying about? Who wants a dead cert for the Gold cup?", + "sentences_02_lowpitch" : "It might happen, he added with an involuntary smile. It is sold, sir, was again his laconic reply. And you must have some water, my dear fellow. What is that flying about? Who wants a dead cert for the Gold cup?", + "sentences_02_regular" : "It might happen, he added with an involuntary smile. It is sold, sir, was again his laconic reply. And you must have some water, my dear fellow. What is that flying about? Who wants a dead cert for the Gold cup?", + "sentences_02_slow" : "It might happen, he added with an involuntary smile. It is sold, sir, was again his laconic reply. And you must have some water, my dear fellow. What is that flying about? Who wants a dead cert for the Gold cup?", + "sentences_02_whisper" : "It might happen, he added with an involuntary smile. It is sold, sir, was again his laconic reply. And you must have some water, my dear fellow. What is that flying about? Who wants a dead cert for the Gold cup?", + "sentences_03_whisper" : "Had it been but one, it had been easy. We have boxed the compass among us. I shall rush out and prevent it. All that is mean slander. The doctor seemed tired and in a hurry.", + "sentences_04_whisper" : "I only heard it last night. We had now got into the month of March. But go thy ways; I had forgot. Conceited fellow with his waxed up moustache! Anne's unhappiness continued for a week.", + "sentences_05_loud" : "In fact, the count's face brightened. For God's sake, talk to her. In what an amiable light does this place him! Take me out of my way. I heard many things in hell.", + "sentences_06_loud" : "Yes; but we do not invite people of fashion. You see what he writes. Silent with awe and pity I went to her bedside. Happy to say, I never knew him. 
Birthdays are of no importance to a rational being.", + "sentences_07_slow" : "But it may all be put in two words. Clear up the room, the sick man said with effort. He was still in sight. He delayed; he seemed almost afraid of something. Then they carried me in.", + "sentences_08_slow" : "But I have never been presented. But we were only in fun! Now, look at that third name. And serve them both right, too. Good glass of burgundy take away that.", + "sentences_09_fast" : "And it seemed to her that God heard her prayer. My word, I admire you. I also have a pious visit to pay. She has promised to come on the twentieth. I want to tell you something.", + "sentences_10_fast" : "Oh, sir, it will break bones. I am very glad to see you. This question absorbed all his mental powers. Before going away forever, I'll tell him all. I told you it was mother.", + "sentences_11_highpitch" : "You're all in good spirits. They might retreat and leave the pickets. But I like sentimental people. Our potato crop is very good this year. Why is the chestnut on the right?", + "sentences_12_highpitch" : "His room was on the first floor. I have had a pattern in my hand. The knocking still continued and grew louder. May my sorrows ever shun the light. How must I arrange it, then?", + "sentences_13_lowpitch" : "Just read it out to me. I shall take your advice in every particular. What mortal imagination could conceive it? The gate was again hidden by smoke. After a while I left him.", + "sentences_14_lowpitch" : "There was a catch in her breath. They told me, but I didn't understand. What a cabin it is. A cry of joy broke from his lips. He had obviously prepared the sentence beforehand.", + "sentences_15_regular" : "They were all sitting in her room. So that's how it stands. He did not know why he embraced it. Why don't you speak, cousin? I didn't tell a tale.", + "sentences_16_regular" : "My head aches dreadfully now. Not to say every word. I have only found out. He is trying to discover something. I have done my duty.", + "sentences_17_regular" : "I always had a value for him. He is a deceiver and a villain. But those tears were pleasant to them both. She conquered her fears, and spoke. Oh, he couldn't overhear me at the door.", + "sentences_18_regular" : "How could I have said it more directly? She remembered her oath. My kingdom for a drink! Have they caught the little girl and the boy? Then she gave him the dry bread.", + "sentences_19_regular" : "Your sister is given to government. Water was being sprinkled on his face. The clumsy things are dear. He jumped up and sat on the sofa. How do you know her?", + "sentences_20_regular" : "I never could guess a riddle in my life. The expression of her face was cold. Besides, what on earth could happen to you? Allow me to give you a piece of advice. This must be stopped at once.", + "sentences_21_regular" : "The lawyer was right about that. You are fond of fighting. Every word is so deep. So you were never in London before? Death is now, perhaps, striking a fourth blow.", + "sentences_22_regular" : "It seemed that sleep and night had resumed their empire. The snowstorm was still raging. But we'll talk later on. Take the baby, Mum, and give me your book. The doctor gave him his hand.", + "sentences_23_regular" : "It is, nevertheless, conclusive to my mind. Give this to the countess. It is only a question of a few hours. No, we don't keep a cat. The cool evening air refreshed him.", + "sentences_24_regular" : "You can well enjoy the evening now. We'll make up for it now. 
The weakness of a murderer. But they wouldn't leave me alone. The telegram was from his wife." + } + root = "/mount/resources/speech/corpora/EARS/" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = dict() + for speaker in os.listdir(root): + if os.path.isdir(os.path.join(root, speaker)): + for sentence_type in transcript_for_ears: + path = os.path.join(root, speaker, sentence_type + ".wav") + path_to_transcript[path] = transcript_for_ears[sentence_type] + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + def build_path_to_transcript_dict_mls_english(re_cache=False): lang = "english" root = f"/mount/resources/speech/corpora/MultiLingLibriSpeech/mls_{lang}/train" @@ -884,8 +1036,288 @@ def build_path_to_transcript_dict_css10hu(re_cache=False): return torch.load(cache_path) +# JAPANESE + +def build_path_to_transcript_dict_captain_japanese(re_cache=False): + root = "/mount/resources/speech/corpora/HiFiCaptainJapanese" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = dict() + with open(root + "/male/text/train_parallel.txt", encoding="utf8") as f: + transcriptions = f.read() + for line in transcriptions.split("\n"): + if line.strip() != "": + parsed_line = line.split() + audio_path = parsed_line[0] + transcript = parsed_line[1] + audio_path = os.path.join(root, "male", "wav", "train_parallel", audio_path + ".wav") + if os.path.exists(audio_path): + path_to_transcript[audio_path] = transcript.strip() + else: + print(f"{audio_path} does not seem to exist!") + with open(root + "/female/text/train_parallel.txt", encoding="utf8") as f: + transcriptions = f.read() + for line in transcriptions.split("\n"): + if line.strip() != "": + parsed_line = line.split() + audio_path = parsed_line[0] + transcript = parsed_line[1] + audio_path = os.path.join(root, "female", "wav", "train_parallel", audio_path + ".wav") + if os.path.exists(audio_path): + path_to_transcript[audio_path] = transcript.strip() + else: + print(f"{audio_path} does not seem to exist!") + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_jvs(re_cache=False): + root = "/mount/resources/speech/corpora/JVS/jvs_ver1" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = dict() + for data_dir in os.listdir(root): + if os.path.isdir(os.path.join(root, data_dir)): + for data_type in ["parallel100", "nonpara30"]: + with open(os.path.join(root, data_dir, data_type, "transcripts_utf8.txt"), encoding="utf8") as f: + transcriptions = f.read() + for line in transcriptions.split("\n"): + if line.strip() != "": + parsed_line = line.split(":") + audio_path = parsed_line[0] + transcript = parsed_line[1] + audio_path = os.path.join(root, data_dir, data_type, "wav24kHz16bit", audio_path + ".wav") + if os.path.exists(audio_path): + path_to_transcript[audio_path] = transcript.strip() + else: + print(f"{audio_path} does not seem to exist!") + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + # OTHER + +def build_path_to_transcript_dict_indicvoices_Assamese(re_cache=False): + language = "Assamese" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + 
torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Bengali(re_cache=False): + language = "Bengali" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Bodo(re_cache=False): + language = "Bodo" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Dogri(re_cache=False): + language = "Dogri" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Gujarati(re_cache=False): + language = "Gujarati" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Hindi(re_cache=False): + language = "Hindi" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Kannada(re_cache=False): + language = "Kannada" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Kashmiri(re_cache=False): + language = "Kashmiri" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Konkani(re_cache=False): + language = "Konkani" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Maithili(re_cache=False): + language = "Maithili" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: 
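# NOTE: the build_path_to_transcript_dict_indicvoices_* functions in this block all
# follow the same cache-then-load pattern and differ only in the language string.
# A minimal parameterized sketch of that pattern follows; it reuses the
# indic_voices_template helper and the corpus root from the surrounding code, while
# the per-language cache file name is an assumption made here so that the cached
# dictionaries of different languages do not overwrite one another.
def build_path_to_transcript_dict_indicvoices(language, re_cache=False):
    root = "/mount/resources/speech/corpora/IndicVoicesR"
    cache_path = os.path.join(root, f"pttd_cache_{language}.pt")  # assumed per-language cache name
    if not os.path.exists(cache_path) or re_cache:
        path_to_transcript = indic_voices_template(root=root, lang=language)
        torch.save(path_to_transcript, cache_path)
    return torch.load(cache_path)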
+ path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Malayalam(re_cache=False): + language = "Malayalam" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Manipuri(re_cache=False): + language = "Manipuri" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Marathi(re_cache=False): + language = "Marathi" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Nepali(re_cache=False): + language = "Nepali" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Odia(re_cache=False): + language = "Odia" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Punjabi(re_cache=False): + language = "Punjabi" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Sanskrit(re_cache=False): + language = "Sanskrit" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Santali(re_cache=False): + language = "Santali" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Sindhi(re_cache=False): + language = "Sindhi" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = 
os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Tamil(re_cache=False): + language = "Tamil" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Telugu(re_cache=False): + language = "Telugu" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + +def build_path_to_transcript_dict_indicvoices_Urdu(re_cache=False): + language = "Urdu" + root = f"/mount/resources/speech/corpora/IndicVoicesR" + cache_path = os.path.join(root, "pttd_cache.pt") + if not os.path.exists(cache_path) or re_cache: + path_to_transcript = indic_voices_template(root=root, lang=language) + torch.save(path_to_transcript, cache_path) + return torch.load(cache_path) + + def build_file_list_singing_voice_audio_database(re_cache=False): root = "/mount/resources/speech/corpora/singing_voice_audio_dataset/monophonic" cache_path = os.path.join(root, "pttd_cache.pt") @@ -899,12 +1331,6 @@ def build_file_list_singing_voice_audio_database(re_cache=False): return torch.load(cache_path) -from pathlib import Path -import xml.etree.ElementTree as ET -from csv import DictReader -import json - - def build_path_to_transcript_dict_nst_norwegian(): root = '/resources/speech/corpora/NST_norwegian/pcm/cs' path_to_transcript = dict() diff --git a/Utility/utils.py b/Utility/utils.py index 106bd45bea6b8b4b4da7e2142fa1d9c1e8bca9eb..88698fc12a91feeee41f28ef0a0cc5190cc664f5 100644 --- a/Utility/utils.py +++ b/Utility/utils.py @@ -11,7 +11,7 @@ import torch import torch.multiprocessing from matplotlib.lines import Line2D -import Architectures.GeneralLayers.ConditionalLayerNorm +import Modules.GeneralLayers.ConditionalLayerNorm from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend from Preprocessing.TextFrontend import get_language_id @@ -67,7 +67,7 @@ def plot_progress_spec_toucantts(net, step, lang, default_emb, - run_glow): + run_stochastic): tf = ArticulatoryCombinedTextFrontend(language=lang) sentence = tf.get_example_sentence(lang=lang) if sentence is None: @@ -77,7 +77,7 @@ def plot_progress_spec_toucantts(net, return_duration_pitch_energy=True, utterance_embedding=default_emb, lang_id=get_language_id(lang).to(device), - run_glow=run_glow) + run_stochastic=run_stochastic) plot_code_spec(pitch, energy, sentence, durations, mel, os.path.join(save_dir, "visualization"), tf, step) return os.path.join(os.path.join(save_dir, "visualization"), f"{step}.png") @@ -146,12 +146,14 @@ def plot_code_spec(pitch, energy, sentence, durations, mel, save_path, tf, step) plt.close() -def plot_spec_tensor(spec, save_path, name): +def plot_spec_tensor(spec, save_path, name, title=None): fig, spec_plot_axis = plt.subplots(nrows=1, ncols=1, figsize=(9, 4)) spec_plot_axis.imshow(spec.detach().cpu().numpy(), origin="lower", cmap='GnBu') 
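# Usage sketch for the optional title argument of plot_spec_tensor
# (illustrative only; assumes the function is imported from Utility.utils and
# that spec is a 2D tensor shaped [frequency_bins, time_frames]):
#
#   from Utility.utils import plot_spec_tensor
#   plot_spec_tensor(spec=mel,                      # e.g. an [80, T] mel spectrogram tensor
#                    save_path="visualization",     # directory is created if it does not exist
#                    name="step_1000",              # written to visualization/step_1000.png
#                    title="Spectrogram at step 1000")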
spec_plot_axis.yaxis.set_visible(False) spec_plot_axis.set_aspect("auto") - plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.95, wspace=0.0, hspace=0.0) + if title is not None: + spec_plot_axis.set_title(title) + plt.subplots_adjust(left=0.05, bottom=0.1, right=0.95, top=.95 if title is None else .85, wspace=0.0, hspace=0.0) os.makedirs(save_path, exist_ok=True) plt.savefig(os.path.join(save_path, f"{name}.png"), dpi=100) plt.clf() @@ -336,8 +338,8 @@ def initialize(model, init): for m in model.modules(): if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, - Architectures.GeneralLayers.ConditionalLayerNorm.ConditionalLayerNorm, - Architectures.GeneralLayers.ConditionalLayerNorm.SequentialWrappableConditionalLayerNorm + Modules.GeneralLayers.ConditionalLayerNorm.ConditionalLayerNorm, + Modules.GeneralLayers.ConditionalLayerNorm.SequentialWrappableConditionalLayerNorm )): m.reset_parameters() diff --git a/app.py b/app.py index 2b88ea21b81e13c8ec5c0741c8c76f435260111d..7e0d078c93dbe6922a22376c3acbea20d6a91017 100644 --- a/app.py +++ b/app.py @@ -1,144 +1,79 @@ -import os - -from run_model_downloader import download_models - -#if not os.path.exists("Models/ToucanTTS_Meta/best.pt"): -# download_models() import gradio as gr -from Preprocessing.multilinguality.SimilaritySolver import load_json_from_path -from Utility.utils import float2pcm - -import os - -import torch +import torch.cuda -from Architectures.ControllabilityGAN.GAN import GanWrapper -from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface -from Utility.storage_config import MODELS_DIR - - -class ControllableInterface(torch.nn.Module): - - def __init__(self, available_artificial_voices=1000): - super().__init__() - self.model = ToucanTTSInterface(device="cuda", tts_model_path="Meta", language="eng") - self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device="cuda") - self.generated_speaker_embeds = list() - self.available_artificial_voices = available_artificial_voices - self.current_language = "" - self.current_accent = "" +from InferenceInterfaces.ControllableInterface import ControllableInterface +from Utility.utils import float2pcm +from Utility.utils import load_json_from_path + + +class TTSWebUI: + + def __init__(self, gpu_id="cpu", title="ToucanTTS in 7000 Languages", article="Check out the toolkit at https://github.com/DigitalPhonetics/IMS-Toucan", available_artificial_voices=1000, path_to_iso_list="Preprocessing/multilinguality/iso_to_fullname.json"): + iso_to_name = load_json_from_path(path_to_iso_list) + text_selection = [f"{iso_to_name[iso_code]} ({iso_code})" for iso_code in iso_to_name] + # accent_selection = [f"{iso_to_name[iso_code]} Accent ({iso_code})" for iso_code in iso_to_name] + + self.controllable_ui = ControllableInterface(gpu_id=gpu_id, + available_artificial_voices=available_artificial_voices) + self.iface = gr.Interface(fn=self.read, + inputs=[gr.Textbox(lines=2, + placeholder="write what you want the synthesis to read here...", + value="What I cannot create, I do not understand.", + label="Text input"), + gr.Dropdown(text_selection, + type="value", + value='English (eng)', + label="Select the Language of the Text (type on your keyboard to find it quickly)"), + gr.Audio(type="filepath", show_label=True, container=True, label="Voice to Clone (if left empty, will use an artificial voice instead)"), + gr.Slider(minimum=0, maximum=available_artificial_voices, step=1, + value=279, + label="Random Seed for the artificial Voice"), + 
gr.Slider(minimum=0.0, maximum=0.8, step=0.1, value=0.1, label="Prosody Creativity"), + gr.Slider(minimum=0.7, maximum=1.3, step=0.1, value=1.0, label="Duration Scale"), + # gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Pitch Variance Scale"), + # gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Energy Variance Scale"), + gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Femininity / Masculinity"), + gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Voice Depth") + ], + outputs=[gr.Audio(type="numpy", label="Speech"), + gr.Image(label="Visualization")], + title=title, + theme="default", + allow_flagging="never", + article=article) + self.iface.launch() def read(self, prompt, language, - accent, + reference_audio, voice_seed, + prosody_creativity, duration_scaling_factor, - pause_duration_scaling_factor, - pitch_variance_scale, - energy_variance_scale, - emb_slider_1, - emb_slider_2, - emb_slider_3, - emb_slider_4, - emb_slider_5, - emb_slider_6, - loudness_in_db + # pitch_variance_scale, + # energy_variance_scale, + emb1, + emb2 ): - if self.current_language != language: - self.model.set_language(language) - self.current_language = language - - self.wgan.set_latent(voice_seed) - controllability_vector = torch.tensor([emb_slider_1, - emb_slider_2, - emb_slider_3, - emb_slider_4, - emb_slider_5, - emb_slider_6], dtype=torch.float32) - embedding = self.wgan.modify_embed(controllability_vector) - self.model.set_utterance_embedding(embedding=embedding) - - if len(prompt) > 1800: - raise AssertionError("The input is too long!") - phones = self.model.text2phone.get_phone_string(prompt) - if len(phones) > 1800: - raise AssertionError("The input is too long!") - - print("\n\n") - print(prompt) - print(language) - print("\n\n") - - wav, sr, fig = self.model(prompt, - input_is_phones=False, - duration_scaling_factor=duration_scaling_factor, - pitch_variance_scale=pitch_variance_scale, - energy_variance_scale=energy_variance_scale, - pause_duration_scaling_factor=pause_duration_scaling_factor, - return_plot_as_filepath=True, - loudness_in_db=loudness_in_db) - return sr, wav, fig - - -title = "Controllable Text-to-Speech for over 7000 Languages" -article = "Check out the IMS Toucan TTS Toolkit at https://github.com/DigitalPhonetics/IMS-Toucan" -available_artificial_voices = 1000 -path_to_iso_list = "Preprocessing/multilinguality/iso_to_fullname.json" -iso_to_name = load_json_from_path(path_to_iso_list) -text_selection = [f"{iso_to_name[iso_code]} Text ({iso_code})" for iso_code in iso_to_name] -controllable_ui = ControllableInterface(available_artificial_voices=available_artificial_voices) - - -def read(prompt, - language, - voice_seed, - duration_scaling_factor, - pitch_variance_scale, - energy_variance_scale, - emb1, - emb2 - ): - with torch.no_grad(): - sr, wav, fig = controllable_ui.read(prompt, - language.split(" ")[-1].split("(")[1].split(")")[0], - language.split(" ")[-1].split("(")[1].split(")")[0], - voice_seed, - duration_scaling_factor, - 1., - pitch_variance_scale, - energy_variance_scale, - emb1, - emb2, - 0., - 0., - 0., - 0., - -24.) 
- return (sr, float2pcm(wav)), fig - -iface = gr.Interface(fn=read, - inputs=[gr.Textbox(lines=2, - placeholder="write what you want the synthesis to read here...", - value="What I cannot create, I do not understand.", - label="Text input"), - gr.Dropdown(text_selection, - type="value", - value='English Text (eng)', - label="Select the Language of the Text (type on your keyboard to find it quickly)"), - gr.Slider(minimum=0, maximum=available_artificial_voices, step=1, - value=279, - label="Random Seed for the artificial Voice"), - gr.Slider(minimum=0.7, maximum=1.3, step=0.1, value=1.0, label="Duration Scale"), - gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Pitch Variance Scale"), - gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Energy Variance Scale"), - gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Femininity / Masculinity"), - gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Voice Depth") - ], - outputs=[gr.Audio(type="numpy", label="Speech"), - gr.Image(label="Visualization")], - title=title, - theme="default", - allow_flagging="never", - article=article) -iface.launch() + sr, wav, fig = self.controllable_ui.read(prompt, + reference_audio, + language.split(" ")[-1].split("(")[1].split(")")[0], + language.split(" ")[-1].split("(")[1].split(")")[0], + voice_seed, + prosody_creativity, + duration_scaling_factor, + 1., + 1.0, + 1.0, + emb1, + emb2, + 0., + 0., + 0., + 0., + -18.) + return (sr, float2pcm(wav)), fig + + +if __name__ == '__main__': + TTSWebUI(gpu_id="cuda" if torch.cuda.is_available() else "cpu") diff --git a/app_future.py b/app_future.py deleted file mode 100644 index 3ebde739d28980901e7dfe52cdccdcfc6fd45c97..0000000000000000000000000000000000000000 --- a/app_future.py +++ /dev/null @@ -1,193 +0,0 @@ -import os - -import spaces - -os.system("git clone --branch v3.1 https://github.com/DigitalPhonetics/IMS-Toucan.git toucan_codebase") -os.system("mv toucan_codebase/* .") - -from run_model_downloader import download_models - -download_models() - -import gradio as gr -import torch.cuda -from Preprocessing.multilinguality.SimilaritySolver import load_json_from_path -from Utility.utils import float2pcm - -import os - -import torch - -from Architectures.ControllabilityGAN.GAN import GanWrapper -from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface -from Utility.storage_config import MODELS_DIR - - -class ControllableInterface(torch.nn.Module): - - def __init__(self, available_artificial_voices=1000): - super().__init__() - self.model = ToucanTTSInterface(device="cpu", tts_model_path="Meta") - self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device="cpu") - self.generated_speaker_embeds = list() - self.available_artificial_voices = available_artificial_voices - self.current_language = "" - self.current_accent = "" - - def read(self, - prompt, - language, - accent, - voice_seed, - prosody_creativity, - duration_scaling_factor, - pause_duration_scaling_factor, - pitch_variance_scale, - energy_variance_scale, - emb_slider_1, - emb_slider_2, - emb_slider_3, - emb_slider_4, - emb_slider_5, - emb_slider_6, - loudness_in_db - ): - if self.current_language != language: - self.model.set_phonemizer_language(language) - self.current_language = language - if self.current_accent != accent: - self.model.set_accent_language(accent) - self.current_accent = accent - - self.wgan.set_latent(voice_seed) - controllability_vector = torch.tensor([emb_slider_1, - emb_slider_2, - 
emb_slider_3, - emb_slider_4, - emb_slider_5, - emb_slider_6], dtype=torch.float32) - embedding = self.wgan.modify_embed(controllability_vector) - self.model.set_utterance_embedding(embedding=embedding) - - phones = self.model.text2phone.get_phone_string(prompt) - if len(phones) > 1800: - if language == "deu": - prompt = "Deine Eingabe war zu lang. Bitte versuche es entweder mit einem kürzeren Text oder teile ihn in mehrere Teile auf." - elif language == "ell": - prompt = "Η εισήγησή σας ήταν πολύ μεγάλη. Παρακαλώ δοκιμάστε είτε ένα μικρότερο κείμενο είτε χωρίστε το σε διάφορα μέρη." - elif language == "spa": - prompt = "Su entrada es demasiado larga. Por favor, intente un texto más corto o divídalo en varias partes." - elif language == "fin": - prompt = "Vastauksesi oli liian pitkä. Kokeile joko lyhyempää tekstiä tai jaa se useampaan osaan." - elif language == "rus": - prompt = "Ваш текст слишком длинный. Пожалуйста, попробуйте либо сократить текст, либо разделить его на несколько частей." - elif language == "hun": - prompt = "Túl hosszú volt a bevitele. Kérjük, próbáljon meg rövidebb szöveget írni, vagy ossza több részre." - elif language == "nld": - prompt = "Uw input was te lang. Probeer een kortere tekst of splits het in verschillende delen." - elif language == "fra": - prompt = "Votre saisie était trop longue. Veuillez essayer un texte plus court ou le diviser en plusieurs parties." - elif language == 'pol': - prompt = "Twój wpis był zbyt długi. Spróbuj skrócić tekst lub podzielić go na kilka części." - elif language == 'por': - prompt = "O seu contributo foi demasiado longo. Por favor, tente um texto mais curto ou divida-o em várias partes." - elif language == 'ita': - prompt = "Il tuo input era troppo lungo. Per favore, prova un testo più corto o dividilo in più parti." - elif language == 'cmn': - prompt = "你的输入太长了。请尝试使用较短的文本或将其拆分为多个部分。" - elif language == 'vie': - prompt = "Đầu vào của bạn quá dài. Vui lòng thử một văn bản ngắn hơn hoặc chia nó thành nhiều phần." - else: - prompt = "Your input was too long. Please try either a shorter text or split it into several parts." 
- if self.current_language != "eng": - self.model.set_phonemizer_language("eng") - self.current_language = "eng" - if self.current_accent != "eng": - self.model.set_accent_language("eng") - self.current_accent = "eng" - - print(prompt) - wav, sr, fig = self.model(prompt, - input_is_phones=False, - duration_scaling_factor=duration_scaling_factor, - pitch_variance_scale=pitch_variance_scale, - energy_variance_scale=energy_variance_scale, - pause_duration_scaling_factor=pause_duration_scaling_factor, - return_plot_as_filepath=True, - prosody_creativity=prosody_creativity, - loudness_in_db=loudness_in_db) - return sr, wav, fig - - -title = "Controllable Text-to-Speech for over 7000 Languages" -article = "Check out the IMS Toucan TTS Toolkit at https://github.com/DigitalPhonetics/IMS-Toucan" -available_artificial_voices = 1000 -path_to_iso_list = "Preprocessing/multilinguality/iso_to_fullname.json" -iso_to_name = load_json_from_path(path_to_iso_list) -text_selection = [f"{iso_to_name[iso_code]} Text ({iso_code})" for iso_code in iso_to_name] -controllable_ui = ControllableInterface(available_artificial_voices=available_artificial_voices) - - -@spaces.GPU -def read(prompt, - language, - voice_seed, - prosody_creativity, - duration_scaling_factor, - pitch_variance_scale, - energy_variance_scale, - emb1, - emb2 - ): - if torch.cuda.is_available(): - controllable_ui.to("cuda") - controllable_ui.device = "cuda" - try: - sr, wav, fig = controllable_ui.read(prompt, - language.split(" ")[-1].split("(")[1].split(")")[0], - language.split(" ")[-1].split("(")[1].split(")")[0], - voice_seed, - prosody_creativity, - duration_scaling_factor, - 1., - pitch_variance_scale, - energy_variance_scale, - emb1, - emb2, - 0., - 0., - 0., - 0., - -24.) - finally: - controllable_ui.to("cpu") - controllable_ui.device = "cpu" - return (sr, float2pcm(wav)), fig - - -iface = gr.Interface(fn=read, - inputs=[gr.Textbox(lines=2, - placeholder="write what you want the synthesis to read here...", - value="The woods are lovely, dark and deep, but I have promises to keep, and miles to go, before I sleep.", - label="Text input"), - gr.Dropdown(text_selection, - type="value", - value='English Text (eng)', - label="Select the Language of the Text (type on your keyboard to find it quickly)"), - gr.Slider(minimum=0, maximum=available_artificial_voices, step=1, - value=279, - label="Random Seed for the artificial Voice"), - gr.Slider(minimum=0.0, maximum=0.8, step=0.1, value=0.7, label="Prosody Creativity"), - gr.Slider(minimum=0.7, maximum=1.3, step=0.1, value=1.0, label="Duration Scale"), - gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Pitch Variance Scale"), - gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Energy Variance Scale"), - gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Femininity / Masculinity"), - gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Voice Depth") - ], - outputs=[gr.Audio(type="numpy", label="Speech"), - gr.Image(label="Visualization")], - title=title, - theme="default", - allow_flagging="never", - article=article) -iface.launch() diff --git a/pre-requirements.txt b/pre-requirements.txt index 2ef7d0aa44ee2d547d26bef5bec9270ff3cdbc8e..be99cfeec295e56f5393c803ea8f4944dccd3f35 100644 --- a/pre-requirements.txt +++ b/pre-requirements.txt @@ -1,3 +1,3 @@ -torch==2.2.0 -torchaudio==2.2.0 -torchvision==0.17.0 \ No newline at end of file +torch==2.1.0 +torchaudio==2.1.0 +torchvision==0.16.0 \ No newline at end of file diff --git 
a/requirements.txt b/requirements.txt index 2314d835e6d1c1e800da49fa1723ed6b2945b1af..5ba3623344a2a65153eaac234e40dba2737fa69b 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/run_CLI_demo.py b/run_CLI_demo.py deleted file mode 100644 index ea7523d2323f0d15706f0ca4988b6653a4b1f6a9..0000000000000000000000000000000000000000 --- a/run_CLI_demo.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -import sys -import warnings - -import torch - -from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface -from Utility.storage_config import MODELS_DIR - -if __name__ == '__main__': - warnings.filterwarnings("ignore", category=UserWarning) - - PATH_TO_TTS_MODEL = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt") - PATH_TO_REFERENCE_SPEAKER = "" # audios/speaker_references_for_testing/female_high_voice.wav audios/speaker_references_for_testing/male_low_voice.wav - LANGUAGE = "eng" - device = "cuda" if torch.cuda.is_available() else "cpu" - - tts = ToucanTTSInterface(device=device, tts_model_path=PATH_TO_TTS_MODEL) - tts.set_language(lang_id=LANGUAGE) - if PATH_TO_REFERENCE_SPEAKER != "": - if os.path.exists(PATH_TO_REFERENCE_SPEAKER): - tts.set_utterance_embedding(PATH_TO_REFERENCE_SPEAKER) - else: - print(f"\n\nFile {PATH_TO_REFERENCE_SPEAKER} could not be found, please check for typos and re-run. Using default for now.\n\n") - - print("Loading the following configuration:") - print(f"\tTTS Model: {PATH_TO_TTS_MODEL}") - print(f"\tReference Audio: {PATH_TO_REFERENCE_SPEAKER}") - print(f"\tLanguage Used: {LANGUAGE}") - print(f"\tDevice Used: {device}") - - while True: - text = input("\n\nWhat should I say? (or 'exit')\n") - if text == "exit": - sys.exit() - tts.read_aloud(text, - view=True, - blocking=False, - duration_scaling_factor=1.0, - energy_variance_scale=1.0, - pitch_variance_scale=1.0, - glow_sampling_temperature=0.2) diff --git a/run_GUI_demo.py b/run_GUI_demo.py deleted file mode 100644 index 742a0169cb1bd4b0ffd964bfef53b9ffdd234595..0000000000000000000000000000000000000000 --- a/run_GUI_demo.py +++ /dev/null @@ -1,74 +0,0 @@ -import gradio as gr -import torch.cuda - -from InferenceInterfaces.ControllableInterface import ControllableInterface -from Utility.utils import float2pcm -from Utility.utils import load_json_from_path - - -class TTSWebUI: - - def __init__(self, gpu_id="cpu", title="Controllable Text-to-Speech for over 7000 Languages", article="", available_artificial_voices=1000, path_to_iso_list="Preprocessing/multilinguality/iso_to_fullname.json"): - iso_to_name = load_json_from_path(path_to_iso_list) - text_selection = [f"{iso_to_name[iso_code]} Text ({iso_code})" for iso_code in iso_to_name] - # accent_selection = [f"{iso_to_name[iso_code]} Accent ({iso_code})" for iso_code in iso_to_name] - - self.controllable_ui = ControllableInterface(gpu_id=gpu_id, - available_artificial_voices=available_artificial_voices) - self.iface = gr.Interface(fn=self.read, - inputs=[gr.Textbox(lines=2, - placeholder="write what you want the synthesis to read here...", - value="The woods are lovely, dark and deep, but I have promises to keep, and miles to go, before I sleep.", - label="Text input"), - gr.Dropdown(text_selection, - type="value", - value='English Text (eng)', - label="Select the Language of the Text (type on your keyboard to find it quickly)"), - gr.Slider(minimum=0, maximum=available_artificial_voices, step=1, - value=279, - label="Random Seed for the artificial Voice"), - gr.Slider(minimum=0.7, maximum=1.3, step=0.1, value=1.0, label="Duration 
Scale"), - gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Pitch Variance Scale"), - gr.Slider(minimum=0.5, maximum=1.5, step=0.1, value=1.0, label="Energy Variance Scale"), - gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Femininity / Masculinity"), - gr.Slider(minimum=-10.0, maximum=10.0, step=0.1, value=0.0, label="Voice Depth") - ], - outputs=[gr.Audio(type="numpy", label="Speech"), - gr.Image(label="Visualization")], - title=title, - theme="default", - allow_flagging="never", - article=article) - self.iface.launch() - - def read(self, - prompt, - language, - voice_seed, - duration_scaling_factor, - pitch_variance_scale, - energy_variance_scale, - emb1, - emb2 - ): - sr, wav, fig = self.controllable_ui.read( - prompt=prompt, - language=language.split(" ")[-1].split("(")[1].split(")")[0], - accent=language.split(" ")[-1].split("(")[1].split(")")[0], - voice_seed=voice_seed, - duration_scaling_factor=duration_scaling_factor, - pause_duration_scaling_factor=1.0, - pitch_variance_scale=pitch_variance_scale, - energy_variance_scale=energy_variance_scale, - emb_slider_1=emb1, - emb_slider_2=emb2, - emb_slider_3=0.0, - emb_slider_4=0.0, - emb_slider_5=0.0, - emb_slider_6=0.0 - ) - return (sr, float2pcm(wav)), fig - - -if __name__ == '__main__': - TTSWebUI(gpu_id="cuda" if torch.cuda.is_available() else "cpu") diff --git a/run_model_downloader.py b/run_model_downloader.py deleted file mode 100644 index 9e655c906d60966a0ecd5cfb063f1d4a264c1c65..0000000000000000000000000000000000000000 --- a/run_model_downloader.py +++ /dev/null @@ -1,83 +0,0 @@ -import os -import urllib.request - -from Utility.storage_config import MODELS_DIR - - -def report(block_number, read_size, total_size): - if block_number % 1000 == 0: - return_to_front = '\b' * 52 - percent = round(((block_number * read_size) / total_size) * 100) - print(f"{return_to_front}[{'█' * (percent // 2)}{'.' 
* (50 - (percent // 2))}]", end='') - if block_number * read_size >= total_size: - return_to_front = '\b' * 52 - print(f"{return_to_front}Download complete!\n") - - -def download_models(): - ############# - print("Downloading Aligner Model") - os.makedirs(os.path.join(MODELS_DIR, "Aligner"), exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v3.0/aligner.pt", - filename=os.path.abspath(os.path.join(MODELS_DIR, "Aligner", "aligner.pt")), - reporthook=report) - - ############# - print("Downloading Multilingual ToucanTTS Model") - os.makedirs(os.path.join(MODELS_DIR, "ToucanTTS_Meta"), exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v3.0/ToucanTTS_Meta.pt", - filename=os.path.abspath(os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")), - reporthook=report) - - ############# - print("Downloading Vocoder") - os.makedirs(os.path.join(MODELS_DIR, "Vocoder"), exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v3.0/Vocoder.pt", - filename=os.path.abspath(os.path.join(MODELS_DIR, "Vocoder", "best.pt")), - reporthook=report) - - ############# - print("Downloading Embedding Model") - os.makedirs(os.path.join(MODELS_DIR, "Embedding"), exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v3.0/embedding_function.pt", - filename=os.path.abspath(os.path.join(MODELS_DIR, "Embedding", "embedding_function.pt")), - reporthook=report) - - ############# - print("Downloading Embedding GAN") - os.makedirs(os.path.join(MODELS_DIR, "Embedding"), exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v3.0/embedding_gan.pt", - filename=os.path.abspath(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt")), - reporthook=report) - - ############# - print("Downloading Codec Model") - os.makedirs("Preprocessing/Codec", exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://huggingface.co/Dongchao/AcademiCodec/resolve/main/encodec_16k_320d.pth", - filename=os.path.abspath(os.path.join("Preprocessing/Codec", "encodec_16k_320d.pt")), - reporthook=report) - - ############# - print("Downloading ASP lookup") - os.makedirs("Preprocessing/multilinguality", exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v3.0/asp_dict.pkl", - filename=os.path.abspath(os.path.join("Preprocessing/multilinguality", "asp_dict.pkl")), - reporthook=report) - - ############# - print("Downloading Audioseal Model") - os.makedirs("Models/audioseal", exist_ok=True) - filename, headers = urllib.request.urlretrieve( - url="https://github.com/DigitalPhonetics/IMS-Toucan/releases/download/v3.0/audioseal.pth", - filename=os.path.abspath(os.path.join("Models/audioseal", "generator.pth")), - reporthook=report) - -if __name__ == '__main__': - download_models() diff --git a/run_prosody_override.py b/run_prosody_override.py deleted file mode 100644 index 7b865ebf7e345e4b8b933ecd8162634c5db5e597..0000000000000000000000000000000000000000 --- a/run_prosody_override.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch - -from InferenceInterfaces.UtteranceCloner import UtteranceCloner - -if __name__ == '__main__': - uc = 
UtteranceCloner(model_id="Nancy", device="cuda" if torch.cuda.is_available() else "cpu") - - # What is said in path_to_reference_audio_for_intonation has to match the text in the reference_transcription exactly! - uc.clone_utterance(path_to_reference_audio_for_intonation="audios/speaker_references_for_testing/sad.wav", - path_to_reference_audio_for_voice="audios/speaker_references_for_testing/female_mid_voice.wav", # the two reference audios can be the same, but don't have to be - transcription_of_intonation_reference="This report is due tomorrow.", - filename_of_result="audios/test_cloned.wav", - lang="eng") diff --git a/run_scorer.py b/run_scorer.py deleted file mode 100644 index 0dc40bb3d0b4cd10bf222eb443f63e2eef7cd668..0000000000000000000000000000000000000000 --- a/run_scorer.py +++ /dev/null @@ -1,100 +0,0 @@ -""" -Example use of the scorer utility to inspect data. -(pre-trained models and already cache files with extracted features are required.)""" - -from Utility.Scorer import TTSScorer -from Utility.path_to_transcript_dicts import * -from Utility.storage_config import MODELS_DIR -from Utility.storage_config import PREPROCESSING_DIR - -exec_device = "cuda:8" - -lang_id = "fon" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_fon_alf"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "hau" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_hausa_cmv"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "lbb" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_ibibio_lst"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "kik" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_kikuyu_opb"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "lin" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_lingala_opb"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "lug" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_ganda_cmv"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "luo" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_luo_afv"), lang_id=lang_id) 
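# The per-corpus blocks in run_scorer.py repeat the same four calls with different
# corpus names and language IDs. The same procedure can be expressed as a loop over
# (lang_id, dataset) pairs; this is an illustrative sketch only, using two of the
# corpus names that appear in this file:
#
#   for lang_id, dataset in [("fon", "african_voices_fon_alf"),
#                            ("hau", "african_voices_hausa_cmv")]:
#       tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"),
#                              device=exec_device)
#       tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, dataset), lang_id=lang_id)
#       tts_scorer.show_samples_with_highest_loss(20)
#       tts_scorer.remove_samples_with_highest_loss(5)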
-tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "luo" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_luo_opb"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "swh" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_swahili_llsti"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "sxb" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_suba_afv"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "wol" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_wolof_alf"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "yor" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "african_voices_yoruba_opb"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "nya" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "zambezi_voice_nyanja"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "loz" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "zambezi_voice_lozi"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) - -lang_id = "toi" -tts_scorer = TTSScorer(path_to_model=os.path.join(MODELS_DIR, "ToucanTTS_Massive", "best.pt"), device=exec_device) -tts_scorer.score(path_to_toucantts_dataset=os.path.join(PREPROCESSING_DIR, "zambezi_voice_tonga"), lang_id=lang_id) -tts_scorer.show_samples_with_highest_loss(20) -tts_scorer.remove_samples_with_highest_loss(5) diff --git a/run_text_to_file_reader.py b/run_text_to_file_reader.py deleted file mode 100644 index 20c14f67c408a667d960e4b7741b5ca46b54f14e..0000000000000000000000000000000000000000 --- a/run_text_to_file_reader.py +++ /dev/null @@ -1,113 +0,0 @@ -import os - -import torch - -from InferenceInterfaces.ToucanTTSInterface import ToucanTTSInterface - - -def read_texts(model_id, sentence, filename, device="cpu", language="eng", speaker_reference=None, duration_scaling_factor=1.0): - tts = ToucanTTSInterface(device=device, tts_model_path=model_id) - tts.set_language(language) - if speaker_reference is not None: - 
tts.set_utterance_embedding(speaker_reference) - if type(sentence) == str: - sentence = [sentence] - tts.read_to_file(text_list=sentence, file_location=filename, duration_scaling_factor=duration_scaling_factor) - del tts - - -def the_raven(version, model_id="Meta", exec_device="cpu", speaker_reference=None): - os.makedirs("audios", exist_ok=True) - - read_texts(model_id=model_id, - sentence=['Once upon a midnight dreary, while I pondered, weak, and weary,', - 'Over many a quaint, and curious volume, of forgotten lore,', - 'While I nodded, nearly napping, suddenly, there came a tapping,', - 'As of someone gently rapping, rapping at my chamber door.', - 'Ah, distinctly, I remember, it was in the bleak December,', - 'And each separate dying ember, wrought its ghost upon the floor.', - 'Eagerly, I wished the morrow, vainly, I had sought to borrow', - 'From my books surcease of sorrow, sorrow, for the lost Lenore,', - 'And the silken, sad, uncertain, rustling of each purple curtain', - 'Thrilled me, filled me, with fantastic terrors, never felt before.'], - filename=f"audios/{version}_the_raven.wav", - device=exec_device, - language="eng", - speaker_reference=speaker_reference) - - -def sound_of_silence_single_utt(version, model_id="Meta", exec_device="cpu", speaker_reference=None): - os.makedirs("audios", exist_ok=True) - - read_texts(model_id=model_id, - sentence=["""In restless dreams I walked alone, -Narrow streets of cobblestone. -Beneath the halo of a streetlamp, -I turned my collar to the cold and damp, -When my eyes were stabbed, by the flash of a neon light, -That split the night. -And touched the sound, of silence."""], - filename=f"audios/{version}_sound_of_silence_as_single_utterance.wav", - device=exec_device, - language="eng", - speaker_reference=speaker_reference) - - -def die_glocke(version, model_id="Meta", exec_device="cpu", speaker_reference=None): - os.makedirs("audios", exist_ok=True) - - read_texts(model_id=model_id, - sentence=["""Fest gemauert in der Erden, - Steht die Form, aus Lehm gebrannt. - Heute muss die Glocke werden! - Frisch, Gesellen, seid zur Hand!"""], - filename=f"audios/{version}_die_glocke.wav", - device=exec_device, - language="deu", - speaker_reference=speaker_reference) - - -def viet_poem(version, model_id="Meta", exec_device="cpu", speaker_reference=None): - os.makedirs("audios", exist_ok=True) - - read_texts(model_id=model_id, - sentence=["""Thân phận, - ở một nơi luôn phải nhắc mình, - im miệng, - thân phận, - là khi nói về quá khứ, - ngó trước nhìn sau, - là phải biết nhắm mắt bịt tai làm lơ, - thờ ơ, - với tất cả những điều gai chướng, - thân phận chúng tôi ở đó, - những quyển sách chuyền tay nhau như ăn cắp, - ngôn luận ư? 
-                            không có đất cho nghĩa tự do."""],
-               filename=f"audios/{version}_viet_poem.wav",
-               device=exec_device,
-               language="vie",
-               speaker_reference=speaker_reference,
-               duration_scaling_factor=1.2)
-
-
-if __name__ == '__main__':
-    exec_device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"running on {exec_device}")
-
-    merged_speaker_references = ["audios/speaker_references/" + ref for ref in os.listdir("audios/speaker_references/")]
-
-    sound_of_silence_single_utt(version="new_voc",
-                                model_id="Meta",
-                                exec_device=exec_device,
-                                speaker_reference=merged_speaker_references)
-
-    die_glocke(version="new_voc",
-               model_id="Meta",
-               exec_device=exec_device,
-               speaker_reference=merged_speaker_references)
-
-    viet_poem(version="new_voc",
-              model_id="Meta",
-              exec_device=exec_device,
-              speaker_reference=merged_speaker_references)
diff --git a/run_training_pipeline.py b/run_training_pipeline.py
deleted file mode 100644
index 2d0585d65b0aa1d0f07673b90b56cde637b01736..0000000000000000000000000000000000000000
--- a/run_training_pipeline.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import argparse
-import os
-import random
-import sys
-
-import torch
-
-from TrainingPipelines.AlignerPipeline import run as aligner
-from TrainingPipelines.HiFiGAN_combined import run as HiFiGAN
-from TrainingPipelines.StochasticToucanTTS_Nancy import run as nancystoch
-from TrainingPipelines.ToucanTTS_IntegrationTest import run as tt_integration_test
-from TrainingPipelines.ToucanTTS_MLS_English import run as mls
-from TrainingPipelines.ToucanTTS_Massive_stage1 import run as stage1
-from TrainingPipelines.ToucanTTS_Massive_stage2 import run as stage2
-from TrainingPipelines.ToucanTTS_Massive_stage3 import run as stage3
-from TrainingPipelines.ToucanTTS_MetaCheckpoint import run as meta
-from TrainingPipelines.ToucanTTS_Nancy import run as nancy
-from TrainingPipelines.finetuning_example_multilingual import run as fine_tuning_example_multilingual
-from TrainingPipelines.finetuning_example_simple import run as fine_tuning_example_simple
-
-pipeline_dict = {
-    # the finetuning example
-    "finetuning_example_simple"      : fine_tuning_example_simple,
-    "finetuning_example_multilingual": fine_tuning_example_multilingual,
-    # integration tests
-    "tt_it"                          : tt_integration_test,
-    # regular ToucanTTS pipelines
-    "nancy"                          : nancy,
-    "mls"                            : mls,
-    "nancystoch"                     : nancystoch,
-    "meta"                           : meta,
-    "stage1"                         : stage1,
-    "stage2"                         : stage2,
-    "stage3"                         : stage3,
-    # training the aligner from scratch (not recommended, best to use provided checkpoint)
-    "aligner"                        : aligner,
-    # vocoder training (not recommended, best to use provided checkpoint)
-    "hifigan"                        : HiFiGAN
-}
-
-if __name__ == '__main__':
-
-    parser = argparse.ArgumentParser(description='Training with the IMS Toucan Speech Synthesis Toolkit')
-
-    parser.add_argument('pipeline',
-                        choices=list(pipeline_dict.keys()),
-                        help="Select pipeline to train.")
-
-    parser.add_argument('--gpu_id',
-                        type=str,
-                        help="Which GPU(s) to run on. If not specified runs on CPU, but other than for integration tests that doesn't make much sense.",
-                        default="cpu")
-
-    parser.add_argument('--resume_checkpoint',
-                        type=str,
-                        help="Path to checkpoint to resume from.",
-                        default=None)
-
-    parser.add_argument('--resume',
-                        action="store_true",
-                        help="Automatically load the highest checkpoint and continue from there.",
-                        default=False)
-
-    parser.add_argument('--finetune',
-                        action="store_true",
-                        help="Whether to fine-tune from the specified checkpoint.",
-                        default=False)
-
-    parser.add_argument('--model_save_dir',
-                        type=str,
-                        help="Directory where the checkpoints should be saved to.",
-                        default=None)
-
-    parser.add_argument('--wandb',
-                        action="store_true",
-                        help="Whether to use weights and biases to track training runs. Requires you to run wandb login and place your auth key before.",
-                        default=False)
-
-    parser.add_argument('--wandb_resume_id',
-                        type=str,
-                        help="ID of a stopped wandb run to continue tracking",
-                        default=None)
-
-    args = parser.parse_args()
-
-    if args.finetune and args.resume_checkpoint is None and not args.resume:
-        print("Need to provide path to checkpoint to fine-tune from!")
-        sys.exit()
-
-    if args.gpu_id == "cpu":
-        os.environ["CUDA_VISIBLE_DEVICES"] = ""
-        device = torch.device("cpu")
-        print(f"No GPU specified, using CPU. Training will likely not work without GPU.")
-        gpu_count = 1  # for technical reasons this is set to one, indicating it's not gpu_count training, even though there is no GPU in this case
-    else:
-        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"
-        device = torch.device("cuda")
-        print(f"Making GPU {os.environ['CUDA_VISIBLE_DEVICES']} the only visible device(s).")
-        gpu_count = len(args.gpu_id.replace(",", " ").split())
-        # example call for gpu_count training:
-        # torchrun --standalone --nproc_per_node=4 --nnodes=1 run_training_pipeline.py nancy --gpu_id "1,2,3"
-
-    torch.manual_seed(9665)
-    random.seed(9665)
-    torch.random.manual_seed(9665)
-
-    torch.multiprocessing.set_sharing_strategy('file_system')
-
-    pipeline_dict[args.pipeline](gpu_id=args.gpu_id,
-                                 resume_checkpoint=args.resume_checkpoint,
-                                 resume=args.resume,
-                                 finetune=args.finetune,
-                                 model_dir=args.model_save_dir,
-                                 use_wandb=args.wandb,
-                                 wandb_resume_id=args.wandb_resume_id,
-                                 gpu_count=gpu_count)
\ No newline at end of file
diff --git a/run_weight_averaging.py b/run_weight_averaging.py
deleted file mode 100644
index a1a9330ffb689e7d51787efe21361bfed4825d5d..0000000000000000000000000000000000000000
--- a/run_weight_averaging.py
+++ /dev/null
@@ -1,99 +0,0 @@
-"""
-https://alexander-stasiuk.medium.com/pytorch-weights-averaging-e2c0fa611a0c
-"""
-
-import os
-
-import torch
-
-from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
-from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
-from Utility.storage_config import MODELS_DIR
-
-
-def load_net_toucan(path):
-    check_dict = torch.load(path, map_location=torch.device("cpu"))
-    net = ToucanTTS(weights=check_dict["model"], config=check_dict["config"])
-    return net, check_dict["default_emb"]
-
-
-def load_net_bigvgan(path):
-    check_dict = torch.load(path, map_location=torch.device("cpu"))
-    net = HiFiGAN(weights=check_dict["generator"])
-    return net, None
-
-
-def get_n_recent_checkpoints_paths(checkpoint_dir, n=5):
-    print("selecting checkpoints...")
-    checkpoint_list = list()
-    for el in os.listdir(checkpoint_dir):
-        if el.endswith(".pt") and el.startswith("checkpoint_"):
-            try:
-                checkpoint_list.append(int(el.split(".")[0].split("_")[1]))
-            except RuntimeError:
-                pass
-    if len(checkpoint_list) == 0:
-        return None
-    elif len(checkpoint_list) < n:
-        n = len(checkpoint_list)
-    checkpoint_list.sort(reverse=True)
-    return [os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(step)) for step in checkpoint_list[:n]]
-
-
-def average_checkpoints(list_of_checkpoint_paths, load_func):
-    # COLLECT CHECKPOINTS
-    if list_of_checkpoint_paths is None or len(list_of_checkpoint_paths) == 0:
-        return None
-    checkpoints_weights = {}
-    model = None
-    default_embed = None
-
-    # LOAD CHECKPOINTS
-    for path_to_checkpoint in list_of_checkpoint_paths:
-        print("loading model {}".format(path_to_checkpoint))
-        model, default_embed = load_func(path=path_to_checkpoint)
-        checkpoints_weights[path_to_checkpoint] = dict(model.named_parameters())
-
-    # AVERAGE CHECKPOINTS
-    params = model.named_parameters()
-    dict_params = dict(params)
-    checkpoint_amount = len(checkpoints_weights)
-    print("averaging...")
-    for name in dict_params.keys():
-        custom_params = None
-        for _, checkpoint_parameters in checkpoints_weights.items():
-            if custom_params is None:
-                custom_params = checkpoint_parameters[name].data
-            else:
-                custom_params += checkpoint_parameters[name].data
-        dict_params[name].data.copy_(custom_params / checkpoint_amount)
-    model_dict = model.state_dict()
-    model_dict.update(dict_params)
-    model.load_state_dict(model_dict)
-    model.eval()
-    return model, default_embed
-
-
-def save_model_for_use(model, name="", default_embed=None, dict_name="model"):
-    print("saving model...")
-    torch.save({dict_name: model.state_dict(), "default_emb": default_embed, "config": model.config}, name)
-    print("...done!")
-
-
-def make_best_in_all():
-    for model_dir in os.listdir(MODELS_DIR):
-        if os.path.isdir(os.path.join(MODELS_DIR, model_dir)):
-            if "ToucanTTS" in model_dir:
-                checkpoint_paths = get_n_recent_checkpoints_paths(checkpoint_dir=os.path.join(MODELS_DIR, model_dir), n=3)
-                if checkpoint_paths is None:
-                    continue
-                averaged_model, default_embed = average_checkpoints(checkpoint_paths, load_func=load_net_toucan)
-                save_model_for_use(model=averaged_model, default_embed=default_embed, name=os.path.join(MODELS_DIR, model_dir, "best.pt"))
-
-
-def count_parameters(net):
-    return sum(p.numel() for p in net.parameters() if p.requires_grad)
-
-
-if __name__ == '__main__':
-    make_best_in_all()
diff --git a/run_zero_shot_lang_emb_injection.py b/run_zero_shot_lang_emb_injection.py
deleted file mode 100644
index 12246f093aefbc0b1e05392cb7400666faabdbe4..0000000000000000000000000000000000000000
--- a/run_zero_shot_lang_emb_injection.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import argparse
-import os
-
-from Preprocessing.multilinguality.create_distance_lookups import CacheCreator
-from Preprocessing.multilinguality.create_lang_dist_dataset import LangDistDatasetCreator
-from Preprocessing.multilinguality.generate_zero_shot_lang_embs import approximate_and_inject_language_embeddings
-from Utility.storage_config import MODELS_DIR
-
-if __name__ == "__main__":
-    default_model_path = os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", "-m", type=str, default=default_model_path, help="model path from which to obtain pretrained language embeddings")
-    parser.add_argument("--distance_type", "-d", type=str, choices=["map", "tree", "asp", "learned", "combined"], default="learned",
-                        help="which type of distance to use for finding nearest languages")
parser.add_argument("--n_closest", "-k", type=int, default=50, help="how many nearest languages to select for language embedding approximation") - - args = parser.parse_args() - - # make sure that cache files exist - cc = CacheCreator(cache_root="Preprocessing/multilinguality") - cc.create_required_files(model_path=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt")) - - # create distance dataset - dc = LangDistDatasetCreator(args.model_path, cache_root="Preprocessing/multilinguality") - distance_dataset = dc.create_dataset(args.distance_type, n_closest=args.n_closest, zero_shot=True) - - # generate zero-shot lang embs and inject into pretrained model, then save to modified model path - approximate_and_inject_language_embeddings(model_path=args.model_path, - df=distance_dataset, - iso_lookup=dc.iso_lookup)