Flux9665 commited on
Commit
9e275b8
1 Parent(s): 035eb7e

use explicit code instead of relying on release download

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +36 -0
  2. Architectures/Aligner/Aligner.py +164 -0
  3. Architectures/Aligner/CodecAlignerDataset.py +271 -0
  4. Architectures/Aligner/README.md +1 -0
  5. Architectures/Aligner/Reconstructor.py +40 -0
  6. Architectures/Aligner/__init__.py +0 -0
  7. Architectures/Aligner/autoaligner_train_loop.py +188 -0
  8. Architectures/ControllabilityGAN/GAN.py +82 -0
  9. Architectures/ControllabilityGAN/__init__.py +0 -0
  10. Architectures/ControllabilityGAN/dataset/__init__.py +0 -0
  11. Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py +94 -0
  12. Architectures/ControllabilityGAN/wgan/__init__.py +0 -0
  13. Architectures/ControllabilityGAN/wgan/init_weights.py +21 -0
  14. Architectures/ControllabilityGAN/wgan/init_wgan.py +34 -0
  15. Architectures/ControllabilityGAN/wgan/resnet_1.py +181 -0
  16. Architectures/ControllabilityGAN/wgan/resnet_init.py +15 -0
  17. Architectures/ControllabilityGAN/wgan/wgan_qc.py +272 -0
  18. Architectures/EmbeddingModel/GST.py +235 -0
  19. Architectures/EmbeddingModel/README.md +1 -0
  20. Architectures/EmbeddingModel/StyleEmbedding.py +73 -0
  21. Architectures/EmbeddingModel/StyleTTSEncoder.py +156 -0
  22. Architectures/EmbeddingModel/__init__.py +0 -0
  23. Architectures/GeneralLayers/Attention.py +324 -0
  24. Architectures/GeneralLayers/ConditionalLayerNorm.py +118 -0
  25. Architectures/GeneralLayers/Conformer.py +158 -0
  26. Architectures/GeneralLayers/Convolution.py +55 -0
  27. Architectures/GeneralLayers/DurationPredictor.py +171 -0
  28. Architectures/GeneralLayers/EncoderLayer.py +144 -0
  29. Architectures/GeneralLayers/LayerNorm.py +36 -0
  30. Architectures/GeneralLayers/LengthRegulator.py +61 -0
  31. Architectures/GeneralLayers/MultiLayeredConv1d.py +87 -0
  32. Architectures/GeneralLayers/MultiSequential.py +33 -0
  33. Architectures/GeneralLayers/PositionalEncoding.py +166 -0
  34. Architectures/GeneralLayers/PositionwiseFeedForward.py +26 -0
  35. Architectures/GeneralLayers/README.md +2 -0
  36. Architectures/GeneralLayers/ResidualBlock.py +98 -0
  37. Architectures/GeneralLayers/ResidualStack.py +51 -0
  38. Architectures/GeneralLayers/STFT.py +123 -0
  39. Architectures/GeneralLayers/Swish.py +18 -0
  40. Architectures/GeneralLayers/VariancePredictor.py +98 -0
  41. Architectures/GeneralLayers/__init__.py +0 -0
  42. Architectures/README.md +2 -0
  43. Architectures/ToucanTTS/CodecDiscriminator.py +94 -0
  44. Architectures/ToucanTTS/CodecRefinementTransformer.py +199 -0
  45. Architectures/ToucanTTS/DurationCalculator.py +30 -0
  46. Architectures/ToucanTTS/EnergyCalculator.py +94 -0
  47. Architectures/ToucanTTS/Glow.py +402 -0
  48. Architectures/ToucanTTS/InferenceToucanTTS.py +375 -0
  49. Architectures/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py +74 -0
  50. Architectures/ToucanTTS/PitchCalculator.py +117 -0
.gitignore ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .idea/
2
+ tensorboard_logs/
3
+ Corpora/
4
+ Models/
5
+ audios/
6
+ Preprocessing/glottolog/
7
+ Preprocessing/multilinguality/datasets/
8
+ apex/
9
+ pretrained_models/
10
+ .tmp/
11
+ .vscode/
12
+ split/
13
+ singing/
14
+ toucan_conda_venv/
15
+ venv/
16
+ vis/
17
+ Utility/storage_config.py
18
+ Preprocessing/multilinguality/distance_datasets
19
+
20
+
21
+ *_graph
22
+ app.py
23
+ gradio*
24
+ *playground*
25
+ run_phonemizer.py
26
+
27
+ *.pt
28
+ *.out
29
+ *.wav
30
+ *.flac
31
+ *.json
32
+ *.pyc
33
+ *.png
34
+ *.pdf
35
+ *.pkl
36
+ *.gif
Architectures/Aligner/Aligner.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ taken and adapted from https://github.com/as-ideas/DeepForcedAligner
3
+ """
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ import torch.multiprocessing
8
+ import torch.nn as nn
9
+ from torch.nn import CTCLoss
10
+ from torch.nn.utils.rnn import pack_padded_sequence
11
+ from torch.nn.utils.rnn import pad_packed_sequence
12
+
13
+ from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
14
+
15
+
16
+ class BatchNormConv(nn.Module):
17
+
18
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
19
+ super().__init__()
20
+ self.conv = nn.Conv1d(
21
+ in_channels, out_channels, kernel_size,
22
+ stride=1, padding=kernel_size // 2, bias=False)
23
+ self.bnorm = nn.BatchNorm1d(out_channels)
24
+ self.relu = nn.ReLU()
25
+
26
+ def forward(self, x):
27
+ x = x.transpose(1, 2)
28
+ x = self.conv(x)
29
+ x = self.relu(x)
30
+ x = self.bnorm(x)
31
+ x = x.transpose(1, 2)
32
+ return x
33
+
34
+
35
+ class Aligner(torch.nn.Module):
36
+
37
+ def __init__(self,
38
+ n_features=128,
39
+ num_symbols=145,
40
+ lstm_dim=512,
41
+ conv_dim=512):
42
+ super().__init__()
43
+ self.convs = nn.ModuleList([
44
+ BatchNormConv(n_features, conv_dim, 3),
45
+ nn.Dropout(p=0.5),
46
+ BatchNormConv(conv_dim, conv_dim, 3),
47
+ nn.Dropout(p=0.5),
48
+ BatchNormConv(conv_dim, conv_dim, 3),
49
+ nn.Dropout(p=0.5),
50
+ BatchNormConv(conv_dim, conv_dim, 3),
51
+ nn.Dropout(p=0.5),
52
+ BatchNormConv(conv_dim, conv_dim, 3),
53
+ nn.Dropout(p=0.5),
54
+ ])
55
+ self.rnn = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True)
56
+ self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols)
57
+ self.tf = ArticulatoryCombinedTextFrontend(language="eng")
58
+ self.ctc_loss = CTCLoss(blank=144, zero_infinity=True)
59
+ self.vector_to_id = dict()
60
+
61
+ def forward(self, x, lens=None):
62
+ for conv in self.convs:
63
+ x = conv(x)
64
+
65
+ if lens is not None:
66
+ x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
67
+ x, _ = self.rnn(x)
68
+ if lens is not None:
69
+ x, _ = pad_packed_sequence(x, batch_first=True)
70
+
71
+ x = self.proj(x)
72
+
73
+ return x
74
+
75
+ @torch.inference_mode()
76
+ def inference(self, features, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False):
77
+ if not train:
78
+ tokens_indexed = self.tf.text_vectors_to_id_sequence(text_vector=tokens) # first we need to convert the articulatory vectors to IDs, so we can apply dijkstra or viterbi
79
+ tokens = np.asarray(tokens_indexed)
80
+ else:
81
+ tokens = tokens.cpu().detach().numpy()
82
+
83
+ pred = self(features.unsqueeze(0))
84
+ if return_ctc:
85
+ ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2), torch.LongTensor(tokens), torch.LongTensor([len(pred[0])]),
86
+ torch.LongTensor([len(tokens)])).item()
87
+ pred = pred.squeeze().cpu().detach().numpy()
88
+ pred_max = pred[:, tokens]
89
+
90
+ # run monotonic alignment search
91
+
92
+ alignment_matrix = binarize_alignment(pred_max)
93
+
94
+ if save_img_for_debug is not None:
95
+ phones = list()
96
+ for index in tokens:
97
+ for phone in self.tf.phone_to_id:
98
+ if self.tf.phone_to_id[phone] == index:
99
+ phones.append(phone)
100
+ fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
101
+
102
+ ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis')
103
+ ax.set_ylabel("Mel-Frames")
104
+ ax.set_xticks(range(len(pred_max[0])))
105
+ ax.set_xticklabels(labels=phones)
106
+ ax.set_title("MAS Path")
107
+
108
+ plt.tight_layout()
109
+ fig.savefig(save_img_for_debug)
110
+ fig.clf()
111
+ plt.close()
112
+
113
+ if return_ctc:
114
+ return alignment_matrix, ctc_loss
115
+ return alignment_matrix
116
+
117
+
118
+
119
+ def binarize_alignment(alignment_prob):
120
+ """
121
+ # Implementation by:
122
+ # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py
123
+ # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py
124
+
125
+ Binarizes alignment with MAS.
126
+ """
127
+ # assumes features x text
128
+ opt = np.zeros_like(alignment_prob)
129
+ alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0) # make all numbers positive and add an offset to avoid log of 0 later
130
+ alignment_prob * alignment_prob * (1.0 / alignment_prob.max()) # normalize to (0, 1]
131
+ attn_map = np.log(alignment_prob)
132
+ attn_map[0, 1:] = -np.inf
133
+ log_p = np.zeros_like(attn_map)
134
+ log_p[0, :] = attn_map[0, :]
135
+ prev_ind = np.zeros_like(attn_map, dtype=np.int64)
136
+ for i in range(1, attn_map.shape[0]):
137
+ for j in range(attn_map.shape[1]): # for each text dim
138
+ prev_log = log_p[i - 1, j]
139
+ prev_j = j
140
+ if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
141
+ prev_log = log_p[i - 1, j - 1]
142
+ prev_j = j - 1
143
+ log_p[i, j] = attn_map[i, j] + prev_log
144
+ prev_ind[i, j] = prev_j
145
+ # now backtrack
146
+ curr_text_idx = attn_map.shape[1] - 1
147
+ for i in range(attn_map.shape[0] - 1, -1, -1):
148
+ opt[i, curr_text_idx] = 1
149
+ curr_text_idx = prev_ind[i, curr_text_idx]
150
+ opt[0, curr_text_idx] = 1
151
+ return opt
152
+
153
+
154
+ if __name__ == '__main__':
155
+ tf = ArticulatoryCombinedTextFrontend(language="eng")
156
+ from Preprocessing.HiFiCodecAudioPreprocessor import CodecAudioPreprocessor
157
+
158
+ cap = CodecAudioPreprocessor(input_sr=-1)
159
+ dummy_codebook_indexes = torch.randint(low=0, high=1023, size=[9, 20])
160
+ codebook_frames = cap.indexes_to_codec_frames(dummy_codebook_indexes)
161
+ alignment = Aligner().inference(codebook_frames.transpose(0, 1), tokens=tf.string_to_tensor("Hello world"))
162
+ print(alignment.shape)
163
+ plt.imshow(alignment, origin="lower", cmap="GnBu")
164
+ plt.show()
Architectures/Aligner/CodecAlignerDataset.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ import librosa
5
+ import soundfile as sf
6
+ import torch
7
+ from speechbrain.pretrained import EncoderClassifier
8
+ from torch.multiprocessing import Manager
9
+ from torch.multiprocessing import Process
10
+ from torch.utils.data import Dataset
11
+ from torchaudio.transforms import Resample
12
+ from tqdm import tqdm
13
+
14
+ from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
15
+ from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
16
+ from Utility.storage_config import MODELS_DIR
17
+
18
+
19
+ class CodecAlignerDataset(Dataset):
20
+
21
+ def __init__(self,
22
+ path_to_transcript_dict,
23
+ cache_dir,
24
+ lang,
25
+ loading_processes,
26
+ device,
27
+ min_len_in_seconds=1,
28
+ max_len_in_seconds=15,
29
+ rebuild_cache=False,
30
+ verbose=False,
31
+ phone_input=False,
32
+ allow_unknown_symbols=False,
33
+ gpu_count=1,
34
+ rank=0):
35
+ self.gpu_count = gpu_count
36
+ self.rank = rank
37
+ if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache:
38
+ self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict,
39
+ cache_dir=cache_dir,
40
+ lang=lang,
41
+ loading_processes=loading_processes,
42
+ device=device,
43
+ min_len_in_seconds=min_len_in_seconds,
44
+ max_len_in_seconds=max_len_in_seconds,
45
+ verbose=verbose,
46
+ phone_input=phone_input,
47
+ allow_unknown_symbols=allow_unknown_symbols,
48
+ gpu_count=gpu_count,
49
+ rank=rank)
50
+ self.lang = lang
51
+ self.device = device
52
+ self.cache_dir = cache_dir
53
+ self.tf = ArticulatoryCombinedTextFrontend(language=self.lang)
54
+ cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu')
55
+ self.speaker_embeddings = cache[2]
56
+ self.datapoints = cache[0]
57
+ if self.gpu_count > 1:
58
+ # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank.
59
+ while len(self.datapoints) % self.gpu_count != 0:
60
+ self.datapoints.pop(-1) # a bit unfortunate, but if you're using multiple GPUs, you probably have a ton of datapoints anyway.
61
+ chunksize = int(len(self.datapoints) / self.gpu_count)
62
+ self.datapoints = self.datapoints[chunksize * self.rank:chunksize * (self.rank + 1)]
63
+ self.speaker_embeddings = self.speaker_embeddings[chunksize * self.rank:chunksize * (self.rank + 1)]
64
+ print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.")
65
+
66
+ def _build_dataset_cache(self,
67
+ path_to_transcript_dict,
68
+ cache_dir,
69
+ lang,
70
+ loading_processes,
71
+ device,
72
+ min_len_in_seconds=1,
73
+ max_len_in_seconds=15,
74
+ verbose=False,
75
+ phone_input=False,
76
+ allow_unknown_symbols=False,
77
+ gpu_count=1,
78
+ rank=0
79
+ ):
80
+ if gpu_count != 1:
81
+ import sys
82
+ print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.")
83
+ sys.exit()
84
+ os.makedirs(cache_dir, exist_ok=True)
85
+ if type(path_to_transcript_dict) != dict:
86
+ path_to_transcript_dict = path_to_transcript_dict() # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary.
87
+ torch.multiprocessing.set_start_method('spawn', force=True)
88
+ resource_manager = Manager()
89
+ self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict)
90
+ key_list = list(self.path_to_transcript_dict.keys())
91
+ with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note:
92
+ files_used_note.write(str(key_list))
93
+ fisher_yates_shuffle(key_list)
94
+ # build cache
95
+ print("... building dataset cache ...")
96
+ self.result_pool = resource_manager.list()
97
+ # make processes
98
+ key_splits = list()
99
+ process_list = list()
100
+ for i in range(loading_processes):
101
+ key_splits.append(
102
+ key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes])
103
+ for key_split in key_splits:
104
+ process_list.append(
105
+ Process(target=self._cache_builder_process,
106
+ args=(key_split,
107
+ lang,
108
+ min_len_in_seconds,
109
+ max_len_in_seconds,
110
+ verbose,
111
+ device,
112
+ phone_input,
113
+ allow_unknown_symbols),
114
+ daemon=True))
115
+ process_list[-1].start()
116
+ for process in process_list:
117
+ process.join()
118
+ print("pooling results...")
119
+ pooled_datapoints = list()
120
+ for chunk in self.result_pool:
121
+ for datapoint in chunk:
122
+ pooled_datapoints.append(datapoint) # unpack into a joint list
123
+ self.result_pool = pooled_datapoints
124
+ del pooled_datapoints
125
+ print("converting text to tensors...")
126
+ text_tensors = [torch.ShortTensor(x[0]) for x in self.result_pool] # turn everything back to tensors (had to turn it to np arrays to avoid multiprocessing issues)
127
+ print("converting speech to tensors...")
128
+ speech_tensors = [torch.ShortTensor(x[1]) for x in self.result_pool]
129
+ print("converting waves to tensors...")
130
+ norm_waves = [torch.Tensor(x[2]) for x in self.result_pool]
131
+ print("unpacking file list...")
132
+ filepaths = [x[3] for x in self.result_pool]
133
+ del self.result_pool
134
+ self.datapoints = list(zip(text_tensors, speech_tensors))
135
+ del text_tensors
136
+ del speech_tensors
137
+ print("done!")
138
+
139
+ # add speaker embeddings
140
+ self.speaker_embeddings = list()
141
+ speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
142
+ run_opts={"device": str(device)},
143
+ savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))
144
+ with torch.inference_mode():
145
+ for wave in tqdm(norm_waves):
146
+ self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu())
147
+
148
+ # save to cache
149
+ if len(self.datapoints) == 0:
150
+ raise RuntimeError # something went wrong and there are no datapoints
151
+ torch.save((self.datapoints, None, self.speaker_embeddings, filepaths),
152
+ os.path.join(cache_dir, "aligner_train_cache.pt"))
153
+
154
+ def _cache_builder_process(self,
155
+ path_list,
156
+ lang,
157
+ min_len,
158
+ max_len,
159
+ verbose,
160
+ device,
161
+ phone_input,
162
+ allow_unknown_symbols):
163
+ process_internal_dataset_chunk = list()
164
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
165
+ # careful: assumes 16kHz or 8kHz audio
166
+ silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
167
+ model='silero_vad',
168
+ force_reload=False,
169
+ onnx=False,
170
+ verbose=False)
171
+ (get_speech_timestamps,
172
+ save_audio,
173
+ read_audio,
174
+ VADIterator,
175
+ collect_chunks) = utils
176
+ torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets
177
+ # this to false globally during model loading rather than using inference mode or no_grad
178
+ silero_model = silero_model.to(device)
179
+ silence = torch.zeros([16000 // 4], device=device)
180
+ tf = ArticulatoryCombinedTextFrontend(language=lang)
181
+ _, sr = sf.read(path_list[0])
182
+ assumed_sr = sr
183
+ ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
184
+ resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
185
+
186
+ for path in tqdm(path_list):
187
+ if self.path_to_transcript_dict[path].strip() == "":
188
+ continue
189
+
190
+ try:
191
+ wave, sr = sf.read(path)
192
+ except:
193
+ print(f"Problem with an audio file: {path}")
194
+ continue
195
+
196
+ wave = librosa.to_mono(wave)
197
+
198
+ if sr != assumed_sr:
199
+ assumed_sr = sr
200
+ ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device)
201
+ resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device)
202
+ print(f"{path} has a different sampling rate --> adapting the codec processor")
203
+
204
+ try:
205
+ norm_wave = resample(torch.tensor(wave).float().to(device))
206
+ except ValueError:
207
+ continue
208
+ dur_in_seconds = len(norm_wave) / 16000
209
+ if not (min_len <= dur_in_seconds <= max_len):
210
+ if verbose:
211
+ print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.")
212
+ continue
213
+
214
+ # remove silences from front and back, then add constant 1/4th second silences back to front and back
215
+ with torch.no_grad():
216
+ speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000)
217
+ try:
218
+ result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
219
+ except IndexError:
220
+ print("Audio might be too short to cut silences from front and back.")
221
+ continue
222
+ wave = torch.cat([silence, result, silence])
223
+
224
+ # raw audio preprocessing is done
225
+ transcript = self.path_to_transcript_dict[path]
226
+
227
+ try:
228
+ try:
229
+ cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy()
230
+ except KeyError:
231
+ cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy()
232
+ if not allow_unknown_symbols:
233
+ continue # we skip sentences with unknown symbols
234
+ except ValueError:
235
+ # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
236
+ continue
237
+ except KeyError:
238
+ # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample.
239
+ continue
240
+
241
+ cached_speech = ap.audio_to_codebook_indexes(audio=wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy()
242
+ process_internal_dataset_chunk.append([cached_text,
243
+ cached_speech,
244
+ result.cpu().detach().numpy(),
245
+ path])
246
+ self.result_pool.append(process_internal_dataset_chunk)
247
+
248
+ def __getitem__(self, index):
249
+ text_vector = self.datapoints[index][0]
250
+ tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector)
251
+ tokens = torch.LongTensor(tokens)
252
+ token_len = torch.LongTensor([len(tokens)])
253
+
254
+ codes = self.datapoints[index][1]
255
+ if codes.size()[0] != 24: # no clue why this is sometimes the case
256
+ codes = codes.transpose(0, 1)
257
+
258
+ return tokens, \
259
+ token_len, \
260
+ codes, \
261
+ None, \
262
+ self.speaker_embeddings[index]
263
+
264
+ def __len__(self):
265
+ return len(self.datapoints)
266
+
267
+
268
+ def fisher_yates_shuffle(lst):
269
+ for i in range(len(lst) - 1, 0, -1):
270
+ j = random.randint(0, i)
271
+ lst[i], lst[j] = lst[j], lst[i]
Architectures/Aligner/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Everything that is concerned with training and using the aligner model is contained in this directory. It is recommended to use the universal aligner model that we supply in the GitHub releases.
Architectures/Aligner/Reconstructor.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.multiprocessing
3
+ from torch.nn.utils.rnn import pack_padded_sequence
4
+ from torch.nn.utils.rnn import pad_packed_sequence
5
+
6
+ from Utility.utils import make_non_pad_mask
7
+
8
+
9
+ class Reconstructor(torch.nn.Module):
10
+
11
+ def __init__(self,
12
+ n_features=128,
13
+ num_symbols=145,
14
+ speaker_embedding_dim=192,
15
+ lstm_dim=256):
16
+ super().__init__()
17
+ self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, lstm_dim)
18
+ self.rnn1 = torch.nn.LSTM(lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
19
+ self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True)
20
+ self.out_proj = torch.nn.Linear(2 * lstm_dim, n_features)
21
+ self.l1_criterion = torch.nn.L1Loss(reduction="none")
22
+ self.l2_criterion = torch.nn.MSELoss(reduction="none")
23
+
24
+ def forward(self, x, lens, ys):
25
+ x = self.in_proj(x)
26
+ x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False)
27
+ x, _ = self.rnn1(x)
28
+ x, _ = self.rnn2(x)
29
+ x, _ = pad_packed_sequence(x, batch_first=True)
30
+ x = self.out_proj(x)
31
+ out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device)
32
+ out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
33
+ out_weights /= ys.size(0) * ys.size(2)
34
+ l1_loss = self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
35
+ l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum()
36
+ return l1_loss + l2_loss
37
+
38
+
39
+ if __name__ == '__main__':
40
+ print(sum(p.numel() for p in Reconstructor().parameters() if p.requires_grad))
Architectures/Aligner/__init__.py ADDED
File without changes
Architectures/Aligner/autoaligner_train_loop.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import torch
5
+ import torch.multiprocessing
6
+ from torch.nn.utils.rnn import pad_sequence
7
+ from torch.optim import RAdam
8
+ from torch.utils.data.dataloader import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from Architectures.Aligner.Aligner import Aligner
12
+ from Architectures.Aligner.Reconstructor import Reconstructor
13
+ from Preprocessing.AudioPreprocessor import AudioPreprocessor
14
+ from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor
15
+
16
+
17
+ def collate_and_pad(batch):
18
+ # text, text_len, speech, speech_len, embed
19
+ return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True),
20
+ torch.stack([datapoint[1] for datapoint in batch]).squeeze(1),
21
+ [datapoint[2] for datapoint in batch],
22
+ None,
23
+ torch.stack([datapoint[4] for datapoint in batch]).squeeze())
24
+
25
+
26
+ def train_loop(train_dataset,
27
+ device,
28
+ save_directory,
29
+ batch_size,
30
+ steps,
31
+ path_to_checkpoint=None,
32
+ fine_tune=False,
33
+ resume=False,
34
+ debug_img_path=None,
35
+ use_reconstruction=True,
36
+ gpu_count=1,
37
+ rank=0,
38
+ steps_per_checkpoint=None):
39
+ """
40
+ Args:
41
+ resume: whether to resume from the most recent checkpoint
42
+ steps: How many steps to train
43
+ path_to_checkpoint: reloads a checkpoint to continue training from there
44
+ fine_tune: whether to load everything from a checkpoint, or only the model parameters
45
+ train_dataset: Pytorch Dataset Object for train data
46
+ device: Device to put the loaded tensors on
47
+ save_directory: Where to save the checkpoints
48
+ batch_size: How many elements should be loaded at once
49
+ debug_img_path: where to put images of the training progress if desired
50
+ use_reconstruction: whether to use the auxiliary reconstruction procedure/loss, which can make the alignment sharper
51
+ """
52
+ os.makedirs(save_directory, exist_ok=True)
53
+ torch.multiprocessing.set_sharing_strategy('file_system')
54
+ torch.multiprocessing.set_start_method('spawn', force=True)
55
+
56
+ if steps_per_checkpoint is None:
57
+ steps_per_checkpoint = len(train_dataset) // batch_size
58
+ ap = CodecAudioPreprocessor(input_sr=-1, device=device) # only used to transform features into continuous matrices
59
+ spectrogram_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device)
60
+
61
+ asr_model = Aligner().to(device)
62
+ optim_asr = RAdam(asr_model.parameters(), lr=0.0001)
63
+
64
+ tiny_tts = Reconstructor().to(device)
65
+ optim_tts = RAdam(tiny_tts.parameters(), lr=0.0001)
66
+
67
+ if gpu_count > 1:
68
+ asr_model.to(rank)
69
+ tiny_tts.to(rank)
70
+ asr_model = torch.nn.parallel.DistributedDataParallel(
71
+ asr_model,
72
+ device_ids=[rank],
73
+ output_device=rank,
74
+ find_unused_parameters=True,
75
+ ).module
76
+ tiny_tts = torch.nn.parallel.DistributedDataParallel(
77
+ tiny_tts,
78
+ device_ids=[rank],
79
+ output_device=rank,
80
+ find_unused_parameters=True,
81
+ ).module
82
+ torch.distributed.barrier()
83
+ train_sampler = torch.utils.data.RandomSampler(train_dataset)
84
+ batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True)
85
+
86
+ train_loader = DataLoader(dataset=train_dataset,
87
+ num_workers=0, # unfortunately necessary for big data due to mmap errors
88
+ batch_sampler=batch_sampler_train,
89
+ prefetch_factor=None,
90
+ collate_fn=collate_and_pad)
91
+
92
+ step_counter = 0
93
+ loss_sum = list()
94
+
95
+ if resume:
96
+ previous_checkpoint = os.path.join(save_directory, "aligner.pt")
97
+ path_to_checkpoint = previous_checkpoint
98
+ fine_tune = False
99
+
100
+ if path_to_checkpoint is not None:
101
+ check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device)
102
+ asr_model.load_state_dict(check_dict["asr_model"])
103
+ tiny_tts.load_state_dict(check_dict["tts_model"])
104
+ if not fine_tune:
105
+ optim_asr.load_state_dict(check_dict["optimizer"])
106
+ optim_tts.load_state_dict(check_dict["tts_optimizer"])
107
+ step_counter = check_dict["step_counter"]
108
+ if step_counter > steps:
109
+ print("Desired steps already reached in loaded checkpoint.")
110
+ return
111
+ start_time = time.time()
112
+
113
+ while True:
114
+ asr_model.train()
115
+ tiny_tts.train()
116
+ for batch in tqdm(train_loader):
117
+ tokens = batch[0].to(device)
118
+ tokens_len = batch[1].to(device)
119
+ speaker_embeddings = batch[4].to(device)
120
+
121
+ mels = list()
122
+ mel_lengths = list()
123
+ for datapoint in batch[2]:
124
+ with torch.inference_mode():
125
+ # extremely unfortunate that we have to do this over here, but multiprocessing and this don't go together well
126
+ speech = ap.indexes_to_audio(datapoint.int().to(device))
127
+ mel = spectrogram_extractor.audio_to_mel_spec_tensor(speech, explicit_sampling_rate=16000).transpose(0, 1).cpu()
128
+ speech_len = torch.LongTensor([len(mel)])
129
+ mels.append(mel.clone())
130
+ mel_lengths.append(speech_len)
131
+ mel = pad_sequence(mels, batch_first=True).to(device)
132
+ mel_len = torch.stack(mel_lengths).squeeze(1).to(device)
133
+
134
+ pred = asr_model(mel, mel_len)
135
+
136
+ ctc_loss = asr_model.ctc_loss(pred.transpose(0, 1).log_softmax(2),
137
+ tokens,
138
+ mel_len,
139
+ tokens_len)
140
+
141
+ if use_reconstruction:
142
+ speaker_embeddings_expanded = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(1).expand(-1, pred.size(1), -1)
143
+ tts_lambda = min([0.1, step_counter / 10000]) # super simple schedule
144
+ reconstruction_loss = tiny_tts(x=torch.cat([pred, speaker_embeddings_expanded], dim=-1),
145
+ # combine ASR prediction with speaker embeddings to allow for reconstruction loss on multiple speakers
146
+ lens=mel_len,
147
+ ys=mel) * tts_lambda # reconstruction loss to make the states more distinct
148
+ loss = ctc_loss + reconstruction_loss
149
+ else:
150
+ loss = ctc_loss
151
+
152
+ optim_asr.zero_grad()
153
+ if use_reconstruction:
154
+ optim_tts.zero_grad()
155
+ loss.backward()
156
+ torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
157
+ if use_reconstruction:
158
+ torch.nn.utils.clip_grad_norm_(tiny_tts.parameters(), 1.0)
159
+ optim_asr.step()
160
+ if use_reconstruction:
161
+ optim_tts.step()
162
+
163
+ loss_sum.append(loss.item())
164
+ step_counter += 1
165
+
166
+ if step_counter % steps_per_checkpoint == 0 and rank == 0:
167
+ asr_model.eval()
168
+ torch.save({
169
+ "asr_model" : asr_model.state_dict(),
170
+ "optimizer" : optim_asr.state_dict(),
171
+ "tts_model" : tiny_tts.state_dict(),
172
+ "tts_optimizer": optim_tts.state_dict(),
173
+ "step_counter" : step_counter,
174
+ },
175
+ os.path.join(save_directory, "aligner.pt"))
176
+ print("Total Loss: {}".format(round(sum(loss_sum) / len(loss_sum), 3)))
177
+ print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
178
+ print("Steps: {}".format(step_counter))
179
+ if debug_img_path is not None:
180
+ asr_model.inference(features=mel[0][:mel_len[0]],
181
+ tokens=tokens[0][:tokens_len[0]],
182
+ save_img_for_debug=debug_img_path + f"/{step_counter}.png",
183
+ train=True) # for testing
184
+ asr_model.train()
185
+ loss_sum = list()
186
+
187
+ if step_counter > steps and step_counter % steps_per_checkpoint == 0:
188
+ return
Architectures/ControllabilityGAN/GAN.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan
4
+
5
+
6
+ class GanWrapper:
7
+
8
+ def __init__(self, path_wgan, device):
9
+ self.device = device
10
+ self.path_wgan = path_wgan
11
+
12
+ self.mean = None
13
+ self.std = None
14
+ self.wgan = None
15
+ self.normalize = True
16
+
17
+ self.load_model(path_wgan)
18
+
19
+ self.U = self.compute_controllability()
20
+
21
+ self.z_list = list()
22
+ for _ in range(1100):
23
+ self.z_list.append(self.wgan.G.module.sample_latent(1, 32))
24
+ self.z = self.z_list[0]
25
+
26
+ def set_latent(self, seed):
27
+ self.z = self.z = self.z_list[seed]
28
+
29
+ def reset_default_latent(self):
30
+ self.z = self.wgan.G.module.sample_latent(1, 32)
31
+
32
+ def load_model(self, path):
33
+ gan_checkpoint = torch.load(path, map_location="cpu")
34
+
35
+ self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
36
+ self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict'])
37
+ self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict'])
38
+
39
+ self.mean = gan_checkpoint["dataset_mean"]
40
+ self.std = gan_checkpoint["dataset_std"]
41
+
42
+ def compute_controllability(self, n_samples=50000):
43
+ _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
44
+ intermediate = intermediate.cpu()
45
+ z = z.cpu()
46
+ U = self.controllable_speakers(intermediate, z)
47
+ return U
48
+
49
+ def controllable_speakers(self, intermediate, z):
50
+ pca = torch.pca_lowrank(intermediate)
51
+ mu = intermediate.mean()
52
+ X = torch.matmul((intermediate - mu), pca[2])
53
+ U = torch.linalg.lstsq(X, z)
54
+ return U
55
+
56
+ def get_original_embed(self):
57
+ self.wgan.G.eval()
58
+ embed_original = self.wgan.G.module.forward(self.z.to(self.device))
59
+
60
+ if self.normalize:
61
+ embed_original = inverse_normalize(
62
+ embed_original.cpu(),
63
+ self.mean.cpu().unsqueeze(0),
64
+ self.std.cpu().unsqueeze(0)
65
+ )
66
+ return embed_original
67
+
68
+ def modify_embed(self, x):
69
+ self.wgan.G.eval()
70
+ z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
71
+ embed_modified = self.wgan.G.module.forward(z_new.unsqueeze(0).to(self.device))
72
+ if self.normalize:
73
+ embed_modified = inverse_normalize(
74
+ embed_modified.cpu(),
75
+ self.mean.cpu().unsqueeze(0),
76
+ self.std.cpu().unsqueeze(0)
77
+ )
78
+ return embed_modified
79
+
80
+
81
+ def inverse_normalize(tensor, mean, std):
82
+ return tensor * std + mean
Architectures/ControllabilityGAN/__init__.py ADDED
File without changes
Architectures/ControllabilityGAN/dataset/__init__.py ADDED
File without changes
Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ class SpeakerEmbeddingsDataset(torch.utils.data.Dataset):
8
+
9
+ def __init__(self, feature_path, device, mode='utterance'):
10
+ super(SpeakerEmbeddingsDataset, self).__init__()
11
+
12
+ modes = ['utterance', 'speaker']
13
+ assert mode in modes, f'mode: {mode} is not supported'
14
+ if mode == 'utterance':
15
+ self.mode = 'utt'
16
+ elif mode == 'speaker':
17
+ self.mode = 'spk'
18
+
19
+ self.device = device
20
+
21
+ self.x, self.speakers = self._load_features(feature_path)
22
+ # unique_speakers = set(self.speakers)
23
+ # spk2class = dict(zip(unique_speakers, range(len(unique_speakers))))
24
+ # #self.x = self._reformat_features(self.x)
25
+ # self.y = torch.tensor([spk2class[spk] for spk in self.speakers]).to(self.device)
26
+ # self.class2spk = dict(zip(spk2class.values(), spk2class.keys()))
27
+
28
+ def __len__(self):
29
+ return len(self.speakers)
30
+
31
+ def __getitem__(self, index):
32
+ embedding = self.normalize_embedding(self.x[index])
33
+ # speaker_id = self.y[index]
34
+ return embedding, torch.zeros([0])
35
+
36
+ def normalize_embedding(self, vector):
37
+ return torch.sub(vector, self.mean) / self.std
38
+
39
+ def get_speaker(self, label):
40
+ return self.class2spk[label]
41
+
42
+ def get_embedding_dim(self):
43
+ return self.x.shape[-1]
44
+
45
+ def get_num_speaker(self):
46
+ return len(torch.unique((self.y)))
47
+
48
+ def set_labels(self, labels):
49
+ self.y_old = self.y
50
+ self.y = torch.full(size=(len(self),), fill_value=labels).to(self.device)
51
+ # if isinstance(labels, int) or isinstance(labels, float):
52
+ # self.y = torch.full(size=len(self), fill_value=labels)
53
+ # elif len(labels) == len(self):
54
+ # self.y = torch.tensor(labels)
55
+
56
+ def _load_features(self, feature_path):
57
+ if os.path.isfile(feature_path):
58
+ vectors = torch.load(feature_path, map_location=self.device)
59
+ if isinstance(vectors, list):
60
+ vectors = torch.stack(vectors)
61
+
62
+ self.mean = torch.mean(vectors)
63
+ self.std = torch.std(vectors)
64
+ return vectors, torch.zeros(vectors.size(0))
65
+ else:
66
+ vectors = torch.load(feature_path, map_location=self.device)
67
+
68
+ self.mean = torch.mean(vectors)
69
+ self.std = torch.std(vectors)
70
+
71
+ spk2idx = {}
72
+ with open(feature_path / f'{self.mode}2idx', 'r') as f:
73
+ for line in f:
74
+ split_line = line.strip().split()
75
+ if len(split_line) == 2:
76
+ spk2idx[split_line[0].strip()] = int(split_line[1])
77
+
78
+ speakers, indices = zip(*spk2idx.items())
79
+
80
+ if (feature_path / 'utt2spk').exists(): # spk2idx contains utt_ids not speaker_ids
81
+ utt2spk = {}
82
+ with open(feature_path / 'utt2spk', 'r') as f:
83
+ for line in f:
84
+ split_line = line.strip().split()
85
+ if len(split_line) == 2:
86
+ utt2spk[split_line[0].strip()] = split_line[1].strip()
87
+
88
+ speakers = [utt2spk[utt] for utt in speakers]
89
+
90
+ return vectors[np.array(indices)], speakers
91
+
92
+ def _reformat_features(self, features):
93
+ if len(features.shape) == 2:
94
+ return features.reshape(features.shape[0], 1, 1, features.shape[1])
Architectures/ControllabilityGAN/wgan/__init__.py ADDED
File without changes
Architectures/ControllabilityGAN/wgan/init_weights.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ def weights_init_D(m):
5
+ classname = m.__class__.__name__
6
+ if classname.find('Conv') != -1:
7
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
8
+ # nn.init.constant_(m.bias, 0)
9
+ elif classname.find('BatchNorm') != -1:
10
+ nn.init.constant_(m.weight, 1)
11
+ nn.init.constant_(m.bias, 0)
12
+
13
+
14
+ def weights_init_G(m):
15
+ classname = m.__class__.__name__
16
+ if classname.find('Conv') != -1:
17
+ nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
18
+ # nn.init.constant_(m.bias, 0)
19
+ elif classname.find('BatchNorm') != -1:
20
+ nn.init.constant_(m.weight, 1)
21
+ nn.init.constant_(m.bias, 0)
Architectures/ControllabilityGAN/wgan/init_wgan.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from Architectures.ControllabilityGAN.wgan.resnet_init import init_resnet
4
+ from Architectures.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost
5
+
6
+
7
+ def create_wgan(parameters, device, optimizer='adam'):
8
+ if parameters['model'] == "resnet":
9
+ generator, discriminator = init_resnet(parameters)
10
+ else:
11
+ raise NotImplementedError
12
+
13
+ if optimizer == 'adam':
14
+ optimizer_g = torch.optim.Adam(generator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
15
+ optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas'])
16
+ elif optimizer == 'rmsprop':
17
+ optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate'])
18
+ optimizer_d = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate'])
19
+
20
+ criterion = torch.nn.MSELoss()
21
+
22
+ gan = WassersteinGanQuadraticCost(generator,
23
+ discriminator,
24
+ optimizer_g,
25
+ optimizer_d,
26
+ criterion=criterion,
27
+ data_dimensions=parameters['data_dim'],
28
+ epochs=parameters['epochs'],
29
+ batch_size=parameters['batch_size'],
30
+ device=device,
31
+ n_max_iterations=parameters['n_max_iterations'],
32
+ gamma=parameters['gamma'])
33
+
34
+ return gan
Architectures/ControllabilityGAN/wgan/resnet_1.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ import torch.utils.data.distributed
5
+ from torch import nn
6
+
7
+
8
+ class ResNet_G(nn.Module):
9
+
10
+ def __init__(self, data_dim, z_dim, size, nfilter=64, nfilter_max=512, bn=True, res_ratio=0.1, **kwargs):
11
+ super().__init__()
12
+ self.input_dim = z_dim
13
+ self.output_dim = z_dim
14
+ self.dropout_rate = 0
15
+
16
+ s0 = self.s0 = 4
17
+ nf = self.nf = nfilter
18
+ nf_max = self.nf_max = nfilter_max
19
+ self.bn = bn
20
+ self.z_dim = z_dim
21
+
22
+ # Submodules
23
+ nlayers = int(np.log2(size / s0))
24
+ self.nf0 = min(nf_max, nf * 2 ** (nlayers + 1))
25
+
26
+ self.fc = nn.Linear(z_dim, self.nf0 * s0 * s0)
27
+ if self.bn:
28
+ self.bn1d = nn.BatchNorm1d(self.nf0 * s0 * s0)
29
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
30
+
31
+ blocks = []
32
+ for i in range(nlayers, 0, -1):
33
+ nf0 = min(nf * 2 ** (i + 1), nf_max)
34
+ nf1 = min(nf * 2 ** i, nf_max)
35
+ blocks += [
36
+ ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
37
+ nn.Upsample(scale_factor=2)
38
+ ]
39
+
40
+ nf0 = min(nf * 2, nf_max)
41
+ nf1 = min(nf, nf_max)
42
+ blocks += [
43
+ ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio),
44
+ ResNetBlock(nf1, nf1, bn=self.bn, res_ratio=res_ratio)
45
+ ]
46
+
47
+ self.resnet = nn.Sequential(*blocks)
48
+ self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
49
+
50
+ self.fc_out = nn.Linear(3 * size * size, data_dim)
51
+
52
+ def forward(self, z, return_intermediate=False):
53
+ # print(z.shape)
54
+ batch_size = z.size(0)
55
+ # z = z.view(batch_size, -1)
56
+ out = self.fc(z)
57
+ if self.bn:
58
+ out = self.bn1d(out)
59
+ out = self.relu(out)
60
+ if return_intermediate:
61
+ l_1 = out.detach().clone()
62
+ out = out.view(batch_size, self.nf0, self.s0, self.s0)
63
+ # print(out.shape)
64
+
65
+ out = self.resnet(out)
66
+
67
+ # print(out.shape)
68
+ # out = out.view(batch_size, self.nf0*self.s0*self.s0*2)
69
+
70
+ out = self.conv_img(out)
71
+ out = self.relu(out)
72
+ out.flatten(1)
73
+ out = self.fc_out(out.flatten(1))
74
+
75
+ if return_intermediate:
76
+ return out, l_1
77
+ return out
78
+
79
+ def sample_latent(self, n_samples, z_size):
80
+ return torch.randn((n_samples, z_size))
81
+
82
+
83
+ class ResNet_D(nn.Module):
84
+
85
+ def __init__(self, data_dim, size, nfilter=64, nfilter_max=512, res_ratio=0.1):
86
+ super().__init__()
87
+ s0 = self.s0 = 4
88
+ nf = self.nf = nfilter
89
+ nf_max = self.nf_max = nfilter_max
90
+ self.size = size
91
+
92
+ # Submodules
93
+ nlayers = int(np.log2(size / s0))
94
+ self.nf0 = min(nf_max, nf * 2 ** nlayers)
95
+
96
+ nf0 = min(nf, nf_max)
97
+ nf1 = min(nf * 2, nf_max)
98
+ blocks = [
99
+ ResNetBlock(nf0, nf0, bn=False, res_ratio=res_ratio),
100
+ ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio)
101
+ ]
102
+
103
+ self.fc_input = nn.Linear(data_dim, 3 * size * size)
104
+
105
+ for i in range(1, nlayers + 1):
106
+ nf0 = min(nf * 2 ** i, nf_max)
107
+ nf1 = min(nf * 2 ** (i + 1), nf_max)
108
+ blocks += [
109
+ nn.AvgPool2d(3, stride=2, padding=1),
110
+ ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio),
111
+ ]
112
+
113
+ self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1)
114
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
115
+ self.resnet = nn.Sequential(*blocks)
116
+
117
+ self.fc = nn.Linear(self.nf0 * s0 * s0, 1)
118
+
119
+ def forward(self, x):
120
+ batch_size = x.size(0)
121
+
122
+ out = self.fc_input(x)
123
+ out = self.relu(out).view(batch_size, 3, self.size, self.size)
124
+
125
+ out = self.relu((self.conv_img(out)))
126
+ out = self.resnet(out)
127
+ out = out.view(batch_size, self.nf0 * self.s0 * self.s0)
128
+ out = self.fc(out)
129
+
130
+ return out
131
+
132
+
133
+ class ResNetBlock(nn.Module):
134
+
135
+ def __init__(self, fin, fout, fhidden=None, bn=True, res_ratio=0.1):
136
+ super().__init__()
137
+ # Attributes
138
+ self.bn = bn
139
+ self.is_bias = not bn
140
+ self.learned_shortcut = (fin != fout)
141
+ self.fin = fin
142
+ self.fout = fout
143
+ if fhidden is None:
144
+ self.fhidden = min(fin, fout)
145
+ else:
146
+ self.fhidden = fhidden
147
+ self.res_ratio = res_ratio
148
+
149
+ # Submodules
150
+ self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1, bias=self.is_bias)
151
+ if self.bn:
152
+ self.bn2d_0 = nn.BatchNorm2d(self.fhidden)
153
+ self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=self.is_bias)
154
+ if self.bn:
155
+ self.bn2d_1 = nn.BatchNorm2d(self.fout)
156
+ if self.learned_shortcut:
157
+ self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False)
158
+ if self.bn:
159
+ self.bn2d_s = nn.BatchNorm2d(self.fout)
160
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
161
+
162
+ def forward(self, x):
163
+ x_s = self._shortcut(x)
164
+ dx = self.conv_0(x)
165
+ if self.bn:
166
+ dx = self.bn2d_0(dx)
167
+ dx = self.relu(dx)
168
+ dx = self.conv_1(dx)
169
+ if self.bn:
170
+ dx = self.bn2d_1(dx)
171
+ out = self.relu(x_s + self.res_ratio * dx)
172
+ return out
173
+
174
+ def _shortcut(self, x):
175
+ if self.learned_shortcut:
176
+ x_s = self.conv_s(x)
177
+ if self.bn:
178
+ x_s = self.bn2d_s(x_s)
179
+ else:
180
+ x_s = x
181
+ return x_s
Architectures/ControllabilityGAN/wgan/resnet_init.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_D
2
+ from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_G
3
+ from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_D
4
+ from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_G
5
+
6
+
7
+ def init_resnet(parameters):
8
+ critic = ResNet_D(parameters['data_dim'][-1], parameters['size'], nfilter=parameters['nfilter'], nfilter_max=parameters['nfilter_max'])
9
+ generator = ResNet_G(parameters['data_dim'][-1], parameters['z_dim'], parameters['size'], nfilter=parameters['nfilter'],
10
+ nfilter_max=parameters['nfilter_max'])
11
+
12
+ generator.apply(weights_init_G)
13
+ critic.apply(weights_init_D)
14
+
15
+ return generator, critic
Architectures/ControllabilityGAN/wgan/wgan_qc.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from cvxopt import matrix
9
+ from cvxopt import solvers
10
+ from cvxopt import sparse
11
+ from cvxopt import spmatrix
12
+ from torch.autograd import grad as torch_grad
13
+ from tqdm import tqdm
14
+
15
+
16
+ class WassersteinGanQuadraticCost:
17
+
18
+ def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations,
19
+ data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0):
20
+ self.G = generator
21
+ self.G_opt = gen_optimizer
22
+ self.D = discriminator
23
+ self.D_opt = dis_optimizer
24
+ self.losses = {
25
+ 'D' : [],
26
+ 'WD': [],
27
+ 'G' : []
28
+ }
29
+ self.num_steps = 0
30
+ self.gen_steps = 0
31
+ self.epochs = epochs
32
+ self.n_max_iterations = n_max_iterations
33
+ # put in the shape of a dataset sample
34
+ self.data_dim = data_dimensions[0] * data_dimensions[1] * data_dimensions[2]
35
+ self.batch_size = batch_size
36
+ self.device = device
37
+ self.criterion = criterion
38
+ self.mone = torch.FloatTensor([-1]).to(device)
39
+ self.tensorboard_counter = 0
40
+
41
+ if K <= 0:
42
+ self.K = 1 / self.data_dim
43
+ else:
44
+ self.K = K
45
+ self.Kr = np.sqrt(self.K)
46
+ self.LAMBDA = 2 * self.Kr * gamma * 2
47
+
48
+ self.G = nn.DataParallel(self.G.to(self.device))
49
+ self.D = nn.DataParallel(self.D.to(self.device))
50
+
51
+ self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal)
52
+ self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal)
53
+
54
+ self.c, self.A, self.pStart = self._prepare_linear_programming_solver_(self.batch_size)
55
+
56
+ def _build_lr_scheduler_(self, optimizer, milestones, lr_anneal, last_epoch=-1):
57
+ scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=lr_anneal, last_epoch=-1)
58
+ return scheduler
59
+
60
+ def _quadratic_wasserstein_distance_(self, real, generated):
61
+ num_r = real.size(0)
62
+ num_f = generated.size(0)
63
+ real_flat = real.view(num_r, -1)
64
+ fake_flat = generated.view(num_f, -1)
65
+
66
+ real3D = real_flat.unsqueeze(1).expand(num_r, num_f, self.data_dim)
67
+ fake3D = fake_flat.unsqueeze(0).expand(num_r, num_f, self.data_dim)
68
+ # compute squared L2 distance
69
+ dif = real3D - fake3D
70
+ dist = 0.5 * dif.pow(2).sum(2).squeeze()
71
+
72
+ return self.K * dist
73
+
74
+ def _prepare_linear_programming_solver_(self, batch_size):
75
+ A = spmatrix(1.0, range(batch_size), [0] * batch_size, (batch_size, batch_size))
76
+ for i in range(1, batch_size):
77
+ Ai = spmatrix(1.0, range(batch_size), [i] * batch_size, (batch_size, batch_size))
78
+ A = sparse([A, Ai])
79
+
80
+ D = spmatrix(-1.0, range(batch_size), range(batch_size), (batch_size, batch_size))
81
+ DM = D
82
+ for i in range(1, batch_size):
83
+ DM = sparse([DM, D])
84
+
85
+ A = sparse([[A], [DM]])
86
+
87
+ cr = matrix([-1.0 / batch_size] * batch_size)
88
+ cf = matrix([1.0 / batch_size] * batch_size)
89
+ c = matrix([cr, cf])
90
+
91
+ pStart = {}
92
+ pStart['x'] = matrix([matrix([1.0] * batch_size), matrix([-1.0] * batch_size)])
93
+ pStart['s'] = matrix([1.0] * (2 * batch_size))
94
+
95
+ return c, A, pStart
96
+
97
+ def _linear_programming_(self, distance, batch_size):
98
+ b = matrix(distance.cpu().double().detach().numpy().flatten())
99
+ sol = solvers.lp(self.c, self.A, b, primalstart=self.pStart, solver='glpk',
100
+ options={'glpk': {'msg_lev': 'GLP_MSG_OFF'}})
101
+ offset = 0.5 * (sum(sol['x'])) / batch_size
102
+ sol['x'] = sol['x'] - offset
103
+ self.pStart['x'] = sol['x']
104
+ self.pStart['s'] = sol['s']
105
+
106
+ return sol
107
+
108
+ def _approx_OT_(self, sol):
109
+ # Compute the OT mapping for each fake dataset
110
+ ResMat = np.array(sol['z']).reshape((self.batch_size, self.batch_size))
111
+ mapping = torch.from_numpy(np.argmax(ResMat, axis=0)).long().to(self.device)
112
+
113
+ return mapping
114
+
115
+ def _optimal_transport_regularization_(self, output_fake, fake, real_fake_diff):
116
+ output_fake_grad = torch.ones(output_fake.size()).to(self.device)
117
+ gradients = torch_grad(outputs=output_fake, inputs=fake,
118
+ grad_outputs=output_fake_grad,
119
+ create_graph=True, retain_graph=True, only_inputs=True)[0]
120
+ n = gradients.size(0)
121
+ RegLoss = 0.5 * ((gradients.view(n, -1).norm(dim=1) / (2 * self.Kr) - self.Kr / 2 * real_fake_diff.view(n,
122
+ -1).norm(
123
+ dim=1)).pow(2)).mean()
124
+ fake.requires_grad = False
125
+
126
+ return RegLoss
127
+
128
+ def _critic_deep_regression_(self, images, opt_iterations=1):
129
+ images = images.to(self.device)
130
+
131
+ for p in self.D.parameters(): # reset requires_grad
132
+ p.requires_grad = True # they are set to False below in netG update
133
+
134
+ self.G.train()
135
+ self.D.train()
136
+
137
+ # Get generated fake dataset
138
+ generated_data = self.sample_generator(self.batch_size)
139
+
140
+ # compute wasserstein distance
141
+ distance = self._quadratic_wasserstein_distance_(images, generated_data)
142
+ # solve linear programming problem
143
+ sol = self._linear_programming_(distance, self.batch_size)
144
+ # approximate optimal transport
145
+ mapping = self._approx_OT_(sol)
146
+ real_ordered = images[mapping] # match real and fake
147
+ real_fake_diff = real_ordered - generated_data
148
+
149
+ # construct target
150
+ target = torch.from_numpy(np.array(sol['x'])).float()
151
+ target = target.squeeze().to(self.device)
152
+
153
+ for i in range(opt_iterations):
154
+ self.D.zero_grad() # ???
155
+ self.D_opt.zero_grad()
156
+ generated_data.requires_grad_()
157
+ if generated_data.grad is not None:
158
+ generated_data.grad.data.zero_()
159
+ output_real = self.D(images)
160
+ output_fake = self.D(generated_data)
161
+ output_real, output_fake = output_real.squeeze(), output_fake.squeeze()
162
+ output_R_mean = output_real.mean(0).view(1)
163
+ output_F_mean = output_fake.mean(0).view(1)
164
+
165
+ L2LossD_real = self.criterion(output_R_mean[0], target[:self.batch_size].mean())
166
+ L2LossD_fake = self.criterion(output_fake, target[self.batch_size:])
167
+ L2LossD = 0.5 * L2LossD_real + 0.5 * L2LossD_fake
168
+
169
+ reg_loss_D = self._optimal_transport_regularization_(output_fake, generated_data, real_fake_diff)
170
+
171
+ total_loss = L2LossD + self.LAMBDA * reg_loss_D
172
+
173
+ self.losses['D'].append(float(total_loss.data))
174
+
175
+ total_loss.backward()
176
+ self.D_opt.step()
177
+
178
+ # this is supposed to be the wasserstein distance
179
+ wasserstein_distance = output_R_mean - output_F_mean
180
+ self.losses['WD'].append(float(wasserstein_distance.data))
181
+
182
+ def _generator_train_iteration(self, batch_size):
183
+ for p in self.D.parameters():
184
+ p.requires_grad = False # freeze critic
185
+
186
+ self.G.zero_grad()
187
+ self.G_opt.zero_grad()
188
+
189
+ if isinstance(self.G, torch.nn.parallel.DataParallel):
190
+ z = self.G.module.sample_latent(batch_size, self.G.module.z_dim)
191
+ else:
192
+ z = self.G.sample_latent(batch_size, self.G.z_dim)
193
+ z.requires_grad = True
194
+
195
+ fake = self.G(z)
196
+ output_fake = self.D(fake)
197
+ output_F_mean_after = output_fake.mean(0).view(1)
198
+
199
+ self.losses['G'].append(float(output_F_mean_after.data))
200
+
201
+ output_F_mean_after.backward(self.mone)
202
+ self.G_opt.step()
203
+
204
+ self.schedulerD.step()
205
+ self.schedulerG.step()
206
+
207
+ def _train_epoch(self, data_loader, writer, experiment):
208
+ for i, data in enumerate(tqdm(data_loader)):
209
+ images = data[0]
210
+ speaker_ids = data[1]
211
+ self.num_steps += 1
212
+ # self.tensorboard_counter += 1
213
+ if self.gen_steps >= self.n_max_iterations:
214
+ return
215
+ self._critic_deep_regression_(images)
216
+ self._generator_train_iteration(images.size(0))
217
+
218
+ D_loss_avg = np.average(self.losses['D'])
219
+ G_loss_avg = np.average(self.losses['G'])
220
+ wd_avg = np.average(self.losses['WD'])
221
+
222
+ def train(self, data_loader, writer, experiment=None):
223
+ self.G.train()
224
+ self.D.train()
225
+
226
+ for epoch in range(self.epochs):
227
+ if self.gen_steps >= self.n_max_iterations:
228
+ return
229
+ time_start_epoch = time.time()
230
+ self._train_epoch(data_loader, writer, experiment)
231
+
232
+ D_loss_avg = np.average(self.losses['D'])
233
+
234
+ time_end_epoch = time.time()
235
+
236
+ return self
237
+
238
+ def sample_generator(self, num_samples, nograd=False, return_intermediate=False):
239
+ self.G.eval()
240
+ if isinstance(self.G, torch.nn.parallel.DataParallel):
241
+ latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim)
242
+ else:
243
+ latent_samples = self.G.sample_latent(num_samples, self.G.z_dim)
244
+ latent_samples = latent_samples.to(self.device)
245
+ if nograd:
246
+ with torch.no_grad():
247
+ generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
248
+ else:
249
+ generated_data = self.G(latent_samples)
250
+ self.G.train()
251
+ if return_intermediate:
252
+ return generated_data[0].detach(), generated_data[1], latent_samples
253
+ return generated_data.detach()
254
+
255
+ def sample(self, num_samples):
256
+ generated_data = self.sample_generator(num_samples)
257
+ # Remove color channel
258
+ return generated_data.data.cpu().numpy()[:, 0, :, :]
259
+
260
+ def save_model_checkpoint(self, model_path, model_parameters, timestampStr):
261
+ # dateTimeObj = datetime.now()
262
+ # timestampStr = dateTimeObj.strftime("%d-%m-%Y-%H-%M-%S")
263
+ name = '%s_%s' % (timestampStr, 'wgan')
264
+ model_filename = os.path.join(model_path, name)
265
+ torch.save({
266
+ 'generator_state_dict' : self.G.state_dict(),
267
+ 'critic_state_dict' : self.D.state_dict(),
268
+ 'gen_optimizer_state_dict' : self.G_opt.state_dict(),
269
+ 'critic_optimizer_state_dict': self.D_opt.state_dict(),
270
+ 'model_parameters' : model_parameters,
271
+ 'iterations' : self.num_steps
272
+ }, model_filename)
Architectures/EmbeddingModel/GST.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Nagoya University (Tomoki Hayashi)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+
6
+ from Architectures.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention
7
+
8
+
9
+ class GSTStyleEncoder(torch.nn.Module):
10
+ """Style encoder.
11
+ This module is style encoder introduced in `Style Tokens: Unsupervised Style
12
+ Modeling, Control and Transfer in End-to-End Speech Synthesis`.
13
+ .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
14
+ Speech Synthesis`: https://arxiv.org/abs/1803.09017
15
+ Args:
16
+ idim (int, optional): Dimension of the input features.
17
+ gst_tokens (int, optional): The number of GST embeddings.
18
+ gst_token_dim (int, optional): Dimension of each GST embedding.
19
+ gst_heads (int, optional): The number of heads in GST multihead attention.
20
+ conv_layers (int, optional): The number of conv layers in the reference encoder.
21
+ conv_chans_list: (Sequence[int], optional):
22
+ List of the number of channels of conv layers in the reference encoder.
23
+ conv_kernel_size (int, optional):
24
+ Kernel size of conv layers in the reference encoder.
25
+ conv_stride (int, optional):
26
+ Stride size of conv layers in the reference encoder.
27
+ gst_layers (int, optional): The number of GRU layers in the reference encoder.
28
+ gst_units (int, optional): The number of GRU units in the reference encoder.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ idim: int = 128,
34
+ gst_tokens: int = 512, # adaspeech suggests to use many more "basis vectors", but I believe that this is already sufficient
35
+ gst_token_dim: int = 64,
36
+ gst_heads: int = 8,
37
+ conv_layers: int = 8,
38
+ conv_chans_list=(32, 32, 64, 64, 128, 128, 256, 256),
39
+ conv_kernel_size: int = 3,
40
+ conv_stride: int = 2,
41
+ gst_layers: int = 2,
42
+ gst_units: int = 256,
43
+ ):
44
+ """Initialize global style encoder module."""
45
+ super(GSTStyleEncoder, self).__init__()
46
+
47
+ self.num_tokens = gst_tokens
48
+ self.ref_enc = ReferenceEncoder(idim=idim,
49
+ conv_layers=conv_layers,
50
+ conv_chans_list=conv_chans_list,
51
+ conv_kernel_size=conv_kernel_size,
52
+ conv_stride=conv_stride,
53
+ gst_layers=gst_layers,
54
+ gst_units=gst_units, )
55
+ self.stl = StyleTokenLayer(ref_embed_dim=gst_units,
56
+ gst_tokens=gst_tokens,
57
+ gst_token_dim=gst_token_dim,
58
+ gst_heads=gst_heads, )
59
+
60
+ def forward(self, speech):
61
+ """Calculate forward propagation.
62
+ Args:
63
+ speech (Tensor): Batch of padded target features (B, Lmax, odim).
64
+ Returns:
65
+ Tensor: Style token embeddings (B, token_dim).
66
+ """
67
+ ref_embs = self.ref_enc(speech)
68
+ style_embs = self.stl(ref_embs)
69
+
70
+ return style_embs
71
+
72
+ def calculate_ada4_regularization_loss(self):
73
+ losses = list()
74
+ for emb1_index in range(self.num_tokens):
75
+ for emb2_index in range(emb1_index + 1, self.num_tokens):
76
+ if emb1_index != emb2_index:
77
+ losses.append(torch.nn.functional.cosine_similarity(self.stl.gst_embs[emb1_index],
78
+ self.stl.gst_embs[emb2_index], dim=0))
79
+ return sum(losses)
80
+
81
+
82
+ class ReferenceEncoder(torch.nn.Module):
83
+ """Reference encoder module.
84
+ This module is reference encoder introduced in `Style Tokens: Unsupervised Style
85
+ Modeling, Control and Transfer in End-to-End Speech Synthesis`.
86
+ .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
87
+ Speech Synthesis`: https://arxiv.org/abs/1803.09017
88
+ Args:
89
+ idim (int, optional): Dimension of the input features.
90
+ conv_layers (int, optional): The number of conv layers in the reference encoder.
91
+ conv_chans_list: (Sequence[int], optional):
92
+ List of the number of channels of conv layers in the reference encoder.
93
+ conv_kernel_size (int, optional):
94
+ Kernel size of conv layers in the reference encoder.
95
+ conv_stride (int, optional):
96
+ Stride size of conv layers in the reference encoder.
97
+ gst_layers (int, optional): The number of GRU layers in the reference encoder.
98
+ gst_units (int, optional): The number of GRU units in the reference encoder.
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ idim=80,
104
+ conv_layers: int = 6,
105
+ conv_chans_list=(32, 32, 64, 64, 128, 128),
106
+ conv_kernel_size: int = 3,
107
+ conv_stride: int = 2,
108
+ gst_layers: int = 1,
109
+ gst_units: int = 128,
110
+ ):
111
+ """Initialize reference encoder module."""
112
+ super(ReferenceEncoder, self).__init__()
113
+
114
+ # check hyperparameters are valid
115
+ assert conv_kernel_size % 2 == 1, "kernel size must be odd."
116
+ assert (
117
+ len(conv_chans_list) == conv_layers), "the number of conv layers and length of channels list must be the same."
118
+
119
+ convs = []
120
+ padding = (conv_kernel_size - 1) // 2
121
+ for i in range(conv_layers):
122
+ conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1]
123
+ conv_out_chans = conv_chans_list[i]
124
+ convs += [torch.nn.Conv2d(conv_in_chans,
125
+ conv_out_chans,
126
+ kernel_size=conv_kernel_size,
127
+ stride=conv_stride,
128
+ padding=padding,
129
+ # Do not use bias due to the following batch norm
130
+ bias=False, ),
131
+ torch.nn.BatchNorm2d(conv_out_chans),
132
+ torch.nn.ReLU(inplace=True), ]
133
+ self.convs = torch.nn.Sequential(*convs)
134
+
135
+ self.conv_layers = conv_layers
136
+ self.kernel_size = conv_kernel_size
137
+ self.stride = conv_stride
138
+ self.padding = padding
139
+
140
+ # get the number of GRU input units
141
+ gst_in_units = idim
142
+ for i in range(conv_layers):
143
+ gst_in_units = (gst_in_units - conv_kernel_size + 2 * padding) // conv_stride + 1
144
+ gst_in_units *= conv_out_chans
145
+ self.gst = torch.nn.GRU(gst_in_units, gst_units, gst_layers, batch_first=True)
146
+
147
+ def forward(self, speech):
148
+ """Calculate forward propagation.
149
+ Args:
150
+ speech (Tensor): Batch of padded target features (B, Lmax, idim).
151
+ Returns:
152
+ Tensor: Reference embedding (B, gst_units)
153
+ """
154
+ batch_size = speech.size(0)
155
+ xs = speech.unsqueeze(1) # (B, 1, Lmax, idim)
156
+ hs = self.convs(xs).transpose(1, 2) # (B, Lmax', conv_out_chans, idim')
157
+ time_length = hs.size(1)
158
+ hs = hs.contiguous().view(batch_size, time_length, -1) # (B, Lmax', gst_units)
159
+ self.gst.flatten_parameters()
160
+ # pack_padded_sequence(hs, speech_lens, enforce_sorted=False, batch_first=True)
161
+ _, ref_embs = self.gst(hs) # (gst_layers, batch_size, gst_units)
162
+ ref_embs = ref_embs[-1] # (batch_size, gst_units)
163
+
164
+ return ref_embs
165
+
166
+
167
+ class StyleTokenLayer(torch.nn.Module):
168
+ """Style token layer module.
169
+ This module is style token layer introduced in `Style Tokens: Unsupervised Style
170
+ Modeling, Control and Transfer in End-to-End Speech Synthesis`.
171
+ .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
172
+ Speech Synthesis`: https://arxiv.org/abs/1803.09017
173
+ Args:
174
+ ref_embed_dim (int, optional): Dimension of the input reference embedding.
175
+ gst_tokens (int, optional): The number of GST embeddings.
176
+ gst_token_dim (int, optional): Dimension of each GST embedding.
177
+ gst_heads (int, optional): The number of heads in GST multihead attention.
178
+ dropout_rate (float, optional): Dropout rate in multi-head attention.
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ ref_embed_dim: int = 128,
184
+ gst_tokens: int = 10,
185
+ gst_token_dim: int = 128,
186
+ gst_heads: int = 4,
187
+ dropout_rate: float = 0.0,
188
+ ):
189
+ """Initialize style token layer module."""
190
+ super(StyleTokenLayer, self).__init__()
191
+
192
+ gst_embs = torch.randn(gst_tokens, gst_token_dim // gst_heads)
193
+ self.register_parameter("gst_embs", torch.nn.Parameter(gst_embs))
194
+ self.mha = MultiHeadedAttention(q_dim=ref_embed_dim,
195
+ k_dim=gst_token_dim // gst_heads,
196
+ v_dim=gst_token_dim // gst_heads,
197
+ n_head=gst_heads,
198
+ n_feat=gst_token_dim,
199
+ dropout_rate=dropout_rate, )
200
+
201
+ def forward(self, ref_embs):
202
+ """Calculate forward propagation.
203
+ Args:
204
+ ref_embs (Tensor): Reference embeddings (B, ref_embed_dim).
205
+ Returns:
206
+ Tensor: Style token embeddings (B, gst_token_dim).
207
+ """
208
+ batch_size = ref_embs.size(0)
209
+ # (num_tokens, token_dim) -> (batch_size, num_tokens, token_dim)
210
+ gst_embs = torch.tanh(self.gst_embs).unsqueeze(0).expand(batch_size, -1, -1)
211
+ # NOTE(kan-bayashi): Shoule we apply Tanh?
212
+ ref_embs = ref_embs.unsqueeze(1) # (batch_size, 1 ,ref_embed_dim)
213
+ style_embs = self.mha(ref_embs, gst_embs, gst_embs, None)
214
+
215
+ return style_embs.squeeze(1)
216
+
217
+
218
+ class MultiHeadedAttention(BaseMultiHeadedAttention):
219
+ """Multi head attention module with different input dimension."""
220
+
221
+ def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0):
222
+ """Initialize multi head attention module."""
223
+ # NOTE(kan-bayashi): Do not use super().__init__() here since we want to
224
+ # overwrite BaseMultiHeadedAttention.__init__() method.
225
+ torch.nn.Module.__init__(self)
226
+ assert n_feat % n_head == 0
227
+ # We assume d_v always equals d_k
228
+ self.d_k = n_feat // n_head
229
+ self.h = n_head
230
+ self.linear_q = torch.nn.Linear(q_dim, n_feat)
231
+ self.linear_k = torch.nn.Linear(k_dim, n_feat)
232
+ self.linear_v = torch.nn.Linear(v_dim, n_feat)
233
+ self.linear_out = torch.nn.Linear(n_feat, n_feat)
234
+ self.attn = None
235
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
Architectures/EmbeddingModel/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Everything that is concerned with the embedding model is contained in this directory. The embedding function does not have its own train loop, because it is always trained jointly with the TTS. Most of the time however, it is used in a frozen state. We recommend using the embedding function that we publish in the GitHub releases.
Architectures/EmbeddingModel/StyleEmbedding.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from Architectures.EmbeddingModel.GST import GSTStyleEncoder
4
+ from Architectures.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder
5
+
6
+
7
+ class StyleEmbedding(torch.nn.Module):
8
+ """
9
+ The style embedding should provide information of the speaker and their speaking style
10
+
11
+ The feedback signal for the module will come from the TTS objective, so it doesn't have a dedicated train loop.
12
+ The train loop does however supply supervision in the form of a barlow twins objective.
13
+
14
+ See the git history for some other approaches for style embedding, like the SWIN transformer
15
+ and a simple LSTM baseline. GST turned out to be the best.
16
+ """
17
+
18
+ def __init__(self, embedding_dim=16, style_tts_encoder=False):
19
+ super().__init__()
20
+ self.embedding_dim = embedding_dim
21
+ self.use_gst = not style_tts_encoder
22
+ if style_tts_encoder:
23
+ self.style_encoder = StyleTTSEncoder(style_dim=embedding_dim)
24
+ else:
25
+ self.style_encoder = GSTStyleEncoder(gst_token_dim=embedding_dim)
26
+
27
+ def forward(self,
28
+ batch_of_feature_sequences,
29
+ batch_of_feature_sequence_lengths):
30
+ """
31
+ Args:
32
+ batch_of_feature_sequences: b is the batch axis, 128 features per timestep
33
+ and l time-steps, which may include padding
34
+ for most elements in the batch (b, l, 128)
35
+ batch_of_feature_sequence_lengths: indicate for every element in the batch,
36
+ what the true length is, since they are
37
+ all padded to the length of the longest
38
+ element in the batch (b, 1)
39
+ Returns:
40
+ batch of n dimensional embeddings (b,n)
41
+ """
42
+
43
+ minimum_sequence_length = 512
44
+ specs = list()
45
+ for index, spec_length in enumerate(batch_of_feature_sequence_lengths):
46
+ spec = batch_of_feature_sequences[index][:spec_length]
47
+ # double the length at least once, then check
48
+ spec = spec.repeat((2, 1))
49
+ current_spec_length = len(spec)
50
+ while current_spec_length < minimum_sequence_length:
51
+ # make it longer
52
+ spec = spec.repeat((2, 1))
53
+ current_spec_length = len(spec)
54
+ specs.append(spec[:minimum_sequence_length])
55
+
56
+ spec_batch = torch.stack(specs, dim=0)
57
+ return self.style_encoder(speech=spec_batch)
58
+
59
+
60
+ if __name__ == '__main__':
61
+ style_emb = StyleEmbedding(style_tts_encoder=False)
62
+ print(f"GST parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")
63
+
64
+ seq_length = 398
65
+ print(style_emb(torch.randn(5, seq_length, 512),
66
+ torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)
67
+
68
+ style_emb = StyleEmbedding(style_tts_encoder=True)
69
+ print(f"StyleTTS encoder parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")
70
+
71
+ seq_length = 398
72
+ print(style_emb(torch.randn(5, seq_length, 512),
73
+ torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)
Architectures/EmbeddingModel/StyleTTSEncoder.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MIT Licensed Code
3
+
4
+ Copyright (c) 2022 Aaron (Yinghao) Li
5
+
6
+ https://github.com/yl4579/StyleTTS/blob/main/models.py
7
+ """
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.nn.utils import spectral_norm
15
+
16
+
17
+ class StyleEncoder(nn.Module):
18
+ def __init__(self, dim_in=128, style_dim=64, max_conv_dim=384):
19
+ super().__init__()
20
+ blocks = []
21
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
22
+
23
+ repeat_num = 4
24
+ for _ in range(repeat_num):
25
+ dim_out = min(dim_in * 2, max_conv_dim)
26
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
27
+ dim_in = dim_out
28
+
29
+ blocks += [nn.LeakyReLU(0.2)]
30
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
31
+ blocks += [nn.AdaptiveAvgPool2d(1)]
32
+ blocks += [nn.LeakyReLU(0.2)]
33
+ self.shared = nn.Sequential(*blocks)
34
+
35
+ self.unshared = nn.Linear(dim_out, style_dim)
36
+
37
+ def forward(self, speech):
38
+ h = self.shared(speech.unsqueeze(1))
39
+ h = h.view(h.size(0), -1)
40
+ s = self.unshared(h)
41
+
42
+ return s
43
+
44
+
45
+ class ResBlk(nn.Module):
46
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
47
+ normalize=False, downsample='none'):
48
+ super().__init__()
49
+ self.actv = actv
50
+ self.normalize = normalize
51
+ self.downsample = DownSample(downsample)
52
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
53
+ self.learned_sc = dim_in != dim_out
54
+ self._build_weights(dim_in, dim_out)
55
+
56
+ def _build_weights(self, dim_in, dim_out):
57
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
58
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
59
+ if self.normalize:
60
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
61
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
62
+ if self.learned_sc:
63
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
64
+
65
+ def _shortcut(self, x):
66
+ if self.learned_sc:
67
+ x = self.conv1x1(x)
68
+ if self.downsample:
69
+ x = self.downsample(x)
70
+ return x
71
+
72
+ def _residual(self, x):
73
+ if self.normalize:
74
+ x = self.norm1(x)
75
+ x = self.actv(x)
76
+ x = self.conv1(x)
77
+ x = self.downsample_res(x)
78
+ if self.normalize:
79
+ x = self.norm2(x)
80
+ x = self.actv(x)
81
+ x = self.conv2(x)
82
+ return x
83
+
84
+ def forward(self, x):
85
+ x = self._shortcut(x) + self._residual(x)
86
+ return x / math.sqrt(2) # unit variance
87
+
88
+
89
+ class LearnedDownSample(nn.Module):
90
+ def __init__(self, layer_type, dim_in):
91
+ super().__init__()
92
+ self.layer_type = layer_type
93
+
94
+ if self.layer_type == 'none':
95
+ self.conv = nn.Identity()
96
+ elif self.layer_type == 'timepreserve':
97
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
98
+ elif self.layer_type == 'half':
99
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
100
+ else:
101
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
102
+
103
+ def forward(self, x):
104
+ return self.conv(x)
105
+
106
+
107
+ class LearnedUpSample(nn.Module):
108
+ def __init__(self, layer_type, dim_in):
109
+ super().__init__()
110
+ self.layer_type = layer_type
111
+
112
+ if self.layer_type == 'none':
113
+ self.conv = nn.Identity()
114
+ elif self.layer_type == 'timepreserve':
115
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
116
+ elif self.layer_type == 'half':
117
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
118
+ else:
119
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
120
+
121
+ def forward(self, x):
122
+ return self.conv(x)
123
+
124
+
125
+ class DownSample(nn.Module):
126
+ def __init__(self, layer_type):
127
+ super().__init__()
128
+ self.layer_type = layer_type
129
+
130
+ def forward(self, x):
131
+ if self.layer_type == 'none':
132
+ return x
133
+ elif self.layer_type == 'timepreserve':
134
+ return F.avg_pool2d(x, (2, 1))
135
+ elif self.layer_type == 'half':
136
+ if x.shape[-1] % 2 != 0:
137
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
138
+ return F.avg_pool2d(x, 2)
139
+ else:
140
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
141
+
142
+
143
+ class UpSample(nn.Module):
144
+ def __init__(self, layer_type):
145
+ super().__init__()
146
+ self.layer_type = layer_type
147
+
148
+ def forward(self, x):
149
+ if self.layer_type == 'none':
150
+ return x
151
+ elif self.layer_type == 'timepreserve':
152
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
153
+ elif self.layer_type == 'half':
154
+ return F.interpolate(x, scale_factor=2, mode='nearest')
155
+ else:
156
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
Architectures/EmbeddingModel/__init__.py ADDED
File without changes
Architectures/GeneralLayers/Attention.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ """Multi-Head Attention layer definition."""
6
+
7
+ import math
8
+
9
+ import numpy
10
+ import torch
11
+ from torch import nn
12
+
13
+ from Utility.utils import make_non_pad_mask
14
+
15
+
16
+ class MultiHeadedAttention(nn.Module):
17
+ """
18
+ Multi-Head Attention layer.
19
+
20
+ Args:
21
+ n_head (int): The number of heads.
22
+ n_feat (int): The number of features.
23
+ dropout_rate (float): Dropout rate.
24
+ """
25
+
26
+ def __init__(self, n_head, n_feat, dropout_rate):
27
+ """
28
+ Construct an MultiHeadedAttention object.
29
+ """
30
+ super(MultiHeadedAttention, self).__init__()
31
+ assert n_feat % n_head == 0
32
+ # We assume d_v always equals d_k
33
+ self.d_k = n_feat // n_head
34
+ self.h = n_head
35
+ self.linear_q = nn.Linear(n_feat, n_feat)
36
+ self.linear_k = nn.Linear(n_feat, n_feat)
37
+ self.linear_v = nn.Linear(n_feat, n_feat)
38
+ self.linear_out = nn.Linear(n_feat, n_feat)
39
+ self.attn = None
40
+ self.dropout = nn.Dropout(p=dropout_rate)
41
+
42
+ def forward_qkv(self, query, key, value):
43
+ """
44
+ Transform query, key and value.
45
+
46
+ Args:
47
+ query (torch.Tensor): Query tensor (#batch, time1, size).
48
+ key (torch.Tensor): Key tensor (#batch, time2, size).
49
+ value (torch.Tensor): Value tensor (#batch, time2, size).
50
+
51
+ Returns:
52
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
53
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
54
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
55
+ """
56
+ n_batch = query.size(0)
57
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
58
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
59
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
60
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
61
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
62
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
63
+
64
+ return q, k, v
65
+
66
+ def forward_attention(self, value, scores, mask):
67
+ """
68
+ Compute attention context vector.
69
+
70
+ Args:
71
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
72
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
73
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
74
+
75
+ Returns:
76
+ torch.Tensor: Transformed value (#batch, time1, d_model)
77
+ weighted by the attention score (#batch, time1, time2).
78
+ """
79
+ n_batch = value.size(0)
80
+ if mask is not None:
81
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
82
+ min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
83
+ scores = scores.masked_fill(mask, min_value)
84
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
85
+ else:
86
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
87
+
88
+ p_attn = self.dropout(self.attn)
89
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
90
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)) # (batch, time1, d_model)
91
+
92
+ return self.linear_out(x) # (batch, time1, d_model)
93
+
94
+ def forward(self, query, key, value, mask):
95
+ """
96
+ Compute scaled dot product attention.
97
+
98
+ Args:
99
+ query (torch.Tensor): Query tensor (#batch, time1, size).
100
+ key (torch.Tensor): Key tensor (#batch, time2, size).
101
+ value (torch.Tensor): Value tensor (#batch, time2, size).
102
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
103
+ (#batch, time1, time2).
104
+
105
+ Returns:
106
+ torch.Tensor: Output tensor (#batch, time1, d_model).
107
+ """
108
+ q, k, v = self.forward_qkv(query, key, value)
109
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
110
+ return self.forward_attention(v, scores, mask)
111
+
112
+
113
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
114
+ """
115
+ Multi-Head Attention layer with relative position encoding.
116
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
117
+ Paper: https://arxiv.org/abs/1901.02860
118
+ Args:
119
+ n_head (int): The number of heads.
120
+ n_feat (int): The number of features.
121
+ dropout_rate (float): Dropout rate.
122
+ zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
123
+ """
124
+
125
+ def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
126
+ """Construct an RelPositionMultiHeadedAttention object."""
127
+ super().__init__(n_head, n_feat, dropout_rate)
128
+ self.zero_triu = zero_triu
129
+ # linear transformation for positional encoding
130
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
131
+ # these two learnable bias are used in matrix c and matrix d
132
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
133
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
134
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
135
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
136
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
137
+
138
+ def rel_shift(self, x):
139
+ """
140
+ Compute relative positional encoding.
141
+ Args:
142
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
143
+ time1 means the length of query vector.
144
+ Returns:
145
+ torch.Tensor: Output tensor.
146
+ """
147
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
148
+ x_padded = torch.cat([zero_pad, x], dim=-1)
149
+
150
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
151
+ x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1] # only keep the positions from 0 to time2
152
+
153
+ if self.zero_triu:
154
+ ones = torch.ones((x.size(2), x.size(3)), device=x.device)
155
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
156
+
157
+ return x
158
+
159
+ def forward(self, query, key, value, pos_emb, mask):
160
+ """
161
+ Compute 'Scaled Dot Product Attention' with rel. positional encoding.
162
+ Args:
163
+ query (torch.Tensor): Query tensor (#batch, time1, size).
164
+ key (torch.Tensor): Key tensor (#batch, time2, size).
165
+ value (torch.Tensor): Value tensor (#batch, time2, size).
166
+ pos_emb (torch.Tensor): Positional embedding tensor
167
+ (#batch, 2*time1-1, size).
168
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
169
+ (#batch, time1, time2).
170
+ Returns:
171
+ torch.Tensor: Output tensor (#batch, time1, d_model).
172
+ """
173
+ q, k, v = self.forward_qkv(query, key, value)
174
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
175
+
176
+ n_batch_pos = pos_emb.size(0)
177
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
178
+ p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
179
+
180
+ # (batch, head, time1, d_k)
181
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
182
+ # (batch, head, time1, d_k)
183
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
184
+
185
+ # compute attention score
186
+ # first compute matrix a and matrix c
187
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
188
+ # (batch, head, time1, time2)
189
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
190
+
191
+ # compute matrix b and matrix d
192
+ # (batch, head, time1, 2*time1-1)
193
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
194
+ matrix_bd = self.rel_shift(matrix_bd)
195
+
196
+ scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2)
197
+
198
+ return self.forward_attention(v, scores, mask)
199
+
200
+
201
+ class GuidedAttentionLoss(torch.nn.Module):
202
+ """
203
+ Guided attention loss function module.
204
+
205
+ This module calculates the guided attention loss described
206
+ in `Efficiently Trainable Text-to-Speech System Based
207
+ on Deep Convolutional Networks with Guided Attention`_,
208
+ which forces the attention to be diagonal.
209
+
210
+ .. _`Efficiently Trainable Text-to-Speech System
211
+ Based on Deep Convolutional Networks with Guided Attention`:
212
+ https://arxiv.org/abs/1710.08969
213
+ """
214
+
215
+ def __init__(self, sigma=0.4, alpha=1.0):
216
+ """
217
+ Initialize guided attention loss module.
218
+
219
+ Args:
220
+ sigma (float, optional): Standard deviation to control
221
+ how close attention to a diagonal.
222
+ alpha (float, optional): Scaling coefficient (lambda).
223
+ reset_always (bool, optional): Whether to always reset masks.
224
+ """
225
+ super(GuidedAttentionLoss, self).__init__()
226
+ self.sigma = sigma
227
+ self.alpha = alpha
228
+ self.guided_attn_masks = None
229
+ self.masks = None
230
+
231
+ def _reset_masks(self):
232
+ self.guided_attn_masks = None
233
+ self.masks = None
234
+
235
+ def forward(self, att_ws, ilens, olens):
236
+ """
237
+ Calculate forward propagation.
238
+
239
+ Args:
240
+ att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
241
+ ilens (LongTensor): Batch of input lenghts (B,).
242
+ olens (LongTensor): Batch of output lenghts (B,).
243
+
244
+ Returns:
245
+ Tensor: Guided attention loss value.
246
+ """
247
+ self._reset_masks()
248
+ self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
249
+ self.masks = self._make_masks(ilens, olens).to(att_ws.device)
250
+ losses = self.guided_attn_masks * att_ws
251
+ loss = torch.mean(losses.masked_select(self.masks))
252
+ self._reset_masks()
253
+ return self.alpha * loss
254
+
255
+ def _make_guided_attention_masks(self, ilens, olens):
256
+ n_batches = len(ilens)
257
+ max_ilen = max(ilens)
258
+ max_olen = max(olens)
259
+ guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=ilens.device)
260
+ for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
261
+ guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma)
262
+ return guided_attn_masks
263
+
264
+ @staticmethod
265
+ def _make_guided_attention_mask(ilen, olen, sigma):
266
+ """
267
+ Make guided attention mask.
268
+ """
269
+ grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device).float(), torch.arange(ilen, device=ilen.device).float())
270
+ return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
271
+
272
+ @staticmethod
273
+ def _make_masks(ilens, olens):
274
+ """
275
+ Make masks indicating non-padded part.
276
+
277
+ Args:
278
+ ilens (LongTensor or List): Batch of lengths (B,).
279
+ olens (LongTensor or List): Batch of lengths (B,).
280
+
281
+ Returns:
282
+ Tensor: Mask tensor indicating non-padded part.
283
+ dtype=torch.uint8 in PyTorch 1.2-
284
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
285
+ """
286
+ in_masks = make_non_pad_mask(ilens, device=ilens.device) # (B, T_in)
287
+ out_masks = make_non_pad_mask(olens, device=olens.device) # (B, T_out)
288
+ return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
289
+
290
+
291
+ class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
292
+ """
293
+ Guided attention loss function module for multi head attention.
294
+
295
+ Args:
296
+ sigma (float, optional): Standard deviation to control
297
+ how close attention to a diagonal.
298
+ alpha (float, optional): Scaling coefficient (lambda).
299
+ reset_always (bool, optional): Whether to always reset masks.
300
+ """
301
+
302
+ def forward(self, att_ws, ilens, olens):
303
+ """
304
+ Calculate forward propagation.
305
+
306
+ Args:
307
+ att_ws (Tensor):
308
+ Batch of multi head attention weights (B, H, T_max_out, T_max_in).
309
+ ilens (LongTensor): Batch of input lenghts (B,).
310
+ olens (LongTensor): Batch of output lenghts (B,).
311
+
312
+ Returns:
313
+ Tensor: Guided attention loss value.
314
+ """
315
+ if self.guided_attn_masks is None:
316
+ self.guided_attn_masks = (self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1))
317
+ if self.masks is None:
318
+ self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
319
+ losses = self.guided_attn_masks * att_ws
320
+ loss = torch.mean(losses.masked_select(self.masks))
321
+ if self.reset_always:
322
+ self._reset_masks()
323
+
324
+ return self.alpha * loss
Architectures/GeneralLayers/ConditionalLayerNorm.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code taken from https://github.com/tuanh123789/AdaSpeech/blob/main/model/adaspeech_modules.py
3
+ By https://github.com/tuanh123789
4
+ No license specified
5
+
6
+ Implemented as outlined in AdaSpeech https://arxiv.org/pdf/2103.00993.pdf
7
+ Used in this toolkit similar to how it is done in AdaSpeech 4 https://arxiv.org/pdf/2204.00436.pdf
8
+
9
+ """
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+
15
+ class ConditionalLayerNorm(nn.Module):
16
+
17
+ def __init__(self,
18
+ hidden_dim,
19
+ speaker_embedding_dim,
20
+ dim=-1):
21
+ super(ConditionalLayerNorm, self).__init__()
22
+ self.dim = dim
23
+ if isinstance(hidden_dim, int):
24
+ self.normal_shape = hidden_dim
25
+ self.speaker_embedding_dim = speaker_embedding_dim
26
+ self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
27
+ nn.Tanh(),
28
+ nn.Linear(self.normal_shape, self.normal_shape))
29
+ self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
30
+ nn.Tanh(),
31
+ nn.Linear(self.normal_shape, self.normal_shape))
32
+ self.reset_parameters()
33
+
34
+ def reset_parameters(self):
35
+ torch.nn.init.constant_(self.W_scale[0].weight, 0.0)
36
+ torch.nn.init.constant_(self.W_scale[2].weight, 0.0)
37
+ torch.nn.init.constant_(self.W_scale[0].bias, 1.0)
38
+ torch.nn.init.constant_(self.W_scale[2].bias, 1.0)
39
+ torch.nn.init.constant_(self.W_bias[0].weight, 0.0)
40
+ torch.nn.init.constant_(self.W_bias[2].weight, 0.0)
41
+ torch.nn.init.constant_(self.W_bias[0].bias, 0.0)
42
+ torch.nn.init.constant_(self.W_bias[2].bias, 0.0)
43
+
44
+ def forward(self, x, speaker_embedding):
45
+
46
+ if self.dim != -1:
47
+ x = x.transpose(-1, self.dim)
48
+
49
+ mean = x.mean(dim=-1, keepdim=True)
50
+ var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
51
+ scale = self.W_scale(speaker_embedding)
52
+ bias = self.W_bias(speaker_embedding)
53
+
54
+ y = scale.unsqueeze(1) * ((x - mean) / var) + bias.unsqueeze(1)
55
+
56
+ if self.dim != -1:
57
+ y = y.transpose(-1, self.dim)
58
+
59
+ return y
60
+
61
+
62
+ class SequentialWrappableConditionalLayerNorm(nn.Module):
63
+
64
+ def __init__(self,
65
+ hidden_dim,
66
+ speaker_embedding_dim):
67
+ super(SequentialWrappableConditionalLayerNorm, self).__init__()
68
+ if isinstance(hidden_dim, int):
69
+ self.normal_shape = hidden_dim
70
+ self.speaker_embedding_dim = speaker_embedding_dim
71
+ self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
72
+ nn.Tanh(),
73
+ nn.Linear(self.normal_shape, self.normal_shape))
74
+ self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
75
+ nn.Tanh(),
76
+ nn.Linear(self.normal_shape, self.normal_shape))
77
+ self.reset_parameters()
78
+
79
+ def reset_parameters(self):
80
+ torch.nn.init.constant_(self.W_scale[0].weight, 0.0)
81
+ torch.nn.init.constant_(self.W_scale[2].weight, 0.0)
82
+ torch.nn.init.constant_(self.W_scale[0].bias, 1.0)
83
+ torch.nn.init.constant_(self.W_scale[2].bias, 1.0)
84
+ torch.nn.init.constant_(self.W_bias[0].weight, 0.0)
85
+ torch.nn.init.constant_(self.W_bias[2].weight, 0.0)
86
+ torch.nn.init.constant_(self.W_bias[0].bias, 0.0)
87
+ torch.nn.init.constant_(self.W_bias[2].bias, 0.0)
88
+
89
+ def forward(self, packed_input):
90
+ x, speaker_embedding = packed_input
91
+ mean = x.mean(dim=-1, keepdim=True)
92
+ var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
93
+ scale = self.W_scale(speaker_embedding)
94
+ bias = self.W_bias(speaker_embedding)
95
+
96
+ y = scale.unsqueeze(1) * ((x - mean) / var) + bias.unsqueeze(1)
97
+
98
+ return y
99
+
100
+
101
+ class AdaIN1d(nn.Module):
102
+ """
103
+ MIT Licensed
104
+
105
+ Copyright (c) 2022 Aaron (Yinghao) Li
106
+ https://github.com/yl4579/StyleTTS/blob/main/models.py
107
+ """
108
+
109
+ def __init__(self, style_dim, num_features):
110
+ super().__init__()
111
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
112
+ self.fc = nn.Linear(style_dim, num_features * 2)
113
+
114
+ def forward(self, x, s):
115
+ h = self.fc(s)
116
+ h = h.view(h.size(0), h.size(1), 1)
117
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
118
+ return (1 + gamma.transpose(1, 2)) * self.norm(x.transpose(1, 2)).transpose(1, 2) + beta.transpose(1, 2)
Architectures/GeneralLayers/Conformer.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet, but heavily modified
3
+ """
4
+
5
+ import torch
6
+
7
+ from Architectures.GeneralLayers.Attention import RelPositionMultiHeadedAttention
8
+ from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
9
+ from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
10
+ from Architectures.GeneralLayers.Convolution import ConvolutionModule
11
+ from Architectures.GeneralLayers.EncoderLayer import EncoderLayer
12
+ from Architectures.GeneralLayers.LayerNorm import LayerNorm
13
+ from Architectures.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d
14
+ from Architectures.GeneralLayers.MultiSequential import repeat
15
+ from Architectures.GeneralLayers.PositionalEncoding import RelPositionalEncoding
16
+ from Architectures.GeneralLayers.Swish import Swish
17
+ from Utility.utils import integrate_with_utt_embed
18
+
19
+
20
+ class Conformer(torch.nn.Module):
21
+ """
22
+ Conformer encoder module.
23
+
24
+ Args:
25
+ idim (int): Input dimension.
26
+ attention_dim (int): Dimension of attention.
27
+ attention_heads (int): The number of heads of multi head attention.
28
+ linear_units (int): The number of units of position-wise feed forward.
29
+ num_blocks (int): The number of decoder blocks.
30
+ dropout_rate (float): Dropout rate.
31
+ positional_dropout_rate (float): Dropout rate after adding positional encoding.
32
+ attention_dropout_rate (float): Dropout rate in attention.
33
+ input_layer (Union[str, torch.nn.Module]): Input layer type.
34
+ normalize_before (bool): Whether to use layer_norm before the first block.
35
+ concat_after (bool): Whether to concat attention layer's input and output.
36
+ if True, additional linear will be applied.
37
+ i.e. x -> x + linear(concat(x, att(x)))
38
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
39
+ positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
40
+ macaron_style (bool): Whether to use macaron style for positionwise layer.
41
+ use_cnn_module (bool): Whether to use convolution module.
42
+ cnn_module_kernel (int): Kernel size of convolution module.
43
+
44
+ """
45
+
46
+ def __init__(self, conformer_type, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1,
47
+ attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1,
48
+ macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, lang_embs=None, lang_emb_size=16, use_output_norm=True, embedding_integration="AdaIN"):
49
+ super(Conformer, self).__init__()
50
+
51
+ activation = Swish()
52
+ self.conv_subsampling_factor = 1
53
+ self.use_output_norm = use_output_norm
54
+
55
+ if isinstance(input_layer, torch.nn.Module):
56
+ self.embed = input_layer
57
+ self.art_embed_norm = LayerNorm(attention_dim)
58
+ self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate)
59
+ elif input_layer is None:
60
+ self.embed = None
61
+ self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
62
+ else:
63
+ raise ValueError("unknown input_layer: " + input_layer)
64
+
65
+ if self.use_output_norm:
66
+ self.output_norm = LayerNorm(attention_dim)
67
+ self.utt_embed = utt_embed
68
+ self.conformer_type = conformer_type
69
+ self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
70
+ if utt_embed is not None:
71
+ if conformer_type == "encoder": # the encoder gets an additional conditioning signal added to its output
72
+ if embedding_integration == "AdaIN":
73
+ self.encoder_embedding_projection = AdaIN1d(style_dim=utt_embed, num_features=attention_dim)
74
+ elif embedding_integration == "ConditionalLayerNorm":
75
+ self.encoder_embedding_projection = ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim)
76
+ else:
77
+ self.encoder_embedding_projection = torch.nn.Linear(attention_dim + utt_embed, attention_dim)
78
+ else:
79
+ if embedding_integration == "AdaIN":
80
+ self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: AdaIN1d(style_dim=utt_embed, num_features=attention_dim))
81
+ elif embedding_integration == "ConditionalLayerNorm":
82
+ self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim))
83
+ else:
84
+ self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim))
85
+ if lang_embs is not None:
86
+ self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size)
87
+ self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim)
88
+ self.language_emb_norm = LayerNorm(attention_dim)
89
+ # self-attention module definition
90
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
91
+ encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
92
+
93
+ # feed-forward module definition
94
+ positionwise_layer = MultiLayeredConv1d
95
+ positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,)
96
+
97
+ # convolution module definition
98
+ convolution_layer = ConvolutionModule
99
+ convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
100
+
101
+ self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args),
102
+ positionwise_layer(*positionwise_layer_args),
103
+ positionwise_layer(*positionwise_layer_args) if macaron_style else None,
104
+ convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate,
105
+ normalize_before, concat_after))
106
+
107
+ def forward(self,
108
+ xs,
109
+ masks,
110
+ utterance_embedding=None,
111
+ lang_ids=None):
112
+ """
113
+ Encode input sequence.
114
+ Args:
115
+ utterance_embedding: embedding containing lots of conditioning signals
116
+ lang_ids: ids of the languages per sample in the batch
117
+ xs (torch.Tensor): Input tensor (#batch, time, idim).
118
+ masks (torch.Tensor): Mask tensor (#batch, time).
119
+ Returns:
120
+ torch.Tensor: Output tensor (#batch, time, attention_dim).
121
+ torch.Tensor: Mask tensor (#batch, time).
122
+ """
123
+
124
+ if self.embed is not None:
125
+ xs = self.embed(xs)
126
+ xs = self.art_embed_norm(xs)
127
+
128
+ if lang_ids is not None:
129
+ lang_embs = self.language_embedding(lang_ids)
130
+ projected_lang_embs = self.language_embedding_projection(lang_embs).unsqueeze(-1).transpose(1, 2)
131
+ projected_lang_embs = self.language_emb_norm(projected_lang_embs)
132
+ xs = xs + projected_lang_embs # offset phoneme representation by language specific offset
133
+
134
+ xs = self.pos_enc(xs)
135
+
136
+ for encoder_index, encoder in enumerate(self.encoders):
137
+ if self.utt_embed:
138
+ if isinstance(xs, tuple):
139
+ x, pos_emb = xs[0], xs[1]
140
+ if self.conformer_type != "encoder":
141
+ x = integrate_with_utt_embed(hs=x, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration)
142
+ xs = (x, pos_emb)
143
+ else:
144
+ if self.conformer_type != "encoder":
145
+ xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration)
146
+ xs, masks = encoder(xs, masks)
147
+
148
+ if isinstance(xs, tuple):
149
+ xs = xs[0]
150
+
151
+ if self.use_output_norm and not (self.utt_embed and self.conformer_type == "encoder"):
152
+ xs = self.output_norm(xs)
153
+
154
+ if self.utt_embed and self.conformer_type == "encoder":
155
+ xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding,
156
+ projection=self.encoder_embedding_projection, embedding_training=self.use_conditional_layernorm_embedding_integration)
157
+
158
+ return xs, masks
Architectures/GeneralLayers/Convolution.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # Northwestern Polytechnical University (Pengcheng Guo)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # Adapted by Florian Lux 2021
5
+
6
+
7
+ from torch import nn
8
+
9
+
10
+ class ConvolutionModule(nn.Module):
11
+ """
12
+ ConvolutionModule in Conformer model.
13
+
14
+ Args:
15
+ channels (int): The number of channels of conv layers.
16
+ kernel_size (int): Kernel size of conv layers.
17
+
18
+ """
19
+
20
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
21
+ super(ConvolutionModule, self).__init__()
22
+ # kernel_size should be an odd number for 'SAME' padding
23
+ assert (kernel_size - 1) % 2 == 0
24
+
25
+ self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, )
26
+ self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, )
27
+ self.norm = nn.BatchNorm1d(channels)
28
+ self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, )
29
+ self.activation = activation
30
+
31
+ def forward(self, x):
32
+ """
33
+ Compute convolution module.
34
+
35
+ Args:
36
+ x (torch.Tensor): Input tensor (#batch, time, channels).
37
+
38
+ Returns:
39
+ torch.Tensor: Output tensor (#batch, time, channels).
40
+
41
+ """
42
+ # exchange the temporal dimension and the feature dimension
43
+ x = x.transpose(1, 2)
44
+
45
+ # GLU mechanism
46
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
47
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
48
+
49
+ # 1D Depthwise Conv
50
+ x = self.depthwise_conv(x)
51
+ x = self.activation(self.norm(x))
52
+
53
+ x = self.pointwise_conv2(x)
54
+
55
+ return x.transpose(1, 2)
Architectures/GeneralLayers/DurationPredictor.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+
6
+ import torch
7
+
8
+ from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
9
+ from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
10
+ from Architectures.GeneralLayers.LayerNorm import LayerNorm
11
+ from Utility.utils import integrate_with_utt_embed
12
+
13
+
14
+ class DurationPredictor(torch.nn.Module):
15
+ """
16
+ Duration predictor module.
17
+
18
+ This is a module of duration predictor described
19
+ in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
20
+ The duration predictor predicts a duration of each frame in log domain
21
+ from the hidden embeddings of encoder.
22
+
23
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
24
+ https://arxiv.org/pdf/1905.09263.pdf
25
+
26
+ Note:
27
+ The calculation domain of outputs is different
28
+ between in `forward` and in `inference`. In `forward`,
29
+ the outputs are calculated in log domain but in `inference`,
30
+ those are calculated in linear domain.
31
+
32
+ """
33
+
34
+ def __init__(self, idim,
35
+ n_layers=2,
36
+ n_chans=384,
37
+ kernel_size=3,
38
+ dropout_rate=0.1,
39
+ offset=1.0,
40
+ utt_embed_dim=None,
41
+ embedding_integration="AdaIN"):
42
+ """
43
+ Initialize duration predictor module.
44
+
45
+ Args:
46
+ idim (int): Input dimension.
47
+ n_layers (int, optional): Number of convolutional layers.
48
+ n_chans (int, optional): Number of channels of convolutional layers.
49
+ kernel_size (int, optional): Kernel size of convolutional layers.
50
+ dropout_rate (float, optional): Dropout rate.
51
+ offset (float, optional): Offset value to avoid nan in log domain.
52
+
53
+ """
54
+ super(DurationPredictor, self).__init__()
55
+ self.offset = offset
56
+ self.conv = torch.nn.ModuleList()
57
+ self.dropouts = torch.nn.ModuleList()
58
+ self.norms = torch.nn.ModuleList()
59
+ self.embedding_projections = torch.nn.ModuleList()
60
+ self.utt_embed_dim = utt_embed_dim
61
+ self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
62
+
63
+ for idx in range(n_layers):
64
+ if utt_embed_dim is not None:
65
+ if embedding_integration == "AdaIN":
66
+ self.embedding_projections += [AdaIN1d(style_dim=utt_embed_dim, num_features=idim)]
67
+ elif embedding_integration == "ConditionalLayerNorm":
68
+ self.embedding_projections += [ConditionalLayerNorm(speaker_embedding_dim=utt_embed_dim, hidden_dim=idim)]
69
+ else:
70
+ self.embedding_projections += [torch.nn.Linear(utt_embed_dim + idim, idim)]
71
+ else:
72
+ self.embedding_projections += [lambda x: x]
73
+ in_chans = idim if idx == 0 else n_chans
74
+ self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ),
75
+ torch.nn.ReLU())]
76
+ self.norms += [LayerNorm(n_chans, dim=1)]
77
+ self.dropouts += [torch.nn.Dropout(dropout_rate)]
78
+
79
+ self.linear = torch.nn.Linear(n_chans, 1)
80
+
81
+ def _forward(self, xs, x_masks=None, is_inference=False, utt_embed=None):
82
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
83
+
84
+ for f, c, d, p in zip(self.conv, self.norms, self.dropouts, self.embedding_projections):
85
+ xs = f(xs) # (B, C, Tmax)
86
+ if self.utt_embed_dim is not None:
87
+ xs = integrate_with_utt_embed(hs=xs.transpose(1, 2), utt_embeddings=utt_embed, projection=p, embedding_training=self.use_conditional_layernorm_embedding_integration).transpose(1, 2)
88
+ xs = c(xs)
89
+ xs = d(xs)
90
+
91
+ # NOTE: targets are transformed to log domain in the loss calculation, so this will learn to predict in the log space, which makes the value range easier to handle.
92
+ xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax)
93
+
94
+ if is_inference:
95
+ # NOTE: since we learned to predict in the log domain, we have to invert the log during inference.
96
+ xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
97
+ else:
98
+ xs = xs.masked_fill(x_masks, 0.0)
99
+
100
+ return xs
101
+
102
+ def forward(self, xs, padding_mask=None, utt_embed=None):
103
+ """
104
+ Calculate forward propagation.
105
+
106
+ Args:
107
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
108
+ padding_mask (ByteTensor, optional):
109
+ Batch of masks indicating padded part (B, Tmax).
110
+
111
+ Returns:
112
+ Tensor: Batch of predicted durations in log domain (B, Tmax).
113
+
114
+ """
115
+ return self._forward(xs, padding_mask, False, utt_embed=utt_embed)
116
+
117
+ def inference(self, xs, padding_mask=None, utt_embed=None):
118
+ """
119
+ Inference duration.
120
+
121
+ Args:
122
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
123
+ padding_mask (ByteTensor, optional):
124
+ Batch of masks indicating padded part (B, Tmax).
125
+
126
+ Returns:
127
+ LongTensor: Batch of predicted durations in linear domain (B, Tmax).
128
+
129
+ """
130
+ return self._forward(xs, padding_mask, True, utt_embed=utt_embed)
131
+
132
+
133
+ class DurationPredictorLoss(torch.nn.Module):
134
+ """
135
+ Loss function module for duration predictor.
136
+
137
+ The loss value is Calculated in log domain to make it Gaussian.
138
+
139
+ """
140
+
141
+ def __init__(self, offset=1.0, reduction="mean"):
142
+ """
143
+ Args:
144
+ offset (float, optional): Offset value to avoid nan in log domain.
145
+ reduction (str): Reduction type in loss calculation.
146
+
147
+ """
148
+ super(DurationPredictorLoss, self).__init__()
149
+ self.criterion = torch.nn.MSELoss(reduction=reduction)
150
+ self.offset = offset
151
+
152
+ def forward(self, outputs, targets):
153
+ """
154
+ Calculate forward propagation.
155
+
156
+ Args:
157
+ outputs (Tensor): Batch of prediction durations in log domain (B, T)
158
+ targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
159
+
160
+ Returns:
161
+ Tensor: Mean squared error loss value.
162
+
163
+ Note:
164
+ `outputs` is in log domain but `targets` is in linear domain.
165
+
166
+ """
167
+ # NOTE: outputs is in log domain while targets in linear
168
+ targets = torch.log(targets.float() + self.offset)
169
+ loss = self.criterion(outputs, targets)
170
+
171
+ return loss
Architectures/GeneralLayers/EncoderLayer.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # Northwestern Polytechnical University (Pengcheng Guo)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # Adapted by Florian Lux 2021
5
+
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from Architectures.GeneralLayers.LayerNorm import LayerNorm
11
+
12
+
13
+ class EncoderLayer(nn.Module):
14
+ """
15
+ Encoder layer module.
16
+
17
+ Args:
18
+ size (int): Input dimension.
19
+ self_attn (torch.nn.Module): Self-attention module instance.
20
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
21
+ can be used as the argument.
22
+ feed_forward (torch.nn.Module): Feed-forward module instance.
23
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
24
+ can be used as the argument.
25
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
26
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
27
+ can be used as the argument.
28
+ conv_module (torch.nn.Module): Convolution module instance.
29
+ `ConvlutionModule` instance can be used as the argument.
30
+ dropout_rate (float): Dropout rate.
31
+ normalize_before (bool): Whether to use layer_norm before the first block.
32
+ concat_after (bool): Whether to concat attention layer's input and output.
33
+ if True, additional linear will be applied.
34
+ i.e. x -> x + linear(concat(x, att(x)))
35
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
36
+
37
+ """
38
+
39
+ def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, ):
40
+ super(EncoderLayer, self).__init__()
41
+ self.self_attn = self_attn
42
+ self.feed_forward = feed_forward
43
+ self.feed_forward_macaron = feed_forward_macaron
44
+ self.conv_module = conv_module
45
+ self.norm_ff = LayerNorm(size) # for the FNN module
46
+ self.norm_mha = LayerNorm(size) # for the MHA module
47
+ if feed_forward_macaron is not None:
48
+ self.norm_ff_macaron = LayerNorm(size)
49
+ self.ff_scale = 0.5
50
+ else:
51
+ self.ff_scale = 1.0
52
+ if self.conv_module is not None:
53
+ self.norm_conv = LayerNorm(size) # for the CNN module
54
+ self.norm_final = LayerNorm(size) # for the final output of the block
55
+ self.dropout = nn.Dropout(dropout_rate)
56
+ self.size = size
57
+ self.normalize_before = normalize_before
58
+ self.concat_after = concat_after
59
+ if self.concat_after:
60
+ self.concat_linear = nn.Linear(size + size, size)
61
+
62
+ def forward(self, x_input, mask, cache=None):
63
+ """
64
+ Compute encoded features.
65
+
66
+ Args:
67
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
68
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
69
+ - w/o pos emb: Tensor (#batch, time, size).
70
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
71
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
72
+
73
+ Returns:
74
+ torch.Tensor: Output tensor (#batch, time, size).
75
+ torch.Tensor: Mask tensor (#batch, time).
76
+
77
+ """
78
+ if isinstance(x_input, tuple):
79
+ x, pos_emb = x_input[0], x_input[1]
80
+ else:
81
+ x, pos_emb = x_input, None
82
+
83
+ # whether to use macaron style
84
+ if self.feed_forward_macaron is not None:
85
+ residual = x
86
+ if self.normalize_before:
87
+ x = self.norm_ff_macaron(x)
88
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
89
+ if not self.normalize_before:
90
+ x = self.norm_ff_macaron(x)
91
+
92
+ # multi-headed self-attention module
93
+ residual = x
94
+ if self.normalize_before:
95
+ x = self.norm_mha(x)
96
+
97
+ if cache is None:
98
+ x_q = x
99
+ else:
100
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
101
+ x_q = x[:, -1:, :]
102
+ residual = residual[:, -1:, :]
103
+ mask = None if mask is None else mask[:, -1:, :]
104
+
105
+ if pos_emb is not None:
106
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
107
+ else:
108
+ x_att = self.self_attn(x_q, x, x, mask)
109
+
110
+ if self.concat_after:
111
+ x_concat = torch.cat((x, x_att), dim=-1)
112
+ x = residual + self.concat_linear(x_concat)
113
+ else:
114
+ x = residual + self.dropout(x_att)
115
+ if not self.normalize_before:
116
+ x = self.norm_mha(x)
117
+
118
+ # convolution module
119
+ if self.conv_module is not None:
120
+ residual = x
121
+ if self.normalize_before:
122
+ x = self.norm_conv(x)
123
+ x = residual + self.dropout(self.conv_module(x))
124
+ if not self.normalize_before:
125
+ x = self.norm_conv(x)
126
+
127
+ # feed forward module
128
+ residual = x
129
+ if self.normalize_before:
130
+ x = self.norm_ff(x)
131
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
132
+ if not self.normalize_before:
133
+ x = self.norm_ff(x)
134
+
135
+ if self.conv_module is not None:
136
+ x = self.norm_final(x)
137
+
138
+ if cache is not None:
139
+ x = torch.cat([cache, x], dim=1)
140
+
141
+ if pos_emb is not None:
142
+ return (x, pos_emb), mask
143
+
144
+ return x, mask
Architectures/GeneralLayers/LayerNorm.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ import torch
6
+
7
+
8
+ class LayerNorm(torch.nn.LayerNorm):
9
+ """
10
+ Layer normalization module.
11
+
12
+ Args:
13
+ nout (int): Output dim size.
14
+ dim (int): Dimension to be normalized.
15
+ """
16
+
17
+ def __init__(self, nout, dim=-1, eps=1e-12):
18
+ """
19
+ Construct an LayerNorm object.
20
+ """
21
+ super(LayerNorm, self).__init__(nout, eps=eps)
22
+ self.dim = dim
23
+
24
+ def forward(self, x):
25
+ """
26
+ Apply layer normalization.
27
+
28
+ Args:
29
+ x (torch.Tensor): Input tensor.
30
+
31
+ Returns:
32
+ torch.Tensor: Normalized tensor.
33
+ """
34
+ if self.dim == -1:
35
+ return super(LayerNorm, self).forward(x)
36
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
Architectures/GeneralLayers/LengthRegulator.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ from abc import ABC
6
+
7
+ import torch
8
+
9
+ from Utility.utils import pad_list
10
+
11
+
12
+ class LengthRegulator(torch.nn.Module, ABC):
13
+ """
14
+ Length regulator module for feed-forward Transformer.
15
+
16
+ This is a module of length regulator described in
17
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
18
+ The length regulator expands char or
19
+ phoneme-level embedding features to frame-level by repeating each
20
+ feature based on the corresponding predicted durations.
21
+
22
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
23
+ https://arxiv.org/pdf/1905.09263.pdf
24
+
25
+ """
26
+
27
+ def __init__(self, pad_value=0.0):
28
+ """
29
+ Initialize length regulator module.
30
+
31
+ Args:
32
+ pad_value (float, optional): Value used for padding.
33
+ """
34
+ super(LengthRegulator, self).__init__()
35
+ self.pad_value = pad_value
36
+
37
+ def forward(self, xs, ds, alpha=1.0):
38
+ """
39
+ Calculate forward propagation.
40
+ Args:
41
+ xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
42
+ ds (LongTensor): Batch of durations of each frame (B, T).
43
+ alpha (float, optional): Alpha value to control speed of speech.
44
+ Returns:
45
+ Tensor: replicated input tensor based on durations (B, T*, D).
46
+ """
47
+
48
+ if alpha != 1.0:
49
+ assert alpha > 0
50
+ ds = torch.round(ds.float() * alpha).long()
51
+
52
+ if ds.sum() == 0:
53
+ ds[ds.sum(dim=1).eq(0)] = 1
54
+
55
+ return pad_list([self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)], self.pad_value)
56
+
57
+ def _repeat_one_sequence(self, x, d):
58
+ """
59
+ Repeat each frame according to duration
60
+ """
61
+ return torch.repeat_interleave(x, d, dim=0)
Architectures/GeneralLayers/MultiLayeredConv1d.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ """
6
+ Layer modules for FFT block in FastSpeech (Feed-forward Transformer).
7
+ """
8
+
9
+ import torch
10
+
11
+
12
+ class MultiLayeredConv1d(torch.nn.Module):
13
+ """
14
+ Multi-layered conv1d for Transformer block.
15
+
16
+ This is a module of multi-layered conv1d designed
17
+ to replace positionwise feed-forward network
18
+ in Transformer block, which is introduced in
19
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
20
+
21
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
22
+ https://arxiv.org/pdf/1905.09263.pdf
23
+ """
24
+
25
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
26
+ """
27
+ Initialize MultiLayeredConv1d module.
28
+
29
+ Args:
30
+ in_chans (int): Number of input channels.
31
+ hidden_chans (int): Number of hidden channels.
32
+ kernel_size (int): Kernel size of conv1d.
33
+ dropout_rate (float): Dropout rate.
34
+ """
35
+ super(MultiLayeredConv1d, self).__init__()
36
+ self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
37
+ self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
38
+ self.dropout = torch.nn.Dropout(dropout_rate)
39
+
40
+ def forward(self, x):
41
+ """
42
+ Calculate forward propagation.
43
+
44
+ Args:
45
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
46
+
47
+ Returns:
48
+ torch.Tensor: Batch of output tensors (B, T, hidden_chans).
49
+ """
50
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
51
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
52
+
53
+
54
+ class Conv1dLinear(torch.nn.Module):
55
+ """
56
+ Conv1D + Linear for Transformer block.
57
+
58
+ A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
59
+ """
60
+
61
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
62
+ """
63
+ Initialize Conv1dLinear module.
64
+
65
+ Args:
66
+ in_chans (int): Number of input channels.
67
+ hidden_chans (int): Number of hidden channels.
68
+ kernel_size (int): Kernel size of conv1d.
69
+ dropout_rate (float): Dropout rate.
70
+ """
71
+ super(Conv1dLinear, self).__init__()
72
+ self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, )
73
+ self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
74
+ self.dropout = torch.nn.Dropout(dropout_rate)
75
+
76
+ def forward(self, x):
77
+ """
78
+ Calculate forward propagation.
79
+
80
+ Args:
81
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
82
+
83
+ Returns:
84
+ torch.Tensor: Batch of output tensors (B, T, hidden_chans).
85
+ """
86
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
87
+ return self.w_2(self.dropout(x))
Architectures/GeneralLayers/MultiSequential.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+ import torch
6
+
7
+
8
+ class MultiSequential(torch.nn.Sequential):
9
+ """
10
+ Multi-input multi-output torch.nn.Sequential.
11
+ """
12
+
13
+ def forward(self, *args):
14
+ """
15
+ Repeat.
16
+ """
17
+ for m in self:
18
+ args = m(*args)
19
+ return args
20
+
21
+
22
+ def repeat(N, fn):
23
+ """
24
+ Repeat module N times.
25
+
26
+ Args:
27
+ N (int): Number of repeat time.
28
+ fn (Callable): Function to generate module.
29
+
30
+ Returns:
31
+ MultiSequential: Repeated model instance.
32
+ """
33
+ return MultiSequential(*[fn(n) for n in range(N)])
Architectures/GeneralLayers/PositionalEncoding.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import math
6
+
7
+ import torch
8
+
9
+
10
+ class PositionalEncoding(torch.nn.Module):
11
+ """
12
+ Positional encoding.
13
+
14
+ Args:
15
+ d_model (int): Embedding dimension.
16
+ dropout_rate (float): Dropout rate.
17
+ max_len (int): Maximum input length.
18
+ reverse (bool): Whether to reverse the input position.
19
+ """
20
+
21
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
22
+ """
23
+ Construct an PositionalEncoding object.
24
+ """
25
+ super(PositionalEncoding, self).__init__()
26
+ self.d_model = d_model
27
+ self.reverse = reverse
28
+ self.xscale = math.sqrt(self.d_model)
29
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
30
+ self.pe = None
31
+ self.extend_pe(torch.tensor(0.0, device=d_model.device).expand(1, max_len))
32
+
33
+ def extend_pe(self, x):
34
+ """
35
+ Reset the positional encodings.
36
+ """
37
+ if self.pe is not None:
38
+ if self.pe.size(1) >= x.size(1):
39
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
40
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
41
+ return
42
+ pe = torch.zeros(x.size(1), self.d_model)
43
+ if self.reverse:
44
+ position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
45
+ else:
46
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
47
+ div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model))
48
+ pe[:, 0::2] = torch.sin(position * div_term)
49
+ pe[:, 1::2] = torch.cos(position * div_term)
50
+ pe = pe.unsqueeze(0)
51
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
52
+
53
+ def forward(self, x):
54
+ """
55
+ Add positional encoding.
56
+
57
+ Args:
58
+ x (torch.Tensor): Input tensor (batch, time, `*`).
59
+
60
+ Returns:
61
+ torch.Tensor: Encoded tensor (batch, time, `*`).
62
+ """
63
+ self.extend_pe(x)
64
+ x = x * self.xscale + self.pe[:, : x.size(1)]
65
+ return self.dropout(x)
66
+
67
+
68
+ class RelPositionalEncoding(torch.nn.Module):
69
+ """
70
+ Relative positional encoding module (new implementation).
71
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
72
+ See : Appendix B in https://arxiv.org/abs/1901.02860
73
+ Args:
74
+ d_model (int): Embedding dimension.
75
+ dropout_rate (float): Dropout rate.
76
+ max_len (int): Maximum input length.
77
+ """
78
+
79
+ def __init__(self, d_model, dropout_rate, max_len=5000):
80
+ """
81
+ Construct an PositionalEncoding object.
82
+ """
83
+ super(RelPositionalEncoding, self).__init__()
84
+ self.d_model = d_model
85
+ self.xscale = math.sqrt(self.d_model)
86
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
87
+ self.pe = None
88
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
89
+
90
+ def extend_pe(self, x):
91
+ """Reset the positional encodings."""
92
+ if self.pe is not None:
93
+ # self.pe contains both positive and negative parts
94
+ # the length of self.pe is 2 * input_len - 1
95
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
96
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
97
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
98
+ return
99
+ # Suppose `i` means to the position of query vecotr and `j` means the
100
+ # position of key vector. We use position relative positions when keys
101
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
102
+ pe_positive = torch.zeros(x.size(1), self.d_model, device=x.device)
103
+ pe_negative = torch.zeros(x.size(1), self.d_model, device=x.device)
104
+ position = torch.arange(0, x.size(1), dtype=torch.float32, device=x.device).unsqueeze(1)
105
+ div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32, device=x.device) * -(math.log(10000.0) / self.d_model))
106
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
107
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
108
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
109
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
110
+
111
+ # Reserve the order of positive indices and concat both positive and
112
+ # negative indices. This is used to support the shifting trick
113
+ # as in https://arxiv.org/abs/1901.02860
114
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
115
+ pe_negative = pe_negative[1:].unsqueeze(0)
116
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
117
+ self.pe = pe.to(dtype=x.dtype)
118
+
119
+ def forward(self, x):
120
+ """
121
+ Add positional encoding.
122
+ Args:
123
+ x (torch.Tensor): Input tensor (batch, time, `*`).
124
+ Returns:
125
+ torch.Tensor: Encoded tensor (batch, time, `*`).
126
+ """
127
+ self.extend_pe(x)
128
+ x = x * self.xscale
129
+ pos_emb = self.pe[:, self.pe.size(1) // 2 - x.size(1) + 1: self.pe.size(1) // 2 + x.size(1), ]
130
+ return self.dropout(x), self.dropout(pos_emb)
131
+
132
+
133
+ class ScaledPositionalEncoding(PositionalEncoding):
134
+ """
135
+ Scaled positional encoding module.
136
+
137
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
138
+
139
+ Args:
140
+ d_model (int): Embedding dimension.
141
+ dropout_rate (float): Dropout rate.
142
+ max_len (int): Maximum input length.
143
+
144
+ """
145
+
146
+ def __init__(self, d_model, dropout_rate, max_len=5000):
147
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
148
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
149
+
150
+ def reset_parameters(self):
151
+ self.alpha.data = torch.tensor(1.0)
152
+
153
+ def forward(self, x):
154
+ """
155
+ Add positional encoding.
156
+
157
+ Args:
158
+ x (torch.Tensor): Input tensor (batch, time, `*`).
159
+
160
+ Returns:
161
+ torch.Tensor: Encoded tensor (batch, time, `*`).
162
+
163
+ """
164
+ self.extend_pe(x)
165
+ x = x + self.alpha * self.pe[:, : x.size(1)]
166
+ return self.dropout(x)
Architectures/GeneralLayers/PositionwiseFeedForward.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Written by Shigeki Karita, 2019
2
+ # Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux, 2021
4
+
5
+
6
+ import torch
7
+
8
+
9
+ class PositionwiseFeedForward(torch.nn.Module):
10
+ """
11
+ Args:
12
+ idim (int): Input dimenstion.
13
+ hidden_units (int): The number of hidden units.
14
+ dropout_rate (float): Dropout rate.
15
+
16
+ """
17
+
18
+ def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
19
+ super(PositionwiseFeedForward, self).__init__()
20
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
21
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
22
+ self.dropout = torch.nn.Dropout(dropout_rate)
23
+ self.activation = activation
24
+
25
+ def forward(self, x):
26
+ return self.w_2(self.dropout(self.activation(self.w_1(x))))
Architectures/GeneralLayers/README.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ This directory contains a collection of layers that are used both during training time and during inference time. Large
2
+ portions of these layers are either directly taken from ESPnet or adaptations of such.
Architectures/GeneralLayers/ResidualBlock.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ References:
5
+ - https://github.com/jik876/hifi-gan
6
+ - https://github.com/kan-bayashi/ParallelWaveGAN
7
+ """
8
+
9
+ import torch
10
+
11
+
12
+ class Conv1d(torch.nn.Conv1d):
13
+ """
14
+ Conv1d module with customized initialization.
15
+ """
16
+
17
+ def __init__(self, *args, **kwargs):
18
+ super(Conv1d, self).__init__(*args, **kwargs)
19
+
20
+ def reset_parameters(self):
21
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
22
+ if self.bias is not None:
23
+ torch.nn.init.constant_(self.bias, 0.0)
24
+
25
+
26
+ class Conv1d1x1(Conv1d):
27
+ """
28
+ 1x1 Conv1d with customized initialization.
29
+ """
30
+
31
+ def __init__(self, in_channels, out_channels, bias):
32
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias)
33
+
34
+
35
+ class HiFiGANResidualBlock(torch.nn.Module):
36
+ """Residual block module in HiFiGAN."""
37
+
38
+ def __init__(self,
39
+ kernel_size=3,
40
+ channels=512,
41
+ dilations=(1, 3, 5),
42
+ bias=True,
43
+ use_additional_convs=True,
44
+ nonlinear_activation="LeakyReLU",
45
+ nonlinear_activation_params={"negative_slope": 0.1}, ):
46
+ """
47
+ Initialize HiFiGANResidualBlock module.
48
+
49
+ Args:
50
+ kernel_size (int): Kernel size of dilation convolution layer.
51
+ channels (int): Number of channels for convolution layer.
52
+ dilations (List[int]): List of dilation factors.
53
+ use_additional_convs (bool): Whether to use additional convolution layers.
54
+ bias (bool): Whether to add bias parameter in convolution layers.
55
+ nonlinear_activation (str): Activation function module name.
56
+ nonlinear_activation_params (dict): Hyperparameters for activation function.
57
+ """
58
+ super().__init__()
59
+ self.use_additional_convs = use_additional_convs
60
+ self.convs1 = torch.nn.ModuleList()
61
+ if use_additional_convs:
62
+ self.convs2 = torch.nn.ModuleList()
63
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
64
+ for dilation in dilations:
65
+ self.convs1 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
66
+ torch.nn.Conv1d(channels,
67
+ channels,
68
+ kernel_size,
69
+ 1,
70
+ dilation=dilation,
71
+ bias=bias,
72
+ padding=(kernel_size - 1) // 2 * dilation, ), )]
73
+ if use_additional_convs:
74
+ self.convs2 += [torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
75
+ torch.nn.Conv1d(channels,
76
+ channels,
77
+ kernel_size,
78
+ 1,
79
+ dilation=1,
80
+ bias=bias,
81
+ padding=(kernel_size - 1) // 2, ), )]
82
+
83
+ def forward(self, x):
84
+ """
85
+ Calculate forward propagation.
86
+
87
+ Args:
88
+ x (Tensor): Input tensor (B, channels, T).
89
+
90
+ Returns:
91
+ Tensor: Output tensor (B, channels, T).
92
+ """
93
+ for idx in range(len(self.convs1)):
94
+ xt = self.convs1[idx](x)
95
+ if self.use_additional_convs:
96
+ xt = self.convs2[idx](xt)
97
+ x = xt + x
98
+ return x
Architectures/GeneralLayers/ResidualStack.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2021
4
+
5
+
6
+ import torch
7
+
8
+
9
+ class ResidualStack(torch.nn.Module):
10
+
11
+ def __init__(self, kernel_size=3, channels=32, dilation=1, bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2},
12
+ pad="ReflectionPad1d", pad_params={}, ):
13
+ """
14
+ Initialize ResidualStack module.
15
+
16
+ Args:
17
+ kernel_size (int): Kernel size of dilation convolution layer.
18
+ channels (int): Number of channels of convolution layers.
19
+ dilation (int): Dilation factor.
20
+ bias (bool): Whether to add bias parameter in convolution layers.
21
+ nonlinear_activation (str): Activation function module name.
22
+ nonlinear_activation_params (dict): Hyperparameters for activation function.
23
+ pad (str): Padding function module name before dilated convolution layer.
24
+ pad_params (dict): Hyperparameters for padding function.
25
+
26
+ """
27
+ super(ResidualStack, self).__init__()
28
+
29
+ # defile residual stack part
30
+ assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
31
+ self.stack = torch.nn.Sequential(getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
32
+ getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
33
+ torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
34
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
35
+ torch.nn.Conv1d(channels, channels, 1, bias=bias), )
36
+
37
+ # defile extra layer for skip connection
38
+ self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
39
+
40
+ def forward(self, c):
41
+ """
42
+ Calculate forward propagation.
43
+
44
+ Args:
45
+ c (Tensor): Input tensor (B, channels, T).
46
+
47
+ Returns:
48
+ Tensor: Output tensor (B, chennels, T).
49
+
50
+ """
51
+ return self.stack(c) + self.skip_layer(c)
Architectures/GeneralLayers/STFT.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import torch
6
+ from torch.functional import stft as torch_stft
7
+ from torch_complex.tensor import ComplexTensor
8
+
9
+ from Utility.utils import make_pad_mask
10
+
11
+
12
+ class STFT(torch.nn.Module):
13
+
14
+ def __init__(self, n_fft=512,
15
+ win_length=None,
16
+ hop_length=128,
17
+ window="hann",
18
+ center=True,
19
+ normalized=False,
20
+ onesided=True):
21
+ super().__init__()
22
+ self.n_fft = n_fft
23
+ if win_length is None:
24
+ self.win_length = n_fft
25
+ else:
26
+ self.win_length = win_length
27
+ self.hop_length = hop_length
28
+ self.center = center
29
+ self.normalized = normalized
30
+ self.onesided = onesided
31
+ self.window = window
32
+
33
+ def extra_repr(self):
34
+ return (f"n_fft={self.n_fft}, "
35
+ f"win_length={self.win_length}, "
36
+ f"hop_length={self.hop_length}, "
37
+ f"center={self.center}, "
38
+ f"normalized={self.normalized}, "
39
+ f"onesided={self.onesided}")
40
+
41
+ def forward(self, input_wave, ilens=None):
42
+ """
43
+ STFT forward function.
44
+ Args:
45
+ input_wave: (Batch, Nsamples) or (Batch, Nsample, Channels)
46
+ ilens: (Batch)
47
+ Returns:
48
+ output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
49
+ """
50
+ bs = input_wave.size(0)
51
+
52
+ if input_wave.dim() == 3:
53
+ multi_channel = True
54
+ # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
55
+ input_wave = input_wave.transpose(1, 2).reshape(-1, input_wave.size(1))
56
+ else:
57
+ multi_channel = False
58
+
59
+ # output: (Batch, Freq, Frames, 2=real_imag)
60
+ # or (Batch, Channel, Freq, Frames, 2=real_imag)
61
+ if self.window is not None:
62
+ window_func = getattr(torch, f"{self.window}_window")
63
+ window = window_func(self.win_length, dtype=input_wave.dtype, device=input_wave.device)
64
+ else:
65
+ window = None
66
+
67
+ complex_output = torch_stft(input=input_wave,
68
+ n_fft=self.n_fft,
69
+ win_length=self.win_length,
70
+ hop_length=self.hop_length,
71
+ center=self.center,
72
+ window=window,
73
+ normalized=self.normalized,
74
+ onesided=self.onesided,
75
+ return_complex=True)
76
+ output = torch.view_as_real(complex_output)
77
+ # output: (Batch, Freq, Frames, 2=real_imag)
78
+ # -> (Batch, Frames, Freq, 2=real_imag)
79
+ output = output.transpose(1, 2)
80
+ if multi_channel:
81
+ # output: (Batch * Channel, Frames, Freq, 2=real_imag)
82
+ # -> (Batch, Frame, Channel, Freq, 2=real_imag)
83
+ output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(1, 2)
84
+
85
+ if ilens is not None:
86
+ if self.center:
87
+ pad = self.win_length // 2
88
+ ilens = ilens + 2 * pad
89
+
90
+ olens = torch.div((ilens - self.win_length), self.hop_length, rounding_mode='trunc') + 1
91
+ output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
92
+ else:
93
+ olens = None
94
+
95
+ return output, olens
96
+
97
+ def inverse(self, input, ilens=None):
98
+ """
99
+ Inverse STFT.
100
+ Args:
101
+ input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
102
+ ilens: (batch,)
103
+ Returns:
104
+ wavs: (batch, samples)
105
+ ilens: (batch,)
106
+ """
107
+ istft = torch.functional.istft
108
+
109
+ if self.window is not None:
110
+ window_func = getattr(torch, f"{self.window}_window")
111
+ window = window_func(self.win_length, dtype=input.dtype, device=input.device)
112
+ else:
113
+ window = None
114
+
115
+ if isinstance(input, ComplexTensor):
116
+ input = torch.stack([input.real, input.imag], dim=-1)
117
+ assert input.shape[-1] == 2
118
+ input = input.transpose(1, 2)
119
+
120
+ wavs = istft(input, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window=window, center=self.center,
121
+ normalized=self.normalized, onesided=self.onesided, length=ilens.max() if ilens is not None else ilens)
122
+
123
+ return wavs, ilens
Architectures/GeneralLayers/Swish.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # Northwestern Polytechnical University (Pengcheng Guo)
3
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
4
+ # Adapted by Florian Lux 2021
5
+
6
+ import torch
7
+
8
+
9
+ class Swish(torch.nn.Module):
10
+ """
11
+ Construct a Swish activation function for Conformer.
12
+ """
13
+
14
+ def forward(self, x):
15
+ """
16
+ Return Swish activation function.
17
+ """
18
+ return x * torch.sigmoid(x)
Architectures/GeneralLayers/VariancePredictor.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 Tomoki Hayashi
2
+ # MIT License (https://opensource.org/licenses/MIT)
3
+ # Adapted by Florian Lux 2023
4
+
5
+ from abc import ABC
6
+
7
+ import torch
8
+
9
+ from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
10
+ from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
11
+ from Architectures.GeneralLayers.LayerNorm import LayerNorm
12
+ from Utility.utils import integrate_with_utt_embed
13
+
14
+
15
+ class VariancePredictor(torch.nn.Module, ABC):
16
+ """
17
+ Variance predictor module.
18
+
19
+ This is a module of variance predictor described in `FastSpeech 2:
20
+ Fast and High-Quality End-to-End Text to Speech`_.
21
+
22
+ .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
23
+ https://arxiv.org/abs/2006.04558
24
+
25
+ """
26
+
27
+ def __init__(self,
28
+ idim,
29
+ n_layers=2,
30
+ n_chans=384,
31
+ kernel_size=3,
32
+ bias=True,
33
+ dropout_rate=0.5,
34
+ utt_embed_dim=None,
35
+ embedding_integration="AdaIN"):
36
+ """
37
+ Initialize duration predictor module.
38
+
39
+ Args:
40
+ idim (int): Input dimension.
41
+ n_layers (int, optional): Number of convolutional layers.
42
+ n_chans (int, optional): Number of channels of convolutional layers.
43
+ kernel_size (int, optional): Kernel size of convolutional layers.
44
+ dropout_rate (float, optional): Dropout rate.
45
+ """
46
+ super().__init__()
47
+ self.conv = torch.nn.ModuleList()
48
+ self.dropouts = torch.nn.ModuleList()
49
+ self.norms = torch.nn.ModuleList()
50
+ self.embedding_projections = torch.nn.ModuleList()
51
+ self.utt_embed_dim = utt_embed_dim
52
+ self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
53
+
54
+ for idx in range(n_layers):
55
+ if utt_embed_dim is not None:
56
+ if embedding_integration == "AdaIN":
57
+ self.embedding_projections += [AdaIN1d(style_dim=utt_embed_dim, num_features=idim)]
58
+ elif embedding_integration == "ConditionalLayerNorm":
59
+ self.embedding_projections += [ConditionalLayerNorm(speaker_embedding_dim=utt_embed_dim, hidden_dim=idim)]
60
+ else:
61
+ self.embedding_projections += [torch.nn.Linear(utt_embed_dim + idim, idim)]
62
+ else:
63
+ self.embedding_projections += [lambda x: x]
64
+ in_chans = idim if idx == 0 else n_chans
65
+ self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias, ),
66
+ torch.nn.ReLU())]
67
+ self.norms += [LayerNorm(n_chans, dim=1)]
68
+ self.dropouts += [torch.nn.Dropout(dropout_rate)]
69
+
70
+ self.linear = torch.nn.Linear(n_chans, 1)
71
+
72
+ def forward(self, xs, padding_mask=None, utt_embed=None):
73
+ """
74
+ Calculate forward propagation.
75
+
76
+ Args:
77
+ xs (Tensor): Batch of input sequences (B, Tmax, idim).
78
+ padding_mask (ByteTensor, optional):
79
+ Batch of masks indicating padded part (B, Tmax).
80
+
81
+ Returns:
82
+ Tensor: Batch of predicted sequences (B, Tmax, 1).
83
+ """
84
+ xs = xs.transpose(1, -1) # (B, idim, Tmax)
85
+
86
+ for f, c, d, p in zip(self.conv, self.norms, self.dropouts, self.embedding_projections):
87
+ xs = f(xs) # (B, C, Tmax)
88
+ if self.utt_embed_dim is not None:
89
+ xs = integrate_with_utt_embed(hs=xs.transpose(1, 2), utt_embeddings=utt_embed, projection=p, embedding_training=self.use_conditional_layernorm_embedding_integration).transpose(1, 2)
90
+ xs = c(xs)
91
+ xs = d(xs)
92
+
93
+ xs = self.linear(xs.transpose(1, 2)) # (B, Tmax, 1)
94
+
95
+ if padding_mask is not None:
96
+ xs = xs.masked_fill(padding_mask, 0.0)
97
+
98
+ return xs
Architectures/GeneralLayers/__init__.py ADDED
File without changes
Architectures/README.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ This directory contains all the models that are used in this toolkit for various tasks. The models' directories contain their
2
+ feature extractors, their datasets, their architectures, and their train loops.
Architectures/ToucanTTS/CodecDiscriminator.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ def weights_init_D(m):
6
+ classname = m.__class__.__name__
7
+ if classname.find('Conv') != -1:
8
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
9
+ elif classname.find('BatchNorm') != -1:
10
+ nn.init.constant_(m.weight, 1)
11
+ nn.init.constant_(m.bias, 0)
12
+
13
+
14
+ class SpectrogramDiscriminator(torch.nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+ self.D = DiscriminatorNet()
18
+ self.D.apply(weights_init_D)
19
+
20
+ def _generator_feedback(self, data_generated, data_real):
21
+ for p in self.D.parameters():
22
+ p.requires_grad = False # freeze critic
23
+
24
+ score_fake, fmap_fake = self.D(data_generated)
25
+ _, fmap_real = self.D(data_real)
26
+
27
+ feature_matching_loss = 0.0
28
+ for feat_fake, feat_real in zip(fmap_fake, fmap_real):
29
+ feature_matching_loss += nn.functional.l1_loss(feat_fake, feat_real.detach())
30
+
31
+ discr_loss = nn.functional.mse_loss(input=score_fake, target=torch.ones(score_fake.shape, device=score_fake.device), reduction="mean")
32
+
33
+ return feature_matching_loss + discr_loss
34
+
35
+ def _discriminator_feature_matching(self, data_generated, data_real):
36
+ for p in self.D.parameters():
37
+ p.requires_grad = True # unfreeze critic
38
+ self.D.train()
39
+
40
+ score_fake, _ = self.D(data_generated)
41
+ score_real, _ = self.D(data_real)
42
+
43
+ discr_loss = 0.0
44
+ discr_loss = discr_loss + nn.functional.mse_loss(input=score_fake, target=torch.zeros(score_fake.shape, device=score_fake.device), reduction="mean")
45
+ discr_loss = discr_loss + nn.functional.mse_loss(input=score_real, target=torch.ones(score_real.shape, device=score_real.device), reduction="mean")
46
+
47
+ return discr_loss
48
+
49
+ def calc_discriminator_loss(self, data_generated, data_real):
50
+ return self._discriminator_feature_matching(data_generated.detach(), data_real)
51
+
52
+ def calc_generator_feedback(self, data_generated, data_real):
53
+ return self._generator_feedback(data_generated, data_real)
54
+
55
+
56
+ class DiscriminatorNet(nn.Module):
57
+ def __init__(self):
58
+ super().__init__()
59
+ self.filters = nn.ModuleList([
60
+ nn.utils.weight_norm(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
61
+ nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
62
+ nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
63
+ nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
64
+ nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
65
+ ])
66
+
67
+ self.out = nn.utils.weight_norm(nn.Conv2d(32, 1, 3, 1, 1))
68
+
69
+ self.fc = nn.Linear(900, 1) # this needs to be changed everytime the window length is changes. It would be nice if this could be done dynamically.
70
+
71
+ def forward(self, y):
72
+ feature_maps = list()
73
+ feature_maps.append(y)
74
+ for d in self.filters:
75
+ y = d(y)
76
+ feature_maps.append(y)
77
+ y = nn.functional.leaky_relu(y, 0.1)
78
+ y = self.out(y)
79
+ feature_maps.append(y)
80
+ y = torch.flatten(y, 1, -1)
81
+ y = self.fc(y)
82
+
83
+ return y, feature_maps
84
+
85
+
86
+ if __name__ == '__main__':
87
+ d = SpectrogramDiscriminator()
88
+ fake = torch.randn([2, 100, 72]) # [Batch, Sequence Length, Spectrogram Buckets]
89
+ real = torch.randn([2, 100, 72]) # [Batch, Sequence Length, Spectrogram Buckets]
90
+
91
+ critic_loss = d.calc_discriminator_loss((fake.unsqueeze(1)), real.unsqueeze(1))
92
+ generator_loss = d.calc_generator_feedback(fake.unsqueeze(1), real.unsqueeze(1))
93
+ print(critic_loss)
94
+ print(generator_loss)
Architectures/ToucanTTS/CodecRefinementTransformer.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from Architectures.GeneralLayers.Conformer import Conformer
4
+
5
+
6
+ class CodecRefinementTransformer(torch.nn.Module):
7
+
8
+ def __init__(self,
9
+ attention_dimension=128,
10
+ num_codebooks=4,
11
+ codebook_size=1024,
12
+ backtranslation_dim=8,
13
+ attention_heads=4,
14
+ positionwise_conv_kernel_size=1,
15
+ use_macaron_style_in_conformer=True,
16
+ use_cnn_in_conformer=False, # for now, we try using just a regular transformer
17
+ decoder_layers=6,
18
+ decoder_units=1280,
19
+ decoder_concat_after=False,
20
+ conformer_decoder_kernel_size=31,
21
+ decoder_normalize_before=True,
22
+ transformer_dec_dropout_rate=0.2,
23
+ transformer_dec_positional_dropout_rate=0.1,
24
+ transformer_dec_attn_dropout_rate=0.1,
25
+ utt_embed_dim=512,
26
+ use_conditional_layernorm_embedding_integration=False,
27
+ ):
28
+ super().__init__()
29
+
30
+ self.reconstruction_transformer = Conformer(
31
+ conformer_type="decoder",
32
+ attention_dim=num_codebooks * backtranslation_dim,
33
+ attention_heads=attention_heads,
34
+ linear_units=decoder_units,
35
+ num_blocks=decoder_layers,
36
+ input_layer=None,
37
+ dropout_rate=transformer_dec_dropout_rate,
38
+ positional_dropout_rate=transformer_dec_positional_dropout_rate,
39
+ attention_dropout_rate=transformer_dec_attn_dropout_rate,
40
+ normalize_before=decoder_normalize_before,
41
+ concat_after=decoder_concat_after,
42
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size,
43
+ macaron_style=use_macaron_style_in_conformer,
44
+ use_cnn_module=use_cnn_in_conformer,
45
+ cnn_module_kernel=conformer_decoder_kernel_size,
46
+ use_output_norm=False,
47
+ utt_embed=utt_embed_dim,
48
+ use_conditional_layernorm_embedding_integration=use_conditional_layernorm_embedding_integration
49
+ )
50
+
51
+ self.num_codebooks = num_codebooks
52
+ self.codebook_size = codebook_size
53
+ self.input_embeddings = torch.nn.ModuleList()
54
+ self.backtranslation_heads = torch.nn.ModuleList()
55
+ self.hierarchical_classifier = torch.nn.ModuleList()
56
+ self.padding_id = codebook_size + 5
57
+ for head in range(num_codebooks):
58
+ self.input_embeddings.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
59
+ self.backtranslation_heads.append(torch.nn.Embedding(num_embeddings=self.padding_id + 1, embedding_dim=backtranslation_dim, padding_idx=self.padding_id))
60
+ self.hierarchical_classifier.append(torch.nn.Linear(num_codebooks * backtranslation_dim + head * backtranslation_dim, codebook_size))
61
+
62
+ self.criterion = MaskedRefinementObjective()
63
+ for backtranslation_head in self.backtranslation_heads:
64
+ torch.nn.init.normal_(backtranslation_head.weight, mean=0, std=attention_dimension ** -0.5)
65
+ for input_embedding in self.input_embeddings:
66
+ torch.nn.init.normal_(input_embedding.weight, mean=0, std=attention_dimension ** -0.5)
67
+
68
+ def forward(self, index_sequence, is_inference, speaker_embedding, padding_mask=None, gold_index_sequence=None):
69
+ """
70
+ index_sequence: [batch, codebook_index, time_steps] a sequence of indexes that come from an argmax of the previous prediction layer.
71
+ is_inference: boolean flag that indicates whether to return the masked language modelling loss or the refined sequence
72
+ speaker_embedding: [batch, speaker_embed_dim]
73
+ padding_mask: [batch, time_steps] a mask that is True for all time steps that are padding and should not be considered and False everywhere else.
74
+
75
+ return: loss if is_inference is false, otherwise [batch, codebook_index, time_steps] a sequence of indexes with the same shape and same interpretation, refined through iterative masked language modelling.
76
+ """
77
+
78
+ if not is_inference:
79
+ index_sequence_padding_accounted = index_sequence.masked_fill(mask=padding_mask.unsqueeze(1), value=self.padding_id)
80
+ else:
81
+ index_sequence_padding_accounted = index_sequence # in the case of inference, there is no padding
82
+
83
+ sequence_of_continuous_tokens = self.indexes_per_codebook_to_stacked_embedding_vector(index_sequence_padding_accounted) # return [batch, time_steps, num_codebooks x backtranslation_dim]
84
+ contextualized_sequence = self.contextualize_sequence(sequence_of_continuous_tokens, speaker_embedding, non_padding_mask=~padding_mask if padding_mask is not None else None)
85
+
86
+ predicted_indexes_one_hot = list()
87
+ backtranslated_indexes = list()
88
+ for head_index, classifier_head in enumerate(self.hierarchical_classifier):
89
+ # each codebook considers all previous codebooks.
90
+ predicted_indexes_one_hot.append(classifier_head(torch.cat([contextualized_sequence] + backtranslated_indexes, dim=2)))
91
+ predicted_lookup_index = torch.argmax(predicted_indexes_one_hot[-1], dim=-1)
92
+ backtranslation = self.backtranslation_heads[head_index](predicted_lookup_index)
93
+ if len(backtranslation.size()) == 1:
94
+ backtranslation = backtranslation.unsqueeze(0)
95
+ backtranslated_indexes.append(backtranslation)
96
+ indexes = torch.cat(predicted_indexes_one_hot, dim=2)
97
+ # [Batch, Sequence, Hidden]
98
+ indexes = indexes.view(contextualized_sequence.size(0), contextualized_sequence.size(1), self.num_codebooks, self.codebook_size)
99
+ # [Batch, Sequence, Codebook, Classes]
100
+ indexes = indexes.transpose(1, 2)
101
+ # [Batch, Codebook, Sequence, Classes]
102
+ indexes = indexes.transpose(2, 3)
103
+ # [Batch, Codebook, Classes, Sequence]
104
+ indexes = indexes.transpose(0, 1)
105
+ # [Codebook, Batch, Classes, Sequence]
106
+
107
+ if is_inference:
108
+ return indexes
109
+ else:
110
+ return self.criterion(predicted_one_hot=indexes, gold_one_hot=gold_index_sequence, non_pad_mask=~padding_mask)
111
+
112
+ def contextualize_sequence(self, masked_sequence, utterance_embedding, non_padding_mask):
113
+ decoded_speech, _ = self.reconstruction_transformer(masked_sequence, non_padding_mask.unsqueeze(2) if non_padding_mask is not None else None, utterance_embedding=utterance_embedding)
114
+ return decoded_speech
115
+
116
+ def indexes_per_codebook_to_stacked_embedding_vector(self, index_sequence_per_codebook):
117
+ continuous_frame_sequences = list()
118
+
119
+ for codebook_id, backtranslation_head in enumerate(self.backtranslation_heads):
120
+ continuous_frame_sequences.append(backtranslation_head(index_sequence_per_codebook.transpose(0, 1)[codebook_id]))
121
+ stacked_embedding_vector = torch.cat(continuous_frame_sequences, dim=-1)
122
+ return stacked_embedding_vector
123
+
124
+
125
+ class MaskedRefinementObjective(torch.nn.Module):
126
+
127
+ def __init__(self):
128
+ super().__init__()
129
+ self.classification_loss = torch.nn.CrossEntropyLoss(reduction="none")
130
+ self.l1_loss = torch.nn.L1Loss(reduction="none")
131
+
132
+ def forward(self, predicted_one_hot, gold_one_hot, non_pad_mask):
133
+ ce = list()
134
+ for one_hot_pred, one_hot_target in zip(predicted_one_hot, gold_one_hot.transpose(0, 1).transpose(2, 3)):
135
+ # we iterate over codebooks
136
+ ce.append(self.classification_loss(one_hot_pred, one_hot_target))
137
+ classification_loss = torch.stack(ce).sum(0)
138
+ # make weighted mask and apply it
139
+ out_masks = non_pad_mask.unsqueeze(-1).to(gold_one_hot.device)
140
+ out_masks = torch.nn.functional.pad(out_masks.transpose(1, 2), [0, gold_one_hot.size(2) - out_masks.size(1), 0, 0, 0, 0], value=False).transpose(1, 2)
141
+ out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
142
+ out_weights /= gold_one_hot.size(0) * gold_one_hot.size(-1)
143
+ # apply weight
144
+ classification_loss = classification_loss.mul(out_weights.squeeze()).masked_select(out_masks.squeeze()).sum()
145
+
146
+ return classification_loss, classification_loss
147
+
148
+
149
+ def one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook):
150
+ return torch.argmax(batch_of_indexes_one_hot_per_codebook, dim=-2).transpose(0, 1)
151
+
152
+
153
+ if __name__ == '__main__':
154
+ from Architectures.ToucanTTS.ToucanTTS import ToucanTTS
155
+ from Utility.utils import make_pad_mask
156
+
157
+ # prepare dummy inputs
158
+ num_codebooks = 4
159
+ dummy_text_batch = torch.randint(low=0, high=2, size=[3, 3, 62]).float() # [Batch, Sequence Length, Features per Phone]
160
+ dummy_text_lens = torch.LongTensor([2, 3, 3])
161
+ gold_speech_batch = torch.randn([3, num_codebooks, 30, 1024]) # [Batch, Sequence Length, Spectrogram Buckets]
162
+ gold_speech_lens = torch.LongTensor([10, 30, 20])
163
+ gold_durations = torch.LongTensor([[10, 0, 0], [10, 15, 5], [5, 5, 10]])
164
+ gold_pitch = torch.Tensor([[[1.0], [0.], [0.]], [[1.1], [1.2], [0.8]], [[1.1], [1.2], [0.8]]])
165
+ gold_energy = torch.Tensor([[[1.0], [1.3], [0.]], [[1.1], [1.4], [0.8]], [[1.1], [1.2], [0.8]]])
166
+ dummy_utterance_embed = torch.randn([3, 512]) # [Batch, Dimensions of Speaker Embedding]
167
+ dummy_language_id = torch.LongTensor([5, 3, 2]).unsqueeze(1)
168
+
169
+ # run TTS on pseudo inputs
170
+ batch_of_indexes_one_hot_per_codebook, _, _, _, _, _ = ToucanTTS(num_codebooks=num_codebooks, use_language_model=False)._forward(dummy_text_batch,
171
+ dummy_text_lens,
172
+ gold_speech_batch,
173
+ gold_speech_lens,
174
+ gold_durations,
175
+ gold_pitch,
176
+ gold_energy,
177
+ utterance_embedding=dummy_utterance_embed,
178
+ lang_ids=dummy_language_id)
179
+
180
+ # reformat outputs to be a token sequence
181
+ batch_of_indexes = one_hot_sequence_to_token_sequence(batch_of_indexes_one_hot_per_codebook)
182
+
183
+ # refine the output of the TTS with the Language Model
184
+ refiner = CodecRefinementTransformer()
185
+
186
+ loss = refiner(index_sequence=one_hot_sequence_to_token_sequence(gold_speech_batch.transpose(3, 2)).transpose(0, 1), padding_mask=make_pad_mask(gold_speech_lens), is_inference=False, speaker_embedding=dummy_utterance_embed, gold_index_sequence=gold_speech_batch)
187
+ print(loss)
188
+
189
+ refined_indexes = refiner(index_sequence=batch_of_indexes[1].unsqueeze(0), is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
190
+ print(refined_indexes.shape)
191
+ refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
192
+ refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
193
+ print(refined_indexes.shape)
194
+ refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
195
+ refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
196
+ print(refined_indexes.shape)
197
+ refined_indexes = one_hot_sequence_to_token_sequence(refined_indexes)
198
+ refined_indexes = refiner(index_sequence=refined_indexes, is_inference=True, speaker_embedding=dummy_utterance_embed[0].unsqueeze(0), gold_index_sequence=None)
199
+ print(refined_indexes.shape)
Architectures/ToucanTTS/DurationCalculator.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Nagoya University (Tomoki Hayashi)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ import matplotlib.pyplot as plt
6
+
7
+ import torch
8
+
9
+
10
+ class DurationCalculator(torch.nn.Module):
11
+
12
+ def __init__(self, reduction_factor=1.0):
13
+ super().__init__()
14
+
15
+ @torch.no_grad()
16
+ def forward(self, att_ws, vis=None):
17
+ """
18
+ Convert alignment matrix to durations.
19
+ """
20
+ if vis is not None:
21
+ plt.figure(figsize=(8, 4))
22
+ plt.imshow(att_ws.cpu().numpy(), interpolation='nearest', aspect='auto', origin="lower")
23
+ plt.xlabel("Inputs")
24
+ plt.ylabel("Outputs")
25
+ plt.tight_layout()
26
+ plt.savefig(vis)
27
+ plt.close()
28
+ # calculate duration from 2d alignment matrix
29
+ durations = torch.stack([att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])])
30
+ return durations.view(-1)
Architectures/ToucanTTS/EnergyCalculator.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Nagoya University (Tomoki Hayashi)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from Architectures.GeneralLayers.STFT import STFT
9
+ from Utility.utils import pad_list
10
+
11
+
12
+ class EnergyCalculator(torch.nn.Module):
13
+
14
+ def __init__(self, fs=16000, n_fft=1024, win_length=None, hop_length=256, window="hann", center=True,
15
+ normalized=False, onesided=True, use_token_averaged_energy=True, reduction_factor=1):
16
+ super().__init__()
17
+
18
+ self.fs = fs
19
+ self.n_fft = n_fft
20
+ self.hop_length = hop_length
21
+ self.win_length = win_length
22
+ self.window = window
23
+ self.use_token_averaged_energy = use_token_averaged_energy
24
+ if use_token_averaged_energy:
25
+ assert reduction_factor >= 1
26
+ self.reduction_factor = reduction_factor
27
+
28
+ self.stft = STFT(n_fft=n_fft, win_length=win_length, hop_length=hop_length, window=window, center=center, normalized=normalized, onesided=onesided)
29
+
30
+ def output_size(self):
31
+ return 1
32
+
33
+ def get_parameters(self):
34
+ return dict(fs=self.fs, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, win_length=self.win_length, center=self.stft.center,
35
+ normalized=self.stft.normalized, use_token_averaged_energy=self.use_token_averaged_energy, reduction_factor=self.reduction_factor)
36
+
37
+ def forward(self, input_waves, input_waves_lengths=None, feats_lengths=None, durations=None,
38
+ durations_lengths=None, norm_by_average=True, text=None):
39
+ # If not provided, we assume that the inputs have the same length
40
+ if input_waves_lengths is None:
41
+ input_waves_lengths = (input_waves.new_ones(input_waves.shape[0], dtype=torch.long) * input_waves.shape[1])
42
+
43
+ # Domain-conversion: e.g. Stft: time -> time-freq
44
+ input_stft, energy_lengths = self.stft(input_waves, input_waves_lengths)
45
+
46
+ assert input_stft.dim() >= 4, input_stft.shape
47
+ assert input_stft.shape[-1] == 2, input_stft.shape
48
+
49
+ # input_stft: (..., F, 2) -> (..., F)
50
+ input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2
51
+ # sum over frequency (B, N, F) -> (B, N)
52
+ energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10))
53
+
54
+ # (Optional): Adjust length to match with the features
55
+ if feats_lengths is not None:
56
+ energy = [self._adjust_num_frames(e[:el].view(-1), fl) for e, el, fl in zip(energy, energy_lengths, feats_lengths)]
57
+ energy_lengths = feats_lengths
58
+
59
+ # (Optional): Average by duration to calculate token-wise energy
60
+ if self.use_token_averaged_energy:
61
+ energy = [self._average_by_duration(e[:el].view(-1), d, text) for e, el, d in zip(energy, energy_lengths, durations)]
62
+ energy_lengths = durations_lengths
63
+
64
+ # Padding
65
+ if isinstance(energy, list):
66
+ energy = pad_list(energy, 0.0)
67
+
68
+ if norm_by_average:
69
+ average = energy[0][energy[0] != 0.0].mean()
70
+ energy = energy / average
71
+
72
+ # Return with the shape (B, T, 1)
73
+ return energy.unsqueeze(-1), energy_lengths
74
+
75
+ def _average_by_duration(self, x, d, text=None):
76
+ d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
77
+ x_avg = [x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0) for start, end in zip(d_cumsum[:-1], d_cumsum[1:])]
78
+
79
+ # find tokens that are not phoneme and set energy to 0
80
+ # while this makes sense, it make sit harder to model, so we leave this out
81
+ # if text is not None:
82
+ # for i, vector in enumerate(text):
83
+ # if vector[get_feature_to_index_lookup()["phoneme"]] == 0:
84
+ # x_avg[i] = torch.tensor(0.0, device=x.device)
85
+
86
+ return torch.stack(x_avg)
87
+
88
+ @staticmethod
89
+ def _adjust_num_frames(x, num_frames):
90
+ if num_frames > len(x):
91
+ x = F.pad(x, (0, num_frames - len(x)))
92
+ elif num_frames < len(x):
93
+ x = x[:num_frames]
94
+ return x
Architectures/ToucanTTS/Glow.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ import torch
4
+ import torch.distributions as dist
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from Architectures.ToucanTTS import glow_utils
9
+ from Architectures.ToucanTTS.wavenet import WN
10
+
11
+
12
+ class ActNorm(nn.Module):
13
+
14
+ def __init__(self, channels, ddi=False, **kwargs):
15
+ super().__init__()
16
+ self.channels = channels
17
+ self.initialized = not ddi
18
+
19
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
20
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
21
+
22
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
23
+ if x_mask is None:
24
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
25
+ x_len = torch.sum(x_mask, [1, 2])
26
+ if not self.initialized:
27
+ self.initialize(x, x_mask)
28
+ self.initialized = True
29
+
30
+ if reverse:
31
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
32
+ logdet = torch.sum(-self.logs) * x_len
33
+ else:
34
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
35
+ logdet = torch.sum(self.logs) * x_len # [b]
36
+ return z, logdet
37
+
38
+ def store_inverse(self):
39
+ pass
40
+
41
+ def set_ddi(self, ddi):
42
+ self.initialized = not ddi
43
+
44
+ def initialize(self, x, x_mask):
45
+ with torch.no_grad():
46
+ denom = torch.sum(x_mask, [0, 2])
47
+ m = torch.sum(x * x_mask, [0, 2]) / denom
48
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
49
+ v = m_sq - (m ** 2)
50
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
51
+
52
+ bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
53
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
54
+
55
+ self.bias.data.copy_(bias_init)
56
+ self.logs.data.copy_(logs_init)
57
+
58
+
59
+ class InvConvNear(nn.Module):
60
+
61
+ def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
62
+ super().__init__()
63
+ assert (n_split % 2 == 0)
64
+ self.channels = channels
65
+ self.n_split = n_split
66
+ self.n_sqz = n_sqz
67
+ self.no_jacobian = no_jacobian
68
+
69
+ w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_(), 'complete')[0]
70
+ if torch.det(w_init) < 0:
71
+ w_init[:, 0] = -1 * w_init[:, 0]
72
+ self.lu = lu
73
+ if lu:
74
+ # LU decomposition can slightly speed up the inverse
75
+ np_p, np_l, np_u = scipy.linalg.lu(w_init)
76
+ np_s = np.diag(np_u)
77
+ np_sign_s = np.sign(np_s)
78
+ np_log_s = np.log(np.abs(np_s))
79
+ np_u = np.triu(np_u, k=1)
80
+ l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
81
+ eye = np.eye(*w_init.shape, dtype=float)
82
+
83
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
84
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
85
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
86
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
87
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
88
+ self.register_buffer('l_mask', torch.Tensor(l_mask))
89
+ self.register_buffer('eye', torch.Tensor(eye))
90
+ else:
91
+ self.weight = nn.Parameter(w_init)
92
+
93
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
94
+ b, c, t = x.size()
95
+ assert (c % self.n_split == 0)
96
+ if x_mask is None:
97
+ x_mask = 1
98
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
99
+ else:
100
+ x_len = torch.sum(x_mask, [1, 2])
101
+
102
+ x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
103
+ x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
104
+
105
+ if self.lu:
106
+ self.weight, log_s = self._get_weight()
107
+ logdet = log_s.sum()
108
+ logdet = logdet * (c / self.n_split) * x_len
109
+ else:
110
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
111
+
112
+ if reverse:
113
+ if hasattr(self, "weight_inv"):
114
+ weight = self.weight_inv
115
+ else:
116
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
117
+ logdet = -logdet
118
+ else:
119
+ weight = self.weight
120
+ if self.no_jacobian:
121
+ logdet = 0
122
+
123
+ weight = weight.view(self.n_split, self.n_split, 1, 1).to(x.device)
124
+ z = F.conv2d(x, weight)
125
+
126
+ z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
127
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
128
+ return z, logdet
129
+
130
+ def _get_weight(self):
131
+ l, log_s, u = self.l, self.log_s, self.u
132
+ l = l * self.l_mask + self.eye
133
+ u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
134
+ weight = torch.matmul(self.p, torch.matmul(l, u))
135
+ return weight, log_s
136
+
137
+ def store_inverse(self):
138
+ weight, _ = self._get_weight()
139
+ self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
140
+
141
+
142
+ class InvConv(nn.Module):
143
+
144
+ def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
145
+ super().__init__()
146
+ w_shape = [channels, channels]
147
+ w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
148
+ LU_decomposed = lu
149
+ if not LU_decomposed:
150
+ # Sample a random orthogonal matrix:
151
+ self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
152
+ else:
153
+ np_p, np_l, np_u = scipy.linalg.lu(w_init)
154
+ np_s = np.diag(np_u)
155
+ np_sign_s = np.sign(np_s)
156
+ np_log_s = np.log(np.abs(np_s))
157
+ np_u = np.triu(np_u, k=1)
158
+ l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
159
+ eye = np.eye(*w_shape, dtype=float)
160
+
161
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
162
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
163
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
164
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
165
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
166
+ self.l_mask = torch.Tensor(l_mask)
167
+ self.eye = torch.Tensor(eye)
168
+ self.w_shape = w_shape
169
+ self.LU = LU_decomposed
170
+ self.weight = None
171
+
172
+ def get_weight(self, device, reverse):
173
+ w_shape = self.w_shape
174
+ self.p = self.p.to(device)
175
+ self.sign_s = self.sign_s.to(device)
176
+ self.l_mask = self.l_mask.to(device)
177
+ self.eye = self.eye.to(device)
178
+ l = self.l * self.l_mask + self.eye
179
+ u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
180
+ dlogdet = self.log_s.sum()
181
+ if not reverse:
182
+ w = torch.matmul(self.p, torch.matmul(l, u))
183
+ else:
184
+ l = torch.inverse(l.double()).float()
185
+ u = torch.inverse(u.double()).float()
186
+ w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
187
+ return w.view(w_shape[0], w_shape[1], 1), dlogdet
188
+
189
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
190
+ """
191
+ log-det = log|abs(|W|)| * pixels
192
+ """
193
+ b, c, t = x.size()
194
+ if x_mask is None:
195
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
196
+ else:
197
+ x_len = torch.sum(x_mask, [1, 2])
198
+ logdet = 0
199
+ if not reverse:
200
+ weight, dlogdet = self.get_weight(x.device, reverse)
201
+ z = F.conv1d(x, weight)
202
+ if logdet is not None:
203
+ logdet = logdet + dlogdet * x_len
204
+ return z, logdet
205
+ else:
206
+ if self.weight is None:
207
+ weight, dlogdet = self.get_weight(x.device, reverse)
208
+ else:
209
+ weight, dlogdet = self.weight, self.dlogdet
210
+ z = F.conv1d(x, weight)
211
+ if logdet is not None:
212
+ logdet = logdet - dlogdet * x_len
213
+ return z, logdet
214
+
215
+ def store_inverse(self):
216
+ self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
217
+
218
+
219
+ class CouplingBlock(nn.Module):
220
+
221
+ def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
222
+ gin_channels=0, p_dropout=0., sigmoid_scale=False, wn=None, use_weightnorm=True):
223
+ super().__init__()
224
+ self.in_channels = in_channels
225
+ self.hidden_channels = hidden_channels
226
+ self.kernel_size = kernel_size
227
+ self.dilation_rate = dilation_rate
228
+ self.n_layers = n_layers
229
+ self.gin_channels = gin_channels
230
+ self.p_dropout = p_dropout
231
+ self.sigmoid_scale = sigmoid_scale
232
+
233
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
234
+ if use_weightnorm:
235
+ start = torch.nn.utils.weight_norm(start)
236
+ self.start = start
237
+ # Initializing last layer to 0 makes the affine coupling layers
238
+ # do nothing at first. This helps with training stability
239
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
240
+ end.weight.data.zero_()
241
+ end.bias.data.zero_()
242
+ self.end = end
243
+ self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout, use_weightnorm=use_weightnorm)
244
+ if wn is not None:
245
+ self.wn.in_layers = wn.in_layers
246
+ self.wn.res_skip_layers = wn.res_skip_layers
247
+
248
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
249
+ if x_mask is None:
250
+ x_mask = 1
251
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
252
+
253
+ x = self.start(x_0) * x_mask
254
+ x = self.wn(x, x_mask, g)
255
+ out = self.end(x)
256
+
257
+ z_0 = x_0
258
+ m = out[:, :self.in_channels // 2, :]
259
+ logs = out[:, self.in_channels // 2:, :]
260
+ if self.sigmoid_scale:
261
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
262
+ if reverse:
263
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
264
+ logdet = torch.sum(-logs * x_mask, [1, 2])
265
+ else:
266
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
267
+ logdet = torch.sum(logs * x_mask, [1, 2])
268
+ z = torch.cat([z_0, z_1], 1)
269
+ return z, logdet
270
+
271
+ def store_inverse(self):
272
+ self.wn.remove_weight_norm()
273
+
274
+
275
+ class Glow(nn.Module):
276
+
277
+ def __init__(self,
278
+ in_channels,
279
+ hidden_channels,
280
+ kernel_size,
281
+ dilation_rate,
282
+ n_blocks,
283
+ n_layers,
284
+ condition_integration_projection,
285
+ p_dropout=0.,
286
+ n_split=4,
287
+ n_sqz=2,
288
+ sigmoid_scale=False,
289
+ text_condition_channels=0,
290
+ inv_conv_type='near',
291
+ share_cond_layers=False,
292
+ share_wn_layers=0,
293
+ use_weightnorm=True # If weightnorm is set to false, we can deepcopy the module, which we need to be able to do to perform SWA. Without weightnorm, the module will probably take a little longer to converge.
294
+ ):
295
+ super().__init__()
296
+
297
+ self.in_channels = in_channels
298
+ self.hidden_channels = hidden_channels
299
+ self.kernel_size = kernel_size
300
+ self.dilation_rate = dilation_rate
301
+ self.n_blocks = n_blocks
302
+ self.n_layers = n_layers
303
+ self.p_dropout = p_dropout
304
+ self.n_split = n_split
305
+ self.n_sqz = n_sqz
306
+ self.sigmoid_scale = sigmoid_scale
307
+ self.text_condition_channels = text_condition_channels
308
+ self.share_cond_layers = share_cond_layers
309
+ self.prior_dist = dist.Normal(0, 1)
310
+ self.g_proj = condition_integration_projection
311
+ if text_condition_channels != 0 and share_cond_layers:
312
+ cond_layer = torch.nn.Conv1d(text_condition_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
313
+ if use_weightnorm:
314
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
315
+ else:
316
+ self.cond_layer = cond_layer
317
+ wn = None
318
+ self.flows = nn.ModuleList()
319
+ for b in range(n_blocks):
320
+ self.flows.append(ActNorm(channels=in_channels * n_sqz))
321
+ if inv_conv_type == 'near':
322
+ self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
323
+ if inv_conv_type == 'invconv':
324
+ self.flows.append(InvConv(channels=in_channels * n_sqz))
325
+ if share_wn_layers > 0:
326
+ if b % share_wn_layers == 0:
327
+ wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, text_condition_channels * n_sqz, p_dropout, share_cond_layers, use_weightnorm=use_weightnorm)
328
+ self.flows.append(
329
+ CouplingBlock(
330
+ in_channels * n_sqz,
331
+ hidden_channels,
332
+ kernel_size=kernel_size,
333
+ dilation_rate=dilation_rate,
334
+ n_layers=n_layers,
335
+ gin_channels=text_condition_channels * n_sqz,
336
+ p_dropout=p_dropout,
337
+ sigmoid_scale=sigmoid_scale,
338
+ wn=wn,
339
+ use_weightnorm=use_weightnorm
340
+ ))
341
+
342
+ def forward(self, tgt_mels, infer, mel_out, encoded_texts, tgt_nonpadding, glow_sampling_temperature=0.2):
343
+ x_recon = mel_out.transpose(1, 2)
344
+ g = x_recon
345
+ B, _, T = g.shape
346
+ if encoded_texts is not None and self.text_condition_channels != 0:
347
+ g = torch.cat([g, encoded_texts.transpose(1, 2)], 1)
348
+ g = self.g_proj(g)
349
+ prior_dist = self.prior_dist
350
+ if not infer:
351
+ y_lengths = tgt_nonpadding.sum(-1)
352
+ tgt_mels = tgt_mels.transpose(1, 2)
353
+ z_postflow, ldj = self._forward(tgt_mels, tgt_nonpadding, g=g)
354
+ ldj = ldj / y_lengths / 80
355
+ try:
356
+ postflow_loss = -prior_dist.log_prob(z_postflow).mean() - ldj.mean()
357
+ except ValueError:
358
+ print("log probability of postflow could not be calculated for this step")
359
+ postflow_loss = None
360
+ return postflow_loss
361
+ else:
362
+ nonpadding = torch.ones_like(x_recon[:, :1, :]) if tgt_nonpadding is None else tgt_nonpadding
363
+ z_post = torch.randn(x_recon.shape).to(g.device) * glow_sampling_temperature
364
+ x_recon, _ = self._forward(z_post, nonpadding, g, reverse=True)
365
+ return x_recon.transpose(1, 2)
366
+
367
+ def _forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
368
+ logdet_tot = 0
369
+ if not reverse:
370
+ flows = self.flows
371
+ else:
372
+ flows = reversed(self.flows)
373
+ if return_hiddens:
374
+ hs = []
375
+ if self.n_sqz > 1:
376
+ x, x_mask_ = glow_utils.squeeze(x, x_mask, self.n_sqz)
377
+ if g is not None:
378
+ g, _ = glow_utils.squeeze(g, x_mask, self.n_sqz)
379
+ x_mask = x_mask_
380
+ if self.share_cond_layers and g is not None:
381
+ g = self.cond_layer(g)
382
+ for f in flows:
383
+ x, logdet = f(x, x_mask, g=g, reverse=reverse)
384
+ if return_hiddens:
385
+ hs.append(x)
386
+ logdet_tot += logdet
387
+ if self.n_sqz > 1:
388
+ x, x_mask = glow_utils.unsqueeze(x, x_mask, self.n_sqz)
389
+ if return_hiddens:
390
+ return x, logdet_tot, hs
391
+ return x, logdet_tot
392
+
393
+ def store_inverse(self):
394
+ def remove_weight_norm(m):
395
+ try:
396
+ nn.utils.remove_weight_norm(m)
397
+ except ValueError: # this module didn't have weight norm
398
+ return
399
+
400
+ self.apply(remove_weight_norm)
401
+ for f in self.flows:
402
+ f.store_inverse()
Architectures/ToucanTTS/InferenceToucanTTS.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dotwiz
2
+ import torch
3
+ from torch.nn import Linear
4
+ from torch.nn import Sequential
5
+ from torch.nn import Tanh
6
+
7
+ from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d
8
+ from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
9
+ from Architectures.GeneralLayers.Conformer import Conformer
10
+ from Architectures.GeneralLayers.DurationPredictor import DurationPredictor
11
+ from Architectures.GeneralLayers.LengthRegulator import LengthRegulator
12
+ from Architectures.GeneralLayers.VariancePredictor import VariancePredictor
13
+ from Architectures.ToucanTTS.Glow import Glow
14
+ from Preprocessing.articulatory_features import get_feature_to_index_lookup
15
+ from Utility.utils import integrate_with_utt_embed
16
+ from Utility.utils import make_non_pad_mask
17
+
18
+
19
+ class ToucanTTS(torch.nn.Module):
20
+
21
+ def __init__(self,
22
+ weights,
23
+ config):
24
+ super().__init__()
25
+
26
+ self.config = config
27
+ config = dotwiz.DotWiz(config)
28
+
29
+ input_feature_dimensions = config.input_feature_dimensions
30
+ attention_dimension = config.attention_dimension
31
+ attention_heads = config.attention_heads
32
+ positionwise_conv_kernel_size = config.positionwise_conv_kernel_size
33
+ use_scaled_positional_encoding = config.use_scaled_positional_encoding
34
+ use_macaron_style_in_conformer = config.use_macaron_style_in_conformer
35
+ use_cnn_in_conformer = config.use_cnn_in_conformer
36
+ encoder_layers = config.encoder_layers
37
+ encoder_units = config.encoder_units
38
+ encoder_normalize_before = config.encoder_normalize_before
39
+ encoder_concat_after = config.encoder_concat_after
40
+ conformer_encoder_kernel_size = config.conformer_encoder_kernel_size
41
+ transformer_enc_dropout_rate = config.transformer_enc_dropout_rate
42
+ transformer_enc_positional_dropout_rate = config.transformer_enc_positional_dropout_rate
43
+ transformer_enc_attn_dropout_rate = config.transformer_enc_attn_dropout_rate
44
+ decoder_layers = config.decoder_layers
45
+ decoder_units = config.decoder_units
46
+ decoder_concat_after = config.decoder_concat_after
47
+ conformer_decoder_kernel_size = config.conformer_decoder_kernel_size
48
+ decoder_normalize_before = config.decoder_normalize_before
49
+ transformer_dec_dropout_rate = config.transformer_dec_dropout_rate
50
+ transformer_dec_positional_dropout_rate = config.transformer_dec_positional_dropout_rate
51
+ transformer_dec_attn_dropout_rate = config.transformer_dec_attn_dropout_rate
52
+ duration_predictor_layers = config.duration_predictor_layers
53
+ duration_predictor_kernel_size = config.duration_predictor_kernel_size
54
+ duration_predictor_dropout_rate = config.duration_predictor_dropout_rate
55
+ pitch_predictor_layers = config.pitch_predictor_layers
56
+ pitch_predictor_kernel_size = config.pitch_predictor_kernel_size
57
+ pitch_predictor_dropout = config.pitch_predictor_dropout
58
+ pitch_embed_kernel_size = config.pitch_embed_kernel_size
59
+ pitch_embed_dropout = config.pitch_embed_dropout
60
+ energy_predictor_layers = config.energy_predictor_layers
61
+ energy_predictor_kernel_size = config.energy_predictor_kernel_size
62
+ energy_predictor_dropout = config.energy_predictor_dropout
63
+ energy_embed_kernel_size = config.energy_embed_kernel_size
64
+ energy_embed_dropout = config.energy_embed_dropout
65
+ utt_embed_dim = config.utt_embed_dim
66
+ lang_embs = config.lang_embs
67
+ embedding_integration = config.embedding_integration
68
+ glow_kernel_size = config.glow_kernel_size
69
+ glow_blocks = config.glow_blocks
70
+ glow_layers = config.glow_layers
71
+ lang_emb_size = config.lang_emb_size
72
+ integrate_language_embedding_into_encoder_out = config.integrate_language_embedding_into_encoder_out
73
+
74
+ self.input_feature_dimensions = input_feature_dimensions
75
+ self.attention_dimension = attention_dimension
76
+ self.use_scaled_pos_enc = use_scaled_positional_encoding
77
+ self.multilingual_model = lang_embs is not None
78
+ self.multispeaker_model = utt_embed_dim is not None
79
+ self.integrate_language_embedding_into_encoder_out = integrate_language_embedding_into_encoder_out
80
+ self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
81
+
82
+ articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100), Tanh(), Linear(100, attention_dimension))
83
+ self.encoder = Conformer(conformer_type="encoder",
84
+ attention_dim=attention_dimension,
85
+ attention_heads=attention_heads,
86
+ linear_units=encoder_units,
87
+ num_blocks=encoder_layers,
88
+ input_layer=articulatory_feature_embedding,
89
+ dropout_rate=transformer_enc_dropout_rate,
90
+ positional_dropout_rate=transformer_enc_positional_dropout_rate,
91
+ attention_dropout_rate=transformer_enc_attn_dropout_rate,
92
+ normalize_before=encoder_normalize_before,
93
+ concat_after=encoder_concat_after,
94
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size,
95
+ macaron_style=use_macaron_style_in_conformer,
96
+ use_cnn_module=True,
97
+ cnn_module_kernel=conformer_encoder_kernel_size,
98
+ zero_triu=False,
99
+ utt_embed=utt_embed_dim,
100
+ lang_embs=lang_embs,
101
+ lang_emb_size=lang_emb_size,
102
+ use_output_norm=True,
103
+ embedding_integration=embedding_integration)
104
+
105
+ if self.integrate_language_embedding_into_encoder_out:
106
+ if embedding_integration == "AdaIN":
107
+ self.language_embedding_infusion = AdaIN1d(style_dim=lang_emb_size, num_features=attention_dimension)
108
+ elif embedding_integration == "ConditionalLayerNorm":
109
+ self.language_embedding_infusion = ConditionalLayerNorm(speaker_embedding_dim=lang_emb_size, hidden_dim=attention_dimension)
110
+ else:
111
+ self.language_embedding_infusion = torch.nn.Linear(attention_dimension + lang_emb_size, attention_dimension)
112
+
113
+ self.duration_predictor = DurationPredictor(idim=attention_dimension,
114
+ n_layers=duration_predictor_layers,
115
+ n_chans=attention_dimension,
116
+ kernel_size=duration_predictor_kernel_size,
117
+ dropout_rate=duration_predictor_dropout_rate,
118
+ utt_embed_dim=utt_embed_dim,
119
+ embedding_integration=embedding_integration)
120
+
121
+ self.pitch_predictor = VariancePredictor(idim=attention_dimension,
122
+ n_layers=pitch_predictor_layers,
123
+ n_chans=attention_dimension,
124
+ kernel_size=pitch_predictor_kernel_size,
125
+ dropout_rate=pitch_predictor_dropout,
126
+ utt_embed_dim=utt_embed_dim,
127
+ embedding_integration=embedding_integration)
128
+
129
+ self.energy_predictor = VariancePredictor(idim=attention_dimension,
130
+ n_layers=energy_predictor_layers,
131
+ n_chans=attention_dimension,
132
+ kernel_size=energy_predictor_kernel_size,
133
+ dropout_rate=energy_predictor_dropout,
134
+ utt_embed_dim=utt_embed_dim,
135
+ embedding_integration=embedding_integration)
136
+
137
+ self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1,
138
+ out_channels=attention_dimension,
139
+ kernel_size=pitch_embed_kernel_size,
140
+ padding=(pitch_embed_kernel_size - 1) // 2),
141
+ torch.nn.Dropout(pitch_embed_dropout))
142
+
143
+ self.energy_embed = Sequential(torch.nn.Conv1d(in_channels=1,
144
+ out_channels=attention_dimension,
145
+ kernel_size=energy_embed_kernel_size,
146
+ padding=(energy_embed_kernel_size - 1) // 2),
147
+ torch.nn.Dropout(energy_embed_dropout))
148
+
149
+ self.length_regulator = LengthRegulator()
150
+
151
+ self.decoder = Conformer(conformer_type="decoder",
152
+ attention_dim=attention_dimension,
153
+ attention_heads=attention_heads,
154
+ linear_units=decoder_units,
155
+ num_blocks=decoder_layers,
156
+ input_layer=None,
157
+ dropout_rate=transformer_dec_dropout_rate,
158
+ positional_dropout_rate=transformer_dec_positional_dropout_rate,
159
+ attention_dropout_rate=transformer_dec_attn_dropout_rate,
160
+ normalize_before=decoder_normalize_before,
161
+ concat_after=decoder_concat_after,
162
+ positionwise_conv_kernel_size=positionwise_conv_kernel_size,
163
+ macaron_style=use_macaron_style_in_conformer,
164
+ use_cnn_module=use_cnn_in_conformer,
165
+ cnn_module_kernel=conformer_decoder_kernel_size,
166
+ use_output_norm=not embedding_integration in ["AdaIN", "ConditionalLayerNorm"],
167
+ utt_embed=utt_embed_dim,
168
+ embedding_integration=embedding_integration)
169
+
170
+ self.output_projection = torch.nn.Linear(attention_dimension, 128)
171
+
172
+ self.post_flow = Glow(
173
+ in_channels=128,
174
+ hidden_channels=attention_dimension, # post_glow_hidden
175
+ kernel_size=glow_kernel_size, # post_glow_kernel_size
176
+ dilation_rate=1,
177
+ n_blocks=glow_blocks, # post_glow_n_blocks (original 12 in paper)
178
+ n_layers=glow_layers, # post_glow_n_block_layers (original 3 in paper)
179
+ n_split=4,
180
+ n_sqz=2,
181
+ text_condition_channels=attention_dimension,
182
+ share_cond_layers=False, # post_share_cond_layers
183
+ share_wn_layers=4,
184
+ sigmoid_scale=False,
185
+ condition_integration_projection=torch.nn.Conv1d(128 + attention_dimension, attention_dimension, 5, padding=2)
186
+ )
187
+
188
+ self.load_state_dict(weights)
189
+ self.eval()
190
+
191
+ def _forward(self,
192
+ text_tensors,
193
+ text_lengths,
194
+ gold_durations=None,
195
+ gold_pitch=None,
196
+ gold_energy=None,
197
+ duration_scaling_factor=1.0,
198
+ utterance_embedding=None,
199
+ lang_ids=None,
200
+ pitch_variance_scale=1.0,
201
+ energy_variance_scale=1.0,
202
+ pause_duration_scaling_factor=1.0,
203
+ glow_sampling_temperature=0.2):
204
+
205
+ if not self.multilingual_model:
206
+ lang_ids = None
207
+
208
+ if not self.multispeaker_model:
209
+ utterance_embedding = None
210
+ else:
211
+ utterance_embedding = torch.nn.functional.normalize(utterance_embedding)
212
+
213
+ # encoding the texts
214
+ text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2)
215
+ encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
216
+
217
+ if self.integrate_language_embedding_into_encoder_out:
218
+ lang_embs = self.encoder.language_embedding(lang_ids).squeeze(-1).detach()
219
+ encoded_texts = integrate_with_utt_embed(hs=encoded_texts, utt_embeddings=lang_embs, projection=self.language_embedding_infusion, embedding_training=self.use_conditional_layernorm_embedding_integration)
220
+
221
+ # predicting pitch, energy and durations
222
+ pitch_predictions = self.pitch_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_pitch is None else gold_pitch
223
+ energy_predictions = self.energy_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_energy is None else gold_energy
224
+ predicted_durations = self.duration_predictor.inference(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_durations is None else gold_durations
225
+
226
+ # modifying the predictions with control parameters
227
+ for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
228
+ if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
229
+ predicted_durations[0][phoneme_index] = 0
230
+ if phoneme_vector[get_feature_to_index_lookup()["silence"]] == 1 and pause_duration_scaling_factor != 1.0:
231
+ predicted_durations[0][phoneme_index] = torch.round(predicted_durations[0][phoneme_index].float() * pause_duration_scaling_factor).long()
232
+ if duration_scaling_factor != 1.0:
233
+ assert duration_scaling_factor > 0
234
+ predicted_durations = torch.round(predicted_durations.float() * duration_scaling_factor).long()
235
+ pitch_predictions = make_near_zero_to_zero(pitch_predictions.squeeze(0)).unsqueeze(0)
236
+ energy_predictions = make_near_zero_to_zero(energy_predictions.squeeze(0)).unsqueeze(0)
237
+ pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
238
+ energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)
239
+
240
+ # enriching the text with pitch and energy info
241
+ embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
242
+ embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
243
+ enriched_encoded_texts = encoded_texts + embedded_pitch_curve + embedded_energy_curve
244
+
245
+ # predicting durations for text and upsampling accordingly
246
+ upsampled_enriched_encoded_texts = self.length_regulator(enriched_encoded_texts, predicted_durations)
247
+
248
+ # decoding spectrogram
249
+ decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, None, utterance_embedding=utterance_embedding)
250
+
251
+ frames = self.output_projection(decoded_speech)
252
+
253
+ refined_codec_frames = self.post_flow(tgt_mels=None, infer=True, mel_out=frames, encoded_texts=upsampled_enriched_encoded_texts, tgt_nonpadding=None, glow_sampling_temperature=glow_sampling_temperature)
254
+
255
+ return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze()
256
+
257
+ @torch.inference_mode()
258
+ def forward(self,
259
+ text,
260
+ durations=None,
261
+ pitch=None,
262
+ energy=None,
263
+ utterance_embedding=None,
264
+ return_duration_pitch_energy=False,
265
+ lang_id=None,
266
+ duration_scaling_factor=1.0,
267
+ pitch_variance_scale=1.0,
268
+ energy_variance_scale=1.0,
269
+ pause_duration_scaling_factor=1.0,
270
+ glow_sampling_temperature=0.2):
271
+ """
272
+ Generate the sequence of spectrogram frames given the sequence of vectorized phonemes.
273
+
274
+ Args:
275
+ text: input sequence of vectorized phonemes
276
+ durations: durations to be used (optional, if not provided, they will be predicted)
277
+ pitch: token-averaged pitch curve to be used (optional, if not provided, it will be predicted)
278
+ energy: token-averaged energy curve to be used (optional, if not provided, it will be predicted)
279
+ return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
280
+ utterance_embedding: embedding of speaker information
281
+ lang_id: id to be fed into the embedding layer that contains language information
282
+ duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
283
+ 1.0 means no scaling happens, higher values increase durations for the whole
284
+ utterance, lower values decrease durations for the whole utterance.
285
+ pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
286
+ 1.0 means no scaling happens, higher values increase variance of the pitch curve,
287
+ lower values decrease variance of the pitch curve.
288
+ energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
289
+ 1.0 means no scaling happens, higher values increase variance of the energy curve,
290
+ lower values decrease variance of the energy curve.
291
+ pause_duration_scaling_factor: reasonable values are 0.6 < scale < 1.4.
292
+ scales the durations of pauses on top of the regular duration scaling
293
+
294
+ Returns:
295
+ features spectrogram
296
+
297
+ """
298
+ # setup batch axis
299
+ text_length = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
300
+ if durations is not None:
301
+ durations = durations.unsqueeze(0).to(text.device)
302
+ if pitch is not None:
303
+ pitch = pitch.unsqueeze(0).to(text.device)
304
+ if energy is not None:
305
+ energy = energy.unsqueeze(0).to(text.device)
306
+ if lang_id is not None:
307
+ lang_id = lang_id.to(text.device)
308
+
309
+ outs, \
310
+ predicted_durations, \
311
+ pitch_predictions, \
312
+ energy_predictions = self._forward(text.unsqueeze(0),
313
+ text_length,
314
+ gold_durations=durations,
315
+ gold_pitch=pitch,
316
+ gold_energy=energy,
317
+ utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None, lang_ids=lang_id,
318
+ duration_scaling_factor=duration_scaling_factor,
319
+ pitch_variance_scale=pitch_variance_scale,
320
+ energy_variance_scale=energy_variance_scale,
321
+ pause_duration_scaling_factor=pause_duration_scaling_factor,
322
+ glow_sampling_temperature=glow_sampling_temperature)
323
+
324
+ if return_duration_pitch_energy:
325
+ return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions
326
+ return outs.squeeze().transpose(0, 1)
327
+
328
+ def store_inverse_all(self):
329
+ def remove_weight_norm(m):
330
+ try:
331
+ torch.nn.utils.remove_weight_norm(m)
332
+ except ValueError: # this module didn't have weight norm
333
+ return
334
+ self.post_flow.store_inverse()
335
+ self.apply(remove_weight_norm)
336
+
337
+
338
+ def _scale_variance(sequence, scale):
339
+ if scale == 1.0:
340
+ return sequence
341
+ average = sequence[0][sequence[0] != 0.0].mean()
342
+ sequence = sequence - average # center sequence around 0
343
+ sequence = sequence * scale # scale the variance
344
+ sequence = sequence + average # move center back to original with changed variance
345
+ for sequence_index in range(len(sequence[0])):
346
+ if sequence[0][sequence_index] < 0.0:
347
+ sequence[0][sequence_index] = 0.0
348
+ return sequence
349
+
350
+
351
+ def smooth_time_series(matrix, n_neighbors):
352
+ """
353
+ Smooth a 2D matrix along the time axis using a moving average.
354
+
355
+ Parameters:
356
+ - matrix (torch.Tensor): Input matrix (2D tensor) representing the time series.
357
+ - n_neighbors (int): Number of neighboring rows to include in the moving average.
358
+
359
+ Returns:
360
+ - torch.Tensor: Smoothed matrix.
361
+ """
362
+ smoothed_matrix = torch.zeros_like(matrix)
363
+ for i in range(matrix.size(0)):
364
+ lower = max(0, i - n_neighbors)
365
+ upper = min(matrix.size(0), i + n_neighbors + 1)
366
+ smoothed_matrix[i] = torch.mean(matrix[lower:upper], dim=0)
367
+
368
+ return smoothed_matrix
369
+
370
+
371
+ def make_near_zero_to_zero(sequence):
372
+ for index in range(len(sequence)):
373
+ if sequence[index] < 0.2:
374
+ sequence[index] = 0.0
375
+ return sequence
Architectures/ToucanTTS/LanguageEmbeddingSpaceStructureLoss.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import pickle
3
+
4
+ import torch
5
+
6
+ from Preprocessing.multilinguality.create_distance_lookups import CacheCreator
7
+ from Utility.utils import load_json_from_path
8
+
9
+
10
+ class LanguageEmbeddingSpaceStructureLoss(torch.nn.Module):
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+ cc = CacheCreator(cache_root="Preprocessing/multilinguality")
15
+ if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'):
16
+ cc.create_tree_cache(cache_root="Preprocessing/multilinguality")
17
+ if not os.path.exists('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json'):
18
+ cc.create_map_cache(cache_root="Preprocessing/multilinguality")
19
+ if not os.path.exists("Preprocessing/multilinguality/asp_dict.pkl"):
20
+ print("download asp file") # TODO downloader script with release
21
+
22
+ self.tree_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_tree_dist.json')
23
+ self.map_dist = load_json_from_path('Preprocessing/multilinguality/lang_1_to_lang_2_to_map_dist.json')
24
+ with open("Preprocessing/multilinguality/asp_dict.pkl", 'rb') as dictfile:
25
+ self.asp_sim = pickle.load(dictfile)
26
+ self.lang_list = list(self.asp_sim.keys()) # list of all languages, to get lang_b's index
27
+
28
+ self.largest_value_map_dist = 0.0
29
+ for _, values in self.map_dist.items():
30
+ for _, value in values.items():
31
+ self.largest_value_map_dist = max(self.largest_value_map_dist, value)
32
+
33
+ self.iso_codes_to_ids = load_json_from_path("Preprocessing/multilinguality/iso_lookup.json")[-1]
34
+ self.ids_to_iso_codes = {v: k for k, v in self.iso_codes_to_ids.items()}
35
+
36
+ def forward(self, language_ids, language_embeddings):
37
+ """
38
+ Args:
39
+ language_ids (Tensor): IDs of languages in the same order as the embeddings to calculate the distances according to the metrics.
40
+ language_embeddings (Tensor): Batch of language embeddings, of which the distances will be compared to the distances according to the metrics.
41
+
42
+ Returns:
43
+ Tensor: Language Embedding Structure Loss Value
44
+ """
45
+
46
+ losses = list()
47
+ for language_id_1, language_embedding_1 in zip(language_ids, language_embeddings):
48
+ for language_id_2, language_embedding_2 in zip(language_ids, language_embeddings):
49
+ if language_id_1 != language_id_2:
50
+ embed_dist = torch.nn.functional.l1_loss(language_embedding_1, language_embedding_2)
51
+ lang_1 = self.ids_to_iso_codes[language_id_1]
52
+ lang_2 = self.ids_to_iso_codes[language_id_2]
53
+
54
+ # Value Range Normalized Tree Dist
55
+ try:
56
+ tree_dist = self.tree_dist[lang_1][lang_2]
57
+ except KeyError:
58
+ tree_dist = self.tree_dist[lang_2][lang_1]
59
+
60
+ # Value Range Normalized Map Dist
61
+ try:
62
+ map_dist = self.map_dist[lang_1][lang_2] / self.largest_value_map_dist
63
+ except KeyError:
64
+ map_dist = self.map_dist[lang_2][lang_1] / self.largest_value_map_dist
65
+
66
+ # Value Range Normalized ASP Dist
67
+ lang_2_idx = self.lang_list.index(lang_2)
68
+ asp_dist = 1.0 - self.asp_sim[lang_1][lang_2_idx] # it's a similarity measure that goes from 0 to 1, so we subtract it from 1 to turn it into a distance
69
+
70
+ # Average distance should be similar to embedding distance to bring some structure into the embedding-space
71
+ metric_distance = (torch.tensor(tree_dist) + torch.tensor(map_dist) + torch.tensor(asp_dist)) / 3
72
+ losses.append(torch.nn.functional.l1_loss(embed_dist, metric_distance))
73
+
74
+ return sum(losses) / len(losses)
Architectures/ToucanTTS/PitchCalculator.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Nagoya University (Tomoki Hayashi)
2
+ # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+ # Adapted by Florian Lux 2021
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import parselmouth
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from scipy.interpolate import interp1d
12
+
13
+
14
+ class Parselmouth(torch.nn.Module):
15
+ """
16
+ F0 estimation with Parselmouth https://parselmouth.readthedocs.io/en/stable/index.html
17
+ """
18
+
19
+ def __init__(self, fs=16000, n_fft=1024, hop_length=256, f0min=40, f0max=600, use_token_averaged_f0=True,
20
+ use_continuous_f0=True, use_log_f0=False, reduction_factor=1):
21
+ super().__init__()
22
+ self.fs = fs
23
+ self.n_fft = n_fft
24
+ self.hop_length = hop_length
25
+ self.frame_period = 1000 * hop_length / fs
26
+ self.f0min = f0min
27
+ self.f0max = f0max
28
+ self.use_token_averaged_f0 = use_token_averaged_f0
29
+ self.use_continuous_f0 = use_continuous_f0
30
+ self.use_log_f0 = use_log_f0
31
+ if use_token_averaged_f0:
32
+ assert reduction_factor >= 1
33
+ self.reduction_factor = reduction_factor
34
+
35
+ def output_size(self):
36
+ return 1
37
+
38
+ def get_parameters(self):
39
+ return dict(fs=self.fs, n_fft=self.n_fft, hop_length=self.hop_length, f0min=self.f0min, f0max=self.f0max,
40
+ use_token_averaged_f0=self.use_token_averaged_f0, use_continuous_f0=self.use_continuous_f0, use_log_f0=self.use_log_f0,
41
+ reduction_factor=self.reduction_factor)
42
+
43
+ def forward(self, input_waves, input_waves_lengths=None, feats_lengths=None, durations=None,
44
+ durations_lengths=None, norm_by_average=True, text=None):
45
+
46
+ # F0 extraction
47
+ pitch = self._calculate_f0(input_waves[0])
48
+
49
+ # Adjust length to match with the feature sequences
50
+ pitch = self._adjust_num_frames(pitch, feats_lengths[0]).view(-1)
51
+
52
+ pitch = self._average_by_duration(pitch, durations[0], text).view(-1)
53
+ pitch_lengths = durations_lengths
54
+
55
+ if norm_by_average:
56
+ average = pitch[pitch != 0.0].mean()
57
+ pitch = pitch / average
58
+
59
+ # Return with the shape (B, T, 1)
60
+ return pitch.unsqueeze(-1), pitch_lengths
61
+
62
+ def _calculate_f0(self, input):
63
+ x = input.cpu().numpy().astype(np.double)
64
+ snd = parselmouth.Sound(values=x, sampling_frequency=self.fs)
65
+ f0 = snd.to_pitch(time_step=self.hop_length / self.fs, pitch_floor=self.f0min, pitch_ceiling=self.f0max).selected_array['frequency']
66
+ if self.use_continuous_f0:
67
+ f0 = self._convert_to_continuous_f0(f0)
68
+ if self.use_log_f0:
69
+ nonzero_idxs = np.where(f0 != 0)[0]
70
+ f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
71
+ return input.new_tensor(f0.reshape(-1), dtype=torch.float)
72
+
73
+ @staticmethod
74
+ def _adjust_num_frames(x, num_frames):
75
+ if num_frames > len(x):
76
+ # x = F.pad(x, (0, num_frames - len(x)))
77
+ x = F.pad(x, (math.ceil((num_frames - len(x)) / 2), math.floor((num_frames - len(x)) / 2)))
78
+ elif num_frames < len(x):
79
+ x = x[:num_frames]
80
+ return x
81
+
82
+ @staticmethod
83
+ def _convert_to_continuous_f0(f0: np.array):
84
+ if (f0 == 0).all():
85
+ return f0
86
+
87
+ # padding start and end of f0 sequence
88
+ start_f0 = f0[f0 != 0][0]
89
+ end_f0 = f0[f0 != 0][-1]
90
+ start_idx = np.where(f0 == start_f0)[0][0]
91
+ end_idx = np.where(f0 == end_f0)[0][-1]
92
+ f0[:start_idx] = start_f0
93
+ f0[end_idx:] = end_f0
94
+
95
+ # get non-zero frame index
96
+ nonzero_idxs = np.where(f0 != 0)[0]
97
+
98
+ # perform linear interpolation
99
+ interp_fn = interp1d(nonzero_idxs, f0[nonzero_idxs])
100
+ f0 = interp_fn(np.arange(0, f0.shape[0]))
101
+
102
+ return f0
103
+
104
+ def _average_by_duration(self, x, d, text=None):
105
+ d_cumsum = F.pad(d.cumsum(dim=0), (1, 0))
106
+ x_avg = [
107
+ x[start:end].masked_select(x[start:end].gt(0.0)).mean(dim=0) if len(x[start:end].masked_select(x[start:end].gt(0.0))) != 0 else x.new_tensor(0.0)
108
+ for start, end in zip(d_cumsum[:-1], d_cumsum[1:])]
109
+
110
+ # find tokens that are not voiced and set pitch to 0
111
+ # while this makes sense, it makes it harder for the model to learn, so we leave this out now.
112
+ # if text is not None:
113
+ # for i, vector in enumerate(text):
114
+ # if vector[get_feature_to_index_lookup()["voiced"]] == 0:
115
+ # x_avg[i] = torch.tensor(0.0, device=x.device)
116
+
117
+ return torch.stack(x_avg)