ruslanmv commited on
Commit
08d5f37
·
1 Parent(s): 3bfbeda

Add application file

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +19 -0
  2. README.md +2 -1
  3. app.py +234 -0
  4. encoder/__init__.py +0 -0
  5. encoder/__pycache__/__init__.cpython-38.pyc +0 -0
  6. encoder/__pycache__/audio.cpython-38.pyc +0 -0
  7. encoder/__pycache__/inference.cpython-38.pyc +0 -0
  8. encoder/__pycache__/model.cpython-38.pyc +0 -0
  9. encoder/__pycache__/params_data.cpython-38.pyc +0 -0
  10. encoder/__pycache__/params_model.cpython-38.pyc +0 -0
  11. encoder/audio.py +117 -0
  12. encoder/config.py +45 -0
  13. encoder/data_objects/__init__.py +2 -0
  14. encoder/data_objects/random_cycler.py +37 -0
  15. encoder/data_objects/speaker.py +40 -0
  16. encoder/data_objects/speaker_batch.py +13 -0
  17. encoder/data_objects/speaker_verification_dataset.py +56 -0
  18. encoder/data_objects/utterance.py +26 -0
  19. encoder/inference.py +178 -0
  20. encoder/model.py +135 -0
  21. encoder/params_data.py +29 -0
  22. encoder/params_model.py +11 -0
  23. encoder/preprocess.py +184 -0
  24. encoder/train.py +125 -0
  25. encoder/visualizations.py +179 -0
  26. requirements.txt +0 -0
  27. synthesizer/LICENSE.txt +24 -0
  28. synthesizer/__init__.py +1 -0
  29. synthesizer/__pycache__/__init__.cpython-38.pyc +0 -0
  30. synthesizer/__pycache__/audio.cpython-38.pyc +0 -0
  31. synthesizer/__pycache__/hparams.cpython-38.pyc +0 -0
  32. synthesizer/__pycache__/inference.cpython-38.pyc +0 -0
  33. synthesizer/audio.py +206 -0
  34. synthesizer/hparams.py +92 -0
  35. synthesizer/inference.py +165 -0
  36. synthesizer/models/__pycache__/tacotron.cpython-38.pyc +0 -0
  37. synthesizer/models/tacotron.py +519 -0
  38. synthesizer/preprocess.py +258 -0
  39. synthesizer/synthesize.py +92 -0
  40. synthesizer/synthesizer_dataset.py +92 -0
  41. synthesizer/train.py +258 -0
  42. synthesizer/utils/__init__.py +45 -0
  43. synthesizer/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  44. synthesizer/utils/__pycache__/cleaners.cpython-38.pyc +0 -0
  45. synthesizer/utils/__pycache__/numbers.cpython-38.pyc +0 -0
  46. synthesizer/utils/__pycache__/symbols.cpython-38.pyc +0 -0
  47. synthesizer/utils/__pycache__/text.cpython-38.pyc +0 -0
  48. synthesizer/utils/_cmudict.py +62 -0
  49. synthesizer/utils/cleaners.py +88 -0
  50. synthesizer/utils/numbers.py +69 -0
.gitattributes CHANGED
@@ -29,3 +29,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zst filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.pyc
33
+ *.aux
34
+ *.log
35
+ *.out
36
+ *.synctex.gz
37
+ *.suo
38
+ *__pycache__
39
+ *.idea
40
+ *.ipynb_checkpoints
41
+ *.pickle
42
+ *.npy
43
+ *.blg
44
+ *.bbl
45
+ *.bcf
46
+ *.toc
47
+ *.sh
48
+ encoder/saved_models/*
49
+ synthesizer/saved_models/*
50
+ vocoder/saved_models/*
README.md CHANGED
@@ -3,8 +3,9 @@ title: Clone Your Voice
3
  emoji: 📚
4
  colorFrom: blue
5
  colorTo: yellow
 
6
  sdk: gradio
7
- sdk_version: 3.2
8
  app_file: app.py
9
  pinned: false
10
  ---
 
3
  emoji: 📚
4
  colorFrom: blue
5
  colorTo: yellow
6
+ python_version: 3.8.9
7
  sdk: gradio
8
+ sdk_version: 3.0.4
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ import os
5
+ import string
6
+ import numpy as np
7
+ import IPython
8
+ from IPython.display import Audio
9
+ import torch
10
+ import argparse
11
+ import os
12
+ from pathlib import Path
13
+ import librosa
14
+ import numpy as np
15
+ import soundfile as sf
16
+ import torch
17
+ from encoder import inference as encoder
18
+ from encoder.params_model import model_embedding_size as speaker_embedding_size
19
+ from synthesizer.inference import Synthesizer
20
+ from utils.argutils import print_args
21
+ from utils.default_models import ensure_default_models
22
+ from vocoder import inference as vocoder
23
+ import sounddevice as sd
24
+ parser = argparse.ArgumentParser(
25
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
26
+ )
27
+ parser.add_argument("-e", "--enc_model_fpath", type=Path,
28
+ default="saved_models/default/encoder.pt",
29
+ help="Path to a saved encoder")
30
+ parser.add_argument("-s", "--syn_model_fpath", type=Path,
31
+ default="saved_models/default/synthesizer.pt",
32
+ help="Path to a saved synthesizer")
33
+ parser.add_argument("-v", "--voc_model_fpath", type=Path,
34
+ default="saved_models/default/vocoder.pt",
35
+ help="Path to a saved vocoder")
36
+ parser.add_argument("--cpu", action="store_true", help=\
37
+ "If True, processing is done on CPU, even when a GPU is available.")
38
+ parser.add_argument("--no_sound", action="store_true", help=\
39
+ "If True, audio won't be played.")
40
+ parser.add_argument("--seed", type=int, default=None, help=\
41
+ "Optional random number seed value to make toolbox deterministic.")
42
+ args = parser.parse_args()
43
+ arg_dict = vars(args)
44
+ print_args(args, parser)
45
+
46
+ # Hide GPUs from Pytorch to force CPU processing
47
+ if arg_dict.pop("cpu"):
48
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
49
+
50
+ print("Running a test of your configuration...\n")
51
+
52
+ if torch.cuda.is_available():
53
+ device_id = torch.cuda.current_device()
54
+ gpu_properties = torch.cuda.get_device_properties(device_id)
55
+ ## Print some environment information (for debugging purposes)
56
+ print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
57
+ "%.1fGb total memory.\n" %
58
+ (torch.cuda.device_count(),
59
+ device_id,
60
+ gpu_properties.name,
61
+ gpu_properties.major,
62
+ gpu_properties.minor,
63
+ gpu_properties.total_memory / 1e9))
64
+ else:
65
+ print("Using CPU for inference.\n")
66
+
67
+ ## Load the models one by one.
68
+ print("Preparing the encoder, the synthesizer and the vocoder...")
69
+ ensure_default_models(Path("saved_models"))
70
+ encoder.load_model(args.enc_model_fpath)
71
+ synthesizer = Synthesizer(args.syn_model_fpath)
72
+ vocoder.load_model(args.voc_model_fpath)
73
+
74
+ def compute_embedding(in_fpath):
75
+ ## Computing the embedding
76
+ # First, we load the wav using the function that the speaker encoder provides. This is
77
+ # important: there is preprocessing that must be applied.
78
+
79
+ # The following two methods are equivalent:
80
+ # - Directly load from the filepath:
81
+ preprocessed_wav = encoder.preprocess_wav(in_fpath)
82
+ # - If the wav is already loaded:
83
+ original_wav, sampling_rate = librosa.load(str(in_fpath))
84
+ preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
85
+ print("Loaded file succesfully")
86
+
87
+ # Then we derive the embedding. There are many functions and parameters that the
88
+ # speaker encoder interfaces. These are mostly for in-depth research. You will typically
89
+ # only use this function (with its default parameters):
90
+ embed = encoder.embed_utterance(preprocessed_wav)
91
+
92
+ return embed
93
+ def create_spectrogram(text,embed, synthesizer ):
94
+ # If seed is specified, reset torch seed and force synthesizer reload
95
+ if args.seed is not None:
96
+ torch.manual_seed(args.seed)
97
+ synthesizer = Synthesizer(args.syn_model_fpath)
98
+ # The synthesizer works in batch, so you need to put your data in a list or numpy array
99
+ texts = [text]
100
+ embeds = [embed]
101
+ # If you know what the attention layer alignments are, you can retrieve them here by
102
+ # passing return_alignments=True
103
+ specs = synthesizer.synthesize_spectrograms(texts, embeds)
104
+ spec = specs[0]
105
+ return spec
106
+
107
+ def generate_waveform(spec):
108
+ ## Generating the waveform
109
+ print("Synthesizing the waveform:")
110
+ # If seed is specified, reset torch seed and reload vocoder
111
+ if args.seed is not None:
112
+ torch.manual_seed(args.seed)
113
+ vocoder.load_model(args.voc_model_fpath)
114
+ # Synthesizing the waveform is fairly straightforward. Remember that the longer the
115
+ # spectrogram, the more time-efficient the vocoder.
116
+ generated_wav = vocoder.infer_waveform(spec)
117
+
118
+ ## Post-generation
119
+ # There's a bug with sounddevice that makes the audio cut one second earlier, so we
120
+ # pad it.
121
+ generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")
122
+
123
+ # Trim excess silences to compensate for gaps in spectrograms (issue #53)
124
+ generated_wav = encoder.preprocess_wav(generated_wav)
125
+ return generated_wav
126
+
127
+
128
+ def save_on_disk(generated_wav,synthesizer):
129
+ # Save it on the disk
130
+ filename = "cloned_voice.wav"
131
+ print(generated_wav.dtype)
132
+ #OUT=os.environ['OUT_PATH']
133
+ # Returns `None` if key doesn't exist
134
+ #OUT=os.environ.get('OUT_PATH')
135
+ #result = os.path.join(OUT, filename)
136
+ result = filename
137
+ print(" > Saving output to {}".format(result))
138
+ sf.write(result, generated_wav.astype(np.float32), synthesizer.sample_rate)
139
+ print("\nSaved output as %s\n\n" % result)
140
+
141
+ return result
142
+ def play_audio(generated_wav,synthesizer):
143
+ # Play the audio (non-blocking)
144
+ if not args.no_sound:
145
+
146
+ try:
147
+ sd.stop()
148
+ sd.play(generated_wav, synthesizer.sample_rate)
149
+ except sd.PortAudioError as e:
150
+ print("\nCaught exception: %s" % repr(e))
151
+ print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
152
+ except:
153
+ raise
154
+
155
+ def clone_voice(in_fpath, text,synthesizer):
156
+ try:
157
+ # Compute embedding
158
+ embed=compute_embedding(in_fpath)
159
+ print("Created the embedding")
160
+ # Generating the spectrogram
161
+ spec = create_spectrogram(text,embed,synthesizer)
162
+ print("Created the mel spectrogram")
163
+
164
+ # Create waveform
165
+ generated_wav=generate_waveform(spec)
166
+ print("Created the the waveform ")
167
+
168
+ # Save it on the disk
169
+ save_on_disk(generated_wav,synthesizer)
170
+
171
+ #Play the audio
172
+ play_audio(generated_wav,synthesizer)
173
+
174
+ return
175
+ except Exception as e:
176
+ print("Caught exception: %s" % repr(e))
177
+ print("Restarting\n")
178
+
179
+ # Set environment variables
180
+ home_dir = os.getcwd()
181
+ OUT_PATH=os.path.join(home_dir, "out/")
182
+ os.environ['OUT_PATH'] = OUT_PATH
183
+
184
+ # create output path
185
+ os.makedirs(OUT_PATH, exist_ok=True)
186
+
187
+ USE_CUDA = torch.cuda.is_available()
188
+
189
+ os.system('pip install -q pydub ffmpeg-normalize')
190
+ CONFIG_SE_PATH = "config_se.json"
191
+ CHECKPOINT_SE_PATH = "SE_checkpoint.pth.tar"
192
+ def greet(Text,Voicetoclone):
193
+ text= "%s" % (Text)
194
+ #reference_files= "%s" % (Voicetoclone)
195
+ reference_files= Voicetoclone
196
+ print("path url")
197
+ print(Voicetoclone)
198
+ sample= str(Voicetoclone)
199
+ os.environ['sample'] = sample
200
+ size= len(reference_files)*sys.getsizeof(reference_files)
201
+ size2= size / 1000000
202
+ if (size2 > 0.012) or len(text)>2000:
203
+ message="File is greater than 30mb or Text inserted is longer than 2000 characters. Please re-try with smaller sizes."
204
+ print(message)
205
+ raise SystemExit("File is greater than 30mb. Please re-try or Text inserted is longer than 2000 characters. Please re-try with smaller sizes.")
206
+ else:
207
+
208
+ env_var = 'sample'
209
+ if env_var in os.environ:
210
+ print(f'{env_var} value is {os.environ[env_var]}')
211
+ else:
212
+ print(f'{env_var} does not exist')
213
+ #os.system(f'ffmpeg-normalize {os.environ[env_var]} -nt rms -t=-27 -o {os.environ[env_var]} -ar 16000 -f')
214
+ in_fpath = Path(sample)
215
+ #in_fpath= in_fpath.replace("\"", "").replace("\'", "")
216
+
217
+ out_path=clone_voice(in_fpath, text,synthesizer)
218
+
219
+ print(" > text: {}".format(text))
220
+
221
+ print("Generated Audio")
222
+ return "cloned_voice.wav"
223
+
224
+ demo = gr.Interface(
225
+ fn=greet,
226
+ inputs=[gr.inputs.Textbox(label='What would you like the voice to say? (max. 2000 characters per request)'),
227
+ gr.Audio(
228
+ type="filepath",
229
+ source="upload",
230
+ label='Please upload a voice to clone (max. 30mb)')
231
+ ],
232
+ outputs="audio",
233
+ )
234
+ demo.launch()
encoder/__init__.py ADDED
File without changes
encoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (200 Bytes). View file
 
encoder/__pycache__/audio.cpython-38.pyc ADDED
Binary file (4.05 kB). View file
 
encoder/__pycache__/inference.cpython-38.pyc ADDED
Binary file (7.25 kB). View file
 
encoder/__pycache__/model.cpython-38.pyc ADDED
Binary file (4.84 kB). View file
 
encoder/__pycache__/params_data.cpython-38.pyc ADDED
Binary file (507 Bytes). View file
 
encoder/__pycache__/params_model.cpython-38.pyc ADDED
Binary file (387 Bytes). View file
 
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ from warnings import warn
6
+ import numpy as np
7
+ import librosa
8
+ import struct
9
+
10
+ try:
11
+ import webrtcvad
12
+ except:
13
+ warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
14
+ webrtcvad=None
15
+
16
+ int16_max = (2 ** 15) - 1
17
+
18
+
19
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
20
+ source_sr: Optional[int] = None,
21
+ normalize: Optional[bool] = True,
22
+ trim_silence: Optional[bool] = True):
23
+ """
24
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
25
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
26
+
27
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
28
+ just .wav), either the waveform as a numpy array of floats.
29
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
30
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
31
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
32
+ this argument will be ignored.
33
+ """
34
+ # Load the wav from disk if needed
35
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
36
+ wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
37
+ else:
38
+ wav = fpath_or_wav
39
+
40
+ # Resample the wav if needed
41
+ if source_sr is not None and source_sr != sampling_rate:
42
+ wav = librosa.resample(wav, source_sr, sampling_rate)
43
+
44
+ # Apply the preprocessing: normalize volume and shorten long silences
45
+ if normalize:
46
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
47
+ if webrtcvad and trim_silence:
48
+ wav = trim_long_silences(wav)
49
+
50
+ return wav
51
+
52
+
53
+ def wav_to_mel_spectrogram(wav):
54
+ """
55
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
56
+ Note: this not a log-mel spectrogram.
57
+ """
58
+ frames = librosa.feature.melspectrogram(
59
+ wav,
60
+ sampling_rate,
61
+ n_fft=int(sampling_rate * mel_window_length / 1000),
62
+ hop_length=int(sampling_rate * mel_window_step / 1000),
63
+ n_mels=mel_n_channels
64
+ )
65
+ return frames.astype(np.float32).T
66
+
67
+
68
+ def trim_long_silences(wav):
69
+ """
70
+ Ensures that segments without voice in the waveform remain no longer than a
71
+ threshold determined by the VAD parameters in params.py.
72
+
73
+ :param wav: the raw waveform as a numpy array of floats
74
+ :return: the same waveform with silences trimmed away (length <= original wav length)
75
+ """
76
+ # Compute the voice detection window size
77
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
78
+
79
+ # Trim the end of the audio to have a multiple of the window size
80
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
81
+
82
+ # Convert the float waveform to 16-bit mono PCM
83
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
84
+
85
+ # Perform voice activation detection
86
+ voice_flags = []
87
+ vad = webrtcvad.Vad(mode=3)
88
+ for window_start in range(0, len(wav), samples_per_window):
89
+ window_end = window_start + samples_per_window
90
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
91
+ sample_rate=sampling_rate))
92
+ voice_flags = np.array(voice_flags)
93
+
94
+ # Smooth the voice detection with a moving average
95
+ def moving_average(array, width):
96
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
97
+ ret = np.cumsum(array_padded, dtype=float)
98
+ ret[width:] = ret[width:] - ret[:-width]
99
+ return ret[width - 1:] / width
100
+
101
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
102
+ audio_mask = np.round(audio_mask).astype(np.bool)
103
+
104
+ # Dilate the voiced regions
105
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
106
+ audio_mask = np.repeat(audio_mask, samples_per_window)
107
+
108
+ return wav[audio_mask == True]
109
+
110
+
111
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
112
+ if increase_only and decrease_only:
113
+ raise ValueError("Both increase only and decrease only are set")
114
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
115
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
116
+ return wav
117
+ return wav * (10 ** (dBFS_change / 20))
encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.data_objects.random_cycler import RandomCycler
2
+ from encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from encoder.data_objects.speaker import Speaker
4
+
5
+
6
+ class SpeakerBatch:
7
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
8
+ self.speakers = speakers
9
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
10
+
11
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
12
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
13
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.data_objects.random_cycler import RandomCycler
2
+ from encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from encoder.data_objects.speaker import Speaker
4
+ from encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
encoder/inference.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_data import *
2
+ from encoder.model import SpeakerEncoder
3
+ from encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from encoder import audio
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import torch
9
+
10
+ _model = None # type: SpeakerEncoder
11
+ _device = None # type: torch.device
12
+
13
+
14
+ def load_model(weights_fpath: Path, device=None):
15
+ """
16
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
17
+ first call to embed_frames() with the default weights file.
18
+
19
+ :param weights_fpath: the path to saved model weights.
20
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
21
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
22
+ If None, will default to your GPU if it"s available, otherwise your CPU.
23
+ """
24
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
25
+ # was saved on. Worth investigating.
26
+ global _model, _device
27
+ if device is None:
28
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ elif isinstance(device, str):
30
+ _device = torch.device(device)
31
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
32
+ checkpoint = torch.load(weights_fpath, _device)
33
+ _model.load_state_dict(checkpoint["model_state"])
34
+ _model.eval()
35
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
36
+
37
+
38
+ def is_loaded():
39
+ return _model is not None
40
+
41
+
42
+ def embed_frames_batch(frames_batch):
43
+ """
44
+ Computes embeddings for a batch of mel spectrogram.
45
+
46
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
47
+ (batch_size, n_frames, n_channels)
48
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
49
+ """
50
+ if _model is None:
51
+ raise Exception("Model was not loaded. Call load_model() before inference.")
52
+
53
+ frames = torch.from_numpy(frames_batch).to(_device)
54
+ embed = _model.forward(frames).detach().cpu().numpy()
55
+ return embed
56
+
57
+
58
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
59
+ min_pad_coverage=0.75, overlap=0.5):
60
+ """
61
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
62
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
63
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
64
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
65
+ defined in params_data.py.
66
+
67
+ The returned ranges may be indexing further than the length of the waveform. It is
68
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
69
+
70
+ :param n_samples: the number of samples in the waveform
71
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
72
+ utterance
73
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
74
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
75
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
76
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
77
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
78
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
79
+ utterances are entirely disjoint.
80
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
81
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
82
+ utterances.
83
+ """
84
+ assert 0 <= overlap < 1
85
+ assert 0 < min_pad_coverage <= 1
86
+
87
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
88
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
89
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
90
+
91
+ # Compute the slices
92
+ wav_slices, mel_slices = [], []
93
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
94
+ for i in range(0, steps, frame_step):
95
+ mel_range = np.array([i, i + partial_utterance_n_frames])
96
+ wav_range = mel_range * samples_per_frame
97
+ mel_slices.append(slice(*mel_range))
98
+ wav_slices.append(slice(*wav_range))
99
+
100
+ # Evaluate whether extra padding is warranted or not
101
+ last_wav_range = wav_slices[-1]
102
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
103
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
104
+ mel_slices = mel_slices[:-1]
105
+ wav_slices = wav_slices[:-1]
106
+
107
+ return wav_slices, mel_slices
108
+
109
+
110
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
111
+ """
112
+ Computes an embedding for a single utterance.
113
+
114
+ # TODO: handle multiple wavs to benefit from batching on GPU
115
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
116
+ :param using_partials: if True, then the utterance is split in partial utterances of
117
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
118
+ normalized average. If False, the utterance is instead computed from feeding the entire
119
+ spectogram to the network.
120
+ :param return_partials: if True, the partial embeddings will also be returned along with the
121
+ wav slices that correspond to the partial embeddings.
122
+ :param kwargs: additional arguments to compute_partial_splits()
123
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
124
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
125
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
126
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
127
+ instead.
128
+ """
129
+ # Process the entire utterance if not using partials
130
+ if not using_partials:
131
+ frames = audio.wav_to_mel_spectrogram(wav)
132
+ embed = embed_frames_batch(frames[None, ...])[0]
133
+ if return_partials:
134
+ return embed, None, None
135
+ return embed
136
+
137
+ # Compute where to split the utterance into partials and pad if necessary
138
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
139
+ max_wave_length = wave_slices[-1].stop
140
+ if max_wave_length >= len(wav):
141
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
142
+
143
+ # Split the utterance into partials
144
+ frames = audio.wav_to_mel_spectrogram(wav)
145
+ frames_batch = np.array([frames[s] for s in mel_slices])
146
+ partial_embeds = embed_frames_batch(frames_batch)
147
+
148
+ # Compute the utterance embedding from the partial embeddings
149
+ raw_embed = np.mean(partial_embeds, axis=0)
150
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
151
+
152
+ if return_partials:
153
+ return embed, partial_embeds, wave_slices
154
+ return embed
155
+
156
+
157
+ def embed_speaker(wavs, **kwargs):
158
+ raise NotImplemented()
159
+
160
+
161
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
162
+ import matplotlib.pyplot as plt
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ sm = cm.ScalarMappable(cmap=cmap)
175
+ sm.set_clim(*color_range)
176
+
177
+ ax.set_xticks([]), ax.set_yticks([])
178
+ ax.set_title(title)
encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_model import *
2
+ from encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels,
19
+ hidden_size=model_hidden_size,
20
+ num_layers=model_num_layers,
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
encoder/preprocess.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from functools import partial
3
+ from multiprocessing import Pool
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from encoder import audio
10
+ from encoder.config import librispeech_datasets, anglophone_nationalites
11
+ from encoder.params_data import *
12
+
13
+
14
+ _AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
15
+
16
+ class DatasetLog:
17
+ """
18
+ Registers metadata about the dataset in a text file.
19
+ """
20
+ def __init__(self, root, name):
21
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
22
+ self.sample_data = dict()
23
+
24
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
25
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
26
+ self.write_line("-----")
27
+ self._log_params()
28
+
29
+ def _log_params(self):
30
+ from encoder import params_data
31
+ self.write_line("Parameter values:")
32
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
33
+ value = getattr(params_data, param_name)
34
+ self.write_line("\t%s: %s" % (param_name, value))
35
+ self.write_line("-----")
36
+
37
+ def write_line(self, line):
38
+ self.text_file.write("%s\n" % line)
39
+
40
+ def add_sample(self, **kwargs):
41
+ for param_name, value in kwargs.items():
42
+ if not param_name in self.sample_data:
43
+ self.sample_data[param_name] = []
44
+ self.sample_data[param_name].append(value)
45
+
46
+ def finalize(self):
47
+ self.write_line("Statistics:")
48
+ for param_name, values in self.sample_data.items():
49
+ self.write_line("\t%s:" % param_name)
50
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
51
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
52
+ self.write_line("-----")
53
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
54
+ self.write_line("Finished on %s" % end_time)
55
+ self.text_file.close()
56
+
57
+
58
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
59
+ dataset_root = datasets_root.joinpath(dataset_name)
60
+ if not dataset_root.exists():
61
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
62
+ return None, None
63
+ return dataset_root, DatasetLog(out_dir, dataset_name)
64
+
65
+
66
+ def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ audio_durs = []
90
+ for extension in _AUDIO_EXTENSIONS:
91
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
92
+ # Check if the target output file already exists
93
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
94
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
95
+ if skip_existing and out_fname in existing_fnames:
96
+ continue
97
+
98
+ # Load and preprocess the waveform
99
+ wav = audio.preprocess_wav(in_fpath)
100
+ if len(wav) == 0:
101
+ continue
102
+
103
+ # Create the mel spectrogram, discard those that are too short
104
+ frames = audio.wav_to_mel_spectrogram(wav)
105
+ if len(frames) < partials_n_frames:
106
+ continue
107
+
108
+ out_fpath = speaker_out_dir.joinpath(out_fname)
109
+ np.save(out_fpath, frames)
110
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
111
+ audio_durs.append(len(wav) / sampling_rate)
112
+
113
+ sources_file.close()
114
+
115
+ return audio_durs
116
+
117
+
118
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
119
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
120
+
121
+ # Process the utterances for each speaker
122
+ work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
123
+ with Pool(4) as pool:
124
+ tasks = pool.imap(work_fn, speaker_dirs)
125
+ for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
126
+ for sample_dur in sample_durs:
127
+ logger.add_sample(duration=sample_dur)
128
+
129
+ logger.finalize()
130
+ print("Done preprocessing %s.\n" % dataset_name)
131
+
132
+
133
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
134
+ for dataset_name in librispeech_datasets["train"]["other"]:
135
+ # Initialize the preprocessing
136
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
137
+ if not dataset_root:
138
+ return
139
+
140
+ # Preprocess all speakers
141
+ speaker_dirs = list(dataset_root.glob("*"))
142
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
143
+
144
+
145
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
146
+ # Initialize the preprocessing
147
+ dataset_name = "VoxCeleb1"
148
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
149
+ if not dataset_root:
150
+ return
151
+
152
+ # Get the contents of the meta file
153
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
154
+ metadata = [line.split("\t") for line in metafile][1:]
155
+
156
+ # Select the ID and the nationality, filter out non-anglophone speakers
157
+ nationalities = {line[0]: line[3] for line in metadata}
158
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
159
+ nationality.lower() in anglophone_nationalites]
160
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
161
+ (len(keep_speaker_ids), len(nationalities)))
162
+
163
+ # Get the speaker directories for anglophone speakers only
164
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
165
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
166
+ speaker_dir.name in keep_speaker_ids]
167
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
168
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
169
+
170
+ # Preprocess all speakers
171
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
172
+
173
+
174
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
175
+ # Initialize the preprocessing
176
+ dataset_name = "VoxCeleb2"
177
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
178
+ if not dataset_root:
179
+ return
180
+
181
+ # Get the speaker directories
182
+ # Preprocess all speakers
183
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
184
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
6
+ from encoder.model import SpeakerEncoder
7
+ from encoder.params_model import *
8
+ from encoder.visualizations import Visualizations
9
+ from utils.profiler import Profiler
10
+
11
+
12
+ def sync(device: torch.device):
13
+ # For correct profiling (cuda operations are async)
14
+ if device.type == "cuda":
15
+ torch.cuda.synchronize(device)
16
+
17
+
18
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
19
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
20
+ no_visdom: bool):
21
+ # Create a dataset and a dataloader
22
+ dataset = SpeakerVerificationDataset(clean_data_root)
23
+ loader = SpeakerVerificationDataLoader(
24
+ dataset,
25
+ speakers_per_batch,
26
+ utterances_per_speaker,
27
+ num_workers=4,
28
+ )
29
+
30
+ # Setup the device on which to run the forward pass and the loss. These can be different,
31
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
32
+ # hyperparameters) faster on the CPU.
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ # FIXME: currently, the gradient is None if loss_device is cuda
35
+ loss_device = torch.device("cpu")
36
+
37
+ # Create the model and the optimizer
38
+ model = SpeakerEncoder(device, loss_device)
39
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
40
+ init_step = 1
41
+
42
+ # Configure file path for the model
43
+ model_dir = models_dir / run_id
44
+ model_dir.mkdir(exist_ok=True, parents=True)
45
+ state_fpath = model_dir / "encoder.pt"
46
+
47
+ # Load any existing model
48
+ if not force_restart:
49
+ if state_fpath.exists():
50
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
51
+ checkpoint = torch.load(state_fpath)
52
+ init_step = checkpoint["step"]
53
+ model.load_state_dict(checkpoint["model_state"])
54
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
55
+ optimizer.param_groups[0]["lr"] = learning_rate_init
56
+ else:
57
+ print("No model \"%s\" found, starting training from scratch." % run_id)
58
+ else:
59
+ print("Starting the training from scratch.")
60
+ model.train()
61
+
62
+ # Initialize the visualization environment
63
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
64
+ vis.log_dataset(dataset)
65
+ vis.log_params()
66
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
67
+ vis.log_implementation({"Device": device_name})
68
+
69
+ # Training loop
70
+ profiler = Profiler(summarize_every=10, disabled=False)
71
+ for step, speaker_batch in enumerate(loader, init_step):
72
+ profiler.tick("Blocking, waiting for batch (threaded)")
73
+
74
+ # Forward pass
75
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
76
+ sync(device)
77
+ profiler.tick("Data to %s" % device)
78
+ embeds = model(inputs)
79
+ sync(device)
80
+ profiler.tick("Forward pass")
81
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
82
+ loss, eer = model.loss(embeds_loss)
83
+ sync(loss_device)
84
+ profiler.tick("Loss")
85
+
86
+ # Backward pass
87
+ model.zero_grad()
88
+ loss.backward()
89
+ profiler.tick("Backward pass")
90
+ model.do_gradient_ops()
91
+ optimizer.step()
92
+ profiler.tick("Parameter update")
93
+
94
+ # Update visualizations
95
+ # learning_rate = optimizer.param_groups[0]["lr"]
96
+ vis.update(loss.item(), eer, step)
97
+
98
+ # Draw projections and save them to the backup folder
99
+ if umap_every != 0 and step % umap_every == 0:
100
+ print("Drawing and saving projections (step %d)" % step)
101
+ projection_fpath = model_dir / f"umap_{step:06d}.png"
102
+ embeds = embeds.detach().cpu().numpy()
103
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
104
+ vis.save()
105
+
106
+ # Overwrite the latest version of the model
107
+ if save_every != 0 and step % save_every == 0:
108
+ print("Saving the model (step %d)" % step)
109
+ torch.save({
110
+ "step": step + 1,
111
+ "model_state": model.state_dict(),
112
+ "optimizer_state": optimizer.state_dict(),
113
+ }, state_fpath)
114
+
115
+ # Make a backup
116
+ if backup_every != 0 and step % backup_every == 0:
117
+ print("Making a backup (step %d)" % step)
118
+ backup_fpath = model_dir / f"encoder_{step:06d}.bak"
119
+ torch.save({
120
+ "step": step + 1,
121
+ "model_state": model.state_dict(),
122
+ "optimizer_state": optimizer.state_dict(),
123
+ }, backup_fpath)
124
+
125
+ profiler.tick("Extras (visualizations, saving)")
encoder/visualizations.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from time import perf_counter as timer
3
+
4
+ import numpy as np
5
+ import umap
6
+ import visdom
7
+
8
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
9
+
10
+
11
+ colormap = np.array([
12
+ [76, 255, 0],
13
+ [0, 127, 70],
14
+ [255, 0, 0],
15
+ [255, 217, 38],
16
+ [0, 135, 255],
17
+ [165, 0, 165],
18
+ [255, 167, 255],
19
+ [0, 255, 255],
20
+ [255, 96, 38],
21
+ [142, 76, 0],
22
+ [33, 0, 127],
23
+ [0, 0, 0],
24
+ [183, 183, 183],
25
+ ], dtype=np.float) / 255
26
+
27
+
28
+ class Visualizations:
29
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
30
+ # Tracking data
31
+ self.last_update_timestamp = timer()
32
+ self.update_every = update_every
33
+ self.step_times = []
34
+ self.losses = []
35
+ self.eers = []
36
+ print("Updating the visualizations every %d steps." % update_every)
37
+
38
+ # If visdom is disabled TODO: use a better paradigm for that
39
+ self.disabled = disabled
40
+ if self.disabled:
41
+ return
42
+
43
+ # Set the environment name
44
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
45
+ if env_name is None:
46
+ self.env_name = now
47
+ else:
48
+ self.env_name = "%s (%s)" % (env_name, now)
49
+
50
+ # Connect to visdom and open the corresponding window in the browser
51
+ try:
52
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
53
+ except ConnectionError:
54
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
55
+ "start it.")
56
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
57
+
58
+ # Create the windows
59
+ self.loss_win = None
60
+ self.eer_win = None
61
+ # self.lr_win = None
62
+ self.implementation_win = None
63
+ self.projection_win = None
64
+ self.implementation_string = ""
65
+
66
+ def log_params(self):
67
+ if self.disabled:
68
+ return
69
+ from encoder import params_data
70
+ from encoder import params_model
71
+ param_string = "<b>Model parameters</b>:<br>"
72
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
73
+ value = getattr(params_model, param_name)
74
+ param_string += "\t%s: %s<br>" % (param_name, value)
75
+ param_string += "<b>Data parameters</b>:<br>"
76
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
77
+ value = getattr(params_data, param_name)
78
+ param_string += "\t%s: %s<br>" % (param_name, value)
79
+ self.vis.text(param_string, opts={"title": "Parameters"})
80
+
81
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
82
+ if self.disabled:
83
+ return
84
+ dataset_string = ""
85
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
86
+ dataset_string += "\n" + dataset.get_logs()
87
+ dataset_string = dataset_string.replace("\n", "<br>")
88
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
89
+
90
+ def log_implementation(self, params):
91
+ if self.disabled:
92
+ return
93
+ implementation_string = ""
94
+ for param, value in params.items():
95
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
96
+ implementation_string = implementation_string.replace("\n", "<br>")
97
+ self.implementation_string = implementation_string
98
+ self.implementation_win = self.vis.text(
99
+ implementation_string,
100
+ opts={"title": "Training implementation"}
101
+ )
102
+
103
+ def update(self, loss, eer, step):
104
+ # Update the tracking data
105
+ now = timer()
106
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
107
+ self.last_update_timestamp = now
108
+ self.losses.append(loss)
109
+ self.eers.append(eer)
110
+ print(".", end="")
111
+
112
+ # Update the plots every <update_every> steps
113
+ if step % self.update_every != 0:
114
+ return
115
+ time_string = "Step time: mean: %5dms std: %5dms" % \
116
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
117
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
118
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
119
+ if not self.disabled:
120
+ self.loss_win = self.vis.line(
121
+ [np.mean(self.losses)],
122
+ [step],
123
+ win=self.loss_win,
124
+ update="append" if self.loss_win else None,
125
+ opts=dict(
126
+ legend=["Avg. loss"],
127
+ xlabel="Step",
128
+ ylabel="Loss",
129
+ title="Loss",
130
+ )
131
+ )
132
+ self.eer_win = self.vis.line(
133
+ [np.mean(self.eers)],
134
+ [step],
135
+ win=self.eer_win,
136
+ update="append" if self.eer_win else None,
137
+ opts=dict(
138
+ legend=["Avg. EER"],
139
+ xlabel="Step",
140
+ ylabel="EER",
141
+ title="Equal error rate"
142
+ )
143
+ )
144
+ if self.implementation_win is not None:
145
+ self.vis.text(
146
+ self.implementation_string + ("<b>%s</b>" % time_string),
147
+ win=self.implementation_win,
148
+ opts={"title": "Training implementation"},
149
+ )
150
+
151
+ # Reset the tracking
152
+ self.losses.clear()
153
+ self.eers.clear()
154
+ self.step_times.clear()
155
+
156
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
157
+ import matplotlib.pyplot as plt
158
+
159
+ max_speakers = min(max_speakers, len(colormap))
160
+ embeds = embeds[:max_speakers * utterances_per_speaker]
161
+
162
+ n_speakers = len(embeds) // utterances_per_speaker
163
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
164
+ colors = [colormap[i] for i in ground_truth]
165
+
166
+ reducer = umap.UMAP()
167
+ projected = reducer.fit_transform(embeds)
168
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
169
+ plt.gca().set_aspect("equal", "datalim")
170
+ plt.title("UMAP projection (step %d)" % step)
171
+ if not self.disabled:
172
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
173
+ if out_fpath is not None:
174
+ plt.savefig(out_fpath)
175
+ plt.clf()
176
+
177
+ def save(self):
178
+ if not self.disabled:
179
+ self.vis.save([self.env_name])
requirements.txt ADDED
Binary file (450 Bytes). View file
 
synthesizer/LICENSE.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
4
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
5
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
6
+ Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
synthesizer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ #
synthesizer/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (204 Bytes). View file
 
synthesizer/__pycache__/audio.cpython-38.pyc ADDED
Binary file (6.85 kB). View file
 
synthesizer/__pycache__/hparams.cpython-38.pyc ADDED
Binary file (2.85 kB). View file
 
synthesizer/__pycache__/inference.cpython-38.pyc ADDED
Binary file (6.39 kB). View file
 
synthesizer/audio.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+ import soundfile as sf
7
+
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ sf.write(path, wav.astype(np.float32), sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31
+ def start_and_end_indices(quantized, silence_threshold=2):
32
+ for start in range(quantized.size):
33
+ if abs(quantized[start] - 127) > silence_threshold:
34
+ break
35
+ for end in range(quantized.size - 1, 1, -1):
36
+ if abs(quantized[end] - 127) > silence_threshold:
37
+ break
38
+
39
+ assert abs(quantized[start] - 127) > silence_threshold
40
+ assert abs(quantized[end] - 127) > silence_threshold
41
+
42
+ return start, end
43
+
44
+ def get_hop_size(hparams):
45
+ hop_size = hparams.hop_size
46
+ if hop_size is None:
47
+ assert hparams.frame_shift_ms is not None
48
+ hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
49
+ return hop_size
50
+
51
+ def linearspectrogram(wav, hparams):
52
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
53
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
54
+
55
+ if hparams.signal_normalization:
56
+ return _normalize(S, hparams)
57
+ return S
58
+
59
+ def melspectrogram(wav, hparams):
60
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
61
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
62
+
63
+ if hparams.signal_normalization:
64
+ return _normalize(S, hparams)
65
+ return S
66
+
67
+ def inv_linear_spectrogram(linear_spectrogram, hparams):
68
+ """Converts linear spectrogram to waveform using librosa"""
69
+ if hparams.signal_normalization:
70
+ D = _denormalize(linear_spectrogram, hparams)
71
+ else:
72
+ D = linear_spectrogram
73
+
74
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
75
+
76
+ if hparams.use_lws:
77
+ processor = _lws_processor(hparams)
78
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
79
+ y = processor.istft(D).astype(np.float32)
80
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
81
+ else:
82
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83
+
84
+ def inv_mel_spectrogram(mel_spectrogram, hparams):
85
+ """Converts mel spectrogram to waveform using librosa"""
86
+ if hparams.signal_normalization:
87
+ D = _denormalize(mel_spectrogram, hparams)
88
+ else:
89
+ D = mel_spectrogram
90
+
91
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
92
+
93
+ if hparams.use_lws:
94
+ processor = _lws_processor(hparams)
95
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
96
+ y = processor.istft(D).astype(np.float32)
97
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
98
+ else:
99
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100
+
101
+ def _lws_processor(hparams):
102
+ import lws
103
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
104
+
105
+ def _griffin_lim(S, hparams):
106
+ """librosa implementation of Griffin-Lim
107
+ Based on https://github.com/librosa/librosa/issues/434
108
+ """
109
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
110
+ S_complex = np.abs(S).astype(np.complex)
111
+ y = _istft(S_complex * angles, hparams)
112
+ for i in range(hparams.griffin_lim_iters):
113
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
114
+ y = _istft(S_complex * angles, hparams)
115
+ return y
116
+
117
+ def _stft(y, hparams):
118
+ if hparams.use_lws:
119
+ return _lws_processor(hparams).stft(y).T
120
+ else:
121
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
122
+
123
+ def _istft(y, hparams):
124
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
125
+
126
+ ##########################################################
127
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128
+ def num_frames(length, fsize, fshift):
129
+ """Compute number of time frames of spectrogram
130
+ """
131
+ pad = (fsize - fshift)
132
+ if length % fshift == 0:
133
+ M = (length + pad * 2 - fsize) // fshift + 1
134
+ else:
135
+ M = (length + pad * 2 - fsize) // fshift + 2
136
+ return M
137
+
138
+
139
+ def pad_lr(x, fsize, fshift):
140
+ """Compute left and right padding
141
+ """
142
+ M = num_frames(len(x), fsize, fshift)
143
+ pad = (fsize - fshift)
144
+ T = len(x) + 2 * pad
145
+ r = (M - 1) * fshift + fsize - T
146
+ return pad, pad + r
147
+ ##########################################################
148
+ #Librosa correct padding
149
+ def librosa_pad_lr(x, fsize, fshift):
150
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
151
+
152
+ # Conversions
153
+ _mel_basis = None
154
+ _inv_mel_basis = None
155
+
156
+ def _linear_to_mel(spectogram, hparams):
157
+ global _mel_basis
158
+ if _mel_basis is None:
159
+ _mel_basis = _build_mel_basis(hparams)
160
+ return np.dot(_mel_basis, spectogram)
161
+
162
+ def _mel_to_linear(mel_spectrogram, hparams):
163
+ global _inv_mel_basis
164
+ if _inv_mel_basis is None:
165
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
166
+ return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167
+
168
+ def _build_mel_basis(hparams):
169
+ assert hparams.fmax <= hparams.sample_rate // 2
170
+ return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
171
+ fmin=hparams.fmin, fmax=hparams.fmax)
172
+
173
+ def _amp_to_db(x, hparams):
174
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175
+ return 20 * np.log10(np.maximum(min_level, x))
176
+
177
+ def _db_to_amp(x):
178
+ return np.power(10.0, (x) * 0.05)
179
+
180
+ def _normalize(S, hparams):
181
+ if hparams.allow_clipping_in_normalization:
182
+ if hparams.symmetric_mels:
183
+ return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
184
+ -hparams.max_abs_value, hparams.max_abs_value)
185
+ else:
186
+ return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
187
+
188
+ assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
189
+ if hparams.symmetric_mels:
190
+ return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
191
+ else:
192
+ return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193
+
194
+ def _denormalize(D, hparams):
195
+ if hparams.allow_clipping_in_normalization:
196
+ if hparams.symmetric_mels:
197
+ return (((np.clip(D, -hparams.max_abs_value,
198
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
199
+ + hparams.min_level_db)
200
+ else:
201
+ return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
202
+
203
+ if hparams.symmetric_mels:
204
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
205
+ else:
206
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
synthesizer/hparams.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import pprint
3
+
4
+ class HParams(object):
5
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
6
+ def __setitem__(self, key, value): setattr(self, key, value)
7
+ def __getitem__(self, key): return getattr(self, key)
8
+ def __repr__(self): return pprint.pformat(self.__dict__)
9
+
10
+ def parse(self, string):
11
+ # Overrides hparams from a comma-separated string of name=value pairs
12
+ if len(string) > 0:
13
+ overrides = [s.split("=") for s in string.split(",")]
14
+ keys, values = zip(*overrides)
15
+ keys = list(map(str.strip, keys))
16
+ values = list(map(str.strip, values))
17
+ for k in keys:
18
+ self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
19
+ return self
20
+
21
+ hparams = HParams(
22
+ ### Signal Processing (used in both synthesizer and vocoder)
23
+ sample_rate = 16000,
24
+ n_fft = 800,
25
+ num_mels = 80,
26
+ hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
27
+ win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
28
+ fmin = 55,
29
+ min_level_db = -100,
30
+ ref_level_db = 20,
31
+ max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
32
+ preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
33
+ preemphasize = True,
34
+
35
+ ### Tacotron Text-to-Speech (TTS)
36
+ tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
37
+ tts_encoder_dims = 256,
38
+ tts_decoder_dims = 128,
39
+ tts_postnet_dims = 512,
40
+ tts_encoder_K = 5,
41
+ tts_lstm_dims = 1024,
42
+ tts_postnet_K = 5,
43
+ tts_num_highways = 4,
44
+ tts_dropout = 0.5,
45
+ tts_cleaner_names = ["english_cleaners"],
46
+ tts_stop_threshold = -3.4, # Value below which audio generation ends.
47
+ # For example, for a range of [-4, 4], this
48
+ # will terminate the sequence at the first
49
+ # frame that has all values < -3.4
50
+
51
+ ### Tacotron Training
52
+ tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
53
+ (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
54
+ (2, 2e-4, 80_000, 12), #
55
+ (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
56
+ (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
57
+ (2, 1e-5, 640_000, 12)], # lr = learning rate
58
+
59
+ tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
60
+ tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
61
+ # Set to -1 to generate after completing epoch, or 0 to disable
62
+
63
+ tts_eval_num_samples = 1, # Makes this number of samples
64
+
65
+ ### Data Preprocessing
66
+ max_mel_frames = 900,
67
+ rescale = True,
68
+ rescaling_max = 0.9,
69
+ synthesis_batch_size = 16, # For vocoder preprocessing and inference.
70
+
71
+ ### Mel Visualization and Griffin-Lim
72
+ signal_normalization = True,
73
+ power = 1.5,
74
+ griffin_lim_iters = 60,
75
+
76
+ ### Audio processing options
77
+ fmax = 7600, # Should not exceed (sample_rate // 2)
78
+ allow_clipping_in_normalization = True, # Used when signal_normalization = True
79
+ clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
80
+ use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
81
+ symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
82
+ # and [0, max_abs_value] if False
83
+ trim_silence = True, # Use with sample_rate of 16000 for best results
84
+
85
+ ### SV2TTS
86
+ speaker_embedding_size = 256, # Dimension for the speaker embedding
87
+ silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
88
+ utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
89
+ )
90
+
91
+ def hparams_debug_string():
92
+ return str(hparams)
synthesizer/inference.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from synthesizer import audio
3
+ from synthesizer.hparams import hparams
4
+ from synthesizer.models.tacotron import Tacotron
5
+ from synthesizer.utils.symbols import symbols
6
+ from synthesizer.utils.text import text_to_sequence
7
+ from vocoder.display import simple_table
8
+ from pathlib import Path
9
+ from typing import Union, List
10
+ import numpy as np
11
+ import librosa
12
+
13
+
14
+ class Synthesizer:
15
+ sample_rate = hparams.sample_rate
16
+ hparams = hparams
17
+
18
+ def __init__(self, model_fpath: Path, verbose=True):
19
+ """
20
+ The model isn't instantiated and loaded in memory until needed or until load() is called.
21
+
22
+ :param model_fpath: path to the trained model file
23
+ :param verbose: if False, prints less information when using the model
24
+ """
25
+ self.model_fpath = model_fpath
26
+ self.verbose = verbose
27
+
28
+ # Check for GPU
29
+ if torch.cuda.is_available():
30
+ self.device = torch.device("cuda")
31
+ else:
32
+ self.device = torch.device("cpu")
33
+ if self.verbose:
34
+ print("Synthesizer using device:", self.device)
35
+
36
+ # Tacotron model will be instantiated later on first use.
37
+ self._model = None
38
+
39
+ def is_loaded(self):
40
+ """
41
+ Whether the model is loaded in memory.
42
+ """
43
+ return self._model is not None
44
+
45
+ def load(self):
46
+ """
47
+ Instantiates and loads the model given the weights file that was passed in the constructor.
48
+ """
49
+ self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
50
+ num_chars=len(symbols),
51
+ encoder_dims=hparams.tts_encoder_dims,
52
+ decoder_dims=hparams.tts_decoder_dims,
53
+ n_mels=hparams.num_mels,
54
+ fft_bins=hparams.num_mels,
55
+ postnet_dims=hparams.tts_postnet_dims,
56
+ encoder_K=hparams.tts_encoder_K,
57
+ lstm_dims=hparams.tts_lstm_dims,
58
+ postnet_K=hparams.tts_postnet_K,
59
+ num_highways=hparams.tts_num_highways,
60
+ dropout=hparams.tts_dropout,
61
+ stop_threshold=hparams.tts_stop_threshold,
62
+ speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
63
+
64
+ self._model.load(self.model_fpath)
65
+ self._model.eval()
66
+
67
+ if self.verbose:
68
+ print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
69
+
70
+ def synthesize_spectrograms(self, texts: List[str],
71
+ embeddings: Union[np.ndarray, List[np.ndarray]],
72
+ return_alignments=False):
73
+ """
74
+ Synthesizes mel spectrograms from texts and speaker embeddings.
75
+
76
+ :param texts: a list of N text prompts to be synthesized
77
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
78
+ :param return_alignments: if True, a matrix representing the alignments between the
79
+ characters
80
+ and each decoder output step will be returned for each spectrogram
81
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
82
+ sequence length of spectrogram i, and possibly the alignments.
83
+ """
84
+ # Load the model on the first request.
85
+ if not self.is_loaded():
86
+ self.load()
87
+
88
+ # Preprocess text inputs
89
+ inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
90
+ if not isinstance(embeddings, list):
91
+ embeddings = [embeddings]
92
+
93
+ # Batch inputs
94
+ batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
95
+ for i in range(0, len(inputs), hparams.synthesis_batch_size)]
96
+ batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
97
+ for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
98
+
99
+ specs = []
100
+ for i, batch in enumerate(batched_inputs, 1):
101
+ if self.verbose:
102
+ print(f"\n| Generating {i}/{len(batched_inputs)}")
103
+
104
+ # Pad texts so they are all the same length
105
+ text_lens = [len(text) for text in batch]
106
+ max_text_len = max(text_lens)
107
+ chars = [pad1d(text, max_text_len) for text in batch]
108
+ chars = np.stack(chars)
109
+
110
+ # Stack speaker embeddings into 2D array for batch processing
111
+ speaker_embeds = np.stack(batched_embeds[i-1])
112
+
113
+ # Convert to tensor
114
+ chars = torch.tensor(chars).long().to(self.device)
115
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
116
+
117
+ # Inference
118
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings)
119
+ mels = mels.detach().cpu().numpy()
120
+ for m in mels:
121
+ # Trim silence from end of each spectrogram
122
+ while np.max(m[:, -1]) < hparams.tts_stop_threshold:
123
+ m = m[:, :-1]
124
+ specs.append(m)
125
+
126
+ if self.verbose:
127
+ print("\n\nDone.\n")
128
+ return (specs, alignments) if return_alignments else specs
129
+
130
+ @staticmethod
131
+ def load_preprocess_wav(fpath):
132
+ """
133
+ Loads and preprocesses an audio file under the same conditions the audio files were used to
134
+ train the synthesizer.
135
+ """
136
+ wav = librosa.load(str(fpath), hparams.sample_rate)[0]
137
+ if hparams.rescale:
138
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
139
+ return wav
140
+
141
+ @staticmethod
142
+ def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
143
+ """
144
+ Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
145
+ were fed to the synthesizer when training.
146
+ """
147
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
148
+ wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
149
+ else:
150
+ wav = fpath_or_wav
151
+
152
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
153
+ return mel_spectrogram
154
+
155
+ @staticmethod
156
+ def griffin_lim(mel):
157
+ """
158
+ Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
159
+ with the same parameters present in hparams.py.
160
+ """
161
+ return audio.inv_mel_spectrogram(mel, hparams)
162
+
163
+
164
+ def pad1d(x, max_len, pad_value=0):
165
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
synthesizer/models/__pycache__/tacotron.cpython-38.pyc ADDED
Binary file (14.3 kB). View file
 
synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27
+ super().__init__()
28
+ prenet_dims = (encoder_dims, encoder_dims)
29
+ cbhg_channels = encoder_dims
30
+ self.embedding = nn.Embedding(num_chars, embed_dims)
31
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32
+ dropout=dropout)
33
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34
+ proj_channels=[cbhg_channels, cbhg_channels],
35
+ num_highways=num_highways)
36
+
37
+ def forward(self, x, speaker_embedding=None):
38
+ x = self.embedding(x)
39
+ x = self.pre_net(x)
40
+ x.transpose_(1, 2)
41
+ x = self.cbhg(x)
42
+ if speaker_embedding is not None:
43
+ x = self.add_speaker_embedding(x, speaker_embedding)
44
+ return x
45
+
46
+ def add_speaker_embedding(self, x, speaker_embedding):
47
+ # SV2TTS
48
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
49
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
50
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51
+ # This concats the speaker embedding for each char in the encoder output
52
+
53
+ # Save the dimensions as human-readable names
54
+ batch_size = x.size()[0]
55
+ num_chars = x.size()[1]
56
+
57
+ if speaker_embedding.dim() == 1:
58
+ idx = 0
59
+ else:
60
+ idx = 1
61
+
62
+ # Start by making a copy of each speaker embedding to match the input text length
63
+ # The output of this has size (batch_size, num_chars * tts_embed_dims)
64
+ speaker_embedding_size = speaker_embedding.size()[idx]
65
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66
+
67
+ # Reshape it and transpose
68
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69
+ e = e.transpose(1, 2)
70
+
71
+ # Concatenate the tiled speaker embedding with the encoder output
72
+ x = torch.cat((x, e), 2)
73
+ return x
74
+
75
+
76
+ class BatchNormConv(nn.Module):
77
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
78
+ super().__init__()
79
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80
+ self.bnorm = nn.BatchNorm1d(out_channels)
81
+ self.relu = relu
82
+
83
+ def forward(self, x):
84
+ x = self.conv(x)
85
+ x = F.relu(x) if self.relu is True else x
86
+ return self.bnorm(x)
87
+
88
+
89
+ class CBHG(nn.Module):
90
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91
+ super().__init__()
92
+
93
+ # List of all rnns to call `flatten_parameters()` on
94
+ self._to_flatten = []
95
+
96
+ self.bank_kernels = [i for i in range(1, K + 1)]
97
+ self.conv1d_bank = nn.ModuleList()
98
+ for k in self.bank_kernels:
99
+ conv = BatchNormConv(in_channels, channels, k)
100
+ self.conv1d_bank.append(conv)
101
+
102
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103
+
104
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106
+
107
+ # Fix the highway input if necessary
108
+ if proj_channels[-1] != channels:
109
+ self.highway_mismatch = True
110
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111
+ else:
112
+ self.highway_mismatch = False
113
+
114
+ self.highways = nn.ModuleList()
115
+ for i in range(num_highways):
116
+ hn = HighwayNetwork(channels)
117
+ self.highways.append(hn)
118
+
119
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120
+ self._to_flatten.append(self.rnn)
121
+
122
+ # Avoid fragmentation of RNN parameters and associated warning
123
+ self._flatten_parameters()
124
+
125
+ def forward(self, x):
126
+ # Although we `_flatten_parameters()` on init, when using DataParallel
127
+ # the model gets replicated, making it no longer guaranteed that the
128
+ # weights are contiguous in GPU memory. Hence, we must call it again
129
+ self._flatten_parameters()
130
+
131
+ # Save these for later
132
+ residual = x
133
+ seq_len = x.size(-1)
134
+ conv_bank = []
135
+
136
+ # Convolution Bank
137
+ for conv in self.conv1d_bank:
138
+ c = conv(x) # Convolution
139
+ conv_bank.append(c[:, :, :seq_len])
140
+
141
+ # Stack along the channel axis
142
+ conv_bank = torch.cat(conv_bank, dim=1)
143
+
144
+ # dump the last padding to fit residual
145
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
146
+
147
+ # Conv1d projections
148
+ x = self.conv_project1(x)
149
+ x = self.conv_project2(x)
150
+
151
+ # Residual Connect
152
+ x = x + residual
153
+
154
+ # Through the highways
155
+ x = x.transpose(1, 2)
156
+ if self.highway_mismatch is True:
157
+ x = self.pre_highway(x)
158
+ for h in self.highways: x = h(x)
159
+
160
+ # And then the RNN
161
+ x, _ = self.rnn(x)
162
+ return x
163
+
164
+ def _flatten_parameters(self):
165
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166
+ to improve efficiency and avoid PyTorch yelling at us."""
167
+ [m.flatten_parameters() for m in self._to_flatten]
168
+
169
+ class PreNet(nn.Module):
170
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171
+ super().__init__()
172
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
173
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174
+ self.p = dropout
175
+
176
+ def forward(self, x):
177
+ x = self.fc1(x)
178
+ x = F.relu(x)
179
+ x = F.dropout(x, self.p, training=True)
180
+ x = self.fc2(x)
181
+ x = F.relu(x)
182
+ x = F.dropout(x, self.p, training=True)
183
+ return x
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, attn_dims):
188
+ super().__init__()
189
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190
+ self.v = nn.Linear(attn_dims, 1, bias=False)
191
+
192
+ def forward(self, encoder_seq_proj, query, t):
193
+
194
+ # print(encoder_seq_proj.shape)
195
+ # Transform the query vector
196
+ query_proj = self.W(query).unsqueeze(1)
197
+
198
+ # Compute the scores
199
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200
+ scores = F.softmax(u, dim=1)
201
+
202
+ return scores.transpose(1, 2)
203
+
204
+
205
+ class LSA(nn.Module):
206
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
207
+ super().__init__()
208
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209
+ self.L = nn.Linear(filters, attn_dim, bias=False)
210
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211
+ self.v = nn.Linear(attn_dim, 1, bias=False)
212
+ self.cumulative = None
213
+ self.attention = None
214
+
215
+ def init_attention(self, encoder_seq_proj):
216
+ device = next(self.parameters()).device # use same device as parameters
217
+ b, t, c = encoder_seq_proj.size()
218
+ self.cumulative = torch.zeros(b, t, device=device)
219
+ self.attention = torch.zeros(b, t, device=device)
220
+
221
+ def forward(self, encoder_seq_proj, query, t, chars):
222
+
223
+ if t == 0: self.init_attention(encoder_seq_proj)
224
+
225
+ processed_query = self.W(query).unsqueeze(1)
226
+
227
+ location = self.cumulative.unsqueeze(1)
228
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
229
+
230
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231
+ u = u.squeeze(-1)
232
+
233
+ # Mask zero padding chars
234
+ u = u * (chars != 0).float()
235
+
236
+ # Smooth Attention
237
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238
+ scores = F.softmax(u, dim=1)
239
+ self.attention = scores
240
+ self.cumulative = self.cumulative + self.attention
241
+
242
+ return scores.unsqueeze(-1).transpose(1, 2)
243
+
244
+
245
+ class Decoder(nn.Module):
246
+ # Class variable because its value doesn't change between classes
247
+ # yet ought to be scoped by class because its a property of a Decoder
248
+ max_r = 20
249
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250
+ dropout, speaker_embedding_size):
251
+ super().__init__()
252
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253
+ self.n_mels = n_mels
254
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256
+ dropout=dropout)
257
+ self.attn_net = LSA(decoder_dims)
258
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264
+
265
+ def zoneout(self, prev, current, p=0.1):
266
+ device = next(self.parameters()).device # Use same device as parameters
267
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268
+ return prev * mask + current * (1 - mask)
269
+
270
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271
+ hidden_states, cell_states, context_vec, t, chars):
272
+
273
+ # Need this for reshaping mels
274
+ batch_size = encoder_seq.size(0)
275
+
276
+ # Unpack the hidden and cell states
277
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278
+ rnn1_cell, rnn2_cell = cell_states
279
+
280
+ # PreNet for the Attention RNN
281
+ prenet_out = self.prenet(prenet_in)
282
+
283
+ # Compute the Attention RNN hidden state
284
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286
+
287
+ # Compute the attention scores
288
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289
+
290
+ # Dot product to create the context vector
291
+ context_vec = scores @ encoder_seq
292
+ context_vec = context_vec.squeeze(1)
293
+
294
+ # Concat Attention RNN output w. Context Vector & project
295
+ x = torch.cat([context_vec, attn_hidden], dim=1)
296
+ x = self.rnn_input(x)
297
+
298
+ # Compute first Residual RNN
299
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300
+ if self.training:
301
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302
+ else:
303
+ rnn1_hidden = rnn1_hidden_next
304
+ x = x + rnn1_hidden
305
+
306
+ # Compute second Residual RNN
307
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308
+ if self.training:
309
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310
+ else:
311
+ rnn2_hidden = rnn2_hidden_next
312
+ x = x + rnn2_hidden
313
+
314
+ # Project Mels
315
+ mels = self.mel_proj(x)
316
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318
+ cell_states = (rnn1_cell, rnn2_cell)
319
+
320
+ # Stop token prediction
321
+ s = torch.cat((x, context_vec), dim=1)
322
+ s = self.stop_proj(s)
323
+ stop_tokens = torch.sigmoid(s)
324
+
325
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326
+
327
+
328
+ class Tacotron(nn.Module):
329
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331
+ dropout, stop_threshold, speaker_embedding_size):
332
+ super().__init__()
333
+ self.n_mels = n_mels
334
+ self.lstm_dims = lstm_dims
335
+ self.encoder_dims = encoder_dims
336
+ self.decoder_dims = decoder_dims
337
+ self.speaker_embedding_size = speaker_embedding_size
338
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339
+ encoder_K, num_highways, dropout)
340
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342
+ dropout, speaker_embedding_size)
343
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344
+ [postnet_dims, fft_bins], num_highways)
345
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346
+
347
+ self.init_model()
348
+ self.num_params()
349
+
350
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352
+
353
+ @property
354
+ def r(self):
355
+ return self.decoder.r.item()
356
+
357
+ @r.setter
358
+ def r(self, value):
359
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360
+
361
+ def forward(self, x, m, speaker_embedding):
362
+ device = next(self.parameters()).device # use same device as parameters
363
+
364
+ self.step += 1
365
+ batch_size, _, steps = m.size()
366
+
367
+ # Initialise all hidden states and pack into tuple
368
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372
+
373
+ # Initialise all lstm cell states and pack into tuple
374
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376
+ cell_states = (rnn1_cell, rnn2_cell)
377
+
378
+ # <GO> Frame for start of decoder loop
379
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380
+
381
+ # Need an initial context vector
382
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383
+
384
+ # SV2TTS: Run the encoder with the speaker embedding
385
+ # The projection avoids unnecessary matmuls in the decoder loop
386
+ encoder_seq = self.encoder(x, speaker_embedding)
387
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
388
+
389
+ # Need a couple of lists for outputs
390
+ mel_outputs, attn_scores, stop_outputs = [], [], []
391
+
392
+ # Run the decoder loop
393
+ for t in range(0, steps, self.r):
394
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397
+ hidden_states, cell_states, context_vec, t, x)
398
+ mel_outputs.append(mel_frames)
399
+ attn_scores.append(scores)
400
+ stop_outputs.extend([stop_tokens] * self.r)
401
+
402
+ # Concat the mel outputs into sequence
403
+ mel_outputs = torch.cat(mel_outputs, dim=2)
404
+
405
+ # Post-Process for Linear Spectrograms
406
+ postnet_out = self.postnet(mel_outputs)
407
+ linear = self.post_proj(postnet_out)
408
+ linear = linear.transpose(1, 2)
409
+
410
+ # For easy visualisation
411
+ attn_scores = torch.cat(attn_scores, 1)
412
+ # attn_scores = attn_scores.cpu().data.numpy()
413
+ stop_outputs = torch.cat(stop_outputs, 1)
414
+
415
+ return mel_outputs, linear, attn_scores, stop_outputs
416
+
417
+ def generate(self, x, speaker_embedding=None, steps=2000):
418
+ self.eval()
419
+ device = next(self.parameters()).device # use same device as parameters
420
+
421
+ batch_size, _ = x.size()
422
+
423
+ # Need to initialise all hidden states and pack into tuple for tidyness
424
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428
+
429
+ # Need to initialise all lstm cell states and pack into tuple for tidyness
430
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432
+ cell_states = (rnn1_cell, rnn2_cell)
433
+
434
+ # Need a <GO> Frame for start of decoder loop
435
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436
+
437
+ # Need an initial context vector
438
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439
+
440
+ # SV2TTS: Run the encoder with the speaker embedding
441
+ # The projection avoids unnecessary matmuls in the decoder loop
442
+ encoder_seq = self.encoder(x, speaker_embedding)
443
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
444
+
445
+ # Need a couple of lists for outputs
446
+ mel_outputs, attn_scores, stop_outputs = [], [], []
447
+
448
+ # Run the decoder loop
449
+ for t in range(0, steps, self.r):
450
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453
+ hidden_states, cell_states, context_vec, t, x)
454
+ mel_outputs.append(mel_frames)
455
+ attn_scores.append(scores)
456
+ stop_outputs.extend([stop_tokens] * self.r)
457
+ # Stop the loop when all stop tokens in batch exceed threshold
458
+ if (stop_tokens > 0.5).all() and t > 10: break
459
+
460
+ # Concat the mel outputs into sequence
461
+ mel_outputs = torch.cat(mel_outputs, dim=2)
462
+
463
+ # Post-Process for Linear Spectrograms
464
+ postnet_out = self.postnet(mel_outputs)
465
+ linear = self.post_proj(postnet_out)
466
+
467
+
468
+ linear = linear.transpose(1, 2)
469
+
470
+ # For easy visualisation
471
+ attn_scores = torch.cat(attn_scores, 1)
472
+ stop_outputs = torch.cat(stop_outputs, 1)
473
+
474
+ self.train()
475
+
476
+ return mel_outputs, linear, attn_scores
477
+
478
+ def init_model(self):
479
+ for p in self.parameters():
480
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
481
+
482
+ def get_step(self):
483
+ return self.step.data.item()
484
+
485
+ def reset_step(self):
486
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
487
+ self.step = self.step.data.new_tensor(1)
488
+
489
+ def log(self, path, msg):
490
+ with open(path, "a") as f:
491
+ print(msg, file=f)
492
+
493
+ def load(self, path, optimizer=None):
494
+ # Use device of model params as location for loaded state
495
+ device = next(self.parameters()).device
496
+ checkpoint = torch.load(str(path), map_location=device)
497
+ self.load_state_dict(checkpoint["model_state"])
498
+
499
+ if "optimizer_state" in checkpoint and optimizer is not None:
500
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
501
+
502
+ def save(self, path, optimizer=None):
503
+ if optimizer is not None:
504
+ torch.save({
505
+ "model_state": self.state_dict(),
506
+ "optimizer_state": optimizer.state_dict(),
507
+ }, str(path))
508
+ else:
509
+ torch.save({
510
+ "model_state": self.state_dict(),
511
+ }, str(path))
512
+
513
+
514
+ def num_params(self, print_out=True):
515
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
516
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517
+ if print_out:
518
+ print("Trainable Parameters: %.3fM" % parameters)
519
+ return parameters
synthesizer/preprocess.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing.pool import Pool
2
+ from synthesizer import audio
3
+ from functools import partial
4
+ from itertools import chain
5
+ from encoder import inference as encoder
6
+ from pathlib import Path
7
+ from utils import logmmse
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import librosa
11
+
12
+
13
+ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, skip_existing: bool, hparams,
14
+ no_alignments: bool, datasets_name: str, subfolders: str):
15
+ # Gather the input directories
16
+ dataset_root = datasets_root.joinpath(datasets_name)
17
+ input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
18
+ print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
19
+ assert all(input_dir.exists() for input_dir in input_dirs)
20
+
21
+ # Create the output directories for each output file type
22
+ out_dir.joinpath("mels").mkdir(exist_ok=True)
23
+ out_dir.joinpath("audio").mkdir(exist_ok=True)
24
+
25
+ # Create a metadata file
26
+ metadata_fpath = out_dir.joinpath("train.txt")
27
+ metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
28
+
29
+ # Preprocess the dataset
30
+ speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
31
+ func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
32
+ hparams=hparams, no_alignments=no_alignments)
33
+ job = Pool(n_processes).imap(func, speaker_dirs)
34
+ for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
35
+ for metadatum in speaker_metadata:
36
+ metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
37
+ metadata_file.close()
38
+
39
+ # Verify the contents of the metadata file
40
+ with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
41
+ metadata = [line.split("|") for line in metadata_file]
42
+ mel_frames = sum([int(m[4]) for m in metadata])
43
+ timesteps = sum([int(m[3]) for m in metadata])
44
+ sample_rate = hparams.sample_rate
45
+ hours = (timesteps / sample_rate) / 3600
46
+ print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
47
+ (len(metadata), mel_frames, timesteps, hours))
48
+ print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
49
+ print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
50
+ print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
51
+
52
+
53
+ def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
54
+ metadata = []
55
+ for book_dir in speaker_dir.glob("*"):
56
+ if no_alignments:
57
+ # Gather the utterance audios and texts
58
+ # LibriTTS uses .wav but we will include extensions for compatibility with other datasets
59
+ extensions = ["*.wav", "*.flac", "*.mp3"]
60
+ for extension in extensions:
61
+ wav_fpaths = book_dir.glob(extension)
62
+
63
+ for wav_fpath in wav_fpaths:
64
+ # Load the audio waveform
65
+ wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
66
+ if hparams.rescale:
67
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
68
+
69
+ # Get the corresponding text
70
+ # Check for .txt (for compatibility with other datasets)
71
+ text_fpath = wav_fpath.with_suffix(".txt")
72
+ if not text_fpath.exists():
73
+ # Check for .normalized.txt (LibriTTS)
74
+ text_fpath = wav_fpath.with_suffix(".normalized.txt")
75
+ assert text_fpath.exists()
76
+ with text_fpath.open("r") as text_file:
77
+ text = "".join([line for line in text_file])
78
+ text = text.replace("\"", "")
79
+ text = text.strip()
80
+
81
+ # Process the utterance
82
+ metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
83
+ skip_existing, hparams))
84
+ else:
85
+ # Process alignment file (LibriSpeech support)
86
+ # Gather the utterance audios and texts
87
+ try:
88
+ alignments_fpath = next(book_dir.glob("*.alignment.txt"))
89
+ with alignments_fpath.open("r") as alignments_file:
90
+ alignments = [line.rstrip().split(" ") for line in alignments_file]
91
+ except StopIteration:
92
+ # A few alignment files will be missing
93
+ continue
94
+
95
+ # Iterate over each entry in the alignments file
96
+ for wav_fname, words, end_times in alignments:
97
+ wav_fpath = book_dir.joinpath(wav_fname + ".flac")
98
+ assert wav_fpath.exists()
99
+ words = words.replace("\"", "").split(",")
100
+ end_times = list(map(float, end_times.replace("\"", "").split(",")))
101
+
102
+ # Process each sub-utterance
103
+ wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
104
+ for i, (wav, text) in enumerate(zip(wavs, texts)):
105
+ sub_basename = "%s_%02d" % (wav_fname, i)
106
+ metadata.append(process_utterance(wav, text, out_dir, sub_basename,
107
+ skip_existing, hparams))
108
+
109
+ return [m for m in metadata if m is not None]
110
+
111
+
112
+ def split_on_silences(wav_fpath, words, end_times, hparams):
113
+ # Load the audio waveform
114
+ wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
115
+ if hparams.rescale:
116
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
117
+
118
+ words = np.array(words)
119
+ start_times = np.array([0.0] + end_times[:-1])
120
+ end_times = np.array(end_times)
121
+ assert len(words) == len(end_times) == len(start_times)
122
+ assert words[0] == "" and words[-1] == ""
123
+
124
+ # Find pauses that are too long
125
+ mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
126
+ mask[0] = mask[-1] = True
127
+ breaks = np.where(mask)[0]
128
+
129
+ # Profile the noise from the silences and perform noise reduction on the waveform
130
+ silence_times = [[start_times[i], end_times[i]] for i in breaks]
131
+ silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int)
132
+ noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
133
+ if len(noisy_wav) > hparams.sample_rate * 0.02:
134
+ profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
135
+ wav = logmmse.denoise(wav, profile, eta=0)
136
+
137
+ # Re-attach segments that are too short
138
+ segments = list(zip(breaks[:-1], breaks[1:]))
139
+ segment_durations = [start_times[end] - end_times[start] for start, end in segments]
140
+ i = 0
141
+ while i < len(segments) and len(segments) > 1:
142
+ if segment_durations[i] < hparams.utterance_min_duration:
143
+ # See if the segment can be re-attached with the right or the left segment
144
+ left_duration = float("inf") if i == 0 else segment_durations[i - 1]
145
+ right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
146
+ joined_duration = segment_durations[i] + min(left_duration, right_duration)
147
+
148
+ # Do not re-attach if it causes the joined utterance to be too long
149
+ if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
150
+ i += 1
151
+ continue
152
+
153
+ # Re-attach the segment with the neighbour of shortest duration
154
+ j = i - 1 if left_duration <= right_duration else i
155
+ segments[j] = (segments[j][0], segments[j + 1][1])
156
+ segment_durations[j] = joined_duration
157
+ del segments[j + 1], segment_durations[j + 1]
158
+ else:
159
+ i += 1
160
+
161
+ # Split the utterance
162
+ segment_times = [[end_times[start], start_times[end]] for start, end in segments]
163
+ segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int)
164
+ wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
165
+ texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
166
+
167
+ # # DEBUG: play the audio segments (run with -n=1)
168
+ # import sounddevice as sd
169
+ # if len(wavs) > 1:
170
+ # print("This sentence was split in %d segments:" % len(wavs))
171
+ # else:
172
+ # print("There are no silences long enough for this sentence to be split:")
173
+ # for wav, text in zip(wavs, texts):
174
+ # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
175
+ # # when playing them. You shouldn't need to do that in your parsers.
176
+ # wav = np.concatenate((wav, [0] * 16000))
177
+ # print("\t%s" % text)
178
+ # sd.play(wav, 16000, blocking=True)
179
+ # print("")
180
+
181
+ return wavs, texts
182
+
183
+
184
+ def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
185
+ skip_existing: bool, hparams):
186
+ ## FOR REFERENCE:
187
+ # For you not to lose your head if you ever wish to change things here or implement your own
188
+ # synthesizer.
189
+ # - Both the audios and the mel spectrograms are saved as numpy arrays
190
+ # - There is no processing done to the audios that will be saved to disk beyond volume
191
+ # normalization (in split_on_silences)
192
+ # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
193
+ # is why we re-apply it on the audio on the side of the vocoder.
194
+ # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
195
+ # without extra padding. This means that you won't have an exact relation between the length
196
+ # of the wav and of the mel spectrogram. See the vocoder data loader.
197
+
198
+
199
+ # Skip existing utterances if needed
200
+ mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
201
+ wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
202
+ if skip_existing and mel_fpath.exists() and wav_fpath.exists():
203
+ return None
204
+
205
+ # Trim silence
206
+ if hparams.trim_silence:
207
+ wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
208
+
209
+ # Skip utterances that are too short
210
+ if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
211
+ return None
212
+
213
+ # Compute the mel spectrogram
214
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
215
+ mel_frames = mel_spectrogram.shape[1]
216
+
217
+ # Skip utterances that are too long
218
+ if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
219
+ return None
220
+
221
+ # Write the spectrogram, embed and audio to disk
222
+ np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
223
+ np.save(wav_fpath, wav, allow_pickle=False)
224
+
225
+ # Return a tuple describing this training example
226
+ return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
227
+
228
+
229
+ def embed_utterance(fpaths, encoder_model_fpath):
230
+ if not encoder.is_loaded():
231
+ encoder.load_model(encoder_model_fpath)
232
+
233
+ # Compute the speaker embedding of the utterance
234
+ wav_fpath, embed_fpath = fpaths
235
+ wav = np.load(wav_fpath)
236
+ wav = encoder.preprocess_wav(wav)
237
+ embed = encoder.embed_utterance(wav)
238
+ np.save(embed_fpath, embed, allow_pickle=False)
239
+
240
+
241
+ def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
242
+ wav_dir = synthesizer_root.joinpath("audio")
243
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
244
+ assert wav_dir.exists() and metadata_fpath.exists()
245
+ embed_dir = synthesizer_root.joinpath("embeds")
246
+ embed_dir.mkdir(exist_ok=True)
247
+
248
+ # Gather the input wave filepath and the target output embed filepath
249
+ with metadata_fpath.open("r") as metadata_file:
250
+ metadata = [line.split("|") for line in metadata_file]
251
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
252
+
253
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
254
+ # Embed the utterances in separate threads
255
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
256
+ job = Pool(n_processes).imap(func, fpaths)
257
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
258
+
synthesizer/synthesize.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+
10
+ from synthesizer.hparams import hparams_debug_string
11
+ from synthesizer.models.tacotron import Tacotron
12
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
13
+ from synthesizer.utils import data_parallel_workaround
14
+ from synthesizer.utils.symbols import symbols
15
+
16
+
17
+ def run_synthesis(in_dir: Path, out_dir: Path, syn_model_fpath: Path, hparams):
18
+ # This generates ground truth-aligned mels for vocoder training
19
+ synth_dir = out_dir / "mels_gta"
20
+ synth_dir.mkdir(exist_ok=True, parents=True)
21
+ print(hparams_debug_string())
22
+
23
+ # Check for GPU
24
+ if torch.cuda.is_available():
25
+ device = torch.device("cuda")
26
+ if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
27
+ raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
28
+ else:
29
+ device = torch.device("cpu")
30
+ print("Synthesizer using device:", device)
31
+
32
+ # Instantiate Tacotron model
33
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
34
+ num_chars=len(symbols),
35
+ encoder_dims=hparams.tts_encoder_dims,
36
+ decoder_dims=hparams.tts_decoder_dims,
37
+ n_mels=hparams.num_mels,
38
+ fft_bins=hparams.num_mels,
39
+ postnet_dims=hparams.tts_postnet_dims,
40
+ encoder_K=hparams.tts_encoder_K,
41
+ lstm_dims=hparams.tts_lstm_dims,
42
+ postnet_K=hparams.tts_postnet_K,
43
+ num_highways=hparams.tts_num_highways,
44
+ dropout=0., # Use zero dropout for gta mels
45
+ stop_threshold=hparams.tts_stop_threshold,
46
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
47
+
48
+ # Load the weights
49
+ print("\nLoading weights at %s" % syn_model_fpath)
50
+ model.load(syn_model_fpath)
51
+ print("Tacotron weights loaded from step %d" % model.step)
52
+
53
+ # Synthesize using same reduction factor as the model is currently trained
54
+ r = np.int32(model.r)
55
+
56
+ # Set model to eval mode (disable gradient and zoneout)
57
+ model.eval()
58
+
59
+ # Initialize the dataset
60
+ metadata_fpath = in_dir.joinpath("train.txt")
61
+ mel_dir = in_dir.joinpath("mels")
62
+ embed_dir = in_dir.joinpath("embeds")
63
+
64
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
65
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
66
+ data_loader = DataLoader(dataset, hparams.synthesis_batch_size, collate_fn=collate_fn, num_workers=2)
67
+
68
+ # Generate GTA mels
69
+ meta_out_fpath = out_dir / "synthesized.txt"
70
+ with meta_out_fpath.open("w") as file:
71
+ for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
72
+ texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)
73
+
74
+ # Parallelize model onto GPUS using workaround due to python bug
75
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
76
+ _, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
77
+ else:
78
+ _, mels_out, _, _ = model(texts, mels, embeds)
79
+
80
+ for j, k in enumerate(idx):
81
+ # Note: outputs mel-spectrogram files and target ones have same names, just different folders
82
+ mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
83
+ mel_out = mels_out[j].detach().cpu().numpy().T
84
+
85
+ # Use the length of the ground truth mel to remove padding from the generated mels
86
+ mel_out = mel_out[:int(dataset.metadata[k][4])]
87
+
88
+ # Write the spectrogram to disk
89
+ np.save(mel_filename, mel_out, allow_pickle=False)
90
+
91
+ # Write metadata into the synthesized file
92
+ file.write("|".join(dataset.metadata[k]))
synthesizer/synthesizer_dataset.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from synthesizer.utils.text import text_to_sequence
6
+
7
+
8
+ class SynthesizerDataset(Dataset):
9
+ def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
10
+ print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
11
+
12
+ with metadata_fpath.open("r") as metadata_file:
13
+ metadata = [line.split("|") for line in metadata_file]
14
+
15
+ mel_fnames = [x[1] for x in metadata if int(x[4])]
16
+ mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
17
+ embed_fnames = [x[2] for x in metadata if int(x[4])]
18
+ embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
19
+ self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
20
+ self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
21
+ self.metadata = metadata
22
+ self.hparams = hparams
23
+
24
+ print("Found %d samples" % len(self.samples_fpaths))
25
+
26
+ def __getitem__(self, index):
27
+ # Sometimes index may be a list of 2 (not sure why this happens)
28
+ # If that is the case, return a single item corresponding to first element in index
29
+ if index is list:
30
+ index = index[0]
31
+
32
+ mel_path, embed_path = self.samples_fpaths[index]
33
+ mel = np.load(mel_path).T.astype(np.float32)
34
+
35
+ # Load the embed
36
+ embed = np.load(embed_path)
37
+
38
+ # Get the text and clean it
39
+ text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
40
+
41
+ # Convert the list returned by text_to_sequence to a numpy array
42
+ text = np.asarray(text).astype(np.int32)
43
+
44
+ return text, mel.astype(np.float32), embed.astype(np.float32), index
45
+
46
+ def __len__(self):
47
+ return len(self.samples_fpaths)
48
+
49
+
50
+ def collate_synthesizer(batch, r, hparams):
51
+ # Text
52
+ x_lens = [len(x[0]) for x in batch]
53
+ max_x_len = max(x_lens)
54
+
55
+ chars = [pad1d(x[0], max_x_len) for x in batch]
56
+ chars = np.stack(chars)
57
+
58
+ # Mel spectrogram
59
+ spec_lens = [x[1].shape[-1] for x in batch]
60
+ max_spec_len = max(spec_lens) + 1
61
+ if max_spec_len % r != 0:
62
+ max_spec_len += r - max_spec_len % r
63
+
64
+ # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
65
+ # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
66
+ if hparams.symmetric_mels:
67
+ mel_pad_value = -1 * hparams.max_abs_value
68
+ else:
69
+ mel_pad_value = 0
70
+
71
+ mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
72
+ mel = np.stack(mel)
73
+
74
+ # Speaker embedding (SV2TTS)
75
+ embeds = np.array([x[2] for x in batch])
76
+
77
+ # Index (for vocoder preprocessing)
78
+ indices = [x[3] for x in batch]
79
+
80
+
81
+ # Convert all to tensor
82
+ chars = torch.tensor(chars).long()
83
+ mel = torch.tensor(mel)
84
+ embeds = torch.tensor(embeds)
85
+
86
+ return chars, mel, embeds, indices
87
+
88
+ def pad1d(x, max_len, pad_value=0):
89
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
90
+
91
+ def pad2d(x, max_len, pad_value=0):
92
+ return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
synthesizer/train.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import optim
8
+ from torch.utils.data import DataLoader
9
+
10
+ from synthesizer import audio
11
+ from synthesizer.models.tacotron import Tacotron
12
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
13
+ from synthesizer.utils import ValueWindow, data_parallel_workaround
14
+ from synthesizer.utils.plot import plot_spectrogram
15
+ from synthesizer.utils.symbols import symbols
16
+ from synthesizer.utils.text import sequence_to_text
17
+ from vocoder.display import *
18
+
19
+
20
+ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
21
+
22
+
23
+ def time_string():
24
+ return datetime.now().strftime("%Y-%m-%d %H:%M")
25
+
26
+
27
+ def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
28
+ hparams):
29
+ models_dir.mkdir(exist_ok=True)
30
+
31
+ model_dir = models_dir.joinpath(run_id)
32
+ plot_dir = model_dir.joinpath("plots")
33
+ wav_dir = model_dir.joinpath("wavs")
34
+ mel_output_dir = model_dir.joinpath("mel-spectrograms")
35
+ meta_folder = model_dir.joinpath("metas")
36
+ model_dir.mkdir(exist_ok=True)
37
+ plot_dir.mkdir(exist_ok=True)
38
+ wav_dir.mkdir(exist_ok=True)
39
+ mel_output_dir.mkdir(exist_ok=True)
40
+ meta_folder.mkdir(exist_ok=True)
41
+
42
+ weights_fpath = model_dir / f"synthesizer.pt"
43
+ metadata_fpath = syn_dir.joinpath("train.txt")
44
+
45
+ print("Checkpoint path: {}".format(weights_fpath))
46
+ print("Loading training data from: {}".format(metadata_fpath))
47
+ print("Using model: Tacotron")
48
+
49
+ # Bookkeeping
50
+ time_window = ValueWindow(100)
51
+ loss_window = ValueWindow(100)
52
+
53
+ # From WaveRNN/train_tacotron.py
54
+ if torch.cuda.is_available():
55
+ device = torch.device("cuda")
56
+
57
+ for session in hparams.tts_schedule:
58
+ _, _, _, batch_size = session
59
+ if batch_size % torch.cuda.device_count() != 0:
60
+ raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
61
+ else:
62
+ device = torch.device("cpu")
63
+ print("Using device:", device)
64
+
65
+ # Instantiate Tacotron Model
66
+ print("\nInitialising Tacotron Model...\n")
67
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
68
+ num_chars=len(symbols),
69
+ encoder_dims=hparams.tts_encoder_dims,
70
+ decoder_dims=hparams.tts_decoder_dims,
71
+ n_mels=hparams.num_mels,
72
+ fft_bins=hparams.num_mels,
73
+ postnet_dims=hparams.tts_postnet_dims,
74
+ encoder_K=hparams.tts_encoder_K,
75
+ lstm_dims=hparams.tts_lstm_dims,
76
+ postnet_K=hparams.tts_postnet_K,
77
+ num_highways=hparams.tts_num_highways,
78
+ dropout=hparams.tts_dropout,
79
+ stop_threshold=hparams.tts_stop_threshold,
80
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
81
+
82
+ # Initialize the optimizer
83
+ optimizer = optim.Adam(model.parameters())
84
+
85
+ # Load the weights
86
+ if force_restart or not weights_fpath.exists():
87
+ print("\nStarting the training of Tacotron from scratch\n")
88
+ model.save(weights_fpath)
89
+
90
+ # Embeddings metadata
91
+ char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
92
+ with open(char_embedding_fpath, "w", encoding="utf-8") as f:
93
+ for symbol in symbols:
94
+ if symbol == " ":
95
+ symbol = "\\s" # For visual purposes, swap space with \s
96
+
97
+ f.write("{}\n".format(symbol))
98
+
99
+ else:
100
+ print("\nLoading weights at %s" % weights_fpath)
101
+ model.load(weights_fpath, optimizer)
102
+ print("Tacotron weights loaded from step %d" % model.step)
103
+
104
+ # Initialize the dataset
105
+ metadata_fpath = syn_dir.joinpath("train.txt")
106
+ mel_dir = syn_dir.joinpath("mels")
107
+ embed_dir = syn_dir.joinpath("embeds")
108
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
109
+
110
+ for i, session in enumerate(hparams.tts_schedule):
111
+ current_step = model.get_step()
112
+
113
+ r, lr, max_step, batch_size = session
114
+
115
+ training_steps = max_step - current_step
116
+
117
+ # Do we need to change to the next session?
118
+ if current_step >= max_step:
119
+ # Are there no further sessions than the current one?
120
+ if i == len(hparams.tts_schedule) - 1:
121
+ # We have completed training. Save the model and exit
122
+ model.save(weights_fpath, optimizer)
123
+ break
124
+ else:
125
+ # There is a following session, go to it
126
+ continue
127
+
128
+ model.r = r
129
+
130
+ # Begin the training
131
+ simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
132
+ ("Batch Size", batch_size),
133
+ ("Learning Rate", lr),
134
+ ("Outputs/Step (r)", model.r)])
135
+
136
+ for p in optimizer.param_groups:
137
+ p["lr"] = lr
138
+
139
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
140
+ data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
141
+
142
+ total_iters = len(dataset)
143
+ steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
144
+ epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
145
+
146
+ for epoch in range(1, epochs+1):
147
+ for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
148
+ start_time = time.time()
149
+
150
+ # Generate stop tokens for training
151
+ stop = torch.ones(mels.shape[0], mels.shape[2])
152
+ for j, k in enumerate(idx):
153
+ stop[j, :int(dataset.metadata[k][4])-1] = 0
154
+
155
+ texts = texts.to(device)
156
+ mels = mels.to(device)
157
+ embeds = embeds.to(device)
158
+ stop = stop.to(device)
159
+
160
+ # Forward pass
161
+ # Parallelize model onto GPUS using workaround due to python bug
162
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
163
+ m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, mels, embeds)
164
+ else:
165
+ m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
166
+
167
+ # Backward pass
168
+ m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
169
+ m2_loss = F.mse_loss(m2_hat, mels)
170
+ stop_loss = F.binary_cross_entropy(stop_pred, stop)
171
+
172
+ loss = m1_loss + m2_loss + stop_loss
173
+
174
+ optimizer.zero_grad()
175
+ loss.backward()
176
+
177
+ if hparams.tts_clip_grad_norm is not None:
178
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
179
+ if np.isnan(grad_norm.cpu()):
180
+ print("grad_norm was NaN!")
181
+
182
+ optimizer.step()
183
+
184
+ time_window.append(time.time() - start_time)
185
+ loss_window.append(loss.item())
186
+
187
+ step = model.get_step()
188
+ k = step // 1000
189
+
190
+ msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | " \
191
+ f"{1./time_window.average:#.2} steps/s | Step: {k}k | "
192
+ stream(msg)
193
+
194
+ # Backup or save model as appropriate
195
+ if backup_every != 0 and step % backup_every == 0 :
196
+ backup_fpath = weights_fpath.parent / f"synthesizer_{k:06d}.pt"
197
+ model.save(backup_fpath, optimizer)
198
+
199
+ if save_every != 0 and step % save_every == 0 :
200
+ # Must save latest optimizer state to ensure that resuming training
201
+ # doesn't produce artifacts
202
+ model.save(weights_fpath, optimizer)
203
+
204
+ # Evaluate model to generate samples
205
+ epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
206
+ step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
207
+ if epoch_eval or step_eval:
208
+ for sample_idx in range(hparams.tts_eval_num_samples):
209
+ # At most, generate samples equal to number in the batch
210
+ if sample_idx + 1 <= len(texts):
211
+ # Remove padding from mels using frame length in metadata
212
+ mel_length = int(dataset.metadata[idx[sample_idx]][4])
213
+ mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
214
+ target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
215
+ attention_len = mel_length // model.r
216
+
217
+ eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
218
+ mel_prediction=mel_prediction,
219
+ target_spectrogram=target_spectrogram,
220
+ input_seq=np_now(texts[sample_idx]),
221
+ step=step,
222
+ plot_dir=plot_dir,
223
+ mel_output_dir=mel_output_dir,
224
+ wav_dir=wav_dir,
225
+ sample_num=sample_idx + 1,
226
+ loss=loss,
227
+ hparams=hparams)
228
+
229
+ # Break out of loop to update training schedule
230
+ if step >= max_step:
231
+ break
232
+
233
+ # Add line break after every epoch
234
+ print("")
235
+
236
+
237
+ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
238
+ plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
239
+ # Save some results for evaluation
240
+ attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
241
+ save_attention(attention, attention_path)
242
+
243
+ # save predicted mel spectrogram to disk (debug)
244
+ mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
245
+ np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
246
+
247
+ # save griffin lim inverted wav for debug (mel -> wav)
248
+ wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
249
+ wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
250
+ audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
251
+
252
+ # save real and predicted mel-spectrogram plot to disk (control purposes)
253
+ spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
254
+ title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
255
+ plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
256
+ target_spectrogram=target_spectrogram,
257
+ max_len=target_spectrogram.size // hparams.num_mels)
258
+ print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
synthesizer/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+ def data_parallel_workaround(model, *input):
8
+ global _output_ref
9
+ global _replicas_ref
10
+ device_ids = list(range(torch.cuda.device_count()))
11
+ output_device = device_ids[0]
12
+ replicas = torch.nn.parallel.replicate(model, device_ids)
13
+ # input.shape = (num_args, batch, ...)
14
+ inputs = torch.nn.parallel.scatter(input, device_ids)
15
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
16
+ replicas = replicas[:len(inputs)]
17
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
18
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
19
+ _output_ref = outputs
20
+ _replicas_ref = replicas
21
+ return y_hat
22
+
23
+
24
+ class ValueWindow():
25
+ def __init__(self, window_size=100):
26
+ self._window_size = window_size
27
+ self._values = []
28
+
29
+ def append(self, x):
30
+ self._values = self._values[-(self._window_size - 1):] + [x]
31
+
32
+ @property
33
+ def sum(self):
34
+ return sum(self._values)
35
+
36
+ @property
37
+ def count(self):
38
+ return len(self._values)
39
+
40
+ @property
41
+ def average(self):
42
+ return self.sum / max(1, self.count)
43
+
44
+ def reset(self):
45
+ self._values = []
synthesizer/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.76 kB). View file
 
synthesizer/utils/__pycache__/cleaners.cpython-38.pyc ADDED
Binary file (2.87 kB). View file
 
synthesizer/utils/__pycache__/numbers.cpython-38.pyc ADDED
Binary file (2.23 kB). View file
 
synthesizer/utils/__pycache__/symbols.cpython-38.pyc ADDED
Binary file (623 Bytes). View file
 
synthesizer/utils/__pycache__/text.cpython-38.pyc ADDED
Binary file (2.77 kB). View file
 
synthesizer/utils/_cmudict.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ valid_symbols = [
4
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
5
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
6
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
7
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
8
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
9
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
10
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
11
+ ]
12
+
13
+ _valid_symbol_set = set(valid_symbols)
14
+
15
+
16
+ class CMUDict:
17
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
18
+ def __init__(self, file_or_path, keep_ambiguous=True):
19
+ if isinstance(file_or_path, str):
20
+ with open(file_or_path, encoding="latin-1") as f:
21
+ entries = _parse_cmudict(f)
22
+ else:
23
+ entries = _parse_cmudict(file_or_path)
24
+ if not keep_ambiguous:
25
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
26
+ self._entries = entries
27
+
28
+
29
+ def __len__(self):
30
+ return len(self._entries)
31
+
32
+
33
+ def lookup(self, word):
34
+ """Returns list of ARPAbet pronunciations of the given word."""
35
+ return self._entries.get(word.upper())
36
+
37
+
38
+
39
+ _alt_re = re.compile(r"\([0-9]+\)")
40
+
41
+
42
+ def _parse_cmudict(file):
43
+ cmudict = {}
44
+ for line in file:
45
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
46
+ parts = line.split(" ")
47
+ word = re.sub(_alt_re, "", parts[0])
48
+ pronunciation = _get_pronunciation(parts[1])
49
+ if pronunciation:
50
+ if word in cmudict:
51
+ cmudict[word].append(pronunciation)
52
+ else:
53
+ cmudict[word] = [pronunciation]
54
+ return cmudict
55
+
56
+
57
+ def _get_pronunciation(s):
58
+ parts = s.strip().split(" ")
59
+ for part in parts:
60
+ if part not in _valid_symbol_set:
61
+ return None
62
+ return " ".join(parts)
synthesizer/utils/cleaners.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+ import re
13
+ from unidecode import unidecode
14
+ from synthesizer.utils.numbers import normalize_numbers
15
+
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22
+ ("mrs", "misess"),
23
+ ("mr", "mister"),
24
+ ("dr", "doctor"),
25
+ ("st", "saint"),
26
+ ("co", "company"),
27
+ ("jr", "junior"),
28
+ ("maj", "major"),
29
+ ("gen", "general"),
30
+ ("drs", "doctors"),
31
+ ("rev", "reverend"),
32
+ ("lt", "lieutenant"),
33
+ ("hon", "honorable"),
34
+ ("sgt", "sergeant"),
35
+ ("capt", "captain"),
36
+ ("esq", "esquire"),
37
+ ("ltd", "limited"),
38
+ ("col", "colonel"),
39
+ ("ft", "fort"),
40
+ ]]
41
+
42
+
43
+ def expand_abbreviations(text):
44
+ for regex, replacement in _abbreviations:
45
+ text = re.sub(regex, replacement, text)
46
+ return text
47
+
48
+
49
+ def expand_numbers(text):
50
+ return normalize_numbers(text)
51
+
52
+
53
+ def lowercase(text):
54
+ """lowercase input tokens."""
55
+ return text.lower()
56
+
57
+
58
+ def collapse_whitespace(text):
59
+ return re.sub(_whitespace_re, " ", text)
60
+
61
+
62
+ def convert_to_ascii(text):
63
+ return unidecode(text)
64
+
65
+
66
+ def basic_cleaners(text):
67
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68
+ text = lowercase(text)
69
+ text = collapse_whitespace(text)
70
+ return text
71
+
72
+
73
+ def transliteration_cleaners(text):
74
+ """Pipeline for non-English text that transliterates to ASCII."""
75
+ text = convert_to_ascii(text)
76
+ text = lowercase(text)
77
+ text = collapse_whitespace(text)
78
+ return text
79
+
80
+
81
+ def english_cleaners(text):
82
+ """Pipeline for English text, including number and abbreviation expansion."""
83
+ text = convert_to_ascii(text)
84
+ text = lowercase(text)
85
+ text = expand_numbers(text)
86
+ text = expand_abbreviations(text)
87
+ text = collapse_whitespace(text)
88
+ return text
synthesizer/utils/numbers.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import inflect
3
+
4
+
5
+ _inflect = inflect.engine()
6
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
7
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
8
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
9
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
10
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
11
+ _number_re = re.compile(r"[0-9]+")
12
+
13
+
14
+ def _remove_commas(m):
15
+ return m.group(1).replace(",", "")
16
+
17
+
18
+ def _expand_decimal_point(m):
19
+ return m.group(1).replace(".", " point ")
20
+
21
+
22
+ def _expand_dollars(m):
23
+ match = m.group(1)
24
+ parts = match.split(".")
25
+ if len(parts) > 2:
26
+ return match + " dollars" # Unexpected format
27
+ dollars = int(parts[0]) if parts[0] else 0
28
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
29
+ if dollars and cents:
30
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
31
+ cent_unit = "cent" if cents == 1 else "cents"
32
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
33
+ elif dollars:
34
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
35
+ return "%s %s" % (dollars, dollar_unit)
36
+ elif cents:
37
+ cent_unit = "cent" if cents == 1 else "cents"
38
+ return "%s %s" % (cents, cent_unit)
39
+ else:
40
+ return "zero dollars"
41
+
42
+
43
+ def _expand_ordinal(m):
44
+ return _inflect.number_to_words(m.group(0))
45
+
46
+
47
+ def _expand_number(m):
48
+ num = int(m.group(0))
49
+ if num > 1000 and num < 3000:
50
+ if num == 2000:
51
+ return "two thousand"
52
+ elif num > 2000 and num < 2010:
53
+ return "two thousand " + _inflect.number_to_words(num % 100)
54
+ elif num % 100 == 0:
55
+ return _inflect.number_to_words(num // 100) + " hundred"
56
+ else:
57
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
58
+ else:
59
+ return _inflect.number_to_words(num, andword="")
60
+
61
+
62
+ def normalize_numbers(text):
63
+ text = re.sub(_comma_number_re, _remove_commas, text)
64
+ text = re.sub(_pounds_re, r"\1 pounds", text)
65
+ text = re.sub(_dollars_re, _expand_dollars, text)
66
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
67
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
68
+ text = re.sub(_number_re, _expand_number, text)
69
+ return text