import torch

from Modules.GeneralLayers.Conformer import Conformer


class CodecRefinementTransformer(torch.nn.Module):
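    """
    Iterative refinement of discrete codec indexes in the style of a masked language model:
    the argmax output of a previous prediction layer is embedded, contextualized with a
    Conformer decoder conditioned on an utterance embedding, and re-classified codebook by
    codebook, so that each codebook's prediction conditions on the backtranslations of all
    previous codebooks.
    """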

    def __init__(self,
                 attention_dimension=128,
                 num_codebooks=4,
                 codebook_size=1024,
                 backtranslation_dim=8,
                 attention_heads=4,
                 positionwise_conv_kernel_size=1,
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=False,  # for now, we try using just a regular transformer
                 decoder_layers=6,
                 decoder_units=1280,
                 decoder_concat_after=False,
                 conformer_decoder_kernel_size=31,
                 decoder_normalize_before=True,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.1,
                 transformer_dec_attn_dropout_rate=0.1,
                 utt_embed_dim=512,
                 use_conditional_layernorm_embedding_integration=False,
                 ):
        super().__init__()
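
        # the Conformer decoder operates on the stacked backtranslation embeddings, so its
        # attention dimension is num_codebooks * backtranslation_dim rather than attention_dimension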
        self.reconstruction_transformer = Conformer(
            conformer_type="decoder",
            attention_dim=num_codebooks * backtranslation_dim,
            attention_heads=attention_heads,
            linear_units=decoder_units,
            num_blocks=decoder_layers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_decoder_kernel_size,
            use_output_norm=False,
            utt_embed=utt_embed_dim,
            use_conditional_layernorm_embedding_integration=use_conditional_layernorm_embedding_integration
        )

        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.input_embeddings = torch.nn.ModuleList()
        self.backtranslation_heads = torch.nn.ModuleList()
        self.hierarchical_classifier = torch.nn.ModuleList()
        self.padding_id = codebook_size + 5  # an index above the codebook range is reserved as the padding token
        for head in range(num_codebooks):
            self.input_embeddings.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
            self.backtranslation_heads.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
            # each classifier head sees the contextualized sequence plus the backtranslations of all previous codebooks
            self.hierarchical_classifier.append(torch.nn.Linear(num_codebooks * backtranslation_dim + head * backtranslation_dim, codebook_size))

        self.criterion = MaskedRefinementObjective()

        # attention_dimension is only used here, to scale the initialization of the embedding tables
        for backtranslation_head in self.backtranslation_heads:
            torch.nn.init.normal_(backtranslation_head.weight, mean=0, std=attention_dimension ** -0.5)
        for input_embedding in self.input_embeddings:
            torch.nn.init.normal_(input_embedding.weight, mean=0, std=attention_dimension ** -0.5)

    def forward(self, index_sequence, is_inference, speaker_embedding, padding_mask=None, gold_index_sequence=None):
        """
        index_sequence: [batch, codebook_index, time_steps] a sequence of indexes that comes from the argmax of the previous prediction layer.
        is_inference: boolean flag that indicates whether to return the masked language modelling loss or the refined sequence.
        speaker_embedding: [batch, speaker_embed_dim]
        padding_mask: [batch, time_steps] a mask that is True for all time steps that are padding and should not be considered, and False everywhere else.
        gold_index_sequence: the target index sequence, only needed for computing the loss during training.
        return: the loss if is_inference is False, otherwise the refined predictions as one-hot logits with shape [codebook_index, batch, classes, time_steps],
                which can be turned back into an index sequence with one_hot_sequence_to_token_sequence for the next refinement iteration.
        """
        if not is_inference:
            index_sequence_padding_accounted = index_sequence.masked_fill(mask=padding_mask.unsqueeze(1), value=self.padding_id)
        else:
            index_sequence_padding_accounted = index_sequence  # in the case of inference, there is no padding

        sequence_of_continuous_tokens = self.indexes_per_codebook_to_stacked_embedding_vector(index_sequence_padding_accounted)  # [batch, time_steps, num_codebooks x backtranslation_dim]
        contextualized_sequence = self.contextualize_sequence(sequence_of_continuous_tokens, speaker_embedding, non_padding_mask=~padding_mask if padding_mask is not None else None)

        predicted_indexes_one_hot = list()
        backtranslated_indexes = list()
        for head_index, classifier_head in enumerate(self.hierarchical_classifier):
            # each codebook considers all previous codebooks.
            predicted_indexes_one_hot.append(classifier_head(torch.cat([contextualized_sequence] + backtranslated_indexes, dim=2)))
            predicted_lookup_index = torch.argmax(predicted_indexes_one_hot[-1], dim=-1)
            backtranslation = self.backtranslation_heads[head_index](predicted_lookup_index)
            if len(backtranslation.size()) == 1:
                backtranslation = backtranslation.unsqueeze(0)
            backtranslated_indexes.append(backtranslation)
        indexes = torch.cat(predicted_indexes_one_hot, dim=2)
        # [Batch, Sequence, Hidden]
        indexes = indexes.view(contextualized_sequence.size(0), contextualized_sequence.size(1), self.num_codebooks, self.codebook_size)
        # [Batch, Sequence, Codebook, Classes]
        indexes = indexes.transpose(1, 2)
        # [Batch, Codebook, Sequence, Classes]
        indexes = indexes.transpose(2, 3)
        # [Batch, Codebook, Classes, Sequence]
        indexes = indexes.transpose(0, 1)
        # [Codebook, Batch, Classes, Sequence]

        if is_inference:
            return indexes
        else:
            return self.criterion(predicted_one_hot=indexes, gold_one_hot=gold_index_sequence, non_pad_mask=~padding_mask)

    def contextualize_sequence(self, masked_sequence, utterance_embedding, non_padding_mask):
        decoded_speech, _ = self.reconstruction_transformer(masked_sequence, non_padding_mask.unsqueeze(2) if non_padding_mask is not None else None, utterance_embedding=utterance_embedding)
        return decoded_speech

    def indexes_per_codebook_to_stacked_embedding_vector(self, index_sequence_per_codebook):
        continuous_frame_sequences = list()
        for codebook_id, backtranslation_head in enumerate(self.backtranslation_heads):
            continuous_frame_sequences.append(backtranslation_head(index_sequence_per_codebook.transpose(0, 1)[codebook_id]))
        stacked_embedding_vector = torch.cat(continuous_frame_sequences, dim=-1)
        return stacked_embedding_vector


class MaskedRefinementObjective(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.classification_loss = torch.nn.CrossEntropyLoss(reduction="none")
        self.l1_loss = torch.nn.L1Loss(reduction="none")
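        # note: the L1 loss is instantiated but not used below; only the masked cross entropy is returned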

    def forward(self, predicted_one_hot, gold_one_hot, non_pad_mask):
        ce = list()
        for one_hot_pred, one_hot_target in zip(predicted_one_hot, gold_one_hot.transpose(0, 1).transpose(2, 3)):
            # we iterate over codebooks
            ce.append(self.classification_loss(one_hot_pred, one_hot_target))
        classification_loss = torch.stack(ce).sum(0)
        # make weighted mask and apply it
        out_masks = non_pad_mask.unsqueeze(-1).to(gold_one_hot.device)
        out_masks = torch.nn.functional.pad(out_masks.transpose(1, 2), [0, gold_one_hot.size(2) - out_masks.size(1), 0, 0, 0, 0], value=False).transpose(1, 2)
        out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
        out_weights /= gold_one_hot.size(0) * gold_one_hot.size(-1)
        # apply weight
        classification_loss = classification_loss.mul(out_weights.squeeze()).masked_select(out_masks.squeeze()).sum()
        return classification_loss, classification_loss


def one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook):
    return torch.argmax(batch_of_indexes_one_hot_per_codebook, dim=-2).transpose(0, 1)
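
# Shape sketch for the helper above (illustrative only, assuming the layout produced by
# CodecRefinementTransformer.forward at inference time): one-hot logits of shape
# [Codebook, Batch, Classes, Sequence], e.g. [4, 3, 1024, 30], are converted via argmax over
# the class axis into an index sequence of shape [Batch, Codebook, Sequence], e.g. [3, 4, 30].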


if __name__ == '__main__':
    from Modules.ToucanTTS.ToucanTTS import ToucanTTS
    from Utility.utils import make_pad_mask

    # prepare dummy inputs
    num_codebooks = 4
    dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float()  # [Batch, Sequence Length, Features per Phone]
    dummy_text_lens = torch.LongTensor([2, 3, 3])

    gold_speech_batch = torch.randn([3, num_codebooks, 30, 1024])  # [Batch, Codebooks, Sequence Length, Classes per Codebook]
    gold_speech_lens = torch.LongTensor([10, 30, 20])

    gold_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5], [5, 5, 10]])
    gold_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]], [[1.1], [1.2], [0.8]]])
    gold_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]], [[1.1], [1.2], [0.8]]])

    dummy_utterance_embed = torch.randn([3, 512])  # [Batch, Dimensions of Speaker Embedding]
    dummy_language_id = torch.LongTensor([5, 3, 2]).unsqueeze(1)

    # run TTS on pseudo inputs
    batch_of_indexes_one_hot_per_codebook, _, _, _, _, _ = ToucanTTS(num_codebooks=num_codebooks, use_language_model=False)._forward(dummy_text_batch,
                                                                                                                                      dummy_text_lens,
                                                                                                                                      gold_speech_batch,
                                                                                                                                      gold_speech_lens,
                                                                                                                                      gold_durations,
                                                                                                                                      gold_pitch,
                                                                                                                                      gold_energy,
                                                                                                                                      utterance_embedding=dummy_utterance_embed,
                                                                                                                                      lang_ids=dummy_language_id)

    # reformat outputs to be a token sequence
    batch_of_indexes = one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook)

    # refine the output of the TTS with the Language Model
    refiner = CodecRefinementTransformer()

    loss = refiner(index_sequence=one_hot_sequence_to_token_sequence(gold_speech_batch.transpose(3, 2)).transpose(0, 1), padding_mask=make_pad_mask(gold_speech_lens), is_inference=False, speaker_embedding=dummy_utterance_embed, gold_index_sequence=gold_speech_batch)
    print(loss)

    # at inference time, the refiner can be applied repeatedly, feeding its own argmax output back in
    refined_indexes = refiner(index_sequence=batch_of_indexes[1].unsqueeze(0), is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)
    refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
    refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)
    refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
    refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)
    refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
    refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
    print(refined_indexes.shape)