viks66 commited on
Commit
221d863
1 Parent(s): 7f4762b
Files changed (2) hide show
  1. extra.py +787 -0
  2. jit_infer.py +33 -0
extra.py ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Dict, List, Union
2
+ from dataclasses import asdict, dataclass, field
3
+
4
+
5
+ import re
6
+ from dataclasses import replace
7
+ from typing import Dict
8
+ _whitespace_re = re.compile(r"\s+")
9
+
10
+ from dataclasses import dataclass, field
11
+ from typing import List
12
+
13
+ # from TTS.tts.configs.shared_configs import BaseTTSConfig
14
+ # from TTS.tts.models.vits import VitsArgs, VitsAudioConfig
15
+
16
+ @dataclass
17
+ class CharactersConfig():
18
+
19
+ characters_class: str = None
20
+
21
+ # using BaseVocabulary
22
+ vocab_dict: Dict = None
23
+
24
+ # using on BaseCharacters
25
+ pad: str = None
26
+ eos: str = None
27
+ bos: str = None
28
+ blank: str = None
29
+ characters: str = None
30
+ punctuations: str = None
31
+ phonemes: str = None
32
+ is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates
33
+ is_sorted: bool = True
34
+
35
+
36
+ @dataclass
37
+ class BaseTTSConfig():
38
+
39
+ # audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
40
+ # phoneme settings
41
+ use_phonemes: bool = False
42
+ phonemizer: str = None
43
+ phoneme_language: str = None
44
+ compute_input_seq_cache: bool = False
45
+ text_cleaner: str = None
46
+ enable_eos_bos_chars: bool = False
47
+ test_sentences_file: str = ""
48
+ phoneme_cache_path: str = None
49
+ # vocabulary parameters
50
+ characters: CharactersConfig = None
51
+ add_blank: bool = False
52
+ # training params
53
+ batch_group_size: int = 0
54
+ loss_masking: bool = None
55
+ # dataloading
56
+ min_audio_len: int = 1
57
+ max_audio_len: int = float("inf")
58
+ min_text_len: int = 1
59
+ max_text_len: int = float("inf")
60
+ compute_f0: bool = False
61
+ compute_energy: bool = False
62
+ compute_linear_spec: bool = False
63
+ precompute_num_workers: int = 0
64
+ use_noise_augment: bool = False
65
+ start_by_longest: bool = False
66
+ shuffle: bool = False
67
+ drop_last: bool = False
68
+ # dataset
69
+ datasets: str = None
70
+ # optimizer
71
+ optimizer: str = "radam"
72
+ optimizer_params: dict = None
73
+ # scheduler
74
+ lr_scheduler: str = None
75
+ lr_scheduler_params: dict = field(default_factory=lambda: {})
76
+ # testing
77
+ test_sentences: List[str] = field(default_factory=lambda: [])
78
+ # evaluation
79
+ eval_split_max_size: int = None
80
+ eval_split_size: float = 0.01
81
+ # weighted samplers
82
+ use_speaker_weighted_sampler: bool = False
83
+ speaker_weighted_sampler_alpha: float = 1.0
84
+ use_language_weighted_sampler: bool = False
85
+ language_weighted_sampler_alpha: float = 1.0
86
+ use_length_weighted_sampler: bool = False
87
+ length_weighted_sampler_alpha: float = 1.0
88
+
89
+
90
+ @dataclass
91
+ class VitsAudioConfig():
92
+ fft_size: int = 1024
93
+ sample_rate: int = 22050
94
+ win_length: int = 1024
95
+ hop_length: int = 256
96
+ num_mels: int = 80
97
+ mel_fmin: int = 0
98
+ mel_fmax: int = None
99
+
100
+ @dataclass
101
+ class VitsArgs():
102
+ num_chars: int = 100
103
+ out_channels: int = 513
104
+ spec_segment_size: int = 32
105
+ hidden_channels: int = 192
106
+ hidden_channels_ffn_text_encoder: int = 768
107
+ num_heads_text_encoder: int = 2
108
+ num_layers_text_encoder: int = 6
109
+ kernel_size_text_encoder: int = 3
110
+ dropout_p_text_encoder: float = 0.1
111
+ dropout_p_duration_predictor: float = 0.5
112
+ kernel_size_posterior_encoder: int = 5
113
+ dilation_rate_posterior_encoder: int = 1
114
+ num_layers_posterior_encoder: int = 16
115
+ kernel_size_flow: int = 5
116
+ dilation_rate_flow: int = 1
117
+ num_layers_flow: int = 4
118
+ resblock_type_decoder: str = "1"
119
+ resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
120
+ resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
121
+ upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
122
+ upsample_initial_channel_decoder: int = 512
123
+ upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
124
+ periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
125
+ use_sdp: bool = True
126
+ noise_scale: float = 1.0
127
+ inference_noise_scale: float = 0.667
128
+ length_scale: float = 1
129
+ noise_scale_dp: float = 1.0
130
+ inference_noise_scale_dp: float = 1.0
131
+ max_inference_len: int = None
132
+ init_discriminator: bool = True
133
+ use_spectral_norm_disriminator: bool = False
134
+ use_speaker_embedding: bool = False
135
+ num_speakers: int = 0
136
+ speakers_file: str = None
137
+ d_vector_file: List[str] = None
138
+ speaker_embedding_channels: int = 256
139
+ use_d_vector_file: bool = False
140
+ d_vector_dim: int = 0
141
+ detach_dp_input: bool = True
142
+ use_language_embedding: bool = False
143
+ embedded_language_dim: int = 4
144
+ num_languages: int = 0
145
+ language_ids_file: str = None
146
+ use_speaker_encoder_as_loss: bool = False
147
+ speaker_encoder_config_path: str = ""
148
+ speaker_encoder_model_path: str = ""
149
+ condition_dp_on_speaker: bool = True
150
+ freeze_encoder: bool = False
151
+ freeze_DP: bool = False
152
+ freeze_PE: bool = False
153
+ freeze_flow_decoder: bool = False
154
+ freeze_waveform_decoder: bool = False
155
+ encoder_sample_rate: int = None
156
+ interpolate_z: bool = True
157
+ reinit_DP: bool = False
158
+ reinit_text_encoder: bool = False
159
+ @dataclass
160
+ class VitsConfig(BaseTTSConfig):
161
+
162
+ model: str = "vits"
163
+ # model specific params
164
+ model_args: VitsArgs = field(default_factory=VitsArgs)
165
+ audio: VitsAudioConfig = field(default_factory=VitsAudioConfig)
166
+
167
+ # optimizer
168
+ grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
169
+ lr_gen: float = 0.0002
170
+ lr_disc: float = 0.0002
171
+ lr_scheduler_gen: str = "ExponentialLR"
172
+ lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
173
+ lr_scheduler_disc: str = "ExponentialLR"
174
+ lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
175
+ scheduler_after_epoch: bool = True
176
+ optimizer: str = "AdamW"
177
+ optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})
178
+
179
+ # loss params
180
+ kl_loss_alpha: float = 1.0
181
+ disc_loss_alpha: float = 1.0
182
+ gen_loss_alpha: float = 1.0
183
+ feat_loss_alpha: float = 1.0
184
+ mel_loss_alpha: float = 45.0
185
+ dur_loss_alpha: float = 1.0
186
+ speaker_encoder_loss_alpha: float = 1.0
187
+
188
+ # data loader params
189
+ return_wav: bool = True
190
+ compute_linear_spec: bool = True
191
+
192
+ # sampler params
193
+ use_weighted_sampler: bool = False # TODO: move it to the base config
194
+ weighted_sampler_attrs: dict = field(default_factory=lambda: {})
195
+ weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
196
+
197
+ # overrides
198
+ r: int = 1 # DO NOT CHANGE
199
+ add_blank: bool = True
200
+
201
+ # testing
202
+ test_sentences: List[List] = field(
203
+ default_factory=lambda: [
204
+ ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
205
+ ["Be a voice, not an echo."],
206
+ ["I'm sorry Dave. I'm afraid I can't do that."],
207
+ ["This cake is great. It's so delicious and moist."],
208
+ ["Prior to November 22, 1963."],
209
+ ]
210
+ )
211
+
212
+ # multi-speaker settings
213
+ # use speaker embedding layer
214
+ num_speakers: int = 0
215
+ use_speaker_embedding: bool = False
216
+ speakers_file: str = None
217
+ speaker_embedding_channels: int = 256
218
+ language_ids_file: str = None
219
+ use_language_embedding: bool = False
220
+
221
+ # use d-vectors
222
+ use_d_vector_file: bool = False
223
+ d_vector_file: List[str] = None
224
+ d_vector_dim: int = None
225
+
226
+ def __post_init__(self):
227
+ pass
228
+ # for key, val in self.model_args.items():
229
+ # if hasattr(self, key):
230
+ # self[key] = val
231
+
232
+
233
+
234
+
235
+
236
+ def parse_symbols():
237
+ return {
238
+ "pad": _pad,
239
+ "eos": _eos,
240
+ "bos": _bos,
241
+ "characters": _characters,
242
+ "punctuations": _punctuations,
243
+ "phonemes": _phonemes,
244
+ }
245
+
246
+
247
+ # DEFAULT SET OF GRAPHEMES
248
+ _pad = "<PAD>"
249
+ _eos = "<EOS>"
250
+ _bos = "<BOS>"
251
+ _blank = "<BLNK>" # TODO: check if we need this alongside with PAD
252
+ _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
253
+ _punctuations = "!'(),-.:;? "
254
+
255
+
256
+ # DEFAULT SET OF IPA PHONEMES
257
+ # Phonemes definition (All IPA characters)
258
+ _vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
259
+ _non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
260
+ _pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
261
+ _suprasegmentals = "ˈˌːˑ"
262
+ _other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
263
+ _diacrilics = "ɚ˞ɫ"
264
+ _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
265
+
266
+
267
+ class BaseVocabulary:
268
+ """Base Vocabulary class.
269
+
270
+ This class only needs a vocabulary dictionary without specifying the characters.
271
+
272
+ Args:
273
+ vocab (Dict): A dictionary of characters and their corresponding indices.
274
+ """
275
+
276
+ def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None):
277
+ self.vocab = vocab
278
+ self.pad = pad
279
+ self.blank = blank
280
+ self.bos = bos
281
+ self.eos = eos
282
+
283
+ @property
284
+ def pad_id(self) -> int:
285
+ """Return the index of the padding character. If the padding character is not specified, return the length
286
+ of the vocabulary."""
287
+ return self.char_to_id(self.pad) if self.pad else len(self.vocab)
288
+
289
+ @property
290
+ def blank_id(self) -> int:
291
+ """Return the index of the blank character. If the blank character is not specified, return the length of
292
+ the vocabulary."""
293
+ return self.char_to_id(self.blank) if self.blank else len(self.vocab)
294
+
295
+ @property
296
+ def bos_id(self) -> int:
297
+ """Return the index of the bos character. If the bos character is not specified, return the length of the
298
+ vocabulary."""
299
+ return self.char_to_id(self.bos) if self.bos else len(self.vocab)
300
+
301
+ @property
302
+ def eos_id(self) -> int:
303
+ """Return the index of the eos character. If the eos character is not specified, return the length of the
304
+ vocabulary."""
305
+ return self.char_to_id(self.eos) if self.eos else len(self.vocab)
306
+
307
+ @property
308
+ def vocab(self):
309
+ """Return the vocabulary dictionary."""
310
+ return self._vocab
311
+
312
+ @vocab.setter
313
+ def vocab(self, vocab):
314
+ """Set the vocabulary dictionary and character mapping dictionaries."""
315
+ self._vocab, self._char_to_id, self._id_to_char = None, None, None
316
+ if vocab is not None:
317
+ self._vocab = vocab
318
+ self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
319
+ self._id_to_char = {
320
+ idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
321
+ }
322
+
323
+ @staticmethod
324
+ def init_from_config(config, **kwargs):
325
+ """Initialize from the given config."""
326
+ if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict:
327
+ return (
328
+ BaseVocabulary(
329
+ config.characters.vocab_dict,
330
+ config.characters.pad,
331
+ config.characters.blank,
332
+ config.characters.bos,
333
+ config.characters.eos,
334
+ ),
335
+ config,
336
+ )
337
+ return BaseVocabulary(**kwargs), config
338
+
339
+ def to_config(self):
340
+ return CharactersConfig(
341
+ vocab_dict=self._vocab,
342
+ pad=self.pad,
343
+ eos=self.eos,
344
+ bos=self.bos,
345
+ blank=self.blank,
346
+ is_unique=False,
347
+ is_sorted=False,
348
+ )
349
+
350
+ @property
351
+ def num_chars(self):
352
+ """Return number of tokens in the vocabulary."""
353
+ return len(self._vocab)
354
+
355
+ def char_to_id(self, char: str) -> int:
356
+ """Map a character to an token ID."""
357
+ try:
358
+ return self._char_to_id[char]
359
+ except KeyError as e:
360
+ raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
361
+
362
+ def id_to_char(self, idx: int) -> str:
363
+ """Map an token ID to a character."""
364
+ return self._id_to_char[idx]
365
+
366
+
367
+ class BaseCharacters:
368
+
369
+
370
+ def __init__(
371
+ self,
372
+ characters: str = None,
373
+ punctuations: str = None,
374
+ pad: str = None,
375
+ eos: str = None,
376
+ bos: str = None,
377
+ blank: str = None,
378
+ is_unique: bool = False,
379
+ is_sorted: bool = True,
380
+ ) -> None:
381
+ self._characters = characters
382
+ self._punctuations = punctuations
383
+ self._pad = pad
384
+ self._eos = eos
385
+ self._bos = bos
386
+ self._blank = blank
387
+ self.is_unique = is_unique
388
+ self.is_sorted = is_sorted
389
+ self._create_vocab()
390
+
391
+ @property
392
+ def pad_id(self) -> int:
393
+ return self.char_to_id(self.pad) if self.pad else len(self.vocab)
394
+
395
+ @property
396
+ def blank_id(self) -> int:
397
+ return self.char_to_id(self.blank) if self.blank else len(self.vocab)
398
+
399
+ @property
400
+ def eos_id(self) -> int:
401
+ return self.char_to_id(self.eos) if self.eos else len(self.vocab)
402
+
403
+ @property
404
+ def bos_id(self) -> int:
405
+ return self.char_to_id(self.bos) if self.bos else len(self.vocab)
406
+
407
+ @property
408
+ def characters(self):
409
+ return self._characters
410
+
411
+ @characters.setter
412
+ def characters(self, characters):
413
+ self._characters = characters
414
+ self._create_vocab()
415
+
416
+ @property
417
+ def punctuations(self):
418
+ return self._punctuations
419
+
420
+ @punctuations.setter
421
+ def punctuations(self, punctuations):
422
+ self._punctuations = punctuations
423
+ self._create_vocab()
424
+
425
+ @property
426
+ def pad(self):
427
+ return self._pad
428
+
429
+ @pad.setter
430
+ def pad(self, pad):
431
+ self._pad = pad
432
+ self._create_vocab()
433
+
434
+ @property
435
+ def eos(self):
436
+ return self._eos
437
+
438
+ @eos.setter
439
+ def eos(self, eos):
440
+ self._eos = eos
441
+ self._create_vocab()
442
+
443
+ @property
444
+ def bos(self):
445
+ return self._bos
446
+
447
+ @bos.setter
448
+ def bos(self, bos):
449
+ self._bos = bos
450
+ self._create_vocab()
451
+
452
+ @property
453
+ def blank(self):
454
+ return self._blank
455
+
456
+ @blank.setter
457
+ def blank(self, blank):
458
+ self._blank = blank
459
+ self._create_vocab()
460
+
461
+ @property
462
+ def vocab(self):
463
+ return self._vocab
464
+
465
+ @vocab.setter
466
+ def vocab(self, vocab):
467
+ self._vocab = vocab
468
+ self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
469
+ self._id_to_char = {
470
+ idx: char for idx, char in enumerate(self.vocab) # pylint: disable=unnecessary-comprehension
471
+ }
472
+
473
+ @property
474
+ def num_chars(self):
475
+ return len(self._vocab)
476
+
477
+ def _create_vocab(self):
478
+ _vocab = self._characters
479
+ if self.is_unique:
480
+ _vocab = list(set(_vocab))
481
+ if self.is_sorted:
482
+ _vocab = sorted(_vocab)
483
+ _vocab = list(_vocab)
484
+ _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab
485
+ _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab
486
+ _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab
487
+ _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab
488
+ self.vocab = _vocab + list(self._punctuations)
489
+ if self.is_unique:
490
+ duplicates = {x for x in self.vocab if self.vocab.count(x) > 1}
491
+ assert (
492
+ len(self.vocab) == len(self._char_to_id) == len(self._id_to_char)
493
+ ), f" [!] There are duplicate characters in the character set. {duplicates}"
494
+
495
+ def char_to_id(self, char: str) -> int:
496
+ try:
497
+ return self._char_to_id[char]
498
+ except KeyError as e:
499
+ raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e
500
+
501
+ def id_to_char(self, idx: int) -> str:
502
+ return self._id_to_char[idx]
503
+
504
+ def print_log(self, level: int = 0):
505
+ """
506
+ Prints the vocabulary in a nice format.
507
+ """
508
+ indent = "\t" * level
509
+ print(f"{indent}| > Characters: {self._characters}")
510
+ print(f"{indent}| > Punctuations: {self._punctuations}")
511
+ print(f"{indent}| > Pad: {self._pad}")
512
+ print(f"{indent}| > EOS: {self._eos}")
513
+ print(f"{indent}| > BOS: {self._bos}")
514
+ print(f"{indent}| > Blank: {self._blank}")
515
+ print(f"{indent}| > Vocab: {self.vocab}")
516
+ print(f"{indent}| > Num chars: {self.num_chars}")
517
+
518
+ @staticmethod
519
+ def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument
520
+ """Init your character class from a config.
521
+
522
+ Implement this method for your subclass.
523
+ """
524
+ # use character set from config
525
+ if config.characters is not None:
526
+ return BaseCharacters(**config.characters), config
527
+ # return default character set
528
+ characters = BaseCharacters()
529
+ new_config = replace(config, characters=characters.to_config())
530
+ return characters, new_config
531
+
532
+ def to_config(self) -> "CharactersConfig":
533
+ return CharactersConfig(
534
+ characters=self._characters,
535
+ punctuations=self._punctuations,
536
+ pad=self._pad,
537
+ eos=self._eos,
538
+ bos=self._bos,
539
+ blank=self._blank,
540
+ is_unique=self.is_unique,
541
+ is_sorted=self.is_sorted,
542
+ )
543
+
544
+
545
+ class IPAPhonemes(BaseCharacters):
546
+
547
+
548
+ def __init__(
549
+ self,
550
+ characters: str = _phonemes,
551
+ punctuations: str = _punctuations,
552
+ pad: str = _pad,
553
+ eos: str = _eos,
554
+ bos: str = _bos,
555
+ blank: str = _blank,
556
+ is_unique: bool = False,
557
+ is_sorted: bool = True,
558
+ ) -> None:
559
+ super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
560
+
561
+ @staticmethod
562
+ def init_from_config(config: "Coqpit"):
563
+ """Init a IPAPhonemes object from a model config
564
+
565
+ If characters are not defined in the config, it will be set to the default characters and the config
566
+ will be updated.
567
+ """
568
+ # band-aid for compatibility with old models
569
+ if "characters" in config and config.characters is not None:
570
+ if "phonemes" in config.characters and config.characters.phonemes is not None:
571
+ config.characters["characters"] = config.characters["phonemes"]
572
+ return (
573
+ IPAPhonemes(
574
+ characters=config.characters["characters"],
575
+ punctuations=config.characters["punctuations"],
576
+ pad=config.characters["pad"],
577
+ eos=config.characters["eos"],
578
+ bos=config.characters["bos"],
579
+ blank=config.characters["blank"],
580
+ is_unique=config.characters["is_unique"],
581
+ is_sorted=config.characters["is_sorted"],
582
+ ),
583
+ config,
584
+ )
585
+ # use character set from config
586
+ if config.characters is not None:
587
+ return IPAPhonemes(**config.characters), config
588
+ # return default character set
589
+ characters = IPAPhonemes()
590
+ new_config = replace(config, characters=characters.to_config())
591
+ return characters, new_config
592
+
593
+
594
+ class Graphemes(BaseCharacters):
595
+
596
+
597
+ def __init__(
598
+ self,
599
+ characters: str = _characters,
600
+ punctuations: str = _punctuations,
601
+ pad: str = _pad,
602
+ eos: str = _eos,
603
+ bos: str = _bos,
604
+ blank: str = _blank,
605
+ is_unique: bool = False,
606
+ is_sorted: bool = True,
607
+ ) -> None:
608
+ super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted)
609
+
610
+ @staticmethod
611
+ def init_from_config(config: "Coqpit"):
612
+ """Init a Graphemes object from a model config
613
+
614
+ If characters are not defined in the config, it will be set to the default characters and the config
615
+ will be updated.
616
+ """
617
+ if config.characters is not None:
618
+ # band-aid for compatibility with old models
619
+ if "phonemes" in config.characters:
620
+ return (
621
+ Graphemes(
622
+ characters=config.characters["characters"],
623
+ punctuations=config.characters["punctuations"],
624
+ pad=config.characters["pad"],
625
+ eos=config.characters["eos"],
626
+ bos=config.characters["bos"],
627
+ blank=config.characters["blank"],
628
+ is_unique=config.characters["is_unique"],
629
+ is_sorted=config.characters["is_sorted"],
630
+ ),
631
+ config,
632
+ )
633
+ return Graphemes(**config.characters), config
634
+ characters = Graphemes()
635
+ new_config = replace(config, characters=characters.to_config())
636
+ return characters, new_config
637
+
638
+
639
+ if __name__ == "__main__":
640
+ gr = Graphemes()
641
+ ph = IPAPhonemes()
642
+ gr.print_log()
643
+ ph.print_log()
644
+
645
+
646
+ class VitsCharacters(BaseCharacters):
647
+ """Characters class for VITs model for compatibility with pre-trained models"""
648
+
649
+ def __init__(
650
+ self,
651
+ graphemes: str = _characters,
652
+ punctuations: str = _punctuations,
653
+ pad: str = _pad,
654
+ ipa_characters: str = _phonemes,
655
+ ) -> None:
656
+ if ipa_characters is not None:
657
+ graphemes += ipa_characters
658
+ super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True)
659
+
660
+ def _create_vocab(self):
661
+ self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
662
+ self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
663
+ # pylint: disable=unnecessary-comprehension
664
+ self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}
665
+
666
+ @staticmethod
667
+ def init_from_config(config):
668
+ _pad = config.characters.pad
669
+ _punctuations = config.characters.punctuations
670
+ _letters = config.characters.characters
671
+ _letters_ipa = config.characters.phonemes
672
+ return (
673
+ VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad),
674
+ config,
675
+ )
676
+
677
+ def to_config(self) -> "CharactersConfig":
678
+ return CharactersConfig(
679
+ characters=self._characters,
680
+ punctuations=self._punctuations,
681
+ pad=self._pad,
682
+ eos=None,
683
+ bos=None,
684
+ blank=self._blank,
685
+ is_unique=False,
686
+ is_sorted=True,
687
+ )
688
+
689
+ class TTSTokenizer:
690
+ def __init__(
691
+ self,
692
+ text_cleaner: Callable = None,
693
+ characters: "BaseCharacters" = None,
694
+ ):
695
+ self.text_cleaner = text_cleaner
696
+ self.characters = characters
697
+ self.not_found_characters = []
698
+
699
+ @property
700
+ def characters(self):
701
+ return self._characters
702
+
703
+ @characters.setter
704
+ def characters(self, new_characters):
705
+ self._characters = new_characters
706
+ self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None
707
+ self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None
708
+
709
+ def encode(self, text: str) -> List[int]:
710
+ """Encodes a string of text as a sequence of IDs."""
711
+ token_ids = []
712
+ for char in text:
713
+ try:
714
+ idx = self.characters.char_to_id(char)
715
+ token_ids.append(idx)
716
+ except KeyError:
717
+ # discard but store not found characters
718
+ if char not in self.not_found_characters:
719
+ self.not_found_characters.append(char)
720
+ print(text)
721
+ print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
722
+ return token_ids
723
+
724
+ def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument
725
+ text = self.text_cleaner(text)
726
+ text = self.encode(text)
727
+ text = self.intersperse_blank_char(text, True)
728
+ return text
729
+
730
+ def pad_with_bos_eos(self, char_sequence: List[str]):
731
+ """Pads a sequence with the special BOS and EOS characters."""
732
+ return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]
733
+
734
+ def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
735
+ """Intersperses the blank character between characters in a sequence.
736
+
737
+ Use the ```blank``` character if defined else use the ```pad``` character.
738
+ """
739
+ char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad
740
+ result = [char_to_use] * (len(char_sequence) * 2 + 1)
741
+ result[1::2] = char_sequence
742
+ return result
743
+
744
+ @staticmethod
745
+ def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):
746
+ text_cleaner = multilingual_cleaners
747
+ CharactersClass = VitsCharacters
748
+ characters, new_config = CharactersClass.init_from_config(config)
749
+ # new_config.characters.characters_class = get_import_path(characters)
750
+ new_config.characters.characters_class = VitsCharacters
751
+ return (
752
+ TTSTokenizer(text_cleaner, characters),new_config)
753
+
754
+
755
+ def multilingual_cleaners(text):
756
+ """Pipeline for multilingual text"""
757
+ text = lowercase(text)
758
+ text = replace_symbols(text, lang=None)
759
+ text = remove_aux_symbols(text)
760
+ text = collapse_whitespace(text)
761
+ return text
762
+
763
+ def lowercase(text):
764
+ return text.lower()
765
+
766
+ def collapse_whitespace(text):
767
+ return re.sub(_whitespace_re, " ", text).strip()
768
+
769
+ def replace_symbols(text, lang="en"):
770
+
771
+ text = text.replace(";", ",")
772
+ text = text.replace("-", " ") if lang != "ca" else text.replace("-", "")
773
+ text = text.replace(":", ",")
774
+ if lang == "en":
775
+ text = text.replace("&", " and ")
776
+ elif lang == "fr":
777
+ text = text.replace("&", " et ")
778
+ elif lang == "pt":
779
+ text = text.replace("&", " e ")
780
+ elif lang == "ca":
781
+ text = text.replace("&", " i ")
782
+ text = text.replace("'", "")
783
+ return text
784
+
785
+ def remove_aux_symbols(text):
786
+ text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text)
787
+ return text
jit_infer.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from extra import TTSTokenizer, VitsConfig, CharactersConfig, VitsCharacters
3
+ import torch
4
+ import numpy as np
5
+
6
+ #ch female
7
+ with open("chars.txt", 'r') as f:
8
+ letters = f.read().strip('\n')
9
+ model="mg_male_vits_30hrs.pt"
10
+ # text = " হলেও আমাদের সবার সার্বিক শৃঙ্খলা বোধের উন্নতি হবে"
11
+ text = "भेजना चाहते हैं हिंदी में मैसेज लेकिन नहीं आती टाइपिंग?"
12
+
13
+ config = VitsConfig(
14
+ text_cleaner="multilingual_cleaners",
15
+ characters=CharactersConfig(
16
+ characters_class=VitsCharacters,
17
+ pad="<PAD>",
18
+ eos="<EOS>",
19
+ bos="<BOS>",
20
+ blank="<BLNK>",
21
+ characters=letters,
22
+ punctuations="!¡'(),-.:;¿? ",
23
+ phonemes=None)
24
+ )
25
+ tokenizer, config = TTSTokenizer.init_from_config(config)
26
+
27
+ x = tokenizer.text_to_ids(text)
28
+ x = torch.from_numpy(np.array(x)).unsqueeze(0)
29
+ net = torch.jit.load(model)
30
+ with torch.no_grad():
31
+ out2 = net(x)
32
+ import soundfile as sf
33
+ sf.write("jit.wav", out2.squeeze().cpu().numpy(), 22050)