drewThomasson committed on
Commit: f341d4d
Parent: 017ec2f

Upload 310 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. TTS/__init__.py +33 -0
  2. TTS/api.py +458 -0
  3. TTS/bin/__init__.py +0 -0
  4. TTS/bin/collect_env_info.py +49 -0
  5. TTS/bin/compute_attention_masks.py +169 -0
  6. TTS/bin/compute_embeddings.py +201 -0
  7. TTS/bin/compute_statistics.py +100 -0
  8. TTS/bin/eval_encoder.py +92 -0
  9. TTS/bin/extract_tts_spectrograms.py +290 -0
  10. TTS/bin/find_unique_chars.py +40 -0
  11. TTS/bin/find_unique_phonemes.py +79 -0
  12. TTS/bin/remove_silence_using_vad.py +128 -0
  13. TTS/bin/resample.py +90 -0
  14. TTS/bin/synthesize.py +486 -0
  15. TTS/bin/train_encoder.py +340 -0
  16. TTS/bin/train_tts.py +75 -0
  17. TTS/bin/train_vocoder.py +81 -0
  18. TTS/bin/tune_wavegrad.py +107 -0
  19. TTS/config/__init__.py +138 -0
  20. TTS/config/shared_configs.py +268 -0
  21. TTS/demos/xtts_ft_demo/requirements.txt +2 -0
  22. TTS/demos/xtts_ft_demo/utils/formatter.py +161 -0
  23. TTS/demos/xtts_ft_demo/utils/gpt_train.py +171 -0
  24. TTS/demos/xtts_ft_demo/xtts_demo.py +433 -0
  25. TTS/encoder/README.md +18 -0
  26. TTS/encoder/__init__.py +0 -0
  27. TTS/encoder/configs/base_encoder_config.py +61 -0
  28. TTS/encoder/configs/emotion_encoder_config.py +12 -0
  29. TTS/encoder/configs/speaker_encoder_config.py +11 -0
  30. TTS/encoder/dataset.py +146 -0
  31. TTS/encoder/losses.py +230 -0
  32. TTS/encoder/models/base_encoder.py +165 -0
  33. TTS/encoder/models/lstm.py +99 -0
  34. TTS/encoder/models/resnet.py +198 -0
  35. TTS/encoder/requirements.txt +2 -0
  36. TTS/encoder/utils/__init__.py +0 -0
  37. TTS/encoder/utils/generic_utils.py +141 -0
  38. TTS/encoder/utils/prepare_voxceleb.py +226 -0
  39. TTS/encoder/utils/training.py +99 -0
  40. TTS/encoder/utils/visual.py +53 -0
  41. TTS/model.py +66 -0
  42. TTS/server/README.md +21 -0
  43. TTS/server/__init__.py +0 -0
  44. TTS/server/conf.json +12 -0
  45. TTS/server/server.py +262 -0
  46. TTS/server/static/coqui-log-green-TTS.png +0 -0
  47. TTS/server/templates/details.html +131 -0
  48. TTS/server/templates/index.html +154 -0
  49. TTS/tts/__init__.py +0 -0
  50. TTS/tts/configs/__init__.py +17 -0
TTS/__init__.py ADDED
@@ -0,0 +1,33 @@
import importlib.metadata

from TTS.utils.generic_utils import is_pytorch_at_least_2_4

__version__ = importlib.metadata.version("coqui-tts")


if is_pytorch_at_least_2_4():
    import _codecs
    from collections import defaultdict

    import numpy as np
    import torch

    from TTS.config.shared_configs import BaseDatasetConfig
    from TTS.tts.configs.xtts_config import XttsConfig
    from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
    from TTS.utils.radam import RAdam

    torch.serialization.add_safe_globals([dict, defaultdict, RAdam])

    # Bark
    torch.serialization.add_safe_globals(
        [
            np.core.multiarray.scalar,
            np.dtype,
            np.dtypes.Float64DType,
            _codecs.encode,  # TODO: safe by default from Pytorch 2.5
        ]
    )

    # XTTS
    torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs])
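Editor's note (not part of the commit): the `add_safe_globals` registrations above exist because PyTorch 2.4+ can load checkpoints with `weights_only=True`, which only unpickles allowlisted globals. A minimal sketch of the mechanism, with an illustrative checkpoint path:

import torch
from TTS.utils.radam import RAdam

# Without this registration, torch.load(..., weights_only=True) rejects checkpoints
# that pickle RAdam (or the XTTS config classes registered above).
torch.serialization.add_safe_globals([RAdam])
state = torch.load("checkpoint.pth", map_location="cpu", weights_only=True)  # illustrative path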
TTS/api.py ADDED
@@ -0,0 +1,458 @@
import logging
import tempfile
import warnings
from pathlib import Path

from torch import nn

from TTS.config import load_config
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

logger = logging.getLogger(__name__)


class TTS(nn.Module):
    """TODO: Add voice conversion and Capacitron support."""

    def __init__(
        self,
        model_name: str = "",
        model_path: str = None,
        config_path: str = None,
        vocoder_path: str = None,
        vocoder_config_path: str = None,
        progress_bar: bool = True,
        gpu=False,
    ):
        """🐸TTS python interface that allows to load and use the released models.

        Example with a multi-speaker model:
            >>> from TTS.api import TTS
            >>> tts = TTS(TTS.list_models()[0])
            >>> wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
            >>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")

        Example with a single-speaker model:
            >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
            >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")

        Example loading a model from a path:
            >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
            >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")

        Example voice cloning with YourTTS in English, French and Portuguese:
            >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
            >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav")
            >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
            >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")

        Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):
            >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
            >>> tts.tts_to_file("This is a test.", file_path="output.wav")

        Args:
            model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
            model_path (str, optional): Path to the model checkpoint. Defaults to None.
            config_path (str, optional): Path to the model config. Defaults to None.
            vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
            vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
            progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True.
            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
        """
        super().__init__()
        self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
        self.config = load_config(config_path) if config_path else None
        self.synthesizer = None
        self.voice_converter = None
        self.model_name = ""
        if gpu:
            warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")

        if model_name is not None and len(model_name) > 0:
            if "tts_models" in model_name:
                self.load_tts_model_by_name(model_name, gpu)
            elif "voice_conversion_models" in model_name:
                self.load_vc_model_by_name(model_name, gpu)
            else:
                self.load_model_by_name(model_name, gpu)

        if model_path:
            self.load_tts_model_by_path(
                model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
            )

    @property
    def models(self):
        return self.manager.list_tts_models()

    @property
    def is_multi_speaker(self):
        if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager:
            return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
        return False

    @property
    def is_multi_lingual(self):
        # Not sure what sets this to None, but applied a fix to prevent crashing.
        if (
            isinstance(self.model_name, str)
            and "xtts" in self.model_name
            or self.config
            and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1)
        ):
            return True
        if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
            return self.synthesizer.tts_model.language_manager.num_languages > 1
        return False

    @property
    def speakers(self):
        if not self.is_multi_speaker:
            return None
        return self.synthesizer.tts_model.speaker_manager.speaker_names

    @property
    def languages(self):
        if not self.is_multi_lingual:
            return None
        return self.synthesizer.tts_model.language_manager.language_names

    @staticmethod
    def get_models_file_path():
        return Path(__file__).parent / ".models.json"

    @staticmethod
    def list_models():
        return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()

    def download_model_by_name(self, model_name: str):
        model_path, config_path, model_item = self.manager.download_model(model_name)
        if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
            # return model directory if there are multiple files
            # we assume that the model knows how to load itself
            return None, None, None, None, model_path
        if model_item.get("default_vocoder") is None:
            return model_path, config_path, None, None, None
        vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
        return model_path, config_path, vocoder_path, vocoder_config_path, None

    def load_model_by_name(self, model_name: str, gpu: bool = False):
        """Load one of the 🐸TTS models by name.

        Args:
            model_name (str): Model name to load. You can list models by ```tts.models```.
            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
        """
        self.load_tts_model_by_name(model_name, gpu)

    def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
        """Load one of the voice conversion models by name.

        Args:
            model_name (str): Model name to load. You can list models by ```tts.models```.
            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
        """
        self.model_name = model_name
        model_path, config_path, _, _, _ = self.download_model_by_name(model_name)
        self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)

    def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
        """Load one of 🐸TTS models by name.

        Args:
            model_name (str): Model name to load. You can list models by ```tts.models```.
            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.

        TODO: Add tests
        """
        self.synthesizer = None
        self.model_name = model_name

        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)

        # init synthesizer
        # None values are fetched from the model
        self.synthesizer = Synthesizer(
            tts_checkpoint=model_path,
            tts_config_path=config_path,
            tts_speakers_file=None,
            tts_languages_file=None,
            vocoder_checkpoint=vocoder_path,
            vocoder_config=vocoder_config_path,
            encoder_checkpoint=None,
            encoder_config=None,
            model_dir=model_dir,
            use_cuda=gpu,
        )

    def load_tts_model_by_path(
        self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
    ):
        """Load a model from a path.

        Args:
            model_path (str): Path to the model checkpoint.
            config_path (str): Path to the model config.
            vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
            vocoder_config (str, optional): Path to the vocoder config. Defaults to None.
            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
        """

        self.synthesizer = Synthesizer(
            tts_checkpoint=model_path,
            tts_config_path=config_path,
            tts_speakers_file=None,
            tts_languages_file=None,
            vocoder_checkpoint=vocoder_path,
            vocoder_config=vocoder_config,
            encoder_checkpoint=None,
            encoder_config=None,
            use_cuda=gpu,
        )

    def _check_arguments(
        self,
        speaker: str = None,
        language: str = None,
        speaker_wav: str = None,
        emotion: str = None,
        speed: float = None,
        **kwargs,
    ) -> None:
        """Check if the arguments are valid for the model."""
        # check for the coqui tts models
        if self.is_multi_speaker and (speaker is None and speaker_wav is None):
            raise ValueError("Model is multi-speaker but no `speaker` is provided.")
        if self.is_multi_lingual and language is None:
            raise ValueError("Model is multi-lingual but no `language` is provided.")
        if not self.is_multi_speaker and speaker is not None and "voice_dir" not in kwargs:
            raise ValueError("Model is not multi-speaker but `speaker` is provided.")
        if not self.is_multi_lingual and language is not None:
            raise ValueError("Model is not multi-lingual but `language` is provided.")
        if emotion is not None and speed is not None:
            raise ValueError("Emotion and speed can only be used with Coqui Studio models, which are discontinued.")

    def tts(
        self,
        text: str,
        speaker: str = None,
        language: str = None,
        speaker_wav: str = None,
        emotion: str = None,
        speed: float = None,
        split_sentences: bool = True,
        **kwargs,
    ):
        """Convert text to speech.

        Args:
            text (str):
                Input text to synthesize.
            speaker (str, optional):
                Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
                `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
            language (str): Language of the text. If None, the default language of the speaker is used. Language is only
                supported by `XTTS` model.
            speaker_wav (str, optional):
                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
                Defaults to None.
            emotion (str, optional):
                Emotion to use for 🐸Coqui Studio models. If None, Studio models use "Neutral". Defaults to None.
            speed (float, optional):
                Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0.
                Defaults to None.
            split_sentences (bool, optional):
                Split text into sentences, synthesize them separately and concatenate the file audio.
                Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
                applicable to the 🐸TTS models. Defaults to True.
            kwargs (dict, optional):
                Additional arguments for the model.
        """
        self._check_arguments(
            speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs
        )
        wav = self.synthesizer.tts(
            text=text,
            speaker_name=speaker,
            language_name=language,
            speaker_wav=speaker_wav,
            reference_wav=None,
            style_wav=None,
            style_text=None,
            reference_speaker_name=None,
            split_sentences=split_sentences,
            **kwargs,
        )
        return wav

    def tts_to_file(
        self,
        text: str,
        speaker: str = None,
        language: str = None,
        speaker_wav: str = None,
        emotion: str = None,
        speed: float = 1.0,
        pipe_out=None,
        file_path: str = "output.wav",
        split_sentences: bool = True,
        **kwargs,
    ):
        """Convert text to speech and save it to a file.

        Args:
            text (str):
                Input text to synthesize.
            speaker (str, optional):
                Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
                `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
            language (str, optional):
                Language code for multi-lingual models. You can check whether loaded model is multi-lingual by
                `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
            speaker_wav (str, optional):
                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
                Defaults to None.
            emotion (str, optional):
                Emotion to use for 🐸Coqui Studio models. Defaults to "Neutral".
            speed (float, optional):
                Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None.
            pipe_out (BytesIO, optional):
                Flag to stdout the generated TTS wav file for shell pipe.
            file_path (str, optional):
                Output file path. Defaults to "output.wav".
            split_sentences (bool, optional):
                Split text into sentences, synthesize them separately and concatenate the file audio.
                Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
                applicable to the 🐸TTS models. Defaults to True.
            kwargs (dict, optional):
                Additional arguments for the model.
        """
        self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)

        wav = self.tts(
            text=text,
            speaker=speaker,
            language=language,
            speaker_wav=speaker_wav,
            split_sentences=split_sentences,
            **kwargs,
        )
        self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
        return file_path

    def voice_conversion(
        self,
        source_wav: str,
        target_wav: str,
    ):
        """Voice conversion with FreeVC. Convert source wav to target speaker.

        Args:
            source_wav (str):
                Path to the source wav file.
            target_wav (str):
                Path to the target wav file.
        """
        wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav)
        return wav

    def voice_conversion_to_file(
        self,
        source_wav: str,
        target_wav: str,
        file_path: str = "output.wav",
    ):
        """Voice conversion with FreeVC. Convert source wav to target speaker.

        Args:
            source_wav (str):
                Path to the source wav file.
            target_wav (str):
                Path to the target wav file.
            file_path (str, optional):
                Output file path. Defaults to "output.wav".
        """
        wav = self.voice_conversion(source_wav=source_wav, target_wav=target_wav)
        save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
        return file_path

    def tts_with_vc(
        self,
        text: str,
        language: str = None,
        speaker_wav: str = None,
        speaker: str = None,
        split_sentences: bool = True,
    ):
        """Convert text to speech with voice conversion.

        It combines tts with voice conversion to fake voice cloning.

        - Convert text to speech with tts.
        - Convert the output wav to target speaker with voice conversion.

        Args:
            text (str):
                Input text to synthesize.
            language (str, optional):
                Language code for multi-lingual models. You can check whether loaded model is multi-lingual by
                `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
            speaker_wav (str, optional):
                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
                Defaults to None.
            speaker (str, optional):
                Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
                `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
            split_sentences (bool, optional):
                Split text into sentences, synthesize them separately and concatenate the file audio.
                Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
                applicable to the 🐸TTS models. Defaults to True.
        """
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            # Lazy code... save it to a temp file to resample it while reading it for VC
            self.tts_to_file(
                text=text, speaker=speaker, language=language, file_path=fp.name, split_sentences=split_sentences
            )
        if self.voice_converter is None:
            self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
        wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
        return wav

    def tts_with_vc_to_file(
        self,
        text: str,
        language: str = None,
        speaker_wav: str = None,
        file_path: str = "output.wav",
        speaker: str = None,
        split_sentences: bool = True,
    ):
        """Convert text to speech with voice conversion and save to file.

        Check `tts_with_vc` for more details.

        Args:
            text (str):
                Input text to synthesize.
            language (str, optional):
                Language code for multi-lingual models. You can check whether loaded model is multi-lingual by
                `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
            speaker_wav (str, optional):
                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
                Defaults to None.
            file_path (str, optional):
                Output file path. Defaults to "output.wav".
            speaker (str, optional):
                Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
                `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
            split_sentences (bool, optional):
                Split text into sentences, synthesize them separately and concatenate the file audio.
                Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
                applicable to the 🐸TTS models. Defaults to True.
        """
        wav = self.tts_with_vc(
            text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
        )
        save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
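Editor's note (not part of the commit): a short usage sketch of the `TTS` class above, pieced together from its docstring examples. The model names come from the released model registry and the audio paths are placeholders.

from TTS.api import TTS

tts = TTS(model_name="tts_models/multilingual/multi-dataset/xtts_v2", progress_bar=False)
tts.to("cuda")  # preferred over the deprecated `gpu=True` flag

# multi-lingual voice cloning straight to a file
tts.tts_to_file(
    text="This is voice cloning.",
    speaker_wav="my/cloning/audio.wav",  # placeholder reference clip
    language="en",
    file_path="output.wav",
)

# FreeVC-based voice conversion between two existing recordings
vc = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False)
vc.voice_conversion_to_file(source_wav="source.wav", target_wav="target.wav", file_path="converted.wav")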
TTS/bin/__init__.py ADDED
File without changes
TTS/bin/collect_env_info.py ADDED
@@ -0,0 +1,49 @@
"""Get detailed info about the working environment."""

import json
import os
import platform
import sys

import numpy
import torch

import TTS

sys.path += [os.path.abspath(".."), os.path.abspath(".")]


def system_info():
    return {
        "OS": platform.system(),
        "architecture": platform.architecture(),
        "version": platform.version(),
        "processor": platform.processor(),
        "python": platform.python_version(),
    }


def cuda_info():
    return {
        "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
        "available": torch.cuda.is_available(),
        "version": torch.version.cuda,
    }


def package_info():
    return {
        "numpy": numpy.__version__,
        "PyTorch_version": torch.__version__,
        "PyTorch_debug": torch.version.debug,
        "TTS": TTS.__version__,
    }


def main():
    details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()}
    print(json.dumps(details, indent=4, sort_keys=True))


if __name__ == "__main__":
    main()
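Editor's note (not part of the commit): the script takes no arguments and prints a JSON report with "System", "CUDA" and "Packages" sections. The same data can be gathered programmatically, as in this sketch:

from TTS.bin.collect_env_info import cuda_info, package_info, system_info

report = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()}
print(report["Packages"]["TTS"])  # installed coqui-tts version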
TTS/bin/compute_attention_masks.py ADDED
@@ -0,0 +1,169 @@
import argparse
import importlib
import logging
import os
from argparse import RawTextHelpFormatter

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.io import load_checkpoint

from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
        """Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n"""
        """
Example run:
    CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda
""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to Tacotron/Tacotron2 config file.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="",
        required=True,
        help="Target dataset processor name from TTS.tts.dataset.preprocess.",
    )

    parser.add_argument(
        "--dataset_metafile",
        type=str,
        default="",
        required=True,
        help="Dataset metafile including file paths with transcripts.",
    )
    parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
    parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="enable/disable cuda.")

    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
    )
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if "characters" in C.keys():
        symbols, phonemes = make_symbols(**C.characters)  # noqa: F811

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(C)
    model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)

    # data loader
    preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        model.decoder.r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )

    dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data
            text_input = data[0]
            text_lengths = data[1]
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_targets = data[6]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            model_outputs = model.forward(text_input, text_lengths, mel_input)

            alignments = model_outputs["alignments"].detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # interpolate if r > 1
                alignment = (
                    torch.nn.functional.interpolate(
                        alignment.transpose(0, 1).unsqueeze(0),
                        size=None,
                        scale_factor=model.decoder.r,
                        mode="nearest",
                        align_corners=None,
                        recompute_scale_factor=None,
                    )
                    .squeeze(0)
                    .transpose(0, 1)
                )
                # remove paddings
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                # set file paths
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
                file_path = item_idx.replace(wav_file_name, align_file_name)
                # save output
                wav_file_abs_path = os.path.abspath(item_idx)
                file_abs_path = os.path.abspath(file_path)
                file_paths.append([wav_file_abs_path, file_abs_path])
                np.save(file_path, alignment)

    # output metafile
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

    with open(metafile, "w", encoding="utf-8") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")
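Editor's note (not part of the commit): each mask is written next to its wav as `<name>_attn.npy`, and `metadata_attn_mask.txt` lists `wav_path|attn_path` pairs. A hedged sketch of consuming one mask (file name illustrative):

import numpy as np

attn = np.load("path/bla_attn.npy")  # (mel_frames, text_length) after the padding removal above
durations = attn.argmax(axis=1)      # rough per-frame character index, e.g. for a duration predictor
print(attn.shape, durations[:10])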
TTS/bin/compute_embeddings.py ADDED
@@ -0,0 +1,201 @@
import argparse
import logging
import os
from argparse import RawTextHelpFormatter

import torch
from tqdm import tqdm

from TTS.config import load_config
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.managers import save_file
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def compute_embeddings(
    model_path,
    config_path,
    output_path,
    old_speakers_file=None,
    old_append=False,
    config_dataset_path=None,
    formatter_name=None,
    dataset_name=None,
    dataset_path=None,
    meta_file_train=None,
    meta_file_val=None,
    disable_cuda=False,
    no_eval=False,
):
    use_cuda = torch.cuda.is_available() and not disable_cuda

    if config_dataset_path is not None:
        c_dataset = load_config(config_dataset_path)
        meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not no_eval)
    else:
        c_dataset = BaseDatasetConfig()
        c_dataset.formatter = formatter_name
        c_dataset.dataset_name = dataset_name
        c_dataset.path = dataset_path
        if meta_file_train is not None:
            c_dataset.meta_file_train = meta_file_train
        if meta_file_val is not None:
            c_dataset.meta_file_val = meta_file_val
        meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not no_eval)

    if meta_data_eval is None:
        samples = meta_data_train
    else:
        samples = meta_data_train + meta_data_eval

    encoder_manager = SpeakerManager(
        encoder_model_path=model_path,
        encoder_config_path=config_path,
        d_vectors_file_path=old_speakers_file,
        use_cuda=use_cuda,
    )

    class_name_key = encoder_manager.encoder_config.class_name_key

    # compute speaker embeddings
    if old_speakers_file is not None and old_append:
        speaker_mapping = encoder_manager.embeddings
    else:
        speaker_mapping = {}

    for fields in tqdm(samples):
        class_name = fields[class_name_key]
        audio_file = fields["audio_file"]
        embedding_key = fields["audio_unique_name"]

        # Only update the speaker name when the embedding is already in the old file.
        if embedding_key in speaker_mapping:
            speaker_mapping[embedding_key]["name"] = class_name
            continue

        if old_speakers_file is not None and embedding_key in encoder_manager.clip_ids:
            # get the embedding from the old file
            embedd = encoder_manager.get_embedding_by_clip(embedding_key)
        else:
            # extract the embedding
            embedd = encoder_manager.compute_embedding_from_clip(audio_file)

        # create speaker_mapping if target dataset is defined
        speaker_mapping[embedding_key] = {}
        speaker_mapping[embedding_key]["name"] = class_name
        speaker_mapping[embedding_key]["embedding"] = embedd

    if speaker_mapping:
        # save speaker_mapping if target dataset is defined
        if os.path.isdir(output_path):
            mapping_file_path = os.path.join(output_path, "speakers.pth")
        else:
            mapping_file_path = output_path

        if os.path.dirname(mapping_file_path) != "":
            os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)

        save_file(speaker_mapping, mapping_file_path)
        print("Speaker embeddings saved at:", mapping_file_path)


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
        """
        Example runs:
        python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --config_dataset_path dataset_config.json

        python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --formatter_name coqui --dataset_path /path/to/vctk/dataset --dataset_name my_vctk --meta_file_train /path/to/vctk/metafile_train.csv --meta_file_val /path/to/vctk/metafile_eval.csv
        """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to model checkpoint file. It defaults to the released speaker encoder.",
        default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to model config file. It defaults to the released speaker encoder config.",
        default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json",
    )
    parser.add_argument(
        "--config_dataset_path",
        type=str,
        help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.",
        default=None,
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Path for output `pth` or `json` file.",
        default="speakers.pth",
    )
    parser.add_argument(
        "--old_file",
        type=str,
        help="The old existing embedding file, from which the embeddings will be directly loaded for already computed audio clips.",
        default=None,
    )
    parser.add_argument(
        "--old_append",
        help="Append new audio clip embeddings to the old embedding file, generate a new non-duplicated merged embedding file. Default False",
        default=False,
        action="store_true",
    )
    parser.add_argument("--disable_cuda", action="store_true", help="Flag to disable cuda.", default=False)
    parser.add_argument("--no_eval", help="Do not compute eval. Default False", default=False, action="store_true")
    parser.add_argument(
        "--formatter_name",
        type=str,
        help="Name of the formatter to use. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name of the dataset to use. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        help="Path to the dataset. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--meta_file_train",
        type=str,
        help="Path to the train meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    parser.add_argument(
        "--meta_file_val",
        type=str,
        help="Path to the evaluation meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
        default=None,
    )
    args = parser.parse_args()

    compute_embeddings(
        args.model_path,
        args.config_path,
        args.output_path,
        old_speakers_file=args.old_file,
        old_append=args.old_append,
        config_dataset_path=args.config_dataset_path,
        formatter_name=args.formatter_name,
        dataset_name=args.dataset_name,
        dataset_path=args.dataset_path,
        meta_file_train=args.meta_file_train,
        meta_file_val=args.meta_file_val,
        disable_cuda=args.disable_cuda,
        no_eval=args.no_eval,
    )
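Editor's note (not part of the commit): the result is a mapping from `audio_unique_name` to `{"name": ..., "embedding": ...}` saved as `.pth` (or `.json`). A small inspection sketch, assuming the default `speakers.pth` output and that a plain `torch.load` is acceptable for this:

import torch

speaker_mapping = torch.load("speakers.pth", map_location="cpu")  # may need weights_only=False on newer PyTorch
key = next(iter(speaker_mapping))
print(key, speaker_mapping[key]["name"], len(speaker_mapping[key]["embedding"]))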
TTS/bin/compute_statistics.py ADDED
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import glob
import logging
import os

import numpy as np
from tqdm import tqdm

# from TTS.utils.io import load_config
from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def main():
    """Run preprocessing process."""
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogram features.")
    parser.add_argument("config_path", type=str, help="TTS config file path to define audio processing parameters.")
    parser.add_argument("out_path", type=str, help="save path (directory and filename).")
    parser.add_argument(
        "--data_path",
        type=str,
        required=False,
        help="folder including the target set of wavs overriding dataset config.",
    )
    args, overrides = parser.parse_known_args()

    CONFIG = load_config(args.config_path)
    CONFIG.parse_known_args(overrides, relaxed_parser=True)

    # load config
    CONFIG.audio.signal_norm = False  # do not apply earlier normalization
    CONFIG.audio.stats_path = None  # discard pre-defined stats

    # load audio processor
    ap = AudioProcessor(**CONFIG.audio.to_dict())

    # load the meta data of target dataset
    if args.data_path:
        dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True)
    else:
        dataset_items = load_tts_samples(CONFIG.datasets)[0]  # take only train data
    print(f" > There are {len(dataset_items)} files.")

    mel_sum = 0
    mel_square_sum = 0
    linear_sum = 0
    linear_square_sum = 0
    N = 0
    for item in tqdm(dataset_items):
        # compute features
        wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
        linear = ap.spectrogram(wav)
        mel = ap.melspectrogram(wav)

        # compute stats
        N += mel.shape[1]
        mel_sum += mel.sum(1)
        linear_sum += linear.sum(1)
        mel_square_sum += (mel**2).sum(axis=1)
        linear_square_sum += (linear**2).sum(axis=1)

    mel_mean = mel_sum / N
    mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2)
    linear_mean = linear_sum / N
    linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2)

    output_file_path = args.out_path
    stats = {}
    stats["mel_mean"] = mel_mean
    stats["mel_std"] = mel_scale
    stats["linear_mean"] = linear_mean
    stats["linear_std"] = linear_scale

    print(f" > Avg mel spec mean: {mel_mean.mean()}")
    print(f" > Avg mel spec scale: {mel_scale.mean()}")
    print(f" > Avg linear spec mean: {linear_mean.mean()}")
    print(f" > Avg linear spec scale: {linear_scale.mean()}")

    # set default config values for mean-var scaling
    CONFIG.audio.stats_path = output_file_path
    CONFIG.audio.signal_norm = True
    # remove redundant values
    del CONFIG.audio.max_norm
    del CONFIG.audio.min_level_db
    del CONFIG.audio.symmetric_norm
    del CONFIG.audio.clip_norm
    stats["audio_config"] = CONFIG.audio.to_dict()
    np.save(output_file_path, stats, allow_pickle=True)
    print(f" > stats saved to {output_file_path}")


if __name__ == "__main__":
    main()
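Editor's note (not part of the commit): the saved file is what `audio.stats_path` should point to for mean-variance normalization during training. A small sketch of reading it back (output file name illustrative):

import numpy as np

stats = np.load("scale_stats.npy", allow_pickle=True).item()  # dict saved above with allow_pickle=True
print(stats["mel_mean"].shape, stats["mel_std"].shape)
print(stats["audio_config"].keys())  # the audio config the stats were computed with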
TTS/bin/eval_encoder.py ADDED
@@ -0,0 +1,92 @@
import argparse
import logging
from argparse import RawTextHelpFormatter

import torch
from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger


def compute_encoder_accuracy(dataset_items, encoder_manager):
    class_name_key = encoder_manager.encoder_config.class_name_key
    map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)

    class_acc_dict = {}

    # compute embeddings for all wav_files
    for item in tqdm(dataset_items):
        class_name = item[class_name_key]
        wav_file = item["audio_file"]

        # extract the embedding
        embedd = encoder_manager.compute_embedding_from_clip(wav_file)
        if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
            embedding = torch.FloatTensor(embedd).unsqueeze(0)
            if encoder_manager.use_cuda:
                embedding = embedding.cuda()

            class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
            predicted_label = map_classid_to_classname[str(class_id)]
        else:
            predicted_label = None

        if class_name is not None and predicted_label is not None:
            is_equal = int(class_name == predicted_label)
            if class_name not in class_acc_dict:
                class_acc_dict[class_name] = [is_equal]
            else:
                class_acc_dict[class_name].append(is_equal)
        else:
            raise RuntimeError("Error: class_name and/or predicted_label are None")

    acc_avg = 0
    for key, values in class_acc_dict.items():
        acc = sum(values) / len(values)
        print("Class", key, "Accuracy:", acc)
        acc_avg += acc

    print("Average Accuracy:", acc_avg / len(class_acc_dict))


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())

    parser = argparse.ArgumentParser(
        description="""Compute the accuracy of the encoder.\n\n"""
        """
        Example runs:
        python TTS/bin/eval_encoder.py emotion_encoder_model.pth emotion_encoder_config.json dataset_config.json
        """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to model config file.",
    )

    parser.add_argument(
        "config_dataset_path",
        type=str,
        help="Path to dataset config file.",
    )
    parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, help="flag to set cuda.", default=True)
    parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True)

    args = parser.parse_args()

    c_dataset = load_config(args.config_dataset_path)

    meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval)
    items = meta_data_train + meta_data_eval

    enc_manager = SpeakerManager(
        encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda
    )

    compute_encoder_accuracy(items, enc_manager)
TTS/bin/extract_tts_spectrograms.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Extract Mel spectrograms with teacher forcing."""
3
+
4
+ import argparse
5
+ import logging
6
+ import os
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
+ from trainer.generic_utils import count_parameters
13
+
14
+ from TTS.config import load_config
15
+ from TTS.tts.datasets import TTSDataset, load_tts_samples
16
+ from TTS.tts.models import setup_model
17
+ from TTS.tts.utils.speakers import SpeakerManager
18
+ from TTS.tts.utils.text.tokenizer import TTSTokenizer
19
+ from TTS.utils.audio import AudioProcessor
20
+ from TTS.utils.audio.numpy_transforms import quantize
21
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
22
+
23
+ use_cuda = torch.cuda.is_available()
24
+
25
+
26
+ def setup_loader(ap, r):
27
+ tokenizer, _ = TTSTokenizer.init_from_config(c)
28
+ dataset = TTSDataset(
29
+ outputs_per_step=r,
30
+ compute_linear_spec=False,
31
+ samples=meta_data,
32
+ tokenizer=tokenizer,
33
+ ap=ap,
34
+ batch_group_size=0,
35
+ min_text_len=c.min_text_len,
36
+ max_text_len=c.max_text_len,
37
+ min_audio_len=c.min_audio_len,
38
+ max_audio_len=c.max_audio_len,
39
+ phoneme_cache_path=c.phoneme_cache_path,
40
+ precompute_num_workers=0,
41
+ use_noise_augment=False,
42
+ speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
43
+ d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
44
+ )
45
+
46
+ if c.use_phonemes and c.compute_input_seq_cache:
47
+ # precompute phonemes to have a better estimate of sequence lengths.
48
+ dataset.compute_input_seq(c.num_loader_workers)
49
+ dataset.preprocess_samples()
50
+
51
+ loader = DataLoader(
52
+ dataset,
53
+ batch_size=c.batch_size,
54
+ shuffle=False,
55
+ collate_fn=dataset.collate_fn,
56
+ drop_last=False,
57
+ sampler=None,
58
+ num_workers=c.num_loader_workers,
59
+ pin_memory=False,
60
+ )
61
+ return loader
62
+
63
+
64
+ def set_filename(wav_path, out_path):
65
+ wav_file = os.path.basename(wav_path)
66
+ file_name = wav_file.split(".")[0]
67
+ os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
68
+ os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
69
+ os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
70
+ os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
71
+ wavq_path = os.path.join(out_path, "quant", file_name)
72
+ mel_path = os.path.join(out_path, "mel", file_name)
73
+ wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
74
+ wav_path = os.path.join(out_path, "wav", file_name + ".wav")
75
+ return file_name, wavq_path, mel_path, wav_gl_path, wav_path
76
+
77
+
78
+ def format_data(data):
79
+ # setup input data
80
+ text_input = data["token_id"]
81
+ text_lengths = data["token_id_lengths"]
82
+ mel_input = data["mel"]
83
+ mel_lengths = data["mel_lengths"]
84
+ item_idx = data["item_idxs"]
85
+ d_vectors = data["d_vectors"]
86
+ speaker_ids = data["speaker_ids"]
87
+ attn_mask = data["attns"]
88
+ avg_text_length = torch.mean(text_lengths.float())
89
+ avg_spec_length = torch.mean(mel_lengths.float())
90
+
91
+ # dispatch data to GPU
92
+ if use_cuda:
93
+ text_input = text_input.cuda(non_blocking=True)
94
+ text_lengths = text_lengths.cuda(non_blocking=True)
95
+ mel_input = mel_input.cuda(non_blocking=True)
96
+ mel_lengths = mel_lengths.cuda(non_blocking=True)
97
+ if speaker_ids is not None:
98
+ speaker_ids = speaker_ids.cuda(non_blocking=True)
99
+ if d_vectors is not None:
100
+ d_vectors = d_vectors.cuda(non_blocking=True)
101
+ if attn_mask is not None:
102
+ attn_mask = attn_mask.cuda(non_blocking=True)
103
+ return (
104
+ text_input,
105
+ text_lengths,
106
+ mel_input,
107
+ mel_lengths,
108
+ speaker_ids,
109
+ d_vectors,
110
+ avg_text_length,
111
+ avg_spec_length,
112
+ attn_mask,
113
+ item_idx,
114
+ )
115
+
116
+
117
+ @torch.no_grad()
118
+ def inference(
119
+ model_name,
120
+ model,
121
+ ap,
122
+ text_input,
123
+ text_lengths,
124
+ mel_input,
125
+ mel_lengths,
126
+ speaker_ids=None,
127
+ d_vectors=None,
128
+ ):
129
+ if model_name == "glow_tts":
130
+ speaker_c = None
131
+ if speaker_ids is not None:
132
+ speaker_c = speaker_ids
133
+ elif d_vectors is not None:
134
+ speaker_c = d_vectors
135
+ outputs = model.inference_with_MAS(
136
+ text_input,
137
+ text_lengths,
138
+ mel_input,
139
+ mel_lengths,
140
+ aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
141
+ )
142
+ model_output = outputs["model_outputs"]
143
+ model_output = model_output.detach().cpu().numpy()
144
+
145
+ elif "tacotron" in model_name:
146
+ aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
147
+ outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
148
+ postnet_outputs = outputs["model_outputs"]
149
+ # normalize tacotron output
150
+ if model_name == "tacotron":
151
+ mel_specs = []
152
+ postnet_outputs = postnet_outputs.data.cpu().numpy()
153
+ for b in range(postnet_outputs.shape[0]):
154
+ postnet_output = postnet_outputs[b]
155
+ mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
156
+ model_output = torch.stack(mel_specs).cpu().numpy()
157
+
158
+ elif model_name == "tacotron2":
159
+ model_output = postnet_outputs.detach().cpu().numpy()
160
+ return model_output
161
+
162
+
163
+ def extract_spectrograms(
164
+ data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt"
165
+ ):
166
+ model.eval()
167
+ export_metadata = []
168
+ for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
169
+ # format data
170
+ (
171
+ text_input,
172
+ text_lengths,
173
+ mel_input,
174
+ mel_lengths,
175
+ speaker_ids,
176
+ d_vectors,
177
+ _,
178
+ _,
179
+ _,
180
+ item_idx,
181
+ ) = format_data(data)
182
+
183
+ model_output = inference(
184
+ c.model.lower(),
185
+ model,
186
+ ap,
187
+ text_input,
188
+ text_lengths,
189
+ mel_input,
190
+ mel_lengths,
191
+ speaker_ids,
192
+ d_vectors,
193
+ )
194
+
195
+ for idx in range(text_input.shape[0]):
196
+ wav_file_path = item_idx[idx]
197
+ wav = ap.load_wav(wav_file_path)
198
+ _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)
199
+
200
+ # quantize and save wav
201
+ if quantize_bits > 0:
202
+ wavq = quantize(wav, quantize_bits)
203
+ np.save(wavq_path, wavq)
204
+
205
+ # save TTS mel
206
+ mel = model_output[idx]
207
+ mel_length = mel_lengths[idx]
208
+ mel = mel[:mel_length, :].T
209
+ np.save(mel_path, mel)
210
+
211
+ export_metadata.append([wav_file_path, mel_path])
212
+ if save_audio:
213
+ ap.save_wav(wav, wav_path)
214
+
215
+ if debug:
216
+ print("Audio for debug saved at:", wav_gl_path)
217
+ wav = ap.inv_melspectrogram(mel)
218
+ ap.save_wav(wav, wav_gl_path)
219
+
220
+ with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f:
221
+ for data in export_metadata:
222
+ f.write(f"{data[0]}|{data[1]+'.npy'}\n")
223
+
224
+
225
+ def main(args): # pylint: disable=redefined-outer-name
226
+ # pylint: disable=global-variable-undefined
227
+ global meta_data, speaker_manager
228
+
229
+ # Audio processor
230
+ ap = AudioProcessor(**c.audio)
231
+
232
+ # load data instances
233
+ meta_data_train, meta_data_eval = load_tts_samples(
234
+ c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
235
+ )
236
+
237
+ # use eval and training partitions
238
+ meta_data = meta_data_train + meta_data_eval
239
+
240
+ # init speaker manager
241
+ if c.use_speaker_embedding:
242
+ speaker_manager = SpeakerManager(data_items=meta_data)
243
+ elif c.use_d_vector_file:
244
+ speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
245
+ else:
246
+ speaker_manager = None
247
+
248
+ # setup model
249
+ model = setup_model(c)
250
+
251
+ # restore model
252
+ model.load_checkpoint(c, args.checkpoint_path, eval=True)
253
+
254
+ if use_cuda:
255
+ model.cuda()
256
+
257
+ num_params = count_parameters(model)
258
+ print("\n > Model has {} parameters".format(num_params), flush=True)
259
+ # set r
260
+ r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
261
+ own_loader = setup_loader(ap, r)
262
+
263
+ extract_spectrograms(
264
+ own_loader,
265
+ model,
266
+ ap,
267
+ args.output_path,
268
+ quantize_bits=args.quantize_bits,
269
+ save_audio=args.save_audio,
270
+ debug=args.debug,
271
+ metada_name="metada.txt",
272
+ )
273
+
274
+
275
+ if __name__ == "__main__":
276
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
277
+
278
+ parser = argparse.ArgumentParser()
279
+ parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
280
+ parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
281
+ parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
282
+ parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
283
+ parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
284
+ parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero")
285
+ parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True)
286
+ args = parser.parse_args()
287
+
288
+ c = load_config(args.config_path)
289
+ c.audio.trim_silence = False
290
+ main(args)
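
For reference, a minimal sketch of how the metadata file written by this script can be consumed downstream. The `metada.txt` name and the `<wav_path>|<mel_path>.npy` line format come from the defaults above; the actual vocoder-side loader in this repo may differ.

```python
# Hypothetical reader for the exported metadata; each line is "<wav_path>|<mel_path>.npy".
import os
import numpy as np

output_path = "output/"  # same value passed as --output_path above
with open(os.path.join(output_path, "metada.txt"), encoding="utf-8") as f:
    for line in f:
        wav_path, mel_path = line.strip().split("|")
        mel = np.load(mel_path)  # mel spectrogram saved with np.save above
        print(wav_path, mel.shape)
```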
TTS/bin/find_unique_chars.py ADDED
@@ -0,0 +1,40 @@
1
+ """Find all the unique characters in a dataset"""
2
+
3
+ import argparse
4
+ import logging
5
+ from argparse import RawTextHelpFormatter
6
+
7
+ from TTS.config import load_config
8
+ from TTS.tts.datasets import find_unique_chars, load_tts_samples
9
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
10
+
11
+
12
+ def main():
13
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
14
+
15
+ # pylint: disable=bad-option-value
16
+ parser = argparse.ArgumentParser(
17
+ description="""Find all the unique characters or phonemes in a dataset.\n\n"""
18
+ """
19
+ Example runs:
20
+
21
+ python TTS/bin/find_unique_chars.py --config_path config.json
22
+ """,
23
+ formatter_class=RawTextHelpFormatter,
24
+ )
25
+ parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
26
+ args = parser.parse_args()
27
+
28
+ c = load_config(args.config_path)
29
+
30
+ # load all datasets
31
+ train_items, eval_items = load_tts_samples(
32
+ c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
33
+ )
34
+
35
+ items = train_items + eval_items
36
+ find_unique_chars(items)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ main()
TTS/bin/find_unique_phonemes.py ADDED
@@ -0,0 +1,79 @@
1
+ """Find all the unique characters in a dataset"""
2
+
3
+ import argparse
4
+ import logging
5
+ import multiprocessing
6
+ from argparse import RawTextHelpFormatter
7
+
8
+ from tqdm.contrib.concurrent import process_map
9
+
10
+ from TTS.config import load_config
11
+ from TTS.tts.datasets import load_tts_samples
12
+ from TTS.tts.utils.text.phonemizers import Gruut
13
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
14
+
15
+
16
+ def compute_phonemes(item):
17
+ text = item["text"]
18
+ ph = phonemizer.phonemize(text).replace("|", "")
19
+ return set(ph)
20
+
21
+
22
+ def main():
23
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
24
+
25
+ # pylint: disable=W0601
26
+ global c, phonemizer
27
+ # pylint: disable=bad-option-value
28
+ parser = argparse.ArgumentParser(
29
+ description="""Find all the unique characters or phonemes in a dataset.\n\n"""
30
+ """
31
+ Example runs:
32
+
33
+ python TTS/bin/find_unique_phonemes.py --config_path config.json
34
+ """,
35
+ formatter_class=RawTextHelpFormatter,
36
+ )
37
+ parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True)
38
+ args = parser.parse_args()
39
+
40
+ c = load_config(args.config_path)
41
+
42
+ # load all datasets
43
+ train_items, eval_items = load_tts_samples(
44
+ c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
45
+ )
46
+ items = train_items + eval_items
47
+ print("Num items:", len(items))
48
+
49
+ language_list = [item["language"] for item in items]
50
+ is_lang_def = all(language_list)
51
+
52
+ if not c.phoneme_language or not is_lang_def:
53
+ raise ValueError("Phoneme language must be defined in config.")
54
+
55
+ if not language_list.count(language_list[0]) == len(language_list):
56
+ raise ValueError(
57
+ "Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!"
58
+ )
59
+
60
+ phonemizer = Gruut(language=language_list[0], keep_puncs=True)
61
+
62
+ phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
63
+ phones = []
64
+ for ph in phonemes:
65
+ phones.extend(ph)
66
+
67
+ phones = set(phones)
68
+ lower_phones = filter(lambda c: c.islower(), phones)
69
+ phones_force_lower = [c.lower() for c in phones]
70
+ phones_force_lower = set(phones_force_lower)
71
+
72
+ print(f" > Number of unique phonemes: {len(phones)}")
73
+ print(f" > Unique phonemes: {''.join(sorted(phones))}")
74
+ print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
75
+ print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ main()
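
As a quick illustration of the per-item phonemization done above, a minimal sketch using the same Gruut wrapper as `compute_phonemes`; the language code here is only an example and must match your dataset.

```python
# Phonemize one sentence and collect its unique phonemes (assumed language: "en-us").
from TTS.tts.utils.text.phonemizers import Gruut

phonemizer = Gruut(language="en-us", keep_puncs=True)
ph = phonemizer.phonemize("Hello world!").replace("|", "")
print(sorted(set(ph)))
```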
TTS/bin/remove_silence_using_vad.py ADDED
@@ -0,0 +1,128 @@
1
+ import argparse
2
+ import glob
3
+ import logging
4
+ import multiprocessing
5
+ import os
6
+ import pathlib
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
12
+ from TTS.utils.vad import get_vad_model_and_utils, remove_silence
13
+
14
+ torch.set_num_threads(1)
15
+
16
+
17
+ def adjust_path_and_remove_silence(audio_path):
18
+ output_path = audio_path.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
19
+ # ignore if the file exists
20
+ if os.path.exists(output_path) and not args.force:
21
+ return output_path, False
22
+
23
+ # create all directory structure
24
+ pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
25
+ # remove the silence and save the audio
26
+ output_path, is_speech = remove_silence(
27
+ model_and_utils,
28
+ audio_path,
29
+ output_path,
30
+ trim_just_beginning_and_end=args.trim_just_beginning_and_end,
31
+ use_cuda=args.use_cuda,
32
+ )
33
+ return output_path, is_speech
34
+
35
+
36
+ def preprocess_audios():
37
+ files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
38
+ print("> Number of files: ", len(files))
39
+ if not args.force:
40
+ print("> Ignoring files that already exist in the output idrectory.")
41
+
42
+ if args.trim_just_beginning_and_end:
43
+ print("> Trimming just the beginning and the end with nonspeech parts.")
44
+ else:
45
+ print("> Trimming all nonspeech parts.")
46
+
47
+ filtered_files = []
48
+ if files:
49
+ # create threads
50
+ # num_threads = multiprocessing.cpu_count()
51
+ # process_map(adjust_path_and_remove_silence, files, max_workers=num_threads, chunksize=15)
52
+
53
+ if args.num_processes > 1:
54
+ with multiprocessing.Pool(processes=args.num_processes) as pool:
55
+ results = list(
56
+ tqdm(
57
+ pool.imap_unordered(adjust_path_and_remove_silence, files),
58
+ total=len(files),
59
+ desc="Processing audio files",
60
+ )
61
+ )
62
+ for output_path, is_speech in results:
63
+ if not is_speech:
64
+ filtered_files.append(output_path)
65
+ else:
66
+ for f in tqdm(files):
67
+ output_path, is_speech = adjust_path_and_remove_silence(f)
68
+ if not is_speech:
69
+ filtered_files.append(output_path)
70
+
71
+ # write files that do not have speech
72
+ with open(os.path.join(args.output_dir, "filtered_files.txt"), "w", encoding="utf-8") as f:
73
+ for file in filtered_files:
74
+ f.write(str(file) + "\n")
75
+ else:
76
+ print("> No files Found !")
77
+
78
+
79
+ if __name__ == "__main__":
80
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
81
+
82
+ parser = argparse.ArgumentParser(
83
+ description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end"
84
+ )
85
+ parser.add_argument("-i", "--input_dir", type=str, help="Dataset root dir", required=True)
86
+ parser.add_argument("-o", "--output_dir", type=str, help="Output Dataset dir", default="")
87
+ parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files")
88
+ parser.add_argument(
89
+ "-g",
90
+ "--glob",
91
+ type=str,
92
+ default="**/*.wav",
93
+ help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
94
+ )
95
+ parser.add_argument(
96
+ "-t",
97
+ "--trim_just_beginning_and_end",
98
+ action=argparse.BooleanOptionalAction,
99
+ default=True,
100
+ help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trimmed.",
101
+ )
102
+ parser.add_argument(
103
+ "-c",
104
+ "--use_cuda",
105
+ action=argparse.BooleanOptionalAction,
106
+ default=False,
107
+ help="If True use cuda",
108
+ )
109
+ parser.add_argument(
110
+ "--use_onnx",
111
+ action=argparse.BooleanOptionalAction,
112
+ default=False,
113
+ help="If True use onnx",
114
+ )
115
+ parser.add_argument(
116
+ "--num_processes",
117
+ type=int,
118
+ default=1,
119
+ help="Number of processes to use",
120
+ )
121
+ args = parser.parse_args()
122
+
123
+ if args.output_dir == "":
124
+ args.output_dir = args.input_dir
125
+
126
+ # load the model and utils
127
+ model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda, use_onnx=args.use_onnx)
128
+ preprocess_audios()
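
For a single file, the same VAD helpers can be called directly. A minimal sketch follows; paths and flags are placeholders, and the call signatures mirror the usage above.

```python
# Hypothetical single-file trim using the helpers imported at the top of this script.
from TTS.utils.vad import get_vad_model_and_utils, remove_silence

model_and_utils = get_vad_model_and_utils(use_cuda=False, use_onnx=False)
out_path, is_speech = remove_silence(
    model_and_utils,
    "sample_in.wav",
    "sample_out.wav",
    trim_just_beginning_and_end=True,
    use_cuda=False,
)
print(out_path, is_speech)  # is_speech is False when the VAD found no speech at all
```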
TTS/bin/resample.py ADDED
@@ -0,0 +1,90 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+ from argparse import RawTextHelpFormatter
5
+ from multiprocessing import Pool
6
+ from shutil import copytree
7
+
8
+ import librosa
9
+ import soundfile as sf
10
+ from tqdm import tqdm
11
+
12
+
13
+ def resample_file(func_args):
14
+ filename, output_sr = func_args
15
+ y, sr = librosa.load(filename, sr=output_sr)
16
+ sf.write(filename, y, sr)
17
+
18
+
19
+ def resample_files(input_dir, output_sr, output_dir=None, file_ext="wav", n_jobs=10):
20
+ if output_dir:
21
+ print("Recursively copying the input folder...")
22
+ copytree(input_dir, output_dir)
23
+ input_dir = output_dir
24
+
25
+ print("Resampling the audio files...")
26
+ audio_files = glob.glob(os.path.join(input_dir, f"**/*.{file_ext}"), recursive=True)
27
+ print(f"Found {len(audio_files)} files...")
28
+ audio_files = list(zip(audio_files, len(audio_files) * [output_sr]))
29
+ with Pool(processes=n_jobs) as p:
30
+ with tqdm(total=len(audio_files)) as pbar:
31
+ for _, _ in enumerate(p.imap_unordered(resample_file, audio_files)):
32
+ pbar.update()
33
+
34
+ print("Done !")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ parser = argparse.ArgumentParser(
39
+ description="""Resample a folder recusively with librosa
40
+ Can be used in place or create a copy of the folder as an output.\n\n
41
+ Example run:
42
+ python TTS/bin/resample.py
43
+ --input_dir /root/LJSpeech-1.1/
44
+ --output_sr 22050
45
+ --output_dir /root/resampled_LJSpeech-1.1/
46
+ --file_ext wav
47
+ --n_jobs 24
48
+ """,
49
+ formatter_class=RawTextHelpFormatter,
50
+ )
51
+
52
+ parser.add_argument(
53
+ "--input_dir",
54
+ type=str,
55
+ default=None,
56
+ required=True,
57
+ help="Path of the folder containing the audio files to resample",
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--output_sr",
62
+ type=int,
63
+ default=22050,
64
+ required=False,
65
+ help="Samlple rate to which the audio files should be resampled",
66
+ )
67
+
68
+ parser.add_argument(
69
+ "--output_dir",
70
+ type=str,
71
+ default=None,
72
+ required=False,
73
+ help="Path of the destination folder. If not defined, the operation is done in place",
74
+ )
75
+
76
+ parser.add_argument(
77
+ "--file_ext",
78
+ type=str,
79
+ default="wav",
80
+ required=False,
81
+ help="Extension of the audio files to resample",
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores"
86
+ )
87
+
88
+ args = parser.parse_args()
89
+
90
+ resample_files(args.input_dir, args.output_sr, args.output_dir, args.file_ext, args.n_jobs)
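
`resample_files` can also be called programmatically. A minimal sketch with placeholder paths: when `output_dir` is given, the input folder is copied first and the copy is resampled in place, as implemented above.

```python
from TTS.bin.resample import resample_files

resample_files(
    input_dir="/data/LJSpeech-1.1",           # source folder, left untouched here
    output_sr=22050,
    output_dir="/data/LJSpeech-1.1-22050hz",  # recursive copy that gets resampled
    file_ext="wav",
    n_jobs=8,
)
```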
TTS/bin/synthesize.py ADDED
@@ -0,0 +1,486 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """Command line interface."""
4
+
5
+ import argparse
6
+ import contextlib
7
+ import logging
8
+ import sys
9
+ from argparse import RawTextHelpFormatter
10
+
11
+ # pylint: disable=redefined-outer-name, unused-argument
12
+ from pathlib import Path
13
+
14
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ description = """
19
+ Synthesize speech on command line.
20
+
21
+ You can either use your trained model or choose a model from the provided list.
22
+
23
+ If you don't specify any models, then it uses the LJSpeech-based English model.
24
+
25
+ #### Single Speaker Models
26
+
27
+ - List provided models:
28
+
29
+ ```
30
+ $ tts --list_models
31
+ ```
32
+
33
+ - Get model info (for both tts_models and vocoder_models):
34
+
35
+ - Query by type/name:
36
+ The model_info_by_name uses the name as it appears in the output of --list_models.
37
+ ```
38
+ $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
39
+ ```
40
+ For example:
41
+ ```
42
+ $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts
43
+ $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2
44
+ ```
45
+ - Query by type/idx:
46
+ The model_query_idx uses the corresponding idx from --list_models.
47
+
48
+ ```
49
+ $ tts --model_info_by_idx "<model_type>/<model_query_idx>"
50
+ ```
51
+
52
+ For example:
53
+
54
+ ```
55
+ $ tts --model_info_by_idx tts_models/3
56
+ ```
57
+
58
+ - Query info for model info by full name:
59
+ ```
60
+ $ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
61
+ ```
62
+
63
+ - Run TTS with default models:
64
+
65
+ ```
66
+ $ tts --text "Text for TTS" --out_path output/path/speech.wav
67
+ ```
68
+
69
+ - Run TTS and pipe out the generated TTS wav file data:
70
+
71
+ ```
72
+ $ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay
73
+ ```
74
+
75
+ - Run a TTS model with its default vocoder model:
76
+
77
+ ```
78
+ $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
79
+ ```
80
+
81
+ For example:
82
+
83
+ ```
84
+ $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav
85
+ ```
86
+
87
+ - Run with specific TTS and vocoder models from the list:
88
+
89
+ ```
90
+ $ tts --text "Text for TTS" --model_name "<model_type>/<language>/<dataset>/<model_name>" --vocoder_name "<model_type>/<language>/<dataset>/<model_name>" --out_path output/path/speech.wav
91
+ ```
92
+
93
+ For example:
94
+
95
+ ```
96
+ $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav
97
+ ```
98
+
99
+ - Run your own TTS model (Using Griffin-Lim Vocoder):
100
+
101
+ ```
102
+ $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
103
+ ```
104
+
105
+ - Run your own TTS and Vocoder models:
106
+
107
+ ```
108
+ $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav
109
+ --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json
110
+ ```
111
+
112
+ #### Multi-speaker Models
113
+
114
+ - List the available speakers and choose a <speaker_id> among them:
115
+
116
+ ```
117
+ $ tts --model_name "<language>/<dataset>/<model_name>" --list_speaker_idxs
118
+ ```
119
+
120
+ - Run the multi-speaker TTS model with the target speaker ID:
121
+
122
+ ```
123
+ $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx <speaker_id>
124
+ ```
125
+
126
+ - Run your own multi-speaker TTS model:
127
+
128
+ ```
129
+ $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
130
+ ```
131
+
132
+ ### Voice Conversion Models
133
+
134
+ ```
135
+ $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
136
+ ```
137
+ """
138
+
139
+
140
+ def parse_args() -> argparse.Namespace:
141
+ """Parse arguments."""
142
+ parser = argparse.ArgumentParser(
143
+ description=description.replace(" ```\n", ""),
144
+ formatter_class=RawTextHelpFormatter,
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--list_models",
149
+ action="store_true",
150
+ help="list available pre-trained TTS and vocoder models.",
151
+ )
152
+
153
+ parser.add_argument(
154
+ "--model_info_by_idx",
155
+ type=str,
156
+ default=None,
157
+ help="model info using query format: <model_type>/<model_query_idx>",
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--model_info_by_name",
162
+ type=str,
163
+ default=None,
164
+ help="model info using query format: <model_type>/<language>/<dataset>/<model_name>",
165
+ )
166
+
167
+ parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")
168
+
169
+ # Args for running pre-trained TTS models.
170
+ parser.add_argument(
171
+ "--model_name",
172
+ type=str,
173
+ default="tts_models/en/ljspeech/tacotron2-DDC",
174
+ help="Name of one of the pre-trained TTS models in format <language>/<dataset>/<model_name>",
175
+ )
176
+ parser.add_argument(
177
+ "--vocoder_name",
178
+ type=str,
179
+ default=None,
180
+ help="Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>",
181
+ )
182
+
183
+ # Args for running custom models
184
+ parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.")
185
+ parser.add_argument(
186
+ "--model_path",
187
+ type=str,
188
+ default=None,
189
+ help="Path to model file.",
190
+ )
191
+ parser.add_argument(
192
+ "--out_path",
193
+ type=str,
194
+ default="tts_output.wav",
195
+ help="Output wav file path.",
196
+ )
197
+ parser.add_argument("--use_cuda", action="store_true", help="Run model on CUDA.")
198
+ parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu")
199
+ parser.add_argument(
200
+ "--vocoder_path",
201
+ type=str,
202
+ help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).",
203
+ default=None,
204
+ )
205
+ parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None)
206
+ parser.add_argument(
207
+ "--encoder_path",
208
+ type=str,
209
+ help="Path to speaker encoder model file.",
210
+ default=None,
211
+ )
212
+ parser.add_argument("--encoder_config_path", type=str, help="Path to speaker encoder config file.", default=None)
213
+ parser.add_argument(
214
+ "--pipe_out",
215
+ help="stdout the generated TTS wav file for shell pipe.",
216
+ action="store_true",
217
+ )
218
+
219
+ # args for multi-speaker synthesis
220
+ parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
221
+ parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
222
+ parser.add_argument(
223
+ "--speaker_idx",
224
+ type=str,
225
+ help="Target speaker ID for a multi-speaker TTS model.",
226
+ default=None,
227
+ )
228
+ parser.add_argument(
229
+ "--language_idx",
230
+ type=str,
231
+ help="Target language ID for a multi-lingual TTS model.",
232
+ default=None,
233
+ )
234
+ parser.add_argument(
235
+ "--speaker_wav",
236
+ nargs="+",
237
+ help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.",
238
+ default=None,
239
+ )
240
+ parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
241
+ parser.add_argument(
242
+ "--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None
243
+ )
244
+ parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None)
245
+ parser.add_argument(
246
+ "--list_speaker_idxs",
247
+ help="List available speaker ids for the defined multi-speaker model.",
248
+ action="store_true",
249
+ )
250
+ parser.add_argument(
251
+ "--list_language_idxs",
252
+ help="List available language ids for the defined multi-lingual model.",
253
+ action="store_true",
254
+ )
255
+ # aux args
256
+ parser.add_argument(
257
+ "--save_spectogram",
258
+ action="store_true",
259
+ help="Save raw spectogram for further (vocoder) processing in out_path.",
260
+ )
261
+ parser.add_argument(
262
+ "--reference_wav",
263
+ type=str,
264
+ help="Reference wav file to convert in the voice of the speaker_idx or speaker_wav",
265
+ default=None,
266
+ )
267
+ parser.add_argument(
268
+ "--reference_speaker_idx",
269
+ type=str,
270
+ help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).",
271
+ default=None,
272
+ )
273
+ parser.add_argument(
274
+ "--progress_bar",
275
+ action=argparse.BooleanOptionalAction,
276
+ help="Show a progress bar for the model download.",
277
+ default=True,
278
+ )
279
+
280
+ # voice conversion args
281
+ parser.add_argument(
282
+ "--source_wav",
283
+ type=str,
284
+ default=None,
285
+ help="Original audio file to convert in the voice of the target_wav",
286
+ )
287
+ parser.add_argument(
288
+ "--target_wav",
289
+ type=str,
290
+ default=None,
291
+ help="Target audio file to convert in the voice of the source_wav",
292
+ )
293
+
294
+ parser.add_argument(
295
+ "--voice_dir",
296
+ type=str,
297
+ default=None,
298
+ help="Voice dir for tortoise model",
299
+ )
300
+
301
+ args = parser.parse_args()
302
+
303
+ # print the description if either text or list_models is not set
304
+ check_args = [
305
+ args.text,
306
+ args.list_models,
307
+ args.list_speaker_idxs,
308
+ args.list_language_idxs,
309
+ args.reference_wav,
310
+ args.model_info_by_idx,
311
+ args.model_info_by_name,
312
+ args.source_wav,
313
+ args.target_wav,
314
+ ]
315
+ if not any(check_args):
316
+ parser.parse_args(["-h"])
317
+ return args
318
+
319
+
320
+ def main():
321
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
322
+ args = parse_args()
323
+
324
+ pipe_out = sys.stdout if args.pipe_out else None
325
+
326
+ with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
327
+ # Late-import to make things load faster
328
+ from TTS.utils.manage import ModelManager
329
+ from TTS.utils.synthesizer import Synthesizer
330
+
331
+ # load model manager
332
+ path = Path(__file__).parent / "../.models.json"
333
+ manager = ModelManager(path, progress_bar=args.progress_bar)
334
+
335
+ tts_path = None
336
+ tts_config_path = None
337
+ speakers_file_path = None
338
+ language_ids_file_path = None
339
+ vocoder_path = None
340
+ vocoder_config_path = None
341
+ encoder_path = None
342
+ encoder_config_path = None
343
+ vc_path = None
344
+ vc_config_path = None
345
+ model_dir = None
346
+
347
+ # CASE1 #list : list pre-trained TTS models
348
+ if args.list_models:
349
+ manager.list_models()
350
+ sys.exit()
351
+
352
+ # CASE2 #info : model info for pre-trained TTS models
353
+ if args.model_info_by_idx:
354
+ model_query = args.model_info_by_idx
355
+ manager.model_info_by_idx(model_query)
356
+ sys.exit()
357
+
358
+ if args.model_info_by_name:
359
+ model_query_full_name = args.model_info_by_name
360
+ manager.model_info_by_full_name(model_query_full_name)
361
+ sys.exit()
362
+
363
+ # CASE3: load pre-trained model paths
364
+ if args.model_name is not None and not args.model_path:
365
+ model_path, config_path, model_item = manager.download_model(args.model_name)
366
+ # tts model
367
+ if model_item["model_type"] == "tts_models":
368
+ tts_path = model_path
369
+ tts_config_path = config_path
370
+ if args.vocoder_name is None and "default_vocoder" in model_item:
371
+ args.vocoder_name = model_item["default_vocoder"]
372
+
373
+ # voice conversion model
374
+ if model_item["model_type"] == "voice_conversion_models":
375
+ vc_path = model_path
376
+ vc_config_path = config_path
377
+
378
+ # tts model with multiple files to be loaded from the directory path
379
+ if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
380
+ model_dir = model_path
381
+ tts_path = None
382
+ tts_config_path = None
383
+ args.vocoder_name = None
384
+
385
+ # load vocoder
386
+ if args.vocoder_name is not None and not args.vocoder_path:
387
+ vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
388
+
389
+ # CASE4: set custom model paths
390
+ if args.model_path is not None:
391
+ tts_path = args.model_path
392
+ tts_config_path = args.config_path
393
+ speakers_file_path = args.speakers_file_path
394
+ language_ids_file_path = args.language_ids_file_path
395
+
396
+ if args.vocoder_path is not None:
397
+ vocoder_path = args.vocoder_path
398
+ vocoder_config_path = args.vocoder_config_path
399
+
400
+ if args.encoder_path is not None:
401
+ encoder_path = args.encoder_path
402
+ encoder_config_path = args.encoder_config_path
403
+
404
+ device = args.device
405
+ if args.use_cuda:
406
+ device = "cuda"
407
+
408
+ # load models
409
+ synthesizer = Synthesizer(
410
+ tts_path,
411
+ tts_config_path,
412
+ speakers_file_path,
413
+ language_ids_file_path,
414
+ vocoder_path,
415
+ vocoder_config_path,
416
+ encoder_path,
417
+ encoder_config_path,
418
+ vc_path,
419
+ vc_config_path,
420
+ model_dir,
421
+ args.voice_dir,
422
+ ).to(device)
423
+
424
+ # query speaker ids of a multi-speaker model.
425
+ if args.list_speaker_idxs:
426
+ if synthesizer.tts_model.speaker_manager is None:
427
+ logger.info("Model only has a single speaker.")
428
+ return
429
+ logger.info(
430
+ "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
431
+ )
432
+ logger.info(synthesizer.tts_model.speaker_manager.name_to_id)
433
+ return
434
+
435
+ # query language ids of a multi-lingual model.
436
+ if args.list_language_idxs:
437
+ if synthesizer.tts_model.language_manager is None:
438
+ logger.info("Monolingual model.")
439
+ return
440
+ logger.info(
441
+ "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
442
+ )
443
+ logger.info(synthesizer.tts_model.language_manager.name_to_id)
444
+ return
445
+
446
+ # check the arguments against a multi-speaker model.
447
+ if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
448
+ logger.error(
449
+ "Looks like you use a multi-speaker model. Define `--speaker_idx` to "
450
+ "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
451
+ )
452
+ return
453
+
454
+ # RUN THE SYNTHESIS
455
+ if args.text:
456
+ logger.info("Text: %s", args.text)
457
+
458
+ # kick it
459
+ if tts_path is not None:
460
+ wav = synthesizer.tts(
461
+ args.text,
462
+ speaker_name=args.speaker_idx,
463
+ language_name=args.language_idx,
464
+ speaker_wav=args.speaker_wav,
465
+ reference_wav=args.reference_wav,
466
+ style_wav=args.capacitron_style_wav,
467
+ style_text=args.capacitron_style_text,
468
+ reference_speaker_name=args.reference_speaker_idx,
469
+ )
470
+ elif vc_path is not None:
471
+ wav = synthesizer.voice_conversion(
472
+ source_wav=args.source_wav,
473
+ target_wav=args.target_wav,
474
+ )
475
+ elif model_dir is not None:
476
+ wav = synthesizer.tts(
477
+ args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
478
+ )
479
+
480
+ # save the results
481
+ synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
482
+ logger.info("Saved output to %s", args.out_path)
483
+
484
+
485
+ if __name__ == "__main__":
486
+ main()
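
The CLI above is a thin wrapper around `Synthesizer`. Below is a rough programmatic sketch of the custom-model path (CASE4); the file paths are placeholders and Griffin-Lim is used because no vocoder is passed.

```python
from TTS.utils.synthesizer import Synthesizer

# Positional arguments follow the same order as the Synthesizer(...) call above;
# the remaining arguments are left at their defaults in this sketch.
synthesizer = Synthesizer("path/to/model.pth", "path/to/config.json").to("cpu")
wav = synthesizer.tts("Text for TTS")
synthesizer.save_wav(wav, "tts_output.wav")
```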
TTS/bin/train_encoder.py ADDED
@@ -0,0 +1,340 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import logging
5
+ import os
6
+ import sys
7
+ import time
8
+ import traceback
9
+ import warnings
10
+
11
+ import torch
12
+ from torch.utils.data import DataLoader
13
+ from trainer.generic_utils import count_parameters, remove_experiment_folder
14
+ from trainer.io import copy_model_files, save_best_model, save_checkpoint
15
+ from trainer.torch import NoamLR
16
+ from trainer.trainer_utils import get_optimizer
17
+
18
+ from TTS.encoder.dataset import EncoderDataset
19
+ from TTS.encoder.utils.generic_utils import setup_encoder_model
20
+ from TTS.encoder.utils.training import init_training
21
+ from TTS.encoder.utils.visual import plot_embeddings
22
+ from TTS.tts.datasets import load_tts_samples
23
+ from TTS.utils.audio import AudioProcessor
24
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
25
+ from TTS.utils.samplers import PerfectBatchSampler
26
+ from TTS.utils.training import check_update
27
+
28
+ torch.backends.cudnn.enabled = True
29
+ torch.backends.cudnn.benchmark = True
30
+ torch.manual_seed(54321)
31
+ use_cuda = torch.cuda.is_available()
32
+ num_gpus = torch.cuda.device_count()
33
+ print(" > Using CUDA: ", use_cuda)
34
+ print(" > Number of GPUs: ", num_gpus)
35
+
36
+
37
+ def setup_loader(ap: AudioProcessor, is_val: bool = False):
38
+ num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
39
+ num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch
40
+
41
+ dataset = EncoderDataset(
42
+ c,
43
+ ap,
44
+ meta_data_eval if is_val else meta_data_train,
45
+ voice_len=c.voice_len,
46
+ num_utter_per_class=num_utter_per_class,
47
+ num_classes_in_batch=num_classes_in_batch,
48
+ augmentation_config=c.audio_augmentation if not is_val else None,
49
+ use_torch_spec=c.model_params.get("use_torch_spec", False),
50
+ )
51
+ # get classes list
52
+ classes = dataset.get_class_list()
53
+
54
+ sampler = PerfectBatchSampler(
55
+ dataset.items,
56
+ classes,
57
+ batch_size=num_classes_in_batch * num_utter_per_class, # total batch size
58
+ num_classes_in_batch=num_classes_in_batch,
59
+ num_gpus=1,
60
+ shuffle=not is_val,
61
+ drop_last=True,
62
+ )
63
+
64
+ if len(classes) < num_classes_in_batch:
65
+ if is_val:
66
+ raise RuntimeError(
67
+ f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !"
68
+ )
69
+ raise RuntimeError(
70
+ f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !"
71
+ )
72
+
73
+ # set the classes to avoid getting a wrong class_id when the numbers of training and eval classes differ
74
+ if is_val:
75
+ dataset.set_classes(train_classes)
76
+
77
+ loader = DataLoader(
78
+ dataset,
79
+ num_workers=c.num_loader_workers,
80
+ batch_sampler=sampler,
81
+ collate_fn=dataset.collate_fn,
82
+ )
83
+
84
+ return loader, classes, dataset.get_map_classid_to_classname()
85
+
86
+
87
+ def evaluation(model, criterion, data_loader, global_step):
88
+ eval_loss = 0
89
+ for _, data in enumerate(data_loader):
90
+ with torch.no_grad():
91
+ # setup input data
92
+ inputs, labels = data
93
+
94
+ # group samples of each class in the batch. The perfect sampler produces [3,2,1,3,2,1]; we need [3,3,2,2,1,1]
95
+ labels = torch.transpose(
96
+ labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1
97
+ ).reshape(labels.shape)
98
+ inputs = torch.transpose(
99
+ inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1
100
+ ).reshape(inputs.shape)
101
+
102
+ # dispatch data to GPU
103
+ if use_cuda:
104
+ inputs = inputs.cuda(non_blocking=True)
105
+ labels = labels.cuda(non_blocking=True)
106
+
107
+ # forward pass model
108
+ outputs = model(inputs)
109
+
110
+ # loss computation
111
+ loss = criterion(
112
+ outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels
113
+ )
114
+
115
+ eval_loss += loss.item()
116
+
117
+ eval_avg_loss = eval_loss / len(data_loader)
118
+ # save stats
119
+ dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss})
120
+ try:
121
+ # plot the last batch in the evaluation
122
+ figures = {
123
+ "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
124
+ }
125
+ dashboard_logger.eval_figures(global_step, figures)
126
+ except ImportError:
127
+ warnings.warn("Install the `umap-learn` package to see embedding plots.")
128
+ return eval_avg_loss
129
+
130
+
131
+ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
132
+ model.train()
133
+ best_loss = {"train_loss": None, "eval_loss": float("inf")}
134
+ avg_loader_time = 0
135
+ end_time = time.time()
136
+ for epoch in range(c.epochs):
137
+ tot_loss = 0
138
+ epoch_time = 0
139
+ for _, data in enumerate(data_loader):
140
+ start_time = time.time()
141
+
142
+ # setup input data
143
+ inputs, labels = data
144
+ # group samples of each class in the batch. The perfect sampler produces [3,2,1,3,2,1]; we need [3,3,2,2,1,1]
145
+ labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(
146
+ labels.shape
147
+ )
148
+ inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(
149
+ inputs.shape
150
+ )
151
+ # ToDo: move it to a unit test
152
+ # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape)
153
+ # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape)
154
+ # idx = 0
155
+ # for j in range(0, c.num_classes_in_batch, 1):
156
+ # for i in range(j, len(labels), c.num_classes_in_batch):
157
+ # if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])):
158
+ # print("Invalid")
159
+ # print(labels)
160
+ # exit()
161
+ # idx += 1
162
+ # labels = labels_converted
163
+ # inputs = inputs_converted
164
+
165
+ loader_time = time.time() - end_time
166
+ global_step += 1
167
+
168
+ optimizer.zero_grad()
169
+
170
+ # dispatch data to GPU
171
+ if use_cuda:
172
+ inputs = inputs.cuda(non_blocking=True)
173
+ labels = labels.cuda(non_blocking=True)
174
+
175
+ # forward pass model
176
+ outputs = model(inputs)
177
+
178
+ # loss computation
179
+ loss = criterion(
180
+ outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels
181
+ )
182
+ loss.backward()
183
+ grad_norm, _ = check_update(model, c.grad_clip)
184
+ optimizer.step()
185
+
186
+ # setup lr
187
+ if c.lr_decay:
188
+ scheduler.step()
189
+
190
+ step_time = time.time() - start_time
191
+ epoch_time += step_time
192
+
193
+ # acumulate the total epoch loss
194
+ tot_loss += loss.item()
195
+
196
+ # Averaged Loader Time
197
+ num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1
198
+ avg_loader_time = (
199
+ 1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
200
+ if avg_loader_time != 0
201
+ else loader_time
202
+ )
203
+ current_lr = optimizer.param_groups[0]["lr"]
204
+
205
+ if global_step % c.steps_plot_stats == 0:
206
+ # Plot Training Epoch Stats
207
+ train_stats = {
208
+ "loss": loss.item(),
209
+ "lr": current_lr,
210
+ "grad_norm": grad_norm,
211
+ "step_time": step_time,
212
+ "avg_loader_time": avg_loader_time,
213
+ }
214
+ dashboard_logger.train_epoch_stats(global_step, train_stats)
215
+ figures = {
216
+ "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch),
217
+ }
218
+ dashboard_logger.train_figures(global_step, figures)
219
+
220
+ if global_step % c.print_step == 0:
221
+ print(
222
+ " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} "
223
+ "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
224
+ global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
225
+ ),
226
+ flush=True,
227
+ )
228
+
229
+ if global_step % c.save_step == 0:
230
+ # save model
231
+ save_checkpoint(
232
+ c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
233
+ )
234
+
235
+ end_time = time.time()
236
+
237
+ print("")
238
+ print(
239
+ ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} "
240
+ "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
241
+ epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time
242
+ ),
243
+ flush=True,
244
+ )
245
+ # evaluation
246
+ if c.run_eval:
247
+ model.eval()
248
+ eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
249
+ print("\n\n")
250
+ print("--> EVAL PERFORMANCE")
251
+ print(
252
+ " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss),
253
+ flush=True,
254
+ )
255
+ # save the best checkpoint
256
+ best_loss = save_best_model(
257
+ {"train_loss": None, "eval_loss": eval_loss},
258
+ best_loss,
259
+ c,
260
+ model,
261
+ optimizer,
262
+ None,
263
+ global_step,
264
+ epoch,
265
+ OUT_PATH,
266
+ criterion=criterion.state_dict(),
267
+ )
268
+ model.train()
269
+
270
+ return best_loss, global_step
271
+
272
+
273
+ def main(args): # pylint: disable=redefined-outer-name
274
+ # pylint: disable=global-variable-undefined
275
+ global meta_data_train
276
+ global meta_data_eval
277
+ global train_classes
278
+
279
+ ap = AudioProcessor(**c.audio)
280
+ model = setup_encoder_model(c)
281
+
282
+ optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)
283
+
284
+ # pylint: disable=redefined-outer-name
285
+ meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)
286
+
287
+ train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
288
+ if c.run_eval:
289
+ eval_data_loader, _, _ = setup_loader(ap, is_val=True)
290
+ else:
291
+ eval_data_loader = None
292
+
293
+ num_classes = len(train_classes)
294
+ criterion = model.get_criterion(c, num_classes)
295
+
296
+ if c.loss == "softmaxproto" and c.model != "speaker_encoder":
297
+ c.map_classid_to_classname = map_classid_to_classname
298
+ copy_model_files(c, OUT_PATH, new_fields={})
299
+
300
+ if args.restore_path:
301
+ criterion, args.restore_step = model.load_checkpoint(
302
+ c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion
303
+ )
304
+ print(" > Model restored from step %d" % args.restore_step, flush=True)
305
+ else:
306
+ args.restore_step = 0
307
+
308
+ if c.lr_decay:
309
+ scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
310
+ else:
311
+ scheduler = None
312
+
313
+ num_params = count_parameters(model)
314
+ print("\n > Model has {} parameters".format(num_params), flush=True)
315
+
316
+ if use_cuda:
317
+ model = model.cuda()
318
+ criterion.cuda()
319
+
320
+ global_step = args.restore_step
321
+ _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
322
+
323
+
324
+ if __name__ == "__main__":
325
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
326
+
327
+ args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()
328
+
329
+ try:
330
+ main(args)
331
+ except KeyboardInterrupt:
332
+ remove_experiment_folder(OUT_PATH)
333
+ try:
334
+ sys.exit(0)
335
+ except SystemExit:
336
+ os._exit(0) # pylint: disable=protected-access
337
+ except Exception: # pylint: disable=broad-except
338
+ remove_experiment_folder(OUT_PATH)
339
+ traceback.print_exc()
340
+ sys.exit(1)
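
A small, self-contained check of the transpose/reshape re-grouping used in both `train` and `evaluation` above; the values follow the comment in the code: the sampler yields utterances interleaved by class and the loss expects them grouped per class.

```python
import torch

num_utter_per_class, num_classes_in_batch = 2, 3
labels = torch.tensor([3, 2, 1, 3, 2, 1])  # as produced by the perfect sampler
regrouped = torch.transpose(
    labels.view(num_utter_per_class, num_classes_in_batch), 0, 1
).reshape(labels.shape)
print(regrouped.tolist())  # [3, 3, 2, 2, 1, 1]
```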
TTS/bin/train_tts.py ADDED
@@ -0,0 +1,75 @@
1
+ import logging
2
+ import os
3
+ from dataclasses import dataclass, field
4
+
5
+ from trainer import Trainer, TrainerArgs
6
+
7
+ from TTS.config import load_config, register_config
8
+ from TTS.tts.datasets import load_tts_samples
9
+ from TTS.tts.models import setup_model
10
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
11
+
12
+
13
+ @dataclass
14
+ class TrainTTSArgs(TrainerArgs):
15
+ config_path: str = field(default=None, metadata={"help": "Path to the config file."})
16
+
17
+
18
+ def main():
19
+ """Run `tts` model training directly by a `config.json` file."""
20
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
21
+
22
+ # init trainer args
23
+ train_args = TrainTTSArgs()
24
+ parser = train_args.init_argparse(arg_prefix="")
25
+
26
+ # override trainer args from command-line args
27
+ args, config_overrides = parser.parse_known_args()
28
+ train_args.parse_args(args)
29
+
30
+ # load config.json and register
31
+ if args.config_path or args.continue_path:
32
+ if args.config_path:
33
+ # init from a file
34
+ config = load_config(args.config_path)
35
+ if len(config_overrides) > 0:
36
+ config.parse_known_args(config_overrides, relaxed_parser=True)
37
+ elif args.continue_path:
38
+ # continue from a prev experiment
39
+ config = load_config(os.path.join(args.continue_path, "config.json"))
40
+ if len(config_overrides) > 0:
41
+ config.parse_known_args(config_overrides, relaxed_parser=True)
42
+ else:
43
+ # init from console args
44
+ from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
45
+
46
+ config_base = BaseTrainingConfig()
47
+ config_base.parse_known_args(config_overrides)
48
+ config = register_config(config_base.model)()
49
+
50
+ # load training samples
51
+ train_samples, eval_samples = load_tts_samples(
52
+ config.datasets,
53
+ eval_split=True,
54
+ eval_split_max_size=config.eval_split_max_size,
55
+ eval_split_size=config.eval_split_size,
56
+ )
57
+
58
+ # init the model from config
59
+ model = setup_model(config, train_samples + eval_samples)
60
+
61
+ # init the trainer and 🚀
62
+ trainer = Trainer(
63
+ train_args,
64
+ model.config,
65
+ config.output_path,
66
+ model=model,
67
+ train_samples=train_samples,
68
+ eval_samples=eval_samples,
69
+ parse_command_line_args=False,
70
+ )
71
+ trainer.fit()
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
TTS/bin/train_vocoder.py ADDED
@@ -0,0 +1,81 @@
1
+ import logging
2
+ import os
3
+ from dataclasses import dataclass, field
4
+
5
+ from trainer import Trainer, TrainerArgs
6
+
7
+ from TTS.config import load_config, register_config
8
+ from TTS.utils.audio import AudioProcessor
9
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
10
+ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
11
+ from TTS.vocoder.models import setup_model
12
+
13
+
14
+ @dataclass
15
+ class TrainVocoderArgs(TrainerArgs):
16
+ config_path: str = field(default=None, metadata={"help": "Path to the config file."})
17
+
18
+
19
+ def main():
20
+ """Run `tts` model training directly by a `config.json` file."""
21
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
22
+
23
+ # init trainer args
24
+ train_args = TrainVocoderArgs()
25
+ parser = train_args.init_argparse(arg_prefix="")
26
+
27
+ # override trainer args from command-line args
28
+ args, config_overrides = parser.parse_known_args()
29
+ train_args.parse_args(args)
30
+
31
+ # load config.json and register
32
+ if args.config_path or args.continue_path:
33
+ if args.config_path:
34
+ # init from a file
35
+ config = load_config(args.config_path)
36
+ if len(config_overrides) > 0:
37
+ config.parse_known_args(config_overrides, relaxed_parser=True)
38
+ elif args.continue_path:
39
+ # continue from a prev experiment
40
+ config = load_config(os.path.join(args.continue_path, "config.json"))
41
+ if len(config_overrides) > 0:
42
+ config.parse_known_args(config_overrides, relaxed_parser=True)
43
+ else:
44
+ # init from console args
45
+ from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
46
+
47
+ config_base = BaseTrainingConfig()
48
+ config_base.parse_known_args(config_overrides)
49
+ config = register_config(config_base.model)()
50
+
51
+ # load training samples
52
+ if "feature_path" in config and config.feature_path:
53
+ # load pre-computed features
54
+ print(f" > Loading features from: {config.feature_path}")
55
+ eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size)
56
+ else:
57
+ # load data raw wav files
58
+ eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
59
+
60
+ # setup audio processor
61
+ ap = AudioProcessor(**config.audio)
62
+
63
+ # init the model from config
64
+ model = setup_model(config)
65
+
66
+ # init the trainer and 🚀
67
+ trainer = Trainer(
68
+ train_args,
69
+ config,
70
+ config.output_path,
71
+ model=model,
72
+ train_samples=train_samples,
73
+ eval_samples=eval_samples,
74
+ training_assets={"audio_processor": ap},
75
+ parse_command_line_args=False,
76
+ )
77
+ trainer.fit()
78
+
79
+
80
+ if __name__ == "__main__":
81
+ main()
TTS/bin/tune_wavegrad.py ADDED
@@ -0,0 +1,107 @@
1
+ """Search a good noise schedule for WaveGrad for a given number of inference iterations"""
2
+
3
+ import argparse
4
+ import logging
5
+ from itertools import product as cartesian_product
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+
12
+ from TTS.config import load_config
13
+ from TTS.utils.audio import AudioProcessor
14
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
15
+ from TTS.vocoder.datasets.preprocess import load_wav_data
16
+ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
17
+ from TTS.vocoder.models import setup_model
18
+
19
+ if __name__ == "__main__":
20
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
21
+
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
24
+ parser.add_argument("--config_path", type=str, help="Path to model config file.")
25
+ parser.add_argument("--data_path", type=str, help="Path to data directory.")
26
+ parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
27
+ parser.add_argument(
28
+ "--num_iter",
29
+ type=int,
30
+ help="Number of model inference iterations that you like to optimize noise schedule for.",
31
+ )
32
+ parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
33
+ parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
34
+ parser.add_argument(
35
+ "--search_depth",
36
+ type=int,
37
+ default=3,
38
+ help="Search granularity. Increasing this increases the run-time exponentially.",
39
+ )
40
+
41
+ # load config
42
+ args = parser.parse_args()
43
+ config = load_config(args.config_path)
44
+
45
+ # setup audio processor
46
+ ap = AudioProcessor(**config.audio)
47
+
48
+ # load dataset
49
+ _, train_data = load_wav_data(args.data_path, 0)
50
+ train_data = train_data[: args.num_samples]
51
+ dataset = WaveGradDataset(
52
+ ap=ap,
53
+ items=train_data,
54
+ seq_len=-1,
55
+ hop_len=ap.hop_length,
56
+ pad_short=config.pad_short,
57
+ conv_pad=config.conv_pad,
58
+ is_training=True,
59
+ return_segments=False,
60
+ use_noise_augment=False,
61
+ use_cache=False,
62
+ )
63
+ loader = DataLoader(
64
+ dataset,
65
+ batch_size=1,
66
+ shuffle=False,
67
+ collate_fn=dataset.collate_full_clips,
68
+ drop_last=False,
69
+ num_workers=config.num_loader_workers,
70
+ pin_memory=False,
71
+ )
72
+
73
+ # setup the model
74
+ model = setup_model(config)
75
+ if args.use_cuda:
76
+ model.cuda()
77
+
78
+ # setup optimization parameters
79
+ base_values = sorted(10 * np.random.uniform(size=args.search_depth))
80
+ print(f" > base values: {base_values}")
81
+ exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
82
+ best_error = float("inf")
83
+ best_schedule = None # pylint: disable=C0103
84
+ total_search_iter = len(base_values) ** args.num_iter
85
+ for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
86
+ beta = exponents * base
87
+ model.compute_noise_level(beta)
88
+ for data in loader:
89
+ mel, audio = data
90
+ y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
91
+
92
+ if args.use_cuda:
93
+ y_hat = y_hat.cpu()
94
+ y_hat = y_hat.numpy()
95
+
96
+ mel_hat = []
97
+ for i in range(y_hat.shape[0]):
98
+ m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
99
+ mel_hat.append(torch.from_numpy(m))
100
+
101
+ mel_hat = torch.stack(mel_hat)
102
+ mse = torch.sum((mel - mel_hat) ** 2).mean()
103
+ if mse.item() < best_error:
104
+ best_error = mse.item()
105
+ best_schedule = {"beta": beta}
106
+ print(f" > Found a better schedule. - MSE: {mse.item()}")
107
+ np.save(args.output_path, best_schedule)
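
To make the size of the grid search above concrete, here is a small sketch that enumerates the same candidate schedules: each candidate is `exponents * bases` for one tuple of bases, giving `search_depth ** num_iter` candidates in total.

```python
from itertools import product

import numpy as np

num_iter, search_depth = 4, 3
base_values = sorted(10 * np.random.uniform(size=search_depth))
exponents = 10 ** np.linspace(-6, -1, num=num_iter)

candidates = [exponents * np.array(bases) for bases in product(base_values, repeat=num_iter)]
print(len(candidates))  # 3 ** 4 == 81 noise schedules to evaluate
print(candidates[0])    # one example beta schedule (smallest values first)
```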
TTS/config/__init__.py ADDED
@@ -0,0 +1,138 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from typing import Dict
5
+
6
+ import fsspec
7
+ import yaml
8
+ from coqpit import Coqpit
9
+
10
+ from TTS.config.shared_configs import *
11
+ from TTS.utils.generic_utils import find_module
12
+
13
+
14
+ def read_json_with_comments(json_path):
15
+ """for backward compat."""
16
+ # fallback to json
17
+ with fsspec.open(json_path, "r", encoding="utf-8") as f:
18
+ input_str = f.read()
19
+ # handle comments but not urls with //
20
+ input_str = re.sub(
21
+ r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
22
+ )
23
+ return json.loads(input_str)
24
+
25
+
26
+ def register_config(model_name: str) -> Coqpit:
27
+ """Find the right config for the given model name.
28
+
29
+ Args:
30
+ model_name (str): Model name.
31
+
32
+ Raises:
33
+ ModuleNotFoundError: No matching config for the model name.
34
+
35
+ Returns:
36
+ Coqpit: config class.
37
+ """
38
+ config_class = None
39
+ config_name = model_name + "_config"
40
+
41
+ # TODO: fix this
42
+ if model_name == "xtts":
43
+ from TTS.tts.configs.xtts_config import XttsConfig
44
+
45
+ config_class = XttsConfig
46
+ paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"]
47
+ for path in paths:
48
+ try:
49
+ config_class = find_module(path, config_name)
50
+ except ModuleNotFoundError:
51
+ pass
52
+ if config_class is None:
53
+ raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.")
54
+ return config_class
55
+
56
+
57
+ def _process_model_name(config_dict: Dict) -> str:
58
+ """Format the model name as expected. It is a band-aid for the old `vocoder` model names.
59
+
60
+ Args:
61
+ config_dict (Dict): A dictionary including the config fields.
62
+
63
+ Returns:
64
+ str: Formatted modelname.
65
+ """
66
+ model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
67
+ model_name = model_name.replace("_generator", "").replace("_discriminator", "")
68
+ return model_name
69
+
70
+
71
+ def load_config(config_path: str) -> Coqpit:
72
+ """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name
73
+ to find the corresponding Config class. Then initialize the Config.
74
+
75
+ Args:
76
+ config_path (str): path to the config file.
77
+
78
+ Raises:
79
+ TypeError: given config file has an unknown type.
80
+
81
+ Returns:
82
+ Coqpit: TTS config object.
83
+ """
84
+ config_dict = {}
85
+ ext = os.path.splitext(config_path)[1]
86
+ if ext in (".yml", ".yaml"):
87
+ with fsspec.open(config_path, "r", encoding="utf-8") as f:
88
+ data = yaml.safe_load(f)
89
+ elif ext == ".json":
90
+ try:
91
+ with fsspec.open(config_path, "r", encoding="utf-8") as f:
92
+ data = json.load(f)
93
+ except json.decoder.JSONDecodeError:
94
+ # backwards compat.
95
+ data = read_json_with_comments(config_path)
96
+ else:
97
+ raise TypeError(f" [!] Unknown config file type {ext}")
98
+ config_dict.update(data)
99
+ model_name = _process_model_name(config_dict)
100
+ config_class = register_config(model_name.lower())
101
+ config = config_class()
102
+ config.from_dict(config_dict)
103
+ return config
104
+
105
+
106
+ def check_config_and_model_args(config, arg_name, value):
107
+ """Check the give argument in `config.model_args` if exist or in `config` for
108
+ the given value.
109
+
110
+ Return False if the argument does not exist in `config.model_args` or `config`.
111
+ This is to patch up the compatibility between models with and without `model_args`.
112
+
113
+ TODO: Remove this in the future with a unified approach.
114
+ """
115
+ if hasattr(config, "model_args"):
116
+ if arg_name in config.model_args:
117
+ return config.model_args[arg_name] == value
118
+ if hasattr(config, arg_name):
119
+ return config[arg_name] == value
120
+ return False
121
+
122
+
123
+ def get_from_config_or_model_args(config, arg_name):
124
+ """Get the given argument from `config.model_args` if exist or in `config`."""
125
+ if hasattr(config, "model_args"):
126
+ if arg_name in config.model_args:
127
+ return config.model_args[arg_name]
128
+ return config[arg_name]
129
+
130
+
131
+ def get_from_config_or_model_args_with_default(config, arg_name, def_val):
132
+ """Get the given argument from `config.model_args` if exist or in `config`."""
133
+ if hasattr(config, "model_args"):
134
+ if arg_name in config.model_args:
135
+ return config.model_args[arg_name]
136
+ if hasattr(config, arg_name):
137
+ return config[arg_name]
138
+ return def_val
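
A minimal sketch of how these helpers are typically used together; the config file and its `model` field are hypothetical, and `load_config` resolves the matching config class via `register_config` as implemented above.

```python
from TTS.config import load_config

config = load_config("path/to/config.json")  # .json or .yml/.yaml
print(config.model)                          # e.g. "glow_tts"
print(config.num_loader_workers)             # any field defined by the resolved config class
```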
TTS/config/shared_configs.py ADDED
@@ -0,0 +1,268 @@
1
+ from dataclasses import asdict, dataclass
2
+ from typing import List
3
+
4
+ from coqpit import Coqpit, check_argument
5
+ from trainer import TrainerConfig
6
+
7
+
8
+ @dataclass
9
+ class BaseAudioConfig(Coqpit):
10
+ """Base config to define audio processing parameters. It is used to initialize
11
+ ```TTS.utils.audio.AudioProcessor.```
12
+
13
+ Args:
14
+ fft_size (int):
15
+ Number of STFT frequency levels, i.e. the size of the linear spectrogram frame. Defaults to 1024.
16
+
17
+ win_length (int):
18
+ Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match
19
+ ```fft_size```. Defaults to 1024.
20
+
21
+ hop_length (int):
22
+ Number of audio samples between adjacent STFT columns. Defaults to 256.
23
+
24
+ frame_shift_ms (int):
25
+ Set ```hop_length``` based on milliseconds and sampling rate.
26
+
27
+ frame_length_ms (int):
28
+ Set ```win_length``` based on milliseconds and sampling rate.
29
+
30
+ stft_pad_mode (str):
31
+ Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'.
32
+
33
+ sample_rate (int):
34
+ Audio sampling rate. Defaults to 22050.
35
+
36
+ resample (bool):
37
+ Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```.
38
+
39
+ preemphasis (float):
40
+ Preemphasis coefficient. Defaults to 0.0.
41
+
42
+ ref_level_db (int):
43
+ Reference dB level to rebase the audio signal and ignore levels below it. 20 dB is assumed to be the sound of air.
44
+ Defaults to 20.
45
+
46
+ do_sound_norm (bool):
47
+ Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False.
48
+
49
+ log_func (str):
50
+ Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'.
51
+
52
+ do_trim_silence (bool):
53
+ Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```.
54
+
55
+ do_amp_to_db_linear (bool, optional):
56
+ enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True.
57
+
58
+ do_amp_to_db_mel (bool, optional):
59
+ enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True.
60
+
61
+ pitch_fmax (float, optional):
62
+ Maximum frequency of the F0 frames. Defaults to ```640```.
63
+
64
+ pitch_fmin (float, optional):
65
+ Minimum frequency of the F0 frames. Defaults to ```1```.
66
+
67
+ trim_db (int):
68
+ Silence threshold used for silence trimming. Defaults to 45.
69
+
70
+ do_rms_norm (bool, optional):
71
+ enable/disable RMS volume normalization when loading an audio file. Defaults to False.
72
+
73
+ db_level (int, optional):
74
+ dB level used for rms normalization. The range is -99 to 0. Defaults to None.
75
+
76
+ power (float):
77
+ Exponent used for expanding spectrogram levels before running Griffin-Lim. It helps to reduce the
78
+ artifacts in the synthesized voice. Defaults to 1.5.
79
+
80
+ griffin_lim_iters (int):
81
+ Number of Griffin-Lim iterations. Defaults to 60.
82
+
83
+ num_mels (int):
84
+ Number of mel-basis filters that defines the number of bands of each mel-spectrogram frame. Defaults to 80.
85
+
86
+ mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices.
87
+ It needs to be adjusted for a dataset. Defaults to 0.
88
+
89
+ mel_fmax (float):
90
+ Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset.
91
+
92
+ spec_gain (int):
93
+ Gain applied when converting amplitude to DB. Defaults to 20.
94
+
95
+ signal_norm (bool):
96
+ enable/disable signal normalization. Defaults to True.
97
+
98
+ min_level_db (int):
99
+ minimum db threshold for the computed melspectrograms. Defaults to -100.
100
+
101
+ symmetric_norm (bool):
102
+ enable/disable symmetric normalization. If set True, normalization is performed in the range [-k, k], else
103
+ [0, k]. Defaults to True.
104
+
105
+ max_norm (float):
106
+ ```k``` defining the normalization range. Defaults to 4.0.
107
+
108
+ clip_norm (bool):
109
+ enable/disable clipping of out-of-range values in the normalized audio signal. Defaults to True.
110
+
111
+ stats_path (str):
112
+ Path to the computed stats file. Defaults to None.
113
+ """
114
+
115
+ # stft parameters
116
+ fft_size: int = 1024
117
+ win_length: int = 1024
118
+ hop_length: int = 256
119
+ frame_shift_ms: int = None
120
+ frame_length_ms: int = None
121
+ stft_pad_mode: str = "reflect"
122
+ # audio processing parameters
123
+ sample_rate: int = 22050
124
+ resample: bool = False
125
+ preemphasis: float = 0.0
126
+ ref_level_db: int = 20
127
+ do_sound_norm: bool = False
128
+ log_func: str = "np.log10"
129
+ # silence trimming
130
+ do_trim_silence: bool = True
131
+ trim_db: int = 45
132
+ # rms volume normalization
133
+ do_rms_norm: bool = False
134
+ db_level: float = None
135
+ # griffin-lim params
136
+ power: float = 1.5
137
+ griffin_lim_iters: int = 60
138
+ # mel-spec params
139
+ num_mels: int = 80
140
+ mel_fmin: float = 0.0
141
+ mel_fmax: float = None
142
+ spec_gain: int = 20
143
+ do_amp_to_db_linear: bool = True
144
+ do_amp_to_db_mel: bool = True
145
+ # f0 params
146
+ pitch_fmax: float = 640.0
147
+ pitch_fmin: float = 1.0
148
+ # normalization params
149
+ signal_norm: bool = True
150
+ min_level_db: int = -100
151
+ symmetric_norm: bool = True
152
+ max_norm: float = 4.0
153
+ clip_norm: bool = True
154
+ stats_path: str = None
155
+
156
+ def check_values(
157
+ self,
158
+ ):
159
+ """Check config fields"""
160
+ c = asdict(self)
161
+ check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056)
162
+ check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058)
163
+ check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000)
164
+ check_argument(
165
+ "frame_length_ms",
166
+ c,
167
+ restricted=True,
168
+ min_val=10,
169
+ max_val=1000,
170
+ alternative="win_length",
171
+ )
172
+ check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length")
173
+ check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1)
174
+ check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10)
175
+ check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000)
176
+ check_argument("power", c, restricted=True, min_val=1, max_val=5)
177
+ check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000)
178
+
179
+ # normalization parameters
180
+ check_argument("signal_norm", c, restricted=True)
181
+ check_argument("symmetric_norm", c, restricted=True)
182
+ check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000)
183
+ check_argument("clip_norm", c, restricted=True)
184
+ check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000)
185
+ check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True)
186
+ check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100)
187
+ check_argument("do_trim_silence", c, restricted=True)
188
+ check_argument("trim_db", c, restricted=True)
189
+
190
+
191
+ @dataclass
192
+ class BaseDatasetConfig(Coqpit):
193
+ """Base config for TTS datasets.
194
+
195
+ Args:
196
+ formatter (str):
197
+ Name of the formatter used from ```TTS.tts.datasets.formatter```. Defaults to `""`.
198
+
199
+ dataset_name (str):
200
+ Unique name for the dataset. Defaults to `""`.
201
+
202
+ path (str):
203
+ Root path to the dataset files. Defaults to `""`.
204
+
205
+ meta_file_train (str):
206
+ Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets.
207
+ Defaults to `""`.
208
+
209
+ ignored_speakers (List):
210
+ List of speaker IDs that are not used for training. Defaults to None.
211
+
212
+ language (str):
213
+ Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`.
214
+
215
+ phonemizer (str):
216
+ Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`.
217
+
218
+ meta_file_val (str):
219
+ Name of the dataset meta file that defines the instances used at validation.
220
+
221
+ meta_file_attn_mask (str):
222
+ Path to the file that lists the attention mask files used with models that require attention masks to
223
+ train the duration predictor.
224
+ """
225
+
226
+ formatter: str = ""
227
+ dataset_name: str = ""
228
+ path: str = ""
229
+ meta_file_train: str = ""
230
+ ignored_speakers: List[str] = None
231
+ language: str = ""
232
+ phonemizer: str = ""
233
+ meta_file_val: str = ""
234
+ meta_file_attn_mask: str = ""
235
+
236
+ def check_values(
237
+ self,
238
+ ):
239
+ """Check config fields"""
240
+ c = asdict(self)
241
+ check_argument("formatter", c, restricted=True)
242
+ check_argument("path", c, restricted=True)
243
+ check_argument("meta_file_train", c, restricted=True)
244
+ check_argument("meta_file_val", c, restricted=False)
245
+ check_argument("meta_file_attn_mask", c, restricted=False)
246
+
247
+
248
+ @dataclass
249
+ class BaseTrainingConfig(TrainerConfig):
250
+ """Base config to define the basic 🐸TTS training parameters that are shared
251
+ among all the models. It is based on ```trainer.TrainerConfig```.
252
+
253
+ Args:
254
+ model (str):
255
+ Name of the model that is used in the training.
256
+
257
+ num_loader_workers (int):
258
+ Number of workers for training time dataloader.
259
+
260
+ num_eval_loader_workers (int):
261
+ Number of workers for evaluation time dataloader.
262
+ """
263
+
264
+ model: str = None
265
+ # dataloading
266
+ num_loader_workers: int = 0
267
+ num_eval_loader_workers: int = 0
268
+ use_noise_augment: bool = False
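These dataclasses are usually instantiated directly and then validated with `check_values()`. A minimal sketch (the formatter name, paths and language below are placeholders):

```python
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig

audio_config = BaseAudioConfig(sample_rate=22050, num_mels=80, do_trim_silence=True)
audio_config.check_values()  # raises if a field is outside its allowed range

dataset_config = BaseDatasetConfig(
    formatter="ljspeech",            # assumed formatter name
    dataset_name="my_dataset",       # placeholder
    path="/data/LJSpeech-1.1/",      # placeholder dataset root
    meta_file_train="metadata.csv",  # placeholder meta file
    language="en",
)
dataset_config.check_values()
```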
TTS/demos/xtts_ft_demo/requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ faster_whisper==0.9.0
2
+ gradio==4.7.1
TTS/demos/xtts_ft_demo/utils/formatter.py ADDED
@@ -0,0 +1,161 @@
1
+ import gc
2
+ import os
3
+
4
+ import pandas
5
+ import torch
6
+ import torchaudio
7
+ from faster_whisper import WhisperModel
8
+ from tqdm import tqdm
9
+
10
+ # torch.set_num_threads(1)
11
+ from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
12
+
13
+ torch.set_num_threads(16)
14
+
15
+ audio_types = (".wav", ".mp3", ".flac")
16
+
17
+
18
+ def list_audios(basePath, contains=None):
19
+ # return the set of files that are valid
20
+ return list_files(basePath, validExts=audio_types, contains=contains)
21
+
22
+
23
+ def list_files(basePath, validExts=None, contains=None):
24
+ # loop over the directory structure
25
+ for rootDir, dirNames, filenames in os.walk(basePath):
26
+ # loop over the filenames in the current directory
27
+ for filename in filenames:
28
+ # if the contains string is not none and the filename does not contain
29
+ # the supplied string, then ignore the file
30
+ if contains is not None and filename.find(contains) == -1:
31
+ continue
32
+
33
+ # determine the file extension of the current file
34
+ ext = filename[filename.rfind(".") :].lower()
35
+
36
+ # check to see if the file is an audio and should be processed
37
+ if validExts is None or ext.endswith(validExts):
38
+ # construct the path to the audio and yield it
39
+ audioPath = os.path.join(rootDir, filename)
40
+ yield audioPath
41
+
42
+
43
+ def format_audio_list(
44
+ audio_files,
45
+ target_language="en",
46
+ out_path=None,
47
+ buffer=0.2,
48
+ eval_percentage=0.15,
49
+ speaker_name="coqui",
50
+ gradio_progress=None,
51
+ ):
52
+ audio_total_size = 0
53
+ # make sure that the output directory exists
54
+ os.makedirs(out_path, exist_ok=True)
55
+
56
+ # Loading Whisper
57
+ device = "cuda" if torch.cuda.is_available() else "cpu"
58
+
59
+ print("Loading Whisper Model!")
60
+ asr_model = WhisperModel("large-v2", device=device, compute_type="float16")
61
+
62
+ metadata = {"audio_file": [], "text": [], "speaker_name": []}
63
+
64
+ if gradio_progress is not None:
65
+ tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...")
66
+ else:
67
+ tqdm_object = tqdm(audio_files)
68
+
69
+ for audio_path in tqdm_object:
70
+ wav, sr = torchaudio.load(audio_path)
71
+ # stereo to mono if needed
72
+ if wav.size(0) != 1:
73
+ wav = torch.mean(wav, dim=0, keepdim=True)
74
+
75
+ wav = wav.squeeze()
76
+ audio_total_size += wav.size(-1) / sr
77
+
78
+ segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
79
+ segments = list(segments)
80
+ i = 0
81
+ sentence = ""
82
+ sentence_start = None
83
+ first_word = True
84
+ # gather the words of all segments in a single list
85
+ words_list = []
86
+ for _, segment in enumerate(segments):
87
+ words = list(segment.words)
88
+ words_list.extend(words)
89
+
90
+ # process each word
91
+ for word_idx, word in enumerate(words_list):
92
+ if first_word:
93
+ sentence_start = word.start
94
+ # If it is the first sentence, add the buffer or clamp to the beginning of the file
95
+ if word_idx == 0:
96
+ sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start
97
+ else:
98
+ # get previous sentence end
99
+ previous_word_end = words_list[word_idx - 1].end
100
+ # add the buffer or take the middle of the silence between the previous sentence and the current one
101
+ sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
102
+
103
+ sentence = word.word
104
+ first_word = False
105
+ else:
106
+ sentence += word.word
107
+
108
+ if word.word[-1] in ["!", ".", "?"]:
109
+ sentence = sentence[1:]
110
+ # Expand number and abbreviations plus normalization
111
+ sentence = multilingual_cleaners(sentence, target_language)
112
+ audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))
113
+
114
+ audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
115
+
116
+ # Check for the next word's existence
117
+ if word_idx + 1 < len(words_list):
118
+ next_word_start = words_list[word_idx + 1].start
119
+ else:
120
+ # If there are no more words, this is the last sentence, so use the audio length as the next word start
121
+ next_word_start = (wav.shape[0] - 1) / sr
122
+
123
+ # Average the current word end and next word start
124
+ word_end = min((word.end + next_word_start) / 2, word.end + buffer)
125
+
126
+ absolute_path = os.path.join(out_path, audio_file)
127
+ os.makedirs(os.path.dirname(absolute_path), exist_ok=True)
128
+ i += 1
129
+ first_word = True
130
+
131
+ audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
132
+ # if the audio is too short, ignore it (i.e. < 0.33 seconds)
133
+ if audio.size(-1) >= sr / 3:
134
+ torchaudio.save(absolute_path, audio, sr)
135
+ else:
136
+ continue
137
+
138
+ metadata["audio_file"].append(audio_file)
139
+ metadata["text"].append(sentence)
140
+ metadata["speaker_name"].append(speaker_name)
141
+
142
+ df = pandas.DataFrame(metadata)
143
+ df = df.sample(frac=1)
144
+ num_val_samples = int(len(df) * eval_percentage)
145
+
146
+ df_eval = df[:num_val_samples]
147
+ df_train = df[num_val_samples:]
148
+
149
+ df_train = df_train.sort_values("audio_file")
150
+ train_metadata_path = os.path.join(out_path, "metadata_train.csv")
151
+ df_train.to_csv(train_metadata_path, sep="|", index=False)
152
+
153
+ eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
154
+ df_eval = df_eval.sort_values("audio_file")
155
+ df_eval.to_csv(eval_metadata_path, sep="|", index=False)
156
+
157
+ # deallocate VRAM and RAM
158
+ del asr_model, df_train, df_eval, df, metadata
159
+ gc.collect()
160
+
161
+ return train_metadata_path, eval_metadata_path, audio_total_size
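`format_audio_list` is the function the demo calls to turn raw recordings into an XTTS fine-tuning dataset (transcription, sentence-level splitting, and train/eval CSVs). A standalone sketch of the same call, with placeholder input and output paths:

```python
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list, list_audios

audio_files = list(list_audios("/data/my_speaker"))  # placeholder folder of .wav/.mp3/.flac files
train_csv, eval_csv, total_seconds = format_audio_list(
    audio_files,
    target_language="en",
    out_path="/tmp/xtts_ft/dataset",  # placeholder output folder
    speaker_name="coqui",
)
print(train_csv, eval_csv, total_seconds)  # the total audio length is returned in seconds
```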
TTS/demos/xtts_ft_demo/utils/gpt_train.py ADDED
@@ -0,0 +1,171 @@
1
+ import gc
2
+ import os
3
+
4
+ from trainer import Trainer, TrainerArgs
5
+
6
+ from TTS.config.shared_configs import BaseDatasetConfig
7
+ from TTS.tts.datasets import load_tts_samples
8
+ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
9
+ from TTS.utils.manage import ModelManager
10
+
11
+
12
+ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
13
+ # Logging parameters
14
+ RUN_NAME = "GPT_XTTS_FT"
15
+ PROJECT_NAME = "XTTS_trainer"
16
+ DASHBOARD_LOGGER = "tensorboard"
17
+ LOGGER_URI = None
18
+
19
+ # Set here the path that the checkpoints will be saved. Default: ./run/training/
20
+ OUT_PATH = os.path.join(output_path, "run", "training")
21
+
22
+ # Training Parameters
23
+ OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # set this to False for multi-GPU training
24
+ START_WITH_EVAL = False # if True it will start with evaluation
25
+ BATCH_SIZE = batch_size # set here the batch size
26
+ GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
27
+
28
+ # Define here the dataset that you want to use for fine-tuning.
29
+ config_dataset = BaseDatasetConfig(
30
+ formatter="coqui",
31
+ dataset_name="ft_dataset",
32
+ path=os.path.dirname(train_csv),
33
+ meta_file_train=train_csv,
34
+ meta_file_val=eval_csv,
35
+ language=language,
36
+ )
37
+
38
+ # Add here the configs of the datasets
39
+ DATASETS_CONFIG_LIST = [config_dataset]
40
+
41
+ # Define the path where XTTS v2.0.1 files will be downloaded
42
+ CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
43
+ os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
44
+
45
+ # DVAE files
46
+ DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
47
+ MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
48
+
49
+ # Set the path to the downloaded files
50
+ DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
51
+ MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
52
+
53
+ # download DVAE files if needed
54
+ if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
55
+ print(" > Downloading DVAE files!")
56
+ ModelManager._download_model_files(
57
+ [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
58
+ )
59
+
60
+ # Download XTTS v2.0 checkpoint if needed
61
+ TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
62
+ XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
63
+ XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"
64
+
65
+ # XTTS transfer learning parameters: you need to provide the paths of the XTTS checkpoint that you want to fine-tune.
66
+ TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file
67
+ XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file
68
+ XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file
69
+
70
+ # download XTTS v2.0 files if needed
71
+ if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
72
+ print(" > Downloading XTTS v2.0 files!")
73
+ ModelManager._download_model_files(
74
+ [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
75
+ )
76
+
77
+ # init args and config
78
+ model_args = GPTArgs(
79
+ max_conditioning_length=132300, # 6 secs
80
+ min_conditioning_length=66150, # 3 secs
81
+ debug_loading_failures=False,
82
+ max_wav_length=max_audio_length, # ~11.6 seconds
83
+ max_text_length=200,
84
+ mel_norm_file=MEL_NORM_FILE,
85
+ dvae_checkpoint=DVAE_CHECKPOINT,
86
+ xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
87
+ tokenizer_file=TOKENIZER_FILE,
88
+ gpt_num_audio_tokens=1026,
89
+ gpt_start_audio_token=1024,
90
+ gpt_stop_audio_token=1025,
91
+ gpt_use_masking_gt_prompt_approach=True,
92
+ gpt_use_perceiver_resampler=True,
93
+ )
94
+ # define audio config
95
+ audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
96
+ # training parameters config
97
+ config = GPTTrainerConfig(
98
+ epochs=num_epochs,
99
+ output_path=OUT_PATH,
100
+ model_args=model_args,
101
+ run_name=RUN_NAME,
102
+ project_name=PROJECT_NAME,
103
+ run_description="""
104
+ GPT XTTS training
105
+ """,
106
+ dashboard_logger=DASHBOARD_LOGGER,
107
+ logger_uri=LOGGER_URI,
108
+ audio=audio_config,
109
+ batch_size=BATCH_SIZE,
110
+ batch_group_size=48,
111
+ eval_batch_size=BATCH_SIZE,
112
+ num_loader_workers=8,
113
+ eval_split_max_size=256,
114
+ print_step=50,
115
+ plot_step=100,
116
+ log_model_step=100,
117
+ save_step=1000,
118
+ save_n_checkpoints=1,
119
+ save_checkpoints=True,
120
+ # target_loss="loss",
121
+ print_eval=False,
122
+ # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
123
+ optimizer="AdamW",
124
+ optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
125
+ optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
126
+ lr=5e-06, # learning rate
127
+ lr_scheduler="MultiStepLR",
128
+ # it was adjusted accordingly for the new step scheme
129
+ lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
130
+ test_sentences=[],
131
+ )
132
+
133
+ # init the model from config
134
+ model = GPTTrainer.init_from_config(config)
135
+
136
+ # load training samples
137
+ train_samples, eval_samples = load_tts_samples(
138
+ DATASETS_CONFIG_LIST,
139
+ eval_split=True,
140
+ eval_split_max_size=config.eval_split_max_size,
141
+ eval_split_size=config.eval_split_size,
142
+ )
143
+
144
+ # init the trainer and 🚀
145
+ trainer = Trainer(
146
+ TrainerArgs(
147
+ restore_path=None, # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it using the Trainer restore_path parameter
148
+ skip_train_epoch=False,
149
+ start_with_eval=START_WITH_EVAL,
150
+ grad_accum_steps=GRAD_ACUMM_STEPS,
151
+ ),
152
+ config,
153
+ output_path=OUT_PATH,
154
+ model=model,
155
+ train_samples=train_samples,
156
+ eval_samples=eval_samples,
157
+ )
158
+ trainer.fit()
159
+
160
+ # get the longest text audio file to use as speaker reference
161
+ samples_len = [len(item["text"].split(" ")) for item in train_samples]
162
+ longest_text_idx = samples_len.index(max(samples_len))
163
+ speaker_ref = train_samples[longest_text_idx]["audio_file"]
164
+
165
+ trainer_out_path = trainer.output_path
166
+
167
+ # deallocate VRAM and RAM
168
+ del model, trainer, train_samples, eval_samples
169
+ gc.collect()
170
+
171
+ return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref
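`train_gpt` strings the formatter output into a full GPT fine-tuning run, downloading the base XTTS v2 files on first use. A sketch of calling it outside the Gradio UI; the CSV and output paths are placeholders produced by the previous step:

```python
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt

config_path, xtts_ckpt, vocab_file, exp_path, speaker_wav = train_gpt(
    language="en",
    num_epochs=10,
    batch_size=4,
    grad_acumm=1,
    train_csv="/tmp/xtts_ft/dataset/metadata_train.csv",  # placeholder
    eval_csv="/tmp/xtts_ft/dataset/metadata_eval.csv",    # placeholder
    output_path="/tmp/xtts_ft",                           # placeholder
    max_audio_length=11 * 22050,  # 11 seconds expressed in waveform samples
)
```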
TTS/demos/xtts_ft_demo/xtts_demo.py ADDED
@@ -0,0 +1,433 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import sys
5
+ import tempfile
6
+ import traceback
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torchaudio
11
+
12
+ from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
13
+ from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
14
+ from TTS.tts.configs.xtts_config import XttsConfig
15
+ from TTS.tts.models.xtts import Xtts
16
+
17
+
18
+ def clear_gpu_cache():
19
+ # clear the GPU cache
20
+ if torch.cuda.is_available():
21
+ torch.cuda.empty_cache()
22
+
23
+
24
+ XTTS_MODEL = None
25
+
26
+
27
+ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
28
+ global XTTS_MODEL
29
+ clear_gpu_cache()
30
+ if not xtts_checkpoint or not xtts_config or not xtts_vocab:
31
+ return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
32
+ config = XttsConfig()
33
+ config.load_json(xtts_config)
34
+ XTTS_MODEL = Xtts.init_from_config(config)
35
+ print("Loading XTTS model! ")
36
+ XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
37
+ if torch.cuda.is_available():
38
+ XTTS_MODEL.cuda()
39
+
40
+ print("Model Loaded!")
41
+ return "Model Loaded!"
42
+
43
+
44
+ def run_tts(lang, tts_text, speaker_audio_file):
45
+ if XTTS_MODEL is None or not speaker_audio_file:
46
+ return "You need to run the previous step to load the model !!", None, None
47
+
48
+ gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
49
+ audio_path=speaker_audio_file,
50
+ gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
51
+ max_ref_length=XTTS_MODEL.config.max_ref_len,
52
+ sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
53
+ )
54
+ out = XTTS_MODEL.inference(
55
+ text=tts_text,
56
+ language=lang,
57
+ gpt_cond_latent=gpt_cond_latent,
58
+ speaker_embedding=speaker_embedding,
59
+ temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
60
+ length_penalty=XTTS_MODEL.config.length_penalty,
61
+ repetition_penalty=XTTS_MODEL.config.repetition_penalty,
62
+ top_k=XTTS_MODEL.config.top_k,
63
+ top_p=XTTS_MODEL.config.top_p,
64
+ )
65
+
66
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
67
+ out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
68
+ out_path = fp.name
69
+ torchaudio.save(out_path, out["wav"], 24000)
70
+
71
+ return "Speech generated !", out_path, speaker_audio_file
72
+
73
+
74
+ # define a logger to redirect
75
+ class Logger:
76
+ def __init__(self, filename="log.out"):
77
+ self.log_file = filename
78
+ self.terminal = sys.stdout
79
+ self.log = open(self.log_file, "w")
80
+
81
+ def write(self, message):
82
+ self.terminal.write(message)
83
+ self.log.write(message)
84
+
85
+ def flush(self):
86
+ self.terminal.flush()
87
+ self.log.flush()
88
+
89
+ def isatty(self):
90
+ return False
91
+
92
+
93
+ # redirect stdout and stderr to a file
94
+ sys.stdout = Logger()
95
+ sys.stderr = sys.stdout
96
+
97
+
98
+ # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
99
+
100
+ logging.basicConfig(
101
+ level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
102
+ )
103
+
104
+
105
+ def read_logs():
106
+ sys.stdout.flush()
107
+ with open(sys.stdout.log_file, "r") as f:
108
+ return f.read()
109
+
110
+
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser(
113
+ description="""XTTS fine-tuning demo\n\n"""
114
+ """
115
+ Example runs:
116
+ python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
117
+ """,
118
+ formatter_class=argparse.RawTextHelpFormatter,
119
+ )
120
+ parser.add_argument(
121
+ "--port",
122
+ type=int,
123
+ help="Port to run the gradio demo. Default: 5003",
124
+ default=5003,
125
+ )
126
+ parser.add_argument(
127
+ "--out_path",
128
+ type=str,
129
+ help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
130
+ default="/tmp/xtts_ft/",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--num_epochs",
135
+ type=int,
136
+ help="Number of epochs to train. Default: 10",
137
+ default=10,
138
+ )
139
+ parser.add_argument(
140
+ "--batch_size",
141
+ type=int,
142
+ help="Batch size. Default: 4",
143
+ default=4,
144
+ )
145
+ parser.add_argument(
146
+ "--grad_acumm",
147
+ type=int,
148
+ help="Grad accumulation steps. Default: 1",
149
+ default=1,
150
+ )
151
+ parser.add_argument(
152
+ "--max_audio_length",
153
+ type=int,
154
+ help="Max permitted audio size in seconds. Default: 11",
155
+ default=11,
156
+ )
157
+
158
+ args = parser.parse_args()
159
+
160
+ with gr.Blocks() as demo:
161
+ with gr.Tab("1 - Data processing"):
162
+ out_path = gr.Textbox(
163
+ label="Output path (where data and checkpoints will be saved):",
164
+ value=args.out_path,
165
+ )
166
+ # upload_file = gr.Audio(
167
+ # sources="upload",
168
+ # label="Select here the audio files that you want to use for XTTS training !",
169
+ # type="filepath",
170
+ # )
171
+ upload_file = gr.File(
172
+ file_count="multiple",
173
+ label="Select here the audio files that you want to use for XTTS training (Supported formats: wav, mp3, and flac)",
174
+ )
175
+ lang = gr.Dropdown(
176
+ label="Dataset Language",
177
+ value="en",
178
+ choices=[
179
+ "en",
180
+ "es",
181
+ "fr",
182
+ "de",
183
+ "it",
184
+ "pt",
185
+ "pl",
186
+ "tr",
187
+ "ru",
188
+ "nl",
189
+ "cs",
190
+ "ar",
191
+ "zh",
192
+ "hu",
193
+ "ko",
194
+ "ja",
195
+ "hi",
196
+ ],
197
+ )
198
+ progress_data = gr.Label(label="Progress:")
199
+ logs = gr.Textbox(
200
+ label="Logs:",
201
+ interactive=False,
202
+ )
203
+ demo.load(read_logs, None, logs, every=1)
204
+
205
+ prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
206
+
207
+ def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
208
+ clear_gpu_cache()
209
+ out_path = os.path.join(out_path, "dataset")
210
+ os.makedirs(out_path, exist_ok=True)
211
+ if audio_path is None:
212
+ return (
213
+ "You should provide one or more audio files! If you did, the upload may not have finished yet!",
214
+ "",
215
+ "",
216
+ )
217
+ else:
218
+ try:
219
+ train_meta, eval_meta, audio_total_size = format_audio_list(
220
+ audio_path, target_language=language, out_path=out_path, gradio_progress=progress
221
+ )
222
+ except:
223
+ traceback.print_exc()
224
+ error = traceback.format_exc()
225
+ return (
226
+ f"The data processing was interrupted due to an error! Please check the console to verify the full error message! \n Error summary: {error}",
227
+ "",
228
+ "",
229
+ )
230
+
231
+ clear_gpu_cache()
232
+
233
+ # if audio total len is less than 2 minutes raise an error
234
+ if audio_total_size < 120:
235
+ message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
236
+ print(message)
237
+ return message, "", ""
238
+
239
+ print("Dataset Processed!")
240
+ return "Dataset Processed!", train_meta, eval_meta
241
+
242
+ with gr.Tab("2 - Fine-tuning XTTS Encoder"):
243
+ train_csv = gr.Textbox(
244
+ label="Train CSV:",
245
+ )
246
+ eval_csv = gr.Textbox(
247
+ label="Eval CSV:",
248
+ )
249
+ num_epochs = gr.Slider(
250
+ label="Number of epochs:",
251
+ minimum=1,
252
+ maximum=100,
253
+ step=1,
254
+ value=args.num_epochs,
255
+ )
256
+ batch_size = gr.Slider(
257
+ label="Batch size:",
258
+ minimum=2,
259
+ maximum=512,
260
+ step=1,
261
+ value=args.batch_size,
262
+ )
263
+ grad_acumm = gr.Slider(
264
+ label="Grad accumulation steps:",
265
+ minimum=2,
266
+ maximum=128,
267
+ step=1,
268
+ value=args.grad_acumm,
269
+ )
270
+ max_audio_length = gr.Slider(
271
+ label="Max permitted audio size in seconds:",
272
+ minimum=2,
273
+ maximum=20,
274
+ step=1,
275
+ value=args.max_audio_length,
276
+ )
277
+ progress_train = gr.Label(label="Progress:")
278
+ logs_tts_train = gr.Textbox(
279
+ label="Logs:",
280
+ interactive=False,
281
+ )
282
+ demo.load(read_logs, None, logs_tts_train, every=1)
283
+ train_btn = gr.Button(value="Step 2 - Run the training")
284
+
285
+ def train_model(
286
+ language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
287
+ ):
288
+ clear_gpu_cache()
289
+ if not train_csv or not eval_csv:
290
+ return (
291
+ "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
292
+ "",
293
+ "",
294
+ "",
295
+ "",
296
+ )
297
+ try:
298
+ # convert seconds to waveform frames
299
+ max_audio_length = int(max_audio_length * 22050)
300
+ config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
301
+ language,
302
+ num_epochs,
303
+ batch_size,
304
+ grad_acumm,
305
+ train_csv,
306
+ eval_csv,
307
+ output_path=output_path,
308
+ max_audio_length=max_audio_length,
309
+ )
310
+ except:
311
+ traceback.print_exc()
312
+ error = traceback.format_exc()
313
+ return (
314
+ f"The training was interrupted due to an error! Please check the console to see the full error message! \n Error summary: {error}",
315
+ "",
316
+ "",
317
+ "",
318
+ "",
319
+ )
320
+
321
+ # copy the original files to avoid issues with parameter changes
322
+ os.system(f"cp {config_path} {exp_path}")
323
+ os.system(f"cp {vocab_file} {exp_path}")
324
+
325
+ ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
326
+ print("Model training done!")
327
+ clear_gpu_cache()
328
+ return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
329
+
330
+ with gr.Tab("3 - Inference"):
331
+ with gr.Row():
332
+ with gr.Column() as col1:
333
+ xtts_checkpoint = gr.Textbox(
334
+ label="XTTS checkpoint path:",
335
+ value="",
336
+ )
337
+ xtts_config = gr.Textbox(
338
+ label="XTTS config path:",
339
+ value="",
340
+ )
341
+
342
+ xtts_vocab = gr.Textbox(
343
+ label="XTTS vocab path:",
344
+ value="",
345
+ )
346
+ progress_load = gr.Label(label="Progress:")
347
+ load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
348
+
349
+ with gr.Column() as col2:
350
+ speaker_reference_audio = gr.Textbox(
351
+ label="Speaker reference audio:",
352
+ value="",
353
+ )
354
+ tts_language = gr.Dropdown(
355
+ label="Language",
356
+ value="en",
357
+ choices=[
358
+ "en",
359
+ "es",
360
+ "fr",
361
+ "de",
362
+ "it",
363
+ "pt",
364
+ "pl",
365
+ "tr",
366
+ "ru",
367
+ "nl",
368
+ "cs",
369
+ "ar",
370
+ "zh",
371
+ "hu",
372
+ "ko",
373
+ "ja",
374
+ "hi",
375
+ ],
376
+ )
377
+ tts_text = gr.Textbox(
378
+ label="Input Text.",
379
+ value="This model sounds really good and above all, it's reasonably fast.",
380
+ )
381
+ tts_btn = gr.Button(value="Step 4 - Inference")
382
+
383
+ with gr.Column() as col3:
384
+ progress_gen = gr.Label(label="Progress:")
385
+ tts_output_audio = gr.Audio(label="Generated Audio.")
386
+ reference_audio = gr.Audio(label="Reference audio used.")
387
+
388
+ prompt_compute_btn.click(
389
+ fn=preprocess_dataset,
390
+ inputs=[
391
+ upload_file,
392
+ lang,
393
+ out_path,
394
+ ],
395
+ outputs=[
396
+ progress_data,
397
+ train_csv,
398
+ eval_csv,
399
+ ],
400
+ )
401
+
402
+ train_btn.click(
403
+ fn=train_model,
404
+ inputs=[
405
+ lang,
406
+ train_csv,
407
+ eval_csv,
408
+ num_epochs,
409
+ batch_size,
410
+ grad_acumm,
411
+ out_path,
412
+ max_audio_length,
413
+ ],
414
+ outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
415
+ )
416
+
417
+ load_btn.click(
418
+ fn=load_model,
419
+ inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
420
+ outputs=[progress_load],
421
+ )
422
+
423
+ tts_btn.click(
424
+ fn=run_tts,
425
+ inputs=[
426
+ tts_language,
427
+ tts_text,
428
+ speaker_reference_audio,
429
+ ],
430
+ outputs=[progress_gen, tts_output_audio, reference_audio],
431
+ )
432
+
433
+ demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
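The demo is normally launched from the command line as shown in the argparse help, but `load_model` and `run_tts` can also be driven directly from Python once a fine-tuned checkpoint exists. A hedged sketch (all paths are placeholders; note that importing this module redirects stdout to a `log.out` file, as defined above):

```python
from TTS.demos.xtts_ft_demo.xtts_demo import load_model, run_tts

load_model(
    xtts_checkpoint="/tmp/xtts_ft/run/training/GPT_XTTS_FT/best_model.pth",  # placeholder
    xtts_config="/tmp/xtts_ft/run/training/GPT_XTTS_FT/config.json",         # placeholder
    xtts_vocab="/tmp/xtts_ft/run/training/XTTS_v2.0_original_model_files/vocab.json",  # placeholder
)
status, wav_path, ref_path = run_tts("en", "Hello from the fine-tuned model.", "/data/my_speaker/ref.wav")
print(status, wav_path)
```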
TTS/encoder/README.md ADDED
@@ -0,0 +1,18 @@
1
+ ### Speaker Encoder
2
+
3
+ This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding.
4
+
5
+ With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart.
6
+
7
+ Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q).
8
+
9
+ ![](umap.png)
10
+
11
+ Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page.
12
+
13
+ To run the code, you need to follow the same flow as in TTS.
14
+
15
+ - Define 'config.json' for your needs. Note that audio parameters should match your TTS model.
16
+ - Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360```
17
+ - Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files.
18
+ - Watch training on Tensorboard as in TTS
TTS/encoder/__init__.py ADDED
File without changes
TTS/encoder/configs/base_encoder_config.py ADDED
@@ -0,0 +1,61 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Dict, List
3
+
4
+ from coqpit import MISSING
5
+
6
+ from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig
7
+
8
+
9
+ @dataclass
10
+ class BaseEncoderConfig(BaseTrainingConfig):
11
+ """Defines parameters for a Generic Encoder model."""
12
+
13
+ model: str = None
14
+ audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
15
+ datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
16
+ # model params
17
+ model_params: Dict = field(
18
+ default_factory=lambda: {
19
+ "model_name": "lstm",
20
+ "input_dim": 80,
21
+ "proj_dim": 256,
22
+ "lstm_dim": 768,
23
+ "num_lstm_layers": 3,
24
+ "use_lstm_with_projection": True,
25
+ }
26
+ )
27
+
28
+ audio_augmentation: Dict = field(default_factory=lambda: {})
29
+
30
+ # training params
31
+ epochs: int = 10000
32
+ loss: str = "angleproto"
33
+ grad_clip: float = 3.0
34
+ lr: float = 0.0001
35
+ optimizer: str = "radam"
36
+ optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0})
37
+ lr_decay: bool = False
38
+ warmup_steps: int = 4000
39
+
40
+ # logging params
41
+ tb_model_param_stats: bool = False
42
+ steps_plot_stats: int = 10
43
+ save_step: int = 1000
44
+ print_step: int = 20
45
+ run_eval: bool = False
46
+
47
+ # data loader
48
+ num_classes_in_batch: int = MISSING
49
+ num_utter_per_class: int = MISSING
50
+ eval_num_classes_in_batch: int = None
51
+ eval_num_utter_per_class: int = None
52
+
53
+ num_loader_workers: int = MISSING
54
+ voice_len: float = 1.6
55
+
56
+ def check_values(self):
57
+ super().check_values()
58
+ c = asdict(self)
59
+ assert (
60
+ c["model_params"]["input_dim"] == self.audio.num_mels
61
+ ), " [!] model input dimension must be equal to melspectrogram dimension."
TTS/encoder/configs/emotion_encoder_config.py ADDED
@@ -0,0 +1,12 @@
1
+ from dataclasses import dataclass
2
+
3
+ from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
4
+
5
+
6
+ @dataclass
7
+ class EmotionEncoderConfig(BaseEncoderConfig):
8
+ """Defines parameters for Emotion Encoder model."""
9
+
10
+ model: str = "emotion_encoder"
11
+ map_classid_to_classname: dict = None
12
+ class_name_key: str = "emotion_name"
TTS/encoder/configs/speaker_encoder_config.py ADDED
@@ -0,0 +1,11 @@
1
+ from dataclasses import dataclass
2
+
3
+ from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
4
+
5
+
6
+ @dataclass
7
+ class SpeakerEncoderConfig(BaseEncoderConfig):
8
+ """Defines parameters for Speaker Encoder model."""
9
+
10
+ model: str = "speaker_encoder"
11
+ class_name_key: str = "speaker_name"
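The encoder configs only differ in the key they read the class name from (`speaker_name` or `emotion_name`). A minimal sketch of building a speaker-encoder config; the batch-composition fields marked `MISSING` in the base class must be set explicitly, and the values below are illustrative:

```python
from TTS.encoder.configs.speaker_encoder_config import SpeakerEncoderConfig

config = SpeakerEncoderConfig(
    loss="angleproto",
    num_classes_in_batch=32,  # illustrative value
    num_utter_per_class=4,    # illustrative value
    num_loader_workers=4,     # illustrative value
)
config.check_values()  # e.g. asserts model input_dim == audio.num_mels
```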
TTS/encoder/dataset.py ADDED
@@ -0,0 +1,146 @@
1
+ import logging
2
+ import random
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+
7
+ from TTS.encoder.utils.generic_utils import AugmentWAV
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class EncoderDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ config,
16
+ ap,
17
+ meta_data,
18
+ voice_len=1.6,
19
+ num_classes_in_batch=64,
20
+ num_utter_per_class=10,
21
+ augmentation_config=None,
22
+ use_torch_spec=None,
23
+ ):
24
+ """
25
+ Args:
26
+ ap (TTS.tts.utils.AudioProcessor): audio processor object.
27
+ meta_data (list): list of dataset instances.
28
+ voice_len (float): voice segment length in seconds.
29
+ """
30
+ super().__init__()
31
+ self.config = config
32
+ self.items = meta_data
33
+ self.sample_rate = ap.sample_rate
34
+ self.seq_len = int(voice_len * self.sample_rate)
35
+ self.num_utter_per_class = num_utter_per_class
36
+ self.ap = ap
37
+ self.use_torch_spec = use_torch_spec
38
+ self.classes, self.items = self.__parse_items()
39
+
40
+ self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
41
+
42
+ # Data Augmentation
43
+ self.augmentator = None
44
+ self.gaussian_augmentation_config = None
45
+ if augmentation_config:
46
+ self.data_augmentation_p = augmentation_config["p"]
47
+ if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config):
48
+ self.augmentator = AugmentWAV(ap, augmentation_config)
49
+
50
+ if "gaussian" in augmentation_config.keys():
51
+ self.gaussian_augmentation_config = augmentation_config["gaussian"]
52
+
53
+ logger.info("DataLoader initialization")
54
+ logger.info(" | Classes per batch: %d", num_classes_in_batch)
55
+ logger.info(" | Number of instances: %d", len(self.items))
56
+ logger.info(" | Sequence length: %d", self.seq_len)
57
+ logger.info(" | Number of classes: %d", len(self.classes))
58
+ logger.info(" | Classes: %s", self.classes)
59
+
60
+ def load_wav(self, filename):
61
+ audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
62
+ return audio
63
+
64
+ def __parse_items(self):
65
+ class_to_utters = {}
66
+ for item in self.items:
67
+ path_ = item["audio_file"]
68
+ class_name = item[self.config.class_name_key]
69
+ if class_name in class_to_utters.keys():
70
+ class_to_utters[class_name].append(path_)
71
+ else:
72
+ class_to_utters[class_name] = [
73
+ path_,
74
+ ]
75
+
76
+ # keep only classes with at least self.num_utter_per_class samples
77
+ class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class}
78
+
79
+ classes = list(class_to_utters.keys())
80
+ classes.sort()
81
+
82
+ new_items = []
83
+ for item in self.items:
84
+ path_ = item["audio_file"]
85
+ class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"]
86
+ # ignore filtered classes
87
+ if class_name not in classes:
88
+ continue
89
+ # ignore audio files that are shorter than the segment length
90
+ if self.load_wav(path_).shape[0] - self.seq_len <= 0:
91
+ continue
92
+
93
+ new_items.append({"wav_file_path": path_, "class_name": class_name})
94
+
95
+ return classes, new_items
96
+
97
+ def __len__(self):
98
+ return len(self.items)
99
+
100
+ def get_num_classes(self):
101
+ return len(self.classes)
102
+
103
+ def get_class_list(self):
104
+ return self.classes
105
+
106
+ def set_classes(self, classes):
107
+ self.classes = classes
108
+ self.classname_to_classid = {key: i for i, key in enumerate(self.classes)}
109
+
110
+ def get_map_classid_to_classname(self):
111
+ return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items())
112
+
113
+ def __getitem__(self, idx):
114
+ return self.items[idx]
115
+
116
+ def collate_fn(self, batch):
117
+ # get the batch class_ids
118
+ labels = []
119
+ feats = []
120
+ for item in batch:
121
+ utter_path = item["wav_file_path"]
122
+ class_name = item["class_name"]
123
+
124
+ # get classid
125
+ class_id = self.classname_to_classid[class_name]
126
+ # load wav file
127
+ wav = self.load_wav(utter_path)
128
+ offset = random.randint(0, wav.shape[0] - self.seq_len)
129
+ wav = wav[offset : offset + self.seq_len]
130
+
131
+ if self.augmentator is not None and self.data_augmentation_p:
132
+ if random.random() < self.data_augmentation_p:
133
+ wav = self.augmentator.apply_one(wav)
134
+
135
+ if not self.use_torch_spec:
136
+ mel = self.ap.melspectrogram(wav)
137
+ feats.append(torch.FloatTensor(mel))
138
+ else:
139
+ feats.append(torch.FloatTensor(wav))
140
+
141
+ labels.append(class_id)
142
+
143
+ feats = torch.stack(feats)
144
+ labels = torch.LongTensor(labels)
145
+
146
+ return feats, labels
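`EncoderDataset` returns raw sample dicts from `__getitem__`; batching into `(features, class_ids)` tensors happens in its own `collate_fn`. A rough sketch of wiring it into a plain `DataLoader`; `config`, `ap` (an audio processor) and `meta_data` are assumed to come from the usual training setup, and this plain loader does not reproduce the class-balanced batching that `num_classes_in_batch`/`num_utter_per_class` are designed for:

```python
from torch.utils.data import DataLoader

from TTS.encoder.dataset import EncoderDataset

dataset = EncoderDataset(
    config,     # an encoder config, e.g. SpeakerEncoderConfig (assumed prepared elsewhere)
    ap,         # audio processor built from config.audio (assumed prepared elsewhere)
    meta_data,  # list of {"audio_file": ..., "speaker_name": ...} samples
    voice_len=1.6,
    num_classes_in_batch=32,
    num_utter_per_class=4,
)
loader = DataLoader(dataset, batch_size=32 * 4, collate_fn=dataset.collate_fn, num_workers=4)
feats, labels = next(iter(loader))  # feats: mel (or waveform) tensors, labels: class ids
```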
TTS/encoder/losses.py ADDED
@@ -0,0 +1,230 @@
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ # adapted from https://github.com/cvqluu/GE2E-Loss
11
+ class GE2ELoss(nn.Module):
12
+ def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
13
+ """
14
+ Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
15
+ Accepts an input of size (N, M, D)
16
+ where N is the number of speakers in the batch,
17
+ M is the number of utterances per speaker,
18
+ and D is the dimensionality of the embedding vector (e.g. d-vector)
19
+ Args:
20
+ - init_w (float): defines the initial value of w in Equation (5) of [1]
21
+ - init_b (float): defines the initial value of b in Equation (5) of [1]
22
+ """
23
+ super().__init__()
24
+ # pylint: disable=E1102
25
+ self.w = nn.Parameter(torch.tensor(init_w))
26
+ # pylint: disable=E1102
27
+ self.b = nn.Parameter(torch.tensor(init_b))
28
+ self.loss_method = loss_method
29
+
30
+ logger.info("Initialized Generalized End-to-End loss")
31
+
32
+ assert self.loss_method in ["softmax", "contrast"]
33
+
34
+ if self.loss_method == "softmax":
35
+ self.embed_loss = self.embed_loss_softmax
36
+ if self.loss_method == "contrast":
37
+ self.embed_loss = self.embed_loss_contrast
38
+
39
+ # pylint: disable=R0201
40
+ def calc_new_centroids(self, dvecs, centroids, spkr, utt):
41
+ """
42
+ Calculates the new centroids excluding the reference utterance
43
+ """
44
+ excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :]))
45
+ excl = torch.mean(excl, 0)
46
+ new_centroids = []
47
+ for i, centroid in enumerate(centroids):
48
+ if i == spkr:
49
+ new_centroids.append(excl)
50
+ else:
51
+ new_centroids.append(centroid)
52
+ return torch.stack(new_centroids)
53
+
54
+ def calc_cosine_sim(self, dvecs, centroids):
55
+ """
56
+ Make the cosine similarity matrix with dims (N,M,N)
57
+ """
58
+ cos_sim_matrix = []
59
+ for spkr_idx, speaker in enumerate(dvecs):
60
+ cs_row = []
61
+ for utt_idx, utterance in enumerate(speaker):
62
+ new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx)
63
+ # vector based cosine similarity for speed
64
+ cs_row.append(
65
+ torch.clamp(
66
+ torch.mm(
67
+ utterance.unsqueeze(1).transpose(0, 1),
68
+ new_centroids.transpose(0, 1),
69
+ )
70
+ / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)),
71
+ 1e-6,
72
+ )
73
+ )
74
+ cs_row = torch.cat(cs_row, dim=0)
75
+ cos_sim_matrix.append(cs_row)
76
+ return torch.stack(cos_sim_matrix)
77
+
78
+ # pylint: disable=R0201
79
+ def embed_loss_softmax(self, dvecs, cos_sim_matrix):
80
+ """
81
+ Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
82
+ """
83
+ N, M, _ = dvecs.shape
84
+ L = []
85
+ for j in range(N):
86
+ L_row = []
87
+ for i in range(M):
88
+ L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j])
89
+ L_row = torch.stack(L_row)
90
+ L.append(L_row)
91
+ return torch.stack(L)
92
+
93
+ # pylint: disable=R0201
94
+ def embed_loss_contrast(self, dvecs, cos_sim_matrix):
95
+ """
96
+ Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
97
+ """
98
+ N, M, _ = dvecs.shape
99
+ L = []
100
+ for j in range(N):
101
+ L_row = []
102
+ for i in range(M):
103
+ centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
104
+ excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]))
105
+ L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids))
106
+ L_row = torch.stack(L_row)
107
+ L.append(L_row)
108
+ return torch.stack(L)
109
+
110
+ def forward(self, x, _label=None):
111
+ """
112
+ Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
113
+ """
114
+
115
+ assert x.size()[1] >= 2
116
+
117
+ centroids = torch.mean(x, 1)
118
+ cos_sim_matrix = self.calc_cosine_sim(x, centroids)
119
+ torch.clamp(self.w, 1e-6)
120
+ cos_sim_matrix = self.w * cos_sim_matrix + self.b
121
+ L = self.embed_loss(x, cos_sim_matrix)
122
+ return L.mean()
123
+
124
+
125
+ # adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
126
+ class AngleProtoLoss(nn.Module):
127
+ """
128
+ Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
129
+ Accepts an input of size (N, M, D)
130
+ where N is the number of speakers in the batch,
131
+ M is the number of utterances per speaker,
132
+ and D is the dimensionality of the embedding vector
133
+ Args:
134
+ - init_w (float): defines the initial value of w
135
+ - init_b (float): defines the initial value of b
136
+ """
137
+
138
+ def __init__(self, init_w=10.0, init_b=-5.0):
139
+ super().__init__()
140
+ # pylint: disable=E1102
141
+ self.w = nn.Parameter(torch.tensor(init_w))
142
+ # pylint: disable=E1102
143
+ self.b = nn.Parameter(torch.tensor(init_b))
144
+ self.criterion = torch.nn.CrossEntropyLoss()
145
+
146
+ logger.info("Initialized Angular Prototypical loss")
147
+
148
+ def forward(self, x, _label=None):
149
+ """
150
+ Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
151
+ """
152
+
153
+ assert x.size()[1] >= 2
154
+
155
+ out_anchor = torch.mean(x[:, 1:, :], 1)
156
+ out_positive = x[:, 0, :]
157
+ num_speakers = out_anchor.size()[0]
158
+
159
+ cos_sim_matrix = F.cosine_similarity(
160
+ out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),
161
+ out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2),
162
+ )
163
+ torch.clamp(self.w, 1e-6)
164
+ cos_sim_matrix = cos_sim_matrix * self.w + self.b
165
+ label = torch.arange(num_speakers).to(cos_sim_matrix.device)
166
+ L = self.criterion(cos_sim_matrix, label)
167
+ return L
168
+
169
+
170
+ class SoftmaxLoss(nn.Module):
171
+ """
172
+ Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982
173
+ Args:
174
+ - embedding_dim (float): speaker embedding dim
175
+ - n_speakers (float): number of speakers
176
+ """
177
+
178
+ def __init__(self, embedding_dim, n_speakers):
179
+ super().__init__()
180
+
181
+ self.criterion = torch.nn.CrossEntropyLoss()
182
+ self.fc = nn.Linear(embedding_dim, n_speakers)
183
+
184
+ logger.info("Initialised Softmax Loss")
185
+
186
+ def forward(self, x, label=None):
187
+ # reshape for compatibility
188
+ x = x.reshape(-1, x.size()[-1])
189
+ label = label.reshape(-1)
190
+
191
+ x = self.fc(x)
192
+ L = self.criterion(x, label)
193
+
194
+ return L
195
+
196
+ def inference(self, embedding):
197
+ x = self.fc(embedding)
198
+ activations = torch.nn.functional.softmax(x, dim=1).squeeze(0)
199
+ class_id = torch.argmax(activations)
200
+ return class_id
201
+
202
+
203
+ class SoftmaxAngleProtoLoss(nn.Module):
204
+ """
205
+ Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153
206
+ Args:
207
+ - embedding_dim (float): speaker embedding dim
208
+ - n_speakers (float): number of speakers
209
+ - init_w (float): defines the initial value of w
210
+ - init_b (float): defines the initial value of b
211
+ """
212
+
213
+ def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0):
214
+ super().__init__()
215
+
216
+ self.softmax = SoftmaxLoss(embedding_dim, n_speakers)
217
+ self.angleproto = AngleProtoLoss(init_w, init_b)
218
+
219
+ logger.info("Initialised SoftmaxAnglePrototypical Loss")
220
+
221
+ def forward(self, x, label=None):
222
+ """
223
+ Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
224
+ """
225
+
226
+ Lp = self.angleproto(x)
227
+
228
+ Ls = self.softmax(x, label)
229
+
230
+ return Ls + Lp
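All three losses consume an embedding tensor shaped `(num_classes_in_batch, num_utterances_per_class, embedding_dim)`, as stated in their docstrings. A quick shape-check sketch with random embeddings (dimensions are illustrative):

```python
import torch

from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss

N, M, D = 8, 4, 256  # classes per batch, utterances per class, embedding dim
x = torch.randn(N, M, D)
labels = torch.arange(N).unsqueeze(1).repeat(1, M)  # class id of every utterance

print(GE2ELoss(loss_method="softmax")(x))                 # scalar loss
print(AngleProtoLoss()(x))                                # scalar loss
print(SoftmaxAngleProtoLoss(D, n_speakers=N)(x, labels))  # the softmax term uses the labels
```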
TTS/encoder/models/base_encoder.py ADDED
@@ -0,0 +1,165 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ from coqpit import Coqpit
7
+ from torch import nn
8
+ from trainer.io import load_fsspec
9
+
10
+ from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
11
+ from TTS.utils.generic_utils import set_init_dict
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class PreEmphasis(nn.Module):
17
+ def __init__(self, coefficient=0.97):
18
+ super().__init__()
19
+ self.coefficient = coefficient
20
+ self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
21
+
22
+ def forward(self, x):
23
+ assert len(x.size()) == 2
24
+
25
+ x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
26
+ return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
27
+
28
+
29
+ class BaseEncoder(nn.Module):
30
+ """Base `encoder` class. Every new `encoder` model must inherit this.
31
+
32
+ It defines common `encoder` specific functions.
33
+ """
34
+
35
+ # pylint: disable=W0102
36
+ def __init__(self):
37
+ super(BaseEncoder, self).__init__()
38
+
39
+ def get_torch_mel_spectrogram_class(self, audio_config):
40
+ return torch.nn.Sequential(
41
+ PreEmphasis(audio_config["preemphasis"]),
42
+ # TorchSTFT(
43
+ # n_fft=audio_config["fft_size"],
44
+ # hop_length=audio_config["hop_length"],
45
+ # win_length=audio_config["win_length"],
46
+ # sample_rate=audio_config["sample_rate"],
47
+ # window="hamming_window",
48
+ # mel_fmin=0.0,
49
+ # mel_fmax=None,
50
+ # use_htk=True,
51
+ # do_amp_to_db=False,
52
+ # n_mels=audio_config["num_mels"],
53
+ # power=2.0,
54
+ # use_mel=True,
55
+ # mel_norm=None,
56
+ # )
57
+ torchaudio.transforms.MelSpectrogram(
58
+ sample_rate=audio_config["sample_rate"],
59
+ n_fft=audio_config["fft_size"],
60
+ win_length=audio_config["win_length"],
61
+ hop_length=audio_config["hop_length"],
62
+ window_fn=torch.hamming_window,
63
+ n_mels=audio_config["num_mels"],
64
+ ),
65
+ )
66
+
67
+ @torch.no_grad()
68
+ def inference(self, x, l2_norm=True):
69
+ return self.forward(x, l2_norm)
70
+
71
+ @torch.no_grad()
72
+ def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True):
73
+ """
74
+ Generate embeddings for a batch of utterances
75
+ x: 1xTxD
76
+ """
77
+ # map to the waveform size
78
+ if self.use_torch_spec:
79
+ num_frames = num_frames * self.audio_config["hop_length"]
80
+
81
+ max_len = x.shape[1]
82
+
83
+ if max_len < num_frames:
84
+ num_frames = max_len
85
+
86
+ offsets = np.linspace(0, max_len - num_frames, num=num_eval)
87
+
88
+ frames_batch = []
89
+ for offset in offsets:
90
+ offset = int(offset)
91
+ end_offset = int(offset + num_frames)
92
+ frames = x[:, offset:end_offset]
93
+ frames_batch.append(frames)
94
+
95
+ frames_batch = torch.cat(frames_batch, dim=0)
96
+ embeddings = self.inference(frames_batch, l2_norm=l2_norm)
97
+
98
+ if return_mean:
99
+ embeddings = torch.mean(embeddings, dim=0, keepdim=True)
100
+ return embeddings
101
+
102
+ def get_criterion(self, c: Coqpit, num_classes=None):
103
+ if c.loss == "ge2e":
104
+ criterion = GE2ELoss(loss_method="softmax")
105
+ elif c.loss == "angleproto":
106
+ criterion = AngleProtoLoss()
107
+ elif c.loss == "softmaxproto":
108
+ criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
109
+ else:
110
+ raise Exception("The %s not is a loss supported" % c.loss)
111
+ return criterion
112
+
113
+ def load_checkpoint(
114
+ self,
115
+ config: Coqpit,
116
+ checkpoint_path: str,
117
+ eval: bool = False,
118
+ use_cuda: bool = False,
119
+ criterion=None,
120
+ cache=False,
121
+ ):
122
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
123
+ try:
124
+ self.load_state_dict(state["model"])
125
+ logger.info("Model fully restored. ")
126
+ except (KeyError, RuntimeError) as error:
127
+ # If eval raise the error
128
+ if eval:
129
+ raise error
130
+
131
+ logger.info("Partial model initialization.")
132
+ model_dict = self.state_dict()
133
+ model_dict = set_init_dict(model_dict, state["model"], config)
134
+ self.load_state_dict(model_dict)
135
+ del model_dict
136
+
137
+ # load the criterion for restore_path
138
+ if criterion is not None and "criterion" in state:
139
+ try:
140
+ criterion.load_state_dict(state["criterion"])
141
+ except (KeyError, RuntimeError) as error:
142
+ logger.exception("Criterion load ignored because of: %s", error)
143
+
144
+ # instance and load the criterion for the encoder classifier in inference time
145
+ if (
146
+ eval
147
+ and criterion is None
148
+ and "criterion" in state
149
+ and getattr(config, "map_classid_to_classname", None) is not None
150
+ ):
151
+ criterion = self.get_criterion(config, len(config.map_classid_to_classname))
152
+ criterion.load_state_dict(state["criterion"])
153
+
154
+ if use_cuda:
155
+ self.cuda()
156
+ if criterion is not None:
157
+ criterion = criterion.cuda()
158
+
159
+ if eval:
160
+ self.eval()
161
+ assert not self.training
162
+
163
+ if not eval:
164
+ return criterion, state["step"]
165
+ return criterion
TTS/encoder/models/lstm.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from TTS.encoder.models.base_encoder import BaseEncoder
5
+
6
+
7
+ class LSTMWithProjection(nn.Module):
8
+ def __init__(self, input_size, hidden_size, proj_size):
9
+ super().__init__()
10
+ self.input_size = input_size
11
+ self.hidden_size = hidden_size
12
+ self.proj_size = proj_size
13
+ self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
14
+ self.linear = nn.Linear(hidden_size, proj_size, bias=False)
15
+
16
+ def forward(self, x):
17
+ self.lstm.flatten_parameters()
18
+ o, (_, _) = self.lstm(x)
19
+ return self.linear(o)
20
+
21
+
22
+ class LSTMWithoutProjection(nn.Module):
23
+ def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
24
+ super().__init__()
25
+ self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
26
+ self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
27
+ self.relu = nn.ReLU()
28
+
29
+ def forward(self, x):
30
+ _, (hidden, _) = self.lstm(x)
31
+ return self.relu(self.linear(hidden[-1]))
32
+
33
+
34
+ class LSTMSpeakerEncoder(BaseEncoder):
35
+ def __init__(
36
+ self,
37
+ input_dim,
38
+ proj_dim=256,
39
+ lstm_dim=768,
40
+ num_lstm_layers=3,
41
+ use_lstm_with_projection=True,
42
+ use_torch_spec=False,
43
+ audio_config=None,
44
+ ):
45
+ super().__init__()
46
+ self.use_lstm_with_projection = use_lstm_with_projection
47
+ self.use_torch_spec = use_torch_spec
48
+ self.audio_config = audio_config
49
+ self.proj_dim = proj_dim
50
+
51
+ layers = []
52
+ # choose the LSTM layer type
53
+ if use_lstm_with_projection:
54
+ layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
55
+ for _ in range(num_lstm_layers - 1):
56
+ layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
57
+ self.layers = nn.Sequential(*layers)
58
+ else:
59
+ self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
60
+
61
+ self.instancenorm = nn.InstanceNorm1d(input_dim)
62
+
63
+ if self.use_torch_spec:
64
+ self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
65
+ else:
66
+ self.torch_spec = None
67
+
68
+ self._init_layers()
69
+
70
+ def _init_layers(self):
71
+ for name, param in self.layers.named_parameters():
72
+ if "bias" in name:
73
+ nn.init.constant_(param, 0.0)
74
+ elif "weight" in name:
75
+ nn.init.xavier_normal_(param)
76
+
77
+ def forward(self, x, l2_norm=True):
78
+ """Forward pass of the model.
79
+
80
+ Args:
81
+ x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
82
+ to compute the spectrogram on-the-fly.
83
+ l2_norm (bool): Whether to L2-normalize the outputs.
84
+
85
+ Shapes:
86
+ - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
87
+ """
88
+ with torch.no_grad():
89
+ with torch.cuda.amp.autocast(enabled=False):
90
+ if self.use_torch_spec:
91
+ x.squeeze_(1)
92
+ x = self.torch_spec(x)
93
+ x = self.instancenorm(x).transpose(1, 2)
94
+ d = self.layers(x)
95
+ if self.use_lstm_with_projection:
96
+ d = d[:, -1]
97
+ if l2_norm:
98
+ d = torch.nn.functional.normalize(d, p=2, dim=1)
99
+ return d
TTS/encoder/models/resnet.py ADDED
@@ -0,0 +1,198 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ # from TTS.utils.audio.torch_transforms import TorchSTFT
5
+ from TTS.encoder.models.base_encoder import BaseEncoder
6
+
7
+
8
+ class SELayer(nn.Module):
9
+ def __init__(self, channel, reduction=8):
10
+ super(SELayer, self).__init__()
11
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
12
+ self.fc = nn.Sequential(
13
+ nn.Linear(channel, channel // reduction),
14
+ nn.ReLU(inplace=True),
15
+ nn.Linear(channel // reduction, channel),
16
+ nn.Sigmoid(),
17
+ )
18
+
19
+ def forward(self, x):
20
+ b, c, _, _ = x.size()
21
+ y = self.avg_pool(x).view(b, c)
22
+ y = self.fc(y).view(b, c, 1, 1)
23
+ return x * y
24
+
25
+
26
+ class SEBasicBlock(nn.Module):
27
+ expansion = 1
28
+
29
+ def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
30
+ super(SEBasicBlock, self).__init__()
31
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
32
+ self.bn1 = nn.BatchNorm2d(planes)
33
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
34
+ self.bn2 = nn.BatchNorm2d(planes)
35
+ self.relu = nn.ReLU(inplace=True)
36
+ self.se = SELayer(planes, reduction)
37
+ self.downsample = downsample
38
+ self.stride = stride
39
+
40
+ def forward(self, x):
41
+ residual = x
42
+
43
+ out = self.conv1(x)
44
+ out = self.relu(out)
45
+ out = self.bn1(out)
46
+
47
+ out = self.conv2(out)
48
+ out = self.bn2(out)
49
+ out = self.se(out)
50
+
51
+ if self.downsample is not None:
52
+ residual = self.downsample(x)
53
+
54
+ out += residual
55
+ out = self.relu(out)
56
+ return out
57
+
58
+
59
+ class ResNetSpeakerEncoder(BaseEncoder):
60
+ """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153
61
+ Adapted from: https://github.com/clovaai/voxceleb_trainer
62
+ """
63
+
64
+ # pylint: disable=W0102
65
+ def __init__(
66
+ self,
67
+ input_dim=64,
68
+ proj_dim=512,
69
+ layers=[3, 4, 6, 3],
70
+ num_filters=[32, 64, 128, 256],
71
+ encoder_type="ASP",
72
+ log_input=False,
73
+ use_torch_spec=False,
74
+ audio_config=None,
75
+ ):
76
+ super(ResNetSpeakerEncoder, self).__init__()
77
+
78
+ self.encoder_type = encoder_type
79
+ self.input_dim = input_dim
80
+ self.log_input = log_input
81
+ self.use_torch_spec = use_torch_spec
82
+ self.audio_config = audio_config
83
+ self.proj_dim = proj_dim
84
+
85
+ self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
86
+ self.relu = nn.ReLU(inplace=True)
87
+ self.bn1 = nn.BatchNorm2d(num_filters[0])
88
+
89
+ self.inplanes = num_filters[0]
90
+ self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
91
+ self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
92
+ self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
93
+ self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
94
+
95
+ self.instancenorm = nn.InstanceNorm1d(input_dim)
96
+
97
+ if self.use_torch_spec:
98
+ self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config)
99
+ else:
100
+ self.torch_spec = None
101
+
102
+ outmap_size = int(self.input_dim / 8)
103
+
104
+ self.attention = nn.Sequential(
105
+ nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
106
+ nn.ReLU(),
107
+ nn.BatchNorm1d(128),
108
+ nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
109
+ nn.Softmax(dim=2),
110
+ )
111
+
112
+ if self.encoder_type == "SAP":
113
+ out_dim = num_filters[3] * outmap_size
114
+ elif self.encoder_type == "ASP":
115
+ out_dim = num_filters[3] * outmap_size * 2
116
+ else:
117
+ raise ValueError("Undefined encoder")
118
+
119
+ self.fc = nn.Linear(out_dim, proj_dim)
120
+
121
+ self._init_layers()
122
+
123
+ def _init_layers(self):
124
+ for m in self.modules():
125
+ if isinstance(m, nn.Conv2d):
126
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
127
+ elif isinstance(m, nn.BatchNorm2d):
128
+ nn.init.constant_(m.weight, 1)
129
+ nn.init.constant_(m.bias, 0)
130
+
131
+ def create_layer(self, block, planes, blocks, stride=1):
132
+ downsample = None
133
+ if stride != 1 or self.inplanes != planes * block.expansion:
134
+ downsample = nn.Sequential(
135
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
136
+ nn.BatchNorm2d(planes * block.expansion),
137
+ )
138
+
139
+ layers = []
140
+ layers.append(block(self.inplanes, planes, stride, downsample))
141
+ self.inplanes = planes * block.expansion
142
+ for _ in range(1, blocks):
143
+ layers.append(block(self.inplanes, planes))
144
+
145
+ return nn.Sequential(*layers)
146
+
147
+ # pylint: disable=R0201
148
+ def new_parameter(self, *size):
149
+ out = nn.Parameter(torch.FloatTensor(*size))
150
+ nn.init.xavier_normal_(out)
151
+ return out
152
+
153
+ def forward(self, x, l2_norm=False):
154
+ """Forward pass of the model.
155
+
156
+ Args:
157
+ x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
158
+ to compute the spectrogram on-the-fly.
159
+ l2_norm (bool): Whether to L2-normalize the outputs.
160
+
161
+ Shapes:
162
+ - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
163
+ """
164
+ x.squeeze_(1)
165
+ # if use_torch_spec is set, compute the spectrogram here; otherwise use the mel spec already computed by the AP
166
+ if self.use_torch_spec:
167
+ x = self.torch_spec(x)
168
+
169
+ if self.log_input:
170
+ x = (x + 1e-6).log()
171
+ x = self.instancenorm(x).unsqueeze(1)
172
+
173
+ x = self.conv1(x)
174
+ x = self.relu(x)
175
+ x = self.bn1(x)
176
+
177
+ x = self.layer1(x)
178
+ x = self.layer2(x)
179
+ x = self.layer3(x)
180
+ x = self.layer4(x)
181
+
182
+ x = x.reshape(x.size()[0], -1, x.size()[-1])
183
+
184
+ w = self.attention(x)
185
+
186
+ if self.encoder_type == "SAP":
187
+ x = torch.sum(x * w, dim=2)
188
+ elif self.encoder_type == "ASP":
189
+ mu = torch.sum(x * w, dim=2)
190
+ sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
191
+ x = torch.cat((mu, sg), 1)
192
+
193
+ x = x.view(x.size()[0], -1)
194
+ x = self.fc(x)
195
+
196
+ if l2_norm:
197
+ x = torch.nn.functional.normalize(x, p=2, dim=1)
198
+ return x
TTS/encoder/requirements.txt ADDED
@@ -0,0 +1,2 @@
1
+ umap-learn
2
+ numpy>=1.17.0
TTS/encoder/utils/__init__.py ADDED
File without changes
TTS/encoder/utils/generic_utils.py ADDED
@@ -0,0 +1,141 @@
1
+ import glob
2
+ import logging
3
+ import os
4
+ import random
5
+
6
+ import numpy as np
7
+ from scipy import signal
8
+
9
+ from TTS.encoder.models.lstm import LSTMSpeakerEncoder
10
+ from TTS.encoder.models.resnet import ResNetSpeakerEncoder
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class AugmentWAV(object):
16
+ def __init__(self, ap, augmentation_config):
17
+ self.ap = ap
18
+ self.use_additive_noise = False
19
+
20
+ if "additive" in augmentation_config.keys():
21
+ self.additive_noise_config = augmentation_config["additive"]
22
+ additive_path = self.additive_noise_config["sounds_path"]
23
+ if additive_path:
24
+ self.use_additive_noise = True
25
+ # get noise types
26
+ self.additive_noise_types = []
27
+ for key in self.additive_noise_config.keys():
28
+ if isinstance(self.additive_noise_config[key], dict):
29
+ self.additive_noise_types.append(key)
30
+
31
+ additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)
32
+
33
+ self.noise_list = {}
34
+
35
+ for wav_file in additive_files:
36
+ noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0]
37
+ # ignore not listed directories
38
+ if noise_dir not in self.additive_noise_types:
39
+ continue
40
+ if noise_dir not in self.noise_list:
41
+ self.noise_list[noise_dir] = []
42
+ self.noise_list[noise_dir].append(wav_file)
43
+
44
+ logger.info(
45
+ "Using Additive Noise Augmentation: with %d audios instances from %s",
46
+ len(additive_files),
47
+ self.additive_noise_types,
48
+ )
49
+
50
+ self.use_rir = False
51
+
52
+ if "rir" in augmentation_config.keys():
53
+ self.rir_config = augmentation_config["rir"]
54
+ if self.rir_config["rir_path"]:
55
+ self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
56
+ self.use_rir = True
57
+
58
+ logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))
59
+
60
+ self.create_augmentation_global_list()
61
+
62
+ def create_augmentation_global_list(self):
63
+ if self.use_additive_noise:
64
+ self.global_noise_list = self.additive_noise_types
65
+ else:
66
+ self.global_noise_list = []
67
+ if self.use_rir:
68
+ self.global_noise_list.append("RIR_AUG")
69
+
70
+ def additive_noise(self, noise_type, audio):
71
+ clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
72
+
73
+ noise_list = random.sample(
74
+ self.noise_list[noise_type],
75
+ random.randint(
76
+ self.additive_noise_config[noise_type]["min_num_noises"],
77
+ self.additive_noise_config[noise_type]["max_num_noises"],
78
+ ),
79
+ )
80
+
81
+ audio_len = audio.shape[0]
82
+ noises_wav = None
83
+ for noise in noise_list:
84
+ noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len]
85
+
86
+ if noiseaudio.shape[0] < audio_len:
87
+ continue
88
+
89
+ noise_snr = random.uniform(
90
+ self.additive_noise_config[noise_type]["min_snr_in_db"],
91
+ self.additive_noise_config[noise_type]["max_num_noises"],
92
+ )
93
+ noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
94
+ noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
95
+
96
+ if noises_wav is None:
97
+ noises_wav = noise_wav
98
+ else:
99
+ noises_wav += noise_wav
100
+
101
+ # if all sampled noise files were shorter than the audio, sample again
102
+ if noises_wav is None:
103
+ return self.additive_noise(noise_type, audio)
104
+
105
+ return audio + noises_wav
106
+
107
+ def reverberate(self, audio):
108
+ audio_len = audio.shape[0]
109
+
110
+ rir_file = random.choice(self.rir_files)
111
+ rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
112
+ rir = rir / np.sqrt(np.sum(rir**2))
113
+ return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
114
+
115
+ def apply_one(self, audio):
116
+ noise_type = random.choice(self.global_noise_list)
117
+ if noise_type == "RIR_AUG":
118
+ return self.reverberate(audio)
119
+
120
+ return self.additive_noise(noise_type, audio)
121
+
122
+
123
+ def setup_encoder_model(config: "Coqpit"):
124
+ if config.model_params["model_name"].lower() == "lstm":
125
+ model = LSTMSpeakerEncoder(
126
+ config.model_params["input_dim"],
127
+ config.model_params["proj_dim"],
128
+ config.model_params["lstm_dim"],
129
+ config.model_params["num_lstm_layers"],
130
+ use_torch_spec=config.model_params.get("use_torch_spec", False),
131
+ audio_config=config.audio,
132
+ )
133
+ elif config.model_params["model_name"].lower() == "resnet":
134
+ model = ResNetSpeakerEncoder(
135
+ input_dim=config.model_params["input_dim"],
136
+ proj_dim=config.model_params["proj_dim"],
137
+ log_input=config.model_params.get("log_input", False),
138
+ use_torch_spec=config.model_params.get("use_torch_spec", False),
139
+ audio_config=config.audio,
140
+ )
141
+ return model
TTS/encoder/utils/prepare_voxceleb.py ADDED
@@ -0,0 +1,226 @@
1
+ # coding=utf-8
2
+ # Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo
3
+ # All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # ==============================================================================
17
+ # Only support eager mode and TF>=2.0.0
18
+ # pylint: disable=no-member, invalid-name, relative-beyond-top-level
19
+ # pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
20
+ """ voxceleb 1 & 2 """
21
+
22
+ import csv
23
+ import hashlib
24
+ import logging
25
+ import os
26
+ import subprocess
27
+ import sys
28
+ import zipfile
29
+
30
+ import soundfile as sf
31
+
32
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ SUBSETS = {
37
+ "vox1_dev_wav": [
38
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
39
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
40
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
41
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad",
42
+ ],
43
+ "vox1_test_wav": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
44
+ "vox2_dev_aac": [
45
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
46
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
47
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
48
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
49
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
50
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
51
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
52
+ "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah",
53
+ ],
54
+ "vox2_test_aac": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"],
55
+ }
56
+
57
+ MD5SUM = {
58
+ "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b",
59
+ "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402",
60
+ "vox1_test_wav": "185fdc63c3c739954633d50379a3d102",
61
+ "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312",
62
+ }
63
+
64
+ USER = {"user": "", "password": ""}
65
+
66
+ speaker_id_dict = {}
67
+
68
+
69
+ def download_and_extract(directory, subset, urls):
70
+ """Download and extract the given split of dataset.
71
+
72
+ Args:
73
+ directory: the directory where to put the downloaded data.
74
+ subset: subset name of the corpus.
75
+ urls: the list of urls to download the data file.
76
+ """
77
+ os.makedirs(directory, exist_ok=True)
78
+
79
+ try:
80
+ for url in urls:
81
+ zip_filepath = os.path.join(directory, url.split("/")[-1])
82
+ if os.path.exists(zip_filepath):
83
+ continue
84
+ logger.info("Downloading %s to %s" % (url, zip_filepath))
85
+ subprocess.call(
86
+ "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
87
+ shell=True,
88
+ )
89
+
90
+ statinfo = os.stat(zip_filepath)
91
+ logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
92
+
93
+ # concatenate all parts into zip files
94
+ if ".zip" not in zip_filepath:
95
+ zip_filepath = "_".join(zip_filepath.split("_")[:-1])
96
+ subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True)
97
+ zip_filepath += ".zip"
98
+ extract_path = zip_filepath[: -len(".zip")]
99
+
100
+ # check zip file md5sum
101
+ with open(zip_filepath, "rb") as f_zip:
102
+ md5 = hashlib.md5(f_zip.read()).hexdigest()
103
+ if md5 != MD5SUM[subset]:
104
+ raise ValueError("md5sum of %s mismatch" % zip_filepath)
105
+
106
+ with zipfile.ZipFile(zip_filepath, "r") as zfile:
107
+ zfile.extractall(directory)
108
+ extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
109
+ subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True)
110
+ finally:
111
+ # os.remove(zip_filepath)
112
+ pass
113
+
114
+
115
+ def exec_cmd(cmd):
116
+ """Run a command in a subprocess.
117
+ Args:
118
+ cmd: command line to be executed.
119
+ Return:
120
+ int, the return code.
121
+ """
122
+ try:
123
+ retcode = subprocess.call(cmd, shell=True)
124
+ if retcode < 0:
125
+ logger.info(f"Child was terminated by signal {retcode}")
126
+ except OSError as e:
127
+ logger.info(f"Execution failed: {e}")
128
+ retcode = -999
129
+ return retcode
130
+
131
+
132
+ def decode_aac_with_ffmpeg(aac_file, wav_file):
133
+ """Decode a given AAC file into WAV using ffmpeg.
134
+ Args:
135
+ aac_file: file path to input AAC file.
136
+ wav_file: file path to output WAV file.
137
+ Return:
138
+ bool, True if success.
139
+ """
140
+ cmd = f"ffmpeg -i {aac_file} {wav_file}"
141
+ logger.info(f"Decoding aac file using command line: {cmd}")
142
+ ret = exec_cmd(cmd)
143
+ if ret != 0:
144
+ logger.error(f"Failed to decode aac file with retcode {ret}")
145
+ logger.error("Please check your ffmpeg installation.")
146
+ return False
147
+ return True
148
+
149
+
150
+ def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
151
+ """Optionally convert AAC to WAV and make speaker labels.
152
+ Args:
153
+ input_dir: the directory which holds the input dataset.
154
+ subset: the name of the specified subset. e.g. vox1_dev_wav
155
+ output_dir: the directory to place the newly generated csv files.
156
+ output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv
157
+ """
158
+
159
+ logger.info("Preprocessing audio and label for subset %s" % subset)
160
+ source_dir = os.path.join(input_dir, subset)
161
+
162
+ files = []
163
+ # Convert all AAC file into WAV format. At the same time, generate the csv
164
+ for root, _, filenames in os.walk(source_dir):
165
+ for filename in filenames:
166
+ name, ext = os.path.splitext(filename)
167
+ if ext.lower() == ".wav":
168
+ _, ext2 = os.path.splitext(name)
169
+ if ext2:
170
+ continue
171
+ wav_file = os.path.join(root, filename)
172
+ elif ext.lower() == ".m4a":
173
+ # Convert AAC to WAV.
174
+ aac_file = os.path.join(root, filename)
175
+ wav_file = aac_file + ".wav"
176
+ if not os.path.exists(wav_file):
177
+ if not decode_aac_with_ffmpeg(aac_file, wav_file):
178
+ raise RuntimeError("Audio decoding failed.")
179
+ else:
180
+ continue
181
+ speaker_name = root.split(os.path.sep)[-2]
182
+ if speaker_name not in speaker_id_dict:
183
+ num = len(speaker_id_dict)
184
+ speaker_id_dict[speaker_name] = num
185
+ # wav_filesize = os.path.getsize(wav_file)
186
+ wav_length = len(sf.read(wav_file)[0])
187
+ files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name))
188
+
189
+ # Write to CSV file which contains four columns:
190
+ # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
191
+ csv_file_path = os.path.join(output_dir, output_file)
192
+ with open(csv_file_path, "w", newline="", encoding="utf-8") as f:
193
+ writer = csv.writer(f, delimiter="\t")
194
+ writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
195
+ for wav_file in files:
196
+ writer.writerow(wav_file)
197
+ logger.info("Successfully generated csv file {}".format(csv_file_path))
198
+
199
+
200
+ def processor(directory, subset, force_process):
201
+ """download and process"""
202
+ urls = SUBSETS
203
+ if subset not in urls:
204
+ raise ValueError(subset, "is not in voxceleb")
205
+
206
+ subset_csv = os.path.join(directory, subset + ".csv")
207
+ if not force_process and os.path.exists(subset_csv):
208
+ return subset_csv
209
+
210
+ logger.info("Downloading and process the voxceleb in %s", directory)
211
+ logger.info("Preparing subset %s", subset)
212
+ download_and_extract(directory, subset, urls[subset])
213
+ convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
214
+ logger.info("Finished downloading and processing")
215
+ return subset_csv
216
+
217
+
218
+ if __name__ == "__main__":
219
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
220
+ if len(sys.argv) != 4:
221
+ print("Usage: python prepare_data.py save_directory user password")
222
+ sys.exit()
223
+
224
+ DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3]
225
+ for SUBSET in SUBSETS:
226
+ processor(DIR, SUBSET, False)
TTS/encoder/utils/training.py ADDED
@@ -0,0 +1,99 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+
4
+ from coqpit import Coqpit
5
+ from trainer import TrainerArgs, get_last_checkpoint
6
+ from trainer.generic_utils import get_experiment_folder_path, get_git_branch
7
+ from trainer.io import copy_model_files
8
+ from trainer.logging import logger_factory
9
+ from trainer.logging.console_logger import ConsoleLogger
10
+
11
+ from TTS.config import load_config, register_config
12
+ from TTS.tts.utils.text.characters import parse_symbols
13
+
14
+
15
+ @dataclass
16
+ class TrainArgs(TrainerArgs):
17
+ config_path: str = field(default=None, metadata={"help": "Path to the config file."})
18
+
19
+
20
+ def getarguments():
21
+ train_config = TrainArgs()
22
+ parser = train_config.init_argparse(arg_prefix="")
23
+ return parser
24
+
25
+
26
+ def process_args(args, config=None):
27
+ """Process parsed comand line arguments and initialize the config if not provided.
28
+ Args:
29
+ args (argparse.Namespace or dict like): Parsed input arguments.
30
+ config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
31
+ Returns:
32
+ c (Coqpit): Config parameters.
33
+ out_path (str): Path to save models and logging.
34
+ audio_path (str): Path to save generated test audios.
35
+ c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
36
+ logging to the console.
37
+ dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
38
+ TODO:
39
+ - Interactive config definition.
40
+ """
41
+ if isinstance(args, tuple):
42
+ args, coqpit_overrides = args
43
+ if args.continue_path:
44
+ # continue a previous training from its output folder
45
+ experiment_path = args.continue_path
46
+ args.config_path = os.path.join(args.continue_path, "config.json")
47
+ args.restore_path, best_model = get_last_checkpoint(args.continue_path)
48
+ if not args.best_path:
49
+ args.best_path = best_model
50
+ # init config if not already defined
51
+ if config is None:
52
+ if args.config_path:
53
+ # init from a file
54
+ config = load_config(args.config_path)
55
+ else:
56
+ # init from console args
57
+ from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
58
+
59
+ config_base = BaseTrainingConfig()
60
+ config_base.parse_known_args(coqpit_overrides)
61
+ config = register_config(config_base.model)()
62
+ # override values from command-line args
63
+ config.parse_known_args(coqpit_overrides, relaxed_parser=True)
64
+ experiment_path = args.continue_path
65
+ if not experiment_path:
66
+ experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
67
+ audio_path = os.path.join(experiment_path, "test_audios")
68
+ config.output_log_path = experiment_path
69
+ # setup rank 0 process in distributed training
70
+ dashboard_logger = None
71
+ if args.rank == 0:
72
+ new_fields = {}
73
+ if args.restore_path:
74
+ new_fields["restore_path"] = args.restore_path
75
+ new_fields["github_branch"] = get_git_branch()
76
+ # if model characters are not set in the config file
77
+ # save the default set to the config file for future
78
+ # compatibility.
79
+ if config.has("characters") and config.characters is None:
80
+ used_characters = parse_symbols()
81
+ new_fields["characters"] = used_characters
82
+ copy_model_files(config, experiment_path, new_fields)
83
+ dashboard_logger = logger_factory(config, experiment_path)
84
+ c_logger = ConsoleLogger()
85
+ return config, experiment_path, audio_path, c_logger, dashboard_logger
86
+
87
+
88
+ def init_arguments():
89
+ train_config = TrainArgs()
90
+ parser = train_config.init_argparse(arg_prefix="")
91
+ return parser
92
+
93
+
94
+ def init_training(config: Coqpit = None):
95
+ """Initialization of a training run."""
96
+ parser = init_arguments()
97
+ args = parser.parse_known_args()
98
+ config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
99
+ return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
TTS/encoder/utils/visual.py ADDED
@@ -0,0 +1,53 @@
1
+ import matplotlib
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+
5
+ matplotlib.use("Agg")
6
+
7
+
8
+ colormap = (
9
+ np.array(
10
+ [
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ],
25
+ dtype=float,
26
+ )
27
+ / 255
28
+ )
29
+
30
+
31
+ def plot_embeddings(embeddings, num_classes_in_batch):
32
+ try:
33
+ import umap
34
+ except ImportError as e:
35
+ raise ImportError("Package not installed: umap-learn") from e
36
+ num_utter_per_class = embeddings.shape[0] // num_classes_in_batch
37
+
38
+ # if necessary get just the first 10 classes
39
+ if num_classes_in_batch > 10:
40
+ num_classes_in_batch = 10
41
+ embeddings = embeddings[: num_classes_in_batch * num_utter_per_class]
42
+
43
+ model = umap.UMAP()
44
+ projection = model.fit_transform(embeddings)
45
+ ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class)
46
+ colors = [colormap[i] for i in ground_truth]
47
+ fig, ax = plt.subplots(figsize=(16, 10))
48
+ _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors)
49
+ plt.gca().set_aspect("equal", "datalim")
50
+ plt.title("UMAP projection")
51
+ plt.tight_layout()
52
+ plt.savefig("umap")
53
+ return fig
TTS/model.py ADDED
@@ -0,0 +1,66 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import Any, Union
4
+
5
+ import torch
6
+ from coqpit import Coqpit
7
+ from trainer import TrainerModel
8
+
9
+ # pylint: skip-file
10
+
11
+
12
+ class BaseTrainerModel(TrainerModel):
13
+ """BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS.
14
+
15
+ Every new 🐸TTS model must inherit it.
16
+ """
17
+
18
+ @staticmethod
19
+ @abstractmethod
20
+ def init_from_config(config: Coqpit) -> "BaseTrainerModel":
21
+ """Init the model and all its attributes from the given config.
22
+
23
+ Override this depending on your model.
24
+ """
25
+ ...
26
+
27
+ @abstractmethod
28
+ def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]:
29
+ """Forward pass for inference.
30
+
31
+ It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs```
32
+ is considered to be the main output and you can add any other auxiliary outputs as you want.
33
+
34
+ We don't use `*kwargs` since it is problematic with the TorchScript API.
35
+
36
+ Args:
37
+ input (torch.Tensor): [description]
38
+ aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
39
+
40
+ Returns:
41
+ Dict: [description]
42
+ """
43
+ outputs_dict = {"model_outputs": None}
44
+ ...
45
+ return outputs_dict
46
+
47
+ @abstractmethod
48
+ def load_checkpoint(
49
+ self,
50
+ config: Coqpit,
51
+ checkpoint_path: Union[str, os.PathLike[Any]],
52
+ eval: bool = False,
53
+ strict: bool = True,
54
+ cache: bool = False,
55
+ ) -> None:
56
+ """Load a model checkpoint file and get ready for training or inference.
57
+
58
+ Args:
59
+ config (Coqpit): Model configuration.
60
+ checkpoint_path (str | os.PathLike): Path to the model checkpoint file.
61
+ eval (bool, optional): If true, init model for inference else for training. Defaults to False.
62
+ strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
63
+ cache (bool, optional): If True, cache the file locally for subsequent calls.
64
+ It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False.
65
+ """
66
+ ...
TTS/server/README.md ADDED
@@ -0,0 +1,21 @@
1
+ # :frog: TTS demo server
2
+ Before you use the server, make sure you
3
+ [install](https://github.com/idiap/coqui-ai-TTS/tree/dev#install-tts) :frog: TTS
4
+ properly and install the additional dependencies with `pip install
5
+ coqui-tts[server]`. Then, you can follow the steps below.
6
+
7
+ **Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` command-line entry point.
8
+
9
+ Example runs:
10
+
11
+ List officially released models.
12
+ ```python TTS/server/server.py --list_models ```
13
+
14
+ Run the server with the official models.
15
+ ```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan```
16
+
17
+ Run the server with the official models on a GPU.
18
+ ```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda```
19
+
20
+ Run the server with custom models.
21
+ ```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth --vocoder_config /path/to/vocoder/config.json```
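22
+
23
+ As a rough usage sketch (assuming the server is running locally on the default port 5002 and `curl` is available), you can then query the `/api/tts` endpoint directly; it reads the `text` parameter and returns a WAV file:
24
+ ```curl -o output.wav "http://localhost:5002/api/tts?text=Hello%20world"```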
TTS/server/__init__.py ADDED
File without changes
TTS/server/conf.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "tts_path":"/media/erogol/data_ssd/Models/libri_tts/5049/", // tts model root folder
3
+ "tts_file":"best_model.pth", // tts checkpoint file
4
+ "tts_config":"config.json", // tts config.json file
5
+ "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
6
+ "vocoder_config":null,
7
+ "vocoder_file": null,
8
+ "is_wavernn_batched":true,
9
+ "port": 5002,
10
+ "use_cuda": true,
11
+ "debug": true
12
+ }
TTS/server/server.py ADDED
@@ -0,0 +1,262 @@
1
+ #!flask/bin/python
2
+
3
+ """TTS demo server."""
4
+
5
+ import argparse
6
+ import io
7
+ import json
8
+ import logging
9
+ import os
10
+ import sys
11
+ from pathlib import Path
12
+ from threading import Lock
13
+ from typing import Union
14
+ from urllib.parse import parse_qs
15
+
16
+ try:
17
+ from flask import Flask, render_template, render_template_string, request, send_file
18
+ except ImportError as e:
19
+ msg = "Server requires requires flask, use `pip install coqui-tts[server]`"
20
+ raise ImportError(msg) from e
21
+
22
+ from TTS.config import load_config
23
+ from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
24
+ from TTS.utils.manage import ModelManager
25
+ from TTS.utils.synthesizer import Synthesizer
26
+
27
+ logger = logging.getLogger(__name__)
28
+ setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
29
+
30
+
31
+ def create_argparser() -> argparse.ArgumentParser:
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument(
34
+ "--list_models",
35
+ action="store_true",
36
+ help="list available pre-trained tts and vocoder models.",
37
+ )
38
+ parser.add_argument(
39
+ "--model_name",
40
+ type=str,
41
+ default="tts_models/en/ljspeech/tacotron2-DDC",
42
+ help="Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>",
43
+ )
44
+ parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.")
45
+
46
+ # Args for running custom models
47
+ parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.")
48
+ parser.add_argument(
49
+ "--model_path",
50
+ type=str,
51
+ default=None,
52
+ help="Path to model file.",
53
+ )
54
+ parser.add_argument(
55
+ "--vocoder_path",
56
+ type=str,
57
+ help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).",
58
+ default=None,
59
+ )
60
+ parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None)
61
+ parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
62
+ parser.add_argument("--port", type=int, default=5002, help="port to listen on.")
63
+ parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="true to use CUDA.")
64
+ parser.add_argument(
65
+ "--debug", action=argparse.BooleanOptionalAction, default=False, help="true to enable Flask debug mode."
66
+ )
67
+ parser.add_argument(
68
+ "--show_details", action=argparse.BooleanOptionalAction, default=False, help="Generate model detail page."
69
+ )
70
+ return parser
71
+
72
+
73
+ # parse the args
74
+ args = create_argparser().parse_args()
75
+
76
+ path = Path(__file__).parent / "../.models.json"
77
+ manager = ModelManager(path)
78
+
79
+ # update in-use models to the specified released models.
80
+ model_path = None
81
+ config_path = None
82
+ speakers_file_path = None
83
+ vocoder_path = None
84
+ vocoder_config_path = None
85
+
86
+ # CASE1: list pre-trained TTS models
87
+ if args.list_models:
88
+ manager.list_models()
89
+ sys.exit()
90
+
91
+ # CASE2: load pre-trained model paths
92
+ if args.model_name is not None and not args.model_path:
93
+ model_path, config_path, model_item = manager.download_model(args.model_name)
94
+ args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
95
+
96
+ if args.vocoder_name is not None and not args.vocoder_path:
97
+ vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
98
+
99
+ # CASE3: set custom model paths
100
+ if args.model_path is not None:
101
+ model_path = args.model_path
102
+ config_path = args.config_path
103
+ speakers_file_path = args.speakers_file_path
104
+
105
+ if args.vocoder_path is not None:
106
+ vocoder_path = args.vocoder_path
107
+ vocoder_config_path = args.vocoder_config_path
108
+
109
+ # load models
110
+ synthesizer = Synthesizer(
111
+ tts_checkpoint=model_path,
112
+ tts_config_path=config_path,
113
+ tts_speakers_file=speakers_file_path,
114
+ tts_languages_file=None,
115
+ vocoder_checkpoint=vocoder_path,
116
+ vocoder_config=vocoder_config_path,
117
+ encoder_checkpoint="",
118
+ encoder_config="",
119
+ use_cuda=args.use_cuda,
120
+ )
121
+
122
+ use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and (
123
+ synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None
124
+ )
125
+ speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None)
126
+
127
+ use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and (
128
+ synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None
129
+ )
130
+ language_manager = getattr(synthesizer.tts_model, "language_manager", None)
131
+
132
+ # TODO: set this from SpeakerManager
133
+ use_gst = synthesizer.tts_config.get("use_gst", False)
134
+ app = Flask(__name__)
135
+
136
+
137
+ def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
138
+ """Transform an uri style_wav, in either a string (path to wav file to be use for style transfer)
139
+ or a dict (gst tokens/values to be use for styling)
140
+
141
+ Args:
142
+ style_wav (str): uri
143
+
144
+ Returns:
145
+ Union[str, dict]: path to file (str) or gst style (dict)
146
+ """
147
+ if style_wav:
148
+ if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
149
+ return style_wav # style_wav is a .wav file located on the server
150
+
151
+ style_wav = json.loads(style_wav)
152
+ return style_wav # style_wav is a gst dictionary with {token1_id : token1_weight, ...}
153
+ return None
154
+
155
+
156
+ @app.route("/")
157
+ def index():
158
+ return render_template(
159
+ "index.html",
160
+ show_details=args.show_details,
161
+ use_multi_speaker=use_multi_speaker,
162
+ use_multi_language=use_multi_language,
163
+ speaker_ids=speaker_manager.name_to_id if speaker_manager is not None else None,
164
+ language_ids=language_manager.name_to_id if language_manager is not None else None,
165
+ use_gst=use_gst,
166
+ )
167
+
168
+
169
+ @app.route("/details")
170
+ def details():
171
+ if args.config_path is not None and os.path.isfile(args.config_path):
172
+ model_config = load_config(args.config_path)
173
+ elif args.model_name is not None:
174
+ model_config = load_config(config_path)
175
+
176
+ if args.vocoder_config_path is not None and os.path.isfile(args.vocoder_config_path):
177
+ vocoder_config = load_config(args.vocoder_config_path)
178
+ elif args.vocoder_name is not None:
179
+ vocoder_config = load_config(vocoder_config_path)
180
+ else:
181
+ vocoder_config = None
182
+
183
+ return render_template(
184
+ "details.html",
185
+ show_details=args.show_details,
186
+ model_config=model_config,
187
+ vocoder_config=vocoder_config,
188
+ args=args.__dict__,
189
+ )
190
+
191
+
192
+ lock = Lock()
193
+
194
+
195
+ @app.route("/api/tts", methods=["GET", "POST"])
196
+ def tts():
197
+ with lock:
198
+ text = request.headers.get("text") or request.values.get("text", "")
199
+ speaker_idx = request.headers.get("speaker-id") or request.values.get("speaker_id", "")
200
+ language_idx = request.headers.get("language-id") or request.values.get("language_id", "")
201
+ style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "")
202
+ style_wav = style_wav_uri_to_dict(style_wav)
203
+
204
+ logger.info("Model input: %s", text)
205
+ logger.info("Speaker idx: %s", speaker_idx)
206
+ logger.info("Language idx: %s", language_idx)
207
+ wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
208
+ out = io.BytesIO()
209
+ synthesizer.save_wav(wavs, out)
210
+ return send_file(out, mimetype="audio/wav")
211
+
212
+
213
+ # Basic MaryTTS compatibility layer
214
+
215
+
216
+ @app.route("/locales", methods=["GET"])
217
+ def mary_tts_api_locales():
218
+ """MaryTTS-compatible /locales endpoint"""
219
+ # NOTE: We currently assume there is only one model active at the same time
220
+ if args.model_name is not None:
221
+ model_details = args.model_name.split("/")
222
+ else:
223
+ model_details = ["", "en", "", "default"]
224
+ return render_template_string("{{ locale }}\n", locale=model_details[1])
225
+
226
+
227
+ @app.route("/voices", methods=["GET"])
228
+ def mary_tts_api_voices():
229
+ """MaryTTS-compatible /voices endpoint"""
230
+ # NOTE: We currently assume there is only one model active at the same time
231
+ if args.model_name is not None:
232
+ model_details = args.model_name.split("/")
233
+ else:
234
+ model_details = ["", "en", "", "default"]
235
+ return render_template_string(
236
+ "{{ name }} {{ locale }} {{ gender }}\n", name=model_details[3], locale=model_details[1], gender="u"
237
+ )
238
+
239
+
240
+ @app.route("/process", methods=["GET", "POST"])
241
+ def mary_tts_api_process():
242
+ """MaryTTS-compatible /process endpoint"""
243
+ with lock:
244
+ if request.method == "POST":
245
+ data = parse_qs(request.get_data(as_text=True))
246
+ # NOTE: we ignore param. LOCALE and VOICE for now since we have only one active model
247
+ text = data.get("INPUT_TEXT", [""])[0]
248
+ else:
249
+ text = request.args.get("INPUT_TEXT", "")
250
+ logger.info("Model input: %s", text)
251
+ wavs = synthesizer.tts(text)
252
+ out = io.BytesIO()
253
+ synthesizer.save_wav(wavs, out)
254
+ return send_file(out, mimetype="audio/wav")
255
+
256
+
257
+ def main():
258
+ app.run(debug=args.debug, host="::", port=args.port)
259
+
260
+
261
+ if __name__ == "__main__":
262
+ main()
TTS/server/static/coqui-log-green-TTS.png ADDED
TTS/server/templates/details.html ADDED
@@ -0,0 +1,131 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+
6
+ <meta charset="utf-8">
7
+ <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
8
+ <meta name="description" content="">
9
+ <meta name="author" content="">
10
+
11
+ <title>TTS engine</title>
12
+
13
+ <!-- Bootstrap core CSS -->
14
+ <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
15
+ integrity="sha384-WskhaSGFgHYWDcbwN70/dfYBj47jz9qbsMId/iRN3ewGhXQFZCSftd1LZCfmhktB" crossorigin="anonymous"
16
+ rel="stylesheet">
17
+
18
+ <!-- Custom styles for this template -->
19
+ <style>
20
+ body {
21
+ padding-top: 54px;
22
+ }
23
+
24
+ @media (min-width: 992px) {
25
+ body {
26
+ padding-top: 56px;
27
+ }
28
+ }
29
+ </style>
30
+ </head>
31
+
32
+ <body>
33
+ <a href="https://github.com/mozilla/TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;"
34
+ src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
35
+
36
+ {% if show_details == true %}
37
+
38
+ <div class="container">
39
+ <b>Model details</b>
40
+ </div>
41
+
42
+ <div class="container">
43
+ <details>
44
+ <summary>CLI arguments:</summary>
45
+ <table border="1" align="center" width="75%">
46
+ <tr>
47
+ <td> CLI key </td>
48
+ <td> Value </td>
49
+ </tr>
50
+
51
+ {% for key, value in args.items() %}
52
+
53
+ <tr>
54
+ <td>{{ key }}</td>
55
+ <td>{{ value }}</td>
56
+ </tr>
57
+
58
+ {% endfor %}
59
+ </table>
60
+ </details>
61
+ </div></br>
62
+
63
+ <div class="container">
64
+
65
+ {% if model_config != None %}
66
+
67
+ <details>
68
+ <summary>Model config:</summary>
69
+
70
+ <table border="1" align="center" width="75%">
71
+ <tr>
72
+ <td> Key </td>
73
+ <td> Value </td>
74
+ </tr>
75
+
76
+
77
+ {% for key, value in model_config.items() %}
78
+
79
+ <tr>
80
+ <td>{{ key }}</td>
81
+ <td>{{ value }}</td>
82
+ </tr>
83
+
84
+ {% endfor %}
85
+
86
+ </table>
87
+ </details>
88
+
89
+ {% endif %}
90
+
91
+ </div></br>
92
+
93
+
94
+
95
+ <div class="container">
96
+ {% if vocoder_config != None %}
97
+ <details>
98
+ <summary>Vocoder model config:</summary>
99
+
100
+ <table border="1" align="center" width="75%">
101
+ <tr>
102
+ <td> Key </td>
103
+ <td> Value </td>
104
+ </tr>
105
+
106
+
107
+ {% for key, value in vocoder_config.items() %}
108
+
109
+ <tr>
110
+ <td>{{ key }}</td>
111
+ <td>{{ value }}</td>
112
+ </tr>
113
+
114
+ {% endfor %}
115
+
116
+
117
+ </table>
118
+ </details>
119
+ {% endif %}
120
+ </div></br>
121
+
122
+ {% else %}
123
+ <div class="container">
124
+ <b>Please start the server with --show_details to see details.</b>
125
+ </div>
126
+
127
+ {% endif %}
128
+
129
+ </body>
130
+
131
+ </html>
TTS/server/templates/index.html ADDED
@@ -0,0 +1,154 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+
6
+ <meta charset="utf-8">
7
+ <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
8
+ <meta name="description" content="🐸Coqui AI TTS demo server.">
9
+ <meta name="author" content="🐸Coqui AI TTS">
10
+
11
+ <title>TTS engine</title>
12
+
13
+ <!-- Bootstrap core CSS -->
14
+ <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
15
+ integrity="sha384-WskhaSGFgHYWDcbwN70/dfYBj47jz9qbsMId/iRN3ewGhXQFZCSftd1LZCfmhktB" crossorigin="anonymous"
16
+ rel="stylesheet">
17
+
18
+ <!-- Custom styles for this template -->
19
+ <style>
20
+ body {
21
+ padding-top: 54px;
22
+ }
23
+
24
+ @media (min-width: 992px) {
25
+ body {
26
+ padding-top: 56px;
27
+ }
28
+ }
29
+ </style>
30
+ </head>
31
+
32
+ <body>
33
+ <a href="https://github.com/idiap/coqui-ai-TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;"
34
+ src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
35
+
36
+ <!-- Navigation -->
37
+ <!--
38
+ <nav class="navbar navbar-expand-lg navbar-dark bg-dark fixed-top">
39
+ <div class="container">
40
+ <a class="navbar-brand" href="#">Coqui TTS</a>
41
+ <button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarResponsive" aria-controls="navbarResponsive" aria-expanded="false" aria-label="Toggle navigation">
42
+ <span class="navbar-toggler-icon"></span>
43
+ </button>
44
+ <div class="collapse navbar-collapse" id="navbarResponsive">
45
+ <ul class="navbar-nav ml-auto">
46
+ <li class="nav-item active">
47
+ <a class="nav-link" href="#">Home
48
+ <span class="sr-only">(current)</span>
49
+ </a>
50
+ </li>
51
+ </ul>
52
+ </div>
53
+ </div>
54
+ </nav>
55
+ -->
56
+
57
+ <!-- Page Content -->
58
+ <div class="container">
59
+ <div class="row">
60
+ <div class="col-lg-12 text-center">
61
+ <img class="mt-5" src="{{url_for('static', filename='coqui-log-green-TTS.png')}}" align="middle"
62
+ width="512" />
63
+
64
+ <ul class="list-unstyled">
65
+ </ul>
66
+
67
+ {%if use_gst%}
68
+ <input value='{"0": 0.1}' id="style_wav" placeholder="style wav (dict or path to wav).." size=45
69
+ type="text" name="style_wav">
70
+ {%endif%}
71
+
72
+ <input id="text" placeholder="Type here..." size=45 type="text" name="text">
73
+ <button id="speak-button" name="speak">Speak</button><br /><br />
74
+
75
+ {%if use_multi_speaker%}
76
+ Choose a speaker:
77
+ <select id="speaker_id" name=speaker_id method="GET" action="/">
78
+ {% for speaker_id in speaker_ids %}
79
+ <option value="{{speaker_id}}" SELECTED>{{speaker_id}}</option>"
80
+ {% endfor %}
81
+ </select><br /><br />
82
+ {%endif%}
83
+
84
+ {%if use_multi_language%}
85
+ Choose a language:
86
+ <select id="language_id" name=language_id method="GET" action="/">
87
+ {% for language_id in language_ids %}
88
+ <option value="{{language_id}}" SELECTED>{{language_id}}</option>"
89
+ {% endfor %}
90
+ </select><br /><br />
91
+ {%endif%}
92
+
93
+
94
+ {%if show_details%}
95
+ <button id="details-button" onclick="location.href = 'details'" name="model-details">Model
96
+ Details</button><br /><br />
97
+ {%endif%}
98
+ <audio id="audio" controls autoplay hidden></audio>
99
+ <p id="message"></p>
100
+ </div>
101
+ </div>
102
+ </div>
103
+
104
+ <!-- Bootstrap core JavaScript -->
105
+ <script>
106
+ function getTextValue(textId) {
107
+ const container = q(textId)
108
+ if (container) {
109
+ return container.value
110
+ }
111
+ return ""
112
+ }
113
+ function q(selector) { return document.querySelector(selector) }
114
+ q('#text').focus()
115
+ function do_tts(e) {
116
+ const text = q('#text').value
117
+ const speaker_id = getTextValue('#speaker_id')
118
+ const style_wav = getTextValue('#style_wav')
119
+ const language_id = getTextValue('#language_id')
120
+ if (text) {
121
+ q('#message').textContent = 'Synthesizing...'
122
+ q('#speak-button').disabled = true
123
+ q('#audio').hidden = true
124
+ synthesize(text, speaker_id, style_wav, language_id)
125
+ }
126
+ e.preventDefault()
127
+ return false
128
+ }
129
+ q('#speak-button').addEventListener('click', do_tts)
130
+ q('#text').addEventListener('keyup', function (e) {
131
+ if (e.keyCode == 13) { // enter
132
+ do_tts(e)
133
+ }
134
+ })
135
+ function synthesize(text, speaker_id = "", style_wav = "", language_id = "") {
136
+ fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker_id=${encodeURIComponent(speaker_id)}&style_wav=${encodeURIComponent(style_wav)}&language_id=${encodeURIComponent(language_id)}`, { cache: 'no-cache' })
137
+ .then(function (res) {
138
+ if (!res.ok) throw Error(res.statusText)
139
+ return res.blob()
140
+ }).then(function (blob) {
141
+ q('#message').textContent = ''
142
+ q('#speak-button').disabled = false
143
+ q('#audio').src = URL.createObjectURL(blob)
144
+ q('#audio').hidden = false
145
+ }).catch(function (err) {
146
+ q('#message').textContent = 'Error: ' + err.message
147
+ q('#speak-button').disabled = false
148
+ })
149
+ }
150
+ </script>
151
+
152
+ </body>
153
+
154
+ </html>
TTS/tts/__init__.py ADDED
File without changes
TTS/tts/configs/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ import importlib
2
+ import os
3
+ from inspect import isclass
4
+
5
+ # import all files under configs/
6
+ # configs_dir = os.path.dirname(__file__)
7
+ # for file in os.listdir(configs_dir):
8
+ # path = os.path.join(configs_dir, file)
9
+ # if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
10
+ # config_name = file[: file.find(".py")] if file.endswith(".py") else file
11
+ # module = importlib.import_module("TTS.tts.configs." + config_name)
12
+ # for attribute_name in dir(module):
13
+ # attribute = getattr(module, attribute_name)
14
+
15
+ # if isclass(attribute):
16
+ # # Add the class to this package's variables
17
+ # globals()[attribute_name] = attribute