myhanhhyugen commited on
Commit
dc9eaa3
·
verified ·
1 Parent(s): 0629725

initial commits

Browse files
Files changed (4) hide show
  1. TTSInferencing.py +267 -0
  2. hyperparams.yaml +187 -0
  3. model.ckpt +3 -0
  4. module_classes.py +214 -0
TTSInferencing.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ import logging
4
+ import torch
5
+ import torchaudio
6
+ import random
7
+ import speechbrain
8
+ from speechbrain.inference.interfaces import Pretrained
9
+ from speechbrain.inference.text import GraphemeToPhoneme
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ class TTSInferencing(Pretrained):
14
+ """
15
+ A ready-to-use wrapper for TTS (text -> mel_spec).
16
+ Arguments
17
+ ---------
18
+ hparams
19
+ Hyperparameters (from HyperPyYAML)
20
+ """
21
+
22
+ HPARAMS_NEEDED = ["modules", "input_encoder"]
23
+
24
+ MODULES_NEEDED = ["encoder_prenet", "pos_emb_enc",
25
+ "decoder_prenet", "pos_emb_dec",
26
+ "Seq2SeqTransformer", "mel_lin",
27
+ "stop_lin", "decoder_postnet"]
28
+
29
+
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+ lexicon = self.hparams.lexicon
33
+ lexicon = ["@@"] + lexicon
34
+ self.input_encoder = self.hparams.input_encoder
35
+ self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
36
+ self.input_encoder.add_unk()
37
+
38
+ self.modules = self.hparams.modules
39
+
40
+ self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
41
+
42
+
43
+
44
+
45
+ def generate_padded_phonemes(self, texts):
46
+ """Computes mel-spectrogram for a list of texts
47
+
48
+ Arguments
49
+ ---------
50
+ texts: List[str]
51
+ texts to be converted to spectrogram
52
+
53
+ Returns
54
+ -------
55
+ tensors of output spectrograms
56
+ """
57
+
58
+ # Preprocessing required at the inference time for the input text
59
+ # "label" below contains input text
60
+ # "phoneme_labels" contain the phoneme sequences corresponding to input text labels
61
+
62
+ phoneme_labels = list()
63
+
64
+ for label in texts:
65
+
66
+ phoneme_label = list()
67
+
68
+ label = self.custom_clean(label).upper()
69
+
70
+ words = label.split()
71
+ words = [word.strip() for word in words]
72
+ words_phonemes = self.g2p(words)
73
+
74
+ for i in range(len(words_phonemes)):
75
+ words_phonemes_seq = words_phonemes[i]
76
+ for phoneme in words_phonemes_seq:
77
+ if not phoneme.isspace():
78
+ phoneme_label.append(phoneme)
79
+ phoneme_labels.append(phoneme_label)
80
+
81
+
82
+ # encode the phonemes with input text encoder
83
+ encoded_phonemes = list()
84
+ for i in range(len(phoneme_labels)):
85
+ phoneme_label = phoneme_labels[i]
86
+ encoded_phoneme = torch.LongTensor(self.input_encoder.encode_sequence(phoneme_label)).to(self.device)
87
+ encoded_phonemes.append(encoded_phoneme)
88
+
89
+
90
+ # Right zero-pad all one-hot text sequences to max input length
91
+ input_lengths, ids_sorted_decreasing = torch.sort(
92
+ torch.LongTensor([len(x) for x in encoded_phonemes]), dim=0, descending=True
93
+ )
94
+
95
+ max_input_len = input_lengths[0]
96
+
97
+ phoneme_padded = torch.LongTensor(len(encoded_phonemes), max_input_len).to(self.device)
98
+ phoneme_padded.zero_()
99
+
100
+ for seq_idx, seq in enumerate(encoded_phonemes):
101
+ phoneme_padded[seq_idx, : len(seq)] = seq
102
+
103
+
104
+ return phoneme_padded.to(self.device, non_blocking=True).float()
105
+
106
+
107
+ def encode_batch(self, texts):
108
+ """Computes mel-spectrogram for a list of texts
109
+
110
+ Texts must be sorted in decreasing order on their lengths
111
+
112
+ Arguments
113
+ ---------
114
+ texts: List[str]
115
+ texts to be encoded into spectrogram
116
+
117
+ Returns
118
+ -------
119
+ tensors of output spectrograms
120
+ """
121
+
122
+ # generate phonemes and padd the input texts
123
+ encoded_phoneme_padded = self.generate_padded_phonemes(texts)
124
+ phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)
125
+ # Positional Embeddings
126
+ phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded)
127
+ # Summing up embeddings
128
+ enc_phoneme_emb = phoneme_prenet_emb.permute(0,2,1) + phoneme_pos_emb
129
+ enc_phoneme_emb = enc_phoneme_emb.to(self.device)
130
+
131
+
132
+ with torch.no_grad():
133
+
134
+ # generate sequential predictions via transformer decoder
135
+ start_token = torch.full((80, 1), fill_value= 0)
136
+ start_token[1] = 2
137
+ decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
138
+ decoder_input = decoder_input.to(self.device, non_blocking=True).float()
139
+
140
+ num_itr = 0
141
+ stop_condition = [False] * decoder_input.size(0)
142
+ max_iter = 100
143
+
144
+ # while not all(stop_condition) and num_itr < max_iter:
145
+ while num_itr < max_iter:
146
+
147
+ # Decoder Prenet
148
+ mel_prenet_emb = self.modules['decoder_prenet'](decoder_input).to(self.device).permute(0,2,1)
149
+
150
+ # Positional Embeddings
151
+ mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)
152
+ # Summing up Embeddings
153
+ dec_mel_spec = mel_prenet_emb + mel_pos_emb
154
+
155
+ # Getting the target mask to avoid looking ahead
156
+ tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)
157
+
158
+ # Getting the source mask
159
+ src_mask = torch.zeros(enc_phoneme_emb.shape[1], enc_phoneme_emb.shape[1]).to(self.device)
160
+
161
+ # Padding masks for source and targets
162
+ src_key_padding_mask = self.hparams.padding_mask(enc_phoneme_emb, pad_idx = self.hparams.blank_index).to(self.device)
163
+ tgt_key_padding_mask = self.hparams.padding_mask(dec_mel_spec, pad_idx = self.hparams.blank_index).to(self.device)
164
+
165
+
166
+ # Running the Seq2Seq Transformer
167
+ decoder_outputs = self.modules['Seq2SeqTransformer'](src = enc_phoneme_emb, tgt = dec_mel_spec, src_mask = src_mask, tgt_mask = tgt_mask,
168
+ src_key_padding_mask = src_key_padding_mask, tgt_key_padding_mask = tgt_key_padding_mask)
169
+
170
+ # Mel Linears
171
+ mel_linears = self.modules['mel_lin'](decoder_outputs).permute(0,2,1)
172
+ mel_postnet = self.modules['decoder_postnet'](mel_linears) # mel tensor output
173
+ mel_pred = mel_linears + mel_postnet # mel tensor output
174
+
175
+ stop_token_pred = self.modules['stop_lin'](decoder_outputs).squeeze(-1)
176
+
177
+ stop_condition_list = self.check_stop_condition(stop_token_pred)
178
+
179
+
180
+ # update the values of main stop conditions
181
+ stop_condition_update = [True if stop_condition_list[i] else stop_condition[i] for i in range(len(stop_condition))]
182
+ stop_condition = stop_condition_update
183
+
184
+
185
+ # Prepare input for the transformer input for next iteration
186
+ current_output = mel_pred[:, :, -1:]
187
+
188
+ decoder_input=torch.cat([decoder_input,current_output],dim=2)
189
+ num_itr = num_itr+1
190
+
191
+ mel_outputs = decoder_input[:, :, 1:]
192
+
193
+ return mel_outputs
194
+
195
+
196
+
197
+ def encode_text(self, text):
198
+ """Runs inference for a single text str"""
199
+ return self.encode_batch([text])
200
+
201
+
202
+ def forward(self, text_list):
203
+ "Encodes the input texts."
204
+ return self.encode_batch(text_list)
205
+
206
+
207
+ def check_stop_condition(self, stop_token_pred):
208
+ """
209
+ check if stop token / EOS reached or not for mel_specs in the batch
210
+ """
211
+
212
+ # Applying sigmoid to perform binary classification
213
+ sigmoid_output = torch.sigmoid(stop_token_pred)
214
+ # Checking if the probability is greater than 0.5
215
+ stop_results = sigmoid_output > 0.8
216
+ stop_output = [all(result) for result in stop_results]
217
+
218
+ return stop_output
219
+
220
+
221
+
222
+ def custom_clean(self, text):
223
+ """
224
+ Uses custom criteria to clean text.
225
+
226
+ Arguments
227
+ ---------
228
+ text : str
229
+ Input text to be cleaned
230
+ model_name : str
231
+ whether to treat punctuations
232
+
233
+ Returns
234
+ -------
235
+ text : str
236
+ Cleaned text
237
+ """
238
+
239
+ _abbreviations = [
240
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
241
+ for x in [
242
+ ("mrs", "missus"),
243
+ ("mr", "mister"),
244
+ ("dr", "doctor"),
245
+ ("st", "saint"),
246
+ ("co", "company"),
247
+ ("jr", "junior"),
248
+ ("maj", "major"),
249
+ ("gen", "general"),
250
+ ("drs", "doctors"),
251
+ ("rev", "reverend"),
252
+ ("lt", "lieutenant"),
253
+ ("hon", "honorable"),
254
+ ("sgt", "sergeant"),
255
+ ("capt", "captain"),
256
+ ("esq", "esquire"),
257
+ ("ltd", "limited"),
258
+ ("col", "colonel"),
259
+ ("ft", "fort"),
260
+ ]
261
+ ]
262
+
263
+ text = re.sub(" +", " ", text)
264
+
265
+ for regex, replacement in _abbreviations:
266
+ text = re.sub(regex, replacement, text)
267
+ return text
hyperparams.yaml ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ############################################################################
3
+ # Model: TTS with attention-based mechanism
4
+ # Tokens: g2p + possitional embeddings
5
+ # losses: MSE & BCE
6
+ # Training: LJSpeech
7
+ # ############################################################################
8
+
9
+ ###################################
10
+ # Experiment Parameters and setup #
11
+ ###################################
12
+ seed: 1234
13
+ __set_seed: !apply:torch.manual_seed [!ref <seed>]
14
+
15
+ # Folder set up
16
+ # output_folder: !ref .\\results\\tts\\<seed>
17
+ # save_folder: !ref <output_folder>\\save
18
+
19
+ output_folder: !ref ./results/<seed>
20
+ save_folder: !ref <output_folder>/save
21
+
22
+
23
+ ################################
24
+ # Model Parameters and model #
25
+ ################################
26
+ # Input parameters
27
+ lexicon:
28
+ - AA
29
+ - AE
30
+ - AH
31
+ - AO
32
+ - AW
33
+ - AY
34
+ - B
35
+ - CH
36
+ - D
37
+ - DH
38
+ - EH
39
+ - ER
40
+ - EY
41
+ - F
42
+ - G
43
+ - HH
44
+ - IH
45
+ - IY
46
+ - JH
47
+ - K
48
+ - L
49
+ - M
50
+ - N
51
+ - NG
52
+ - OW
53
+ - OY
54
+ - P
55
+ - R
56
+ - S
57
+ - SH
58
+ - T
59
+ - TH
60
+ - UH
61
+ - UW
62
+ - V
63
+ - W
64
+ - Y
65
+ - Z
66
+ - ZH
67
+
68
+ input_encoder: !new:speechbrain.dataio.encoder.TextEncoder
69
+
70
+
71
+
72
+ ################################
73
+ # Model Parameters and model #
74
+ # Transformer Parameters
75
+ ################################
76
+ d_model: 512
77
+ nhead: 8
78
+ num_encoder_layers: 3
79
+ num_decoder_layers: 3
80
+ dim_feedforward: 512
81
+ dropout: 0.1
82
+
83
+
84
+ # Decoder parameters
85
+ # The number of frames in the target per encoder step
86
+ n_frames_per_step: 1
87
+ decoder_rnn_dim: 1024
88
+ prenet_dim: 256
89
+ max_decoder_steps: 1000
90
+ gate_threshold: 0.5
91
+ p_decoder_dropout: 0.1
92
+ decoder_no_early_stopping: False
93
+
94
+ blank_index: 0 # This special tokes is for padding
95
+
96
+
97
+ # Masks
98
+ lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
99
+ padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask
100
+
101
+
102
+ ################################
103
+ # CNN 3-layers Prenet #
104
+ ################################
105
+ # Encoder Prenet
106
+ encoder_prenet: !new:module_classes.CNNPrenet
107
+
108
+ # Decoder Prenet
109
+ decoder_prenet: !new:module_classes.CNNDecoderPrenet
110
+
111
+ ################################
112
+ # Positional Encodings #
113
+ ################################
114
+
115
+ #encoder
116
+ pos_emb_enc: !new:module_classes.ScaledPositionalEncoding
117
+ input_size: !ref <d_model>
118
+ max_len: 5000
119
+
120
+ #decoder
121
+ pos_emb_dec: !new:module_classes.ScaledPositionalEncoding
122
+ input_size: !ref <d_model>
123
+ max_len: 5000
124
+
125
+
126
+ ################################
127
+ # S2S Transfomer #
128
+ ################################
129
+
130
+ Seq2SeqTransformer: !new:torch.nn.Transformer
131
+ d_model: !ref <d_model>
132
+ nhead: !ref <nhead>
133
+ num_encoder_layers: !ref <num_encoder_layers>
134
+ num_decoder_layers: !ref <num_decoder_layers>
135
+ dim_feedforward: !ref <dim_feedforward>
136
+ dropout: !ref <dropout>
137
+ batch_first: True
138
+
139
+
140
+ ################################
141
+ # CNN 5-layers PostNet #
142
+ ################################
143
+
144
+ decoder_postnet: !new:speechbrain.lobes.models.Tacotron2.Postnet
145
+
146
+
147
+ # Linear transformation on the top of the decoder.
148
+ stop_lin: !new:speechbrain.nnet.linear.Linear
149
+ input_size: !ref <d_model>
150
+ n_neurons: 1
151
+
152
+
153
+ # Linear transformation on the top of the decoder.
154
+ mel_lin: !new:speechbrain.nnet.linear.Linear
155
+ input_size: !ref <d_model>
156
+ n_neurons: 80
157
+
158
+ modules:
159
+ encoder_prenet: !ref <encoder_prenet>
160
+ pos_emb_enc: !ref <pos_emb_enc>
161
+ decoder_prenet: !ref <decoder_prenet>
162
+ pos_emb_dec: !ref <pos_emb_dec>
163
+ Seq2SeqTransformer: !ref <Seq2SeqTransformer>
164
+ mel_lin: !ref <mel_lin>
165
+ stop_lin: !ref <stop_lin>
166
+ decoder_postnet: !ref <decoder_postnet>
167
+
168
+
169
+ model: !new:torch.nn.ModuleList
170
+ - [!ref <encoder_prenet>,!ref <pos_emb_enc>,
171
+ !ref <decoder_prenet>, !ref <pos_emb_dec>, !ref <Seq2SeqTransformer>,
172
+ !ref <mel_lin>, !ref <stop_lin>, !ref <decoder_postnet>]
173
+
174
+
175
+ pretrained_model_path: ./model.ckpt
176
+
177
+ # The pretrainer allows a mapping between pretrained files and instances that
178
+ # are declared in the yaml. E.g here, we will download the file model.ckpt
179
+ # and it will be loaded into "model" which is pointing to the <model> defined
180
+ # before.
181
+
182
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
183
+ collect_in: !ref <save_folder>
184
+ loadables:
185
+ model: !ref <model>
186
+ paths:
187
+ model: !ref <pretrained_model_path>
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e5421fe987116817841652862ce070a421d7f5d7c8bbef68c83bec876b1eafb
3
+ size 95804314
module_classes.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+
7
+ class CNNPrenet(torch.nn.Module):
8
+ def __init__(self):
9
+ super(CNNPrenet, self).__init__()
10
+
11
+ # Define the layers using Sequential container
12
+ self.conv_layers = nn.Sequential(
13
+ nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding=1),
14
+ nn.BatchNorm1d(512),
15
+ nn.ReLU(),
16
+ nn.Dropout(0.1),
17
+
18
+ nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
19
+ nn.BatchNorm1d(512),
20
+ nn.ReLU(),
21
+ nn.Dropout(0.1),
22
+
23
+ nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
24
+ nn.BatchNorm1d(512),
25
+ nn.ReLU(),
26
+ nn.Dropout(0.1)
27
+ )
28
+
29
+ def forward(self, x):
30
+
31
+ # Add a new dimension for the channel
32
+ x = x.unsqueeze(1)
33
+
34
+ # Pass input through convolutional layers
35
+ x = self.conv_layers(x)
36
+
37
+ # Remove the channel dimension
38
+ x = x.squeeze(1)
39
+
40
+ # Scale the output to the range [-1, 1]
41
+ x = torch.tanh(x)
42
+
43
+ return x
44
+
45
+
46
+
47
+ class CNNDecoderPrenet(nn.Module):
48
+ def __init__(self, input_dim=80, hidden_dim=256, output_dim=256, final_dim=512, dropout_rate=0.5):
49
+ super(CNNDecoderPrenet, self).__init__()
50
+ self.layer1 = nn.Linear(input_dim, hidden_dim)
51
+ self.layer2 = nn.Linear(hidden_dim, output_dim)
52
+ self.linear_projection = nn.Linear(output_dim, final_dim) # Added linear projection
53
+ self.dropout = nn.Dropout(dropout_rate)
54
+
55
+ def forward(self, x):
56
+
57
+ # Transpose the input tensor to have the feature dimension as the last dimension
58
+ x = x.transpose(1, 2)
59
+ # Apply the linear layers
60
+ x = F.relu(self.layer1(x))
61
+ x = self.dropout(x)
62
+ x = F.relu(self.layer2(x))
63
+ x = self.dropout(x)
64
+ # Apply the linear projection
65
+ x = self.linear_projection(x)
66
+ x = x.transpose(1, 2)
67
+
68
+ return x
69
+
70
+
71
+
72
+
73
+ class CNNPostNet(torch.nn.Module):
74
+ """
75
+ Conv Postnet
76
+ Arguments
77
+ ---------
78
+ n_mel_channels: int
79
+ input feature dimension for convolution layers
80
+ postnet_embedding_dim: int
81
+ output feature dimension for convolution layers
82
+ postnet_kernel_size: int
83
+ postnet convolution kernal size
84
+ postnet_n_convolutions: int
85
+ number of convolution layers
86
+ postnet_dropout: float
87
+ dropout probability fot postnet
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ n_mel_channels=80,
93
+ postnet_embedding_dim=512,
94
+ postnet_kernel_size=5,
95
+ postnet_n_convolutions=5,
96
+ postnet_dropout=0.1,
97
+ ):
98
+ super(CNNPostNet, self).__init__()
99
+
100
+ self.conv_pre = nn.Conv1d(
101
+ in_channels=n_mel_channels,
102
+ out_channels=postnet_embedding_dim,
103
+ kernel_size=postnet_kernel_size,
104
+ padding="same",
105
+ )
106
+
107
+ self.convs_intermedite = nn.ModuleList()
108
+ for i in range(1, postnet_n_convolutions - 1):
109
+ self.convs_intermedite.append(
110
+ nn.Conv1d(
111
+ in_channels=postnet_embedding_dim,
112
+ out_channels=postnet_embedding_dim,
113
+ kernel_size=postnet_kernel_size,
114
+ padding="same",
115
+ ),
116
+ )
117
+
118
+ self.conv_post = nn.Conv1d(
119
+ in_channels=postnet_embedding_dim,
120
+ out_channels=n_mel_channels,
121
+ kernel_size=postnet_kernel_size,
122
+ padding="same",
123
+ )
124
+
125
+ self.tanh = nn.Tanh()
126
+ self.ln1 = nn.LayerNorm(postnet_embedding_dim)
127
+ self.ln2 = nn.LayerNorm(postnet_embedding_dim)
128
+ self.ln3 = nn.LayerNorm(n_mel_channels)
129
+ self.dropout1 = nn.Dropout(postnet_dropout)
130
+ self.dropout2 = nn.Dropout(postnet_dropout)
131
+ self.dropout3 = nn.Dropout(postnet_dropout)
132
+
133
+
134
+ def forward(self, x):
135
+ """Computes the forward pass
136
+ Arguments
137
+ ---------
138
+ x: torch.Tensor
139
+ a (batch, time_steps, features) input tensor
140
+ Returns
141
+ -------
142
+ output: torch.Tensor (the spectrogram predicted)
143
+ """
144
+ x = self.conv_pre(x)
145
+ x = self.ln1(x.permute(0, 2, 1)).permute(0, 2, 1) # Transpose to [batch_size, feature_dim, sequence_length]
146
+ x = self.tanh(x)
147
+ x = self.dropout1(x)
148
+
149
+ for i in range(len(self.convs_intermedite)):
150
+ x = self.convs_intermedite[i](x)
151
+ x = self.ln2(x.permute(0, 2, 1)).permute(0, 2, 1) # Transpose to [batch_size, feature_dim, sequence_length]
152
+ x = self.tanh(x)
153
+ x = self.dropout2(x)
154
+
155
+ x = self.conv_post(x)
156
+ x = self.ln3(x.permute(0, 2, 1)).permute(0, 2, 1) # Transpose to [batch_size, feature_dim, sequence_length]
157
+ x = self.dropout3(x)
158
+
159
+ return x
160
+
161
+
162
+ class ScaledPositionalEncoding(nn.Module):
163
+ """
164
+ This class implements the absolute sinusoidal positional encoding function
165
+ with an adaptive weight parameter alpha.
166
+
167
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
168
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
169
+
170
+ Arguments
171
+ ---------
172
+ input_size: int
173
+ Embedding dimension.
174
+ max_len : int, optional
175
+ Max length of the input sequences (default 2500).
176
+ Example
177
+ -------
178
+ >>> a = torch.rand((8, 120, 512))
179
+ >>> enc = PositionalEncoding(input_size=a.shape[-1])
180
+ >>> b = enc(a)
181
+ >>> b.shape
182
+ torch.Size([1, 120, 512])
183
+ """
184
+
185
+ def __init__(self, input_size, max_len=2500):
186
+ super().__init__()
187
+ if input_size % 2 != 0:
188
+ raise ValueError(
189
+ f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
190
+ )
191
+ self.max_len = max_len
192
+ self.alpha = nn.Parameter(torch.ones(1)) # Define alpha as a trainable parameter
193
+ pe = torch.zeros(self.max_len, input_size, requires_grad=False)
194
+ positions = torch.arange(0, self.max_len).unsqueeze(1).float()
195
+ denominator = torch.exp(
196
+ torch.arange(0, input_size, 2).float()
197
+ * -(math.log(10000.0) / input_size)
198
+ )
199
+
200
+ pe[:, 0::2] = torch.sin(positions * denominator)
201
+ pe[:, 1::2] = torch.cos(positions * denominator)
202
+ pe = pe.unsqueeze(0)
203
+ self.register_buffer("pe", pe)
204
+
205
+ def forward(self, x):
206
+ """
207
+ Arguments
208
+ ---------
209
+ x : tensor
210
+ Input feature shape (batch, time, fea)
211
+ """
212
+ pe_scaled = self.pe[:, :x.size(1)].clone().detach() * self.alpha # Scale positional encoding by alpha
213
+ return pe_scaled
214
+