import torch

from Modules.GeneralLayers.Conformer import Conformer


class CodecRefinementTransformer(torch.nn.Module):

    def __init__(self,
                 attention_dimension=128,
                 num_codebooks=4,
                 codebook_size=1024,
                 backtranslation_dim=8,
                 attention_heads=4,
                 positionwise_conv_kernel_size=1,
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=False,  # for now, we try using just a regular transformer
                 decoder_layers=6,
                 decoder_units=1280,
                 decoder_concat_after=False,
                 conformer_decoder_kernel_size=31,
                 decoder_normalize_before=True,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.1,
                 transformer_dec_attn_dropout_rate=0.1,
                 utt_embed_dim=512,
                 use_conditional_layernorm_embedding_integration=False,
                 ):
        super().__init__()

        self.reconstruction_transformer = Conformer(
            conformer_type="decoder",
            attention_dim=num_codebooks * backtranslation_dim,
            attention_heads=attention_heads,
            linear_units=decoder_units,
            num_blocks=decoder_layers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_decoder_kernel_size,
            use_output_norm=False,
            utt_embed=utt_embed_dim,
            use_conditional_layernorm_embedding_integration=use_conditional_layernorm_embedding_integration
        )

        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
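        # One embedding table, one backtranslation table, and one classifier head per codebook.
        # self.padding_id is a dedicated index outside the regular codebook range that serves as the
        # padding/mask token. Each classifier head sees the contextualized sequence concatenated with
        # the backtranslated embeddings of all previously predicted codebooks, hence its growing input size.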
        self.input_embeddings = torch.nn.ModuleList()
        self.backtranslation_heads = torch.nn.ModuleList()
        self.hierarchical_classifier = torch.nn.ModuleList()
        self.padding_id = codebook_size + 5
        for head in range(num_codebooks):
            self.input_embeddings.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
            self.backtranslation_heads.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
            self.hierarchical_classifier.append(torch.nn.Linear(num_codebooks * backtranslation_dim + head * backtranslation_dim, codebook_size))

        self.criterion = MaskedRefinementObjective()

        for backtranslation_head in self.backtranslation_heads:
            torch.nn.init.normal_(backtranslation_head.weight, mean=0, std=attention_dimension ** -0.5)
        for input_embedding in self.input_embeddings:
            torch.nn.init.normal_(input_embedding.weight, mean=0, std=attention_dimension ** -0.5)

    def forward(self, index_sequence, is_inference, speaker_embedding, padding_mask=None, gold_index_sequence=None):
        """
        index_sequence: [batch, codebook_index, time_steps] a sequence of indexes that comes from an argmax of the previous prediction layer.
        is_inference: boolean flag that indicates whether to return the masked language modelling loss or the refined sequence
        speaker_embedding: [batch, speaker_embed_dim]
        padding_mask: [batch, time_steps] a mask that is True for all time steps that are padding and should not be considered, and False everywhere else.
        gold_index_sequence: [batch, codebook_index, time_steps, classes] one-hot target distributions over the codebook entries, only needed to compute the loss during training.

        return: the loss if is_inference is False, otherwise [codebook_index, batch, classes, time_steps] logits, which can be turned back into an index sequence with the same shape and interpretation as the input via one_hot_sequence_to_token_sequence, refined through iterative masked language modelling.
        """
        if not is_inference:
            index_sequence_padding_accounted = index_sequence.masked_fill(mask=padding_mask.unsqueeze(1), value=self.padding_id)
        else:
            index_sequence_padding_accounted = index_sequence  # in the case of inference, there is no padding

        sequence_of_continuous_tokens = self.indexes_per_codebook_to_stacked_embedding_vector(index_sequence_padding_accounted)  # return [batch, time_steps, num_codebooks x backtranslation_dim]
        contextualized_sequence = self.contextualize_sequence(sequence_of_continuous_tokens, speaker_embedding, non_padding_mask=~padding_mask if padding_mask is not None else None)

        predicted_indexes_one_hot = list()
        backtranslated_indexes = list()
        for head_index, classifier_head in enumerate(self.hierarchical_classifier):
            # each codebook considers all previous codebooks.
            predicted_indexes_one_hot.append(classifier_head(torch.cat([contextualized_sequence] + backtranslated_indexes, dim=2)))
            predicted_lookup_index = torch.argmax(predicted_indexes_one_hot[-1], dim=-1)
            backtranslation = self.backtranslation_heads[head_index](predicted_lookup_index)
            if len(backtranslation.size()) == 1:
                backtranslation = backtranslation.unsqueeze(0)
            backtranslated_indexes.append(backtranslation)

        indexes = torch.cat(predicted_indexes_one_hot, dim=2)
        # [Batch, Sequence, Hidden]
        indexes = indexes.view(contextualized_sequence.size(0), contextualized_sequence.size(1), self.num_codebooks, self.codebook_size)
        # [Batch, Sequence, Codebook, Classes]
        indexes = indexes.transpose(1, 2)
        # [Batch, Codebook, Sequence, Classes]
        indexes = indexes.transpose(2, 3)
        # [Batch, Codebook, Classes, Sequence]
        indexes = indexes.transpose(0, 1)
        # [Codebook, Batch, Classes, Sequence]

        if is_inference:
            return indexes
        else:
            return self.criterion(predicted_one_hot=indexes, gold_one_hot=gold_index_sequence, non_pad_mask=~padding_mask)

    def contextualize_sequence(self, masked_sequence, utterance_embedding, non_padding_mask):
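        # the non-padding mask is passed to the Conformer with a trailing singleton dimension: [batch, time_steps, 1]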
        decoded_speech, _ = self.reconstruction_transformer(masked_sequence, non_padding_mask.unsqueeze(2) if non_padding_mask is not None else None, utterance_embedding=utterance_embedding)
        return decoded_speech

    def indexes_per_codebook_to_stacked_embedding_vector(self, index_sequence_per_codebook):
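        # embed the index track of each codebook with its own backtranslation table and concatenate
        # the embeddings along the feature axis: [batch, time_steps, num_codebooks x backtranslation_dim]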
        continuous_frame_sequences = list()
        for codebook_id, backtranslation_head in enumerate(self.backtranslation_heads):
            continuous_frame_sequences.append(backtranslation_head(index_sequence_per_codebook.transpose(0, 1)[codebook_id]))
        stacked_embedding_vector = torch.cat(continuous_frame_sequences, dim=-1)
        return stacked_embedding_vector


class MaskedRefinementObjective(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.classification_loss = torch.nn.CrossEntropyLoss(reduction="none")
        self.l1_loss = torch.nn.L1Loss(reduction="none")

    def forward(self, predicted_one_hot, gold_one_hot, non_pad_mask):
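        # predicted_one_hot and (after reordering) gold_one_hot are [Codebook, Batch, Classes, Sequence];
        # the cross-entropy is computed per codebook, summed over codebooks, and then reduced to a scalar
        # with a mask that zeroes out padded time steps and normalizes by sequence length, batch size and
        # the number of classes.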
        ce = list()
        for one_hot_pred, one_hot_target in zip(predicted_one_hot, gold_one_hot.transpose(0, 1).transpose(2, 3)):
            # we iterate over codebooks
            ce.append(self.classification_loss(one_hot_pred, one_hot_target))
        classification_loss = torch.stack(ce).sum(0)

        # make weighted mask and apply it
        out_masks = non_pad_mask.unsqueeze(-1).to(gold_one_hot.device)
        out_masks = torch.nn.functional.pad(out_masks.transpose(1, 2), [0, gold_one_hot.size(2) - out_masks.size(1), 0, 0, 0, 0], value=False).transpose(1, 2)  # pad the mask along the time axis so it matches the length of the gold sequence
        out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
        out_weights /= gold_one_hot.size(0) * gold_one_hot.size(-1)
        # apply weight
        classification_loss = classification_loss.mul(out_weights.squeeze()).masked_select(out_masks.squeeze()).sum()

        return classification_loss, classification_loss


def one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook):
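    # [Codebook, Batch, Classes, Sequence] logits  ->  [Batch, Codebook, Sequence] token indexes (argmax over the classes)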
    return torch.argmax(batch_of_indexes_one_hot_per_codebook, dim=-2).transpose(0, 1)


if __name__ == '__main__':
    from Modules.ToucanTTS.ToucanTTS import ToucanTTS
    from Utility.utils import make_pad_mask

    # prepare dummy inputs
    num_codebooks = 4
    dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float()  # [Batch, Sequence Length, Features per Phone]
    dummy_text_lens = torch.LongTensor([2, 3, 3])

    gold_speech_batch = torch.randn([3, num_codebooks, 30, 1024])  # [Batch, Codebook, Sequence Length, Classes]
    gold_speech_lens = torch.LongTensor([10, 30, 20])

    gold_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5], [5, 5, 10]])
    gold_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]], [[1.1], [1.2], [0.8]]])
    gold_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]], [[1.1], [1.2], [0.8]]])

    dummy_utterance_embed = torch.randn([3, 512])  # [Batch, Dimensions of Speaker Embedding]
    dummy_language_id = torch.LongTensor([5, 3, 2]).unsqueeze(1)

    # run TTS on pseudo inputs
    batch_of_indexes_one_hot_per_codebook, _, _, _, _, _ = ToucanTTS(num_codebooks=num_codebooks, use_language_model=False)._forward(
        dummy_text_batch,
        dummy_text_lens,
        gold_speech_batch,
        gold_speech_lens,
        gold_durations,
        gold_pitch,
        gold_energy,
        utterance_embedding=dummy_utterance_embed,
        lang_ids=dummy_language_id)

    # reformat outputs to be a token sequence
    batch_of_indexes = one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook)

    # refine the output of the TTS with the Language Model
    refiner = CodecRefinementTransformer()

    loss = refiner(index_sequence=one_hot_sequence_to_token_sequence(gold_speech_batch.transpose(3, 2)).transpose(0, 1), padding_mask=make_pad_mask(gold_speech_lens), is_inference=False, speaker_embedding=dummy_utterance_embed, gold_index_sequence=gold_speech_batch)
    print(loss)
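
    # apply the refiner iteratively: each pass consumes the token sequence produced by the previous pass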
    refined_indexes = refiner(index_sequence=batch_of_indexes[1].unsqueeze(0), is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)
    refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
    refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)
    refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
    refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)
    refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
    refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)