karim23657 committed on
Commit
b03541a
1 Parent(s): f046206

Upload train_vits.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_vits.py +107 -0
train_vits.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from trainer import Trainer, TrainerArgs
4
+
5
+ from TTS.tts.configs.shared_configs import BaseDatasetConfig , CharactersConfig
6
+ from TTS.config.shared_configs import BaseAudioConfig
7
+ from TTS.tts.configs.vits_config import VitsConfig
8
+ from TTS.tts.datasets import load_tts_samples
9
+ from TTS.tts.models.vits import Vits, VitsAudioConfig
10
+ from TTS.tts.utils.text.tokenizer import TTSTokenizer
11
+ from TTS.utils.audio import AudioProcessor
12
+ from TTS.utils.downloaders import download_thorsten_de
13
+
# Base directory for this run: the folder that contains this script.
output_path = os.path.split(os.path.abspath(__file__))[0]

# Dataset description: the "mozilla" formatter reads metadata.csv from the
# Kaggle input directory given in `path`.
dataset_config = BaseDatasetConfig(
    path="/kaggle/input/kaggle-gptinformal-persian",
    meta_file_train="metadata.csv",
    formatter="mozilla",
)
21
+
22
+
23
+
# Audio settings: 24 kHz sample rate, with silence trimming and resampling
# both switched off. The mel range starts at 0 and mel_fmax is left as None.
audio_config = BaseAudioConfig(
    sample_rate=24000,
    resample=False,
    do_trim_silence=False,
    mel_fmin=0,
    mel_fmax=None,
)
# Symbol inventory handed to the tokenizer through the IPAPhonemes character
# class: special tokens first, then the grapheme, punctuation and phoneme
# sets. The three vocabulary strings are kept verbatim.
character_config = CharactersConfig(
    characters_class="TTS.tts.utils.text.characters.IPAPhonemes",
    pad="<PAD>",
    eos="<EOS>",
    bos="<BOS>",
    blank="<BLNK>",
    characters='ءابتثجحخدذرزسشصضطظعغفقلمنهويِپچژکگیآأؤإئًَُّ',
    punctuations='!(),-.:;? ̠،؛؟‌<>',
    phonemes='ˈˌːˑpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟaegiouwyɪʊ̩æɑɔəɚɛɝɨ̃ʉʌʍ0123456789"#$%*+/=ABCDEFGHIJKLMNOPRSTUVWXYZ[]^_{}',
)
# Full VITS training configuration. Keyword arguments are grouped by concern;
# the values are the ones this run uses.
config = VitsConfig(
    # Run identity, data and output location.
    run_name="vits_fa_female",
    output_path=output_path,
    datasets=[dataset_config],
    audio=audio_config,
    # Batching and data-loader workers.
    batch_size=32,
    eval_batch_size=16,
    batch_group_size=5,
    num_loader_workers=0,
    num_eval_loader_workers=2,
    # Training length, evaluation, checkpointing and console output.
    epochs=1000,
    run_eval=True,
    test_delay_epochs=-1,
    save_step=1000,
    print_step=25,
    print_eval=True,
    mixed_precision=False,
    # Text front end: basic cleaning, then Persian ("fa") phonemization using
    # the custom symbol set; the phoneme cache lives next to this script.
    text_cleaner="basic_cleaners",
    use_phonemes=True,
    phoneme_language="fa",
    characters=character_config,
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    compute_input_seq_cache=True,
    # Persian sentences supplied as `test_sentences`, one single-item list each.
    test_sentences=[
        ["سلطان محمود در زمستانی سخت به طلخک گفت که: با این جامه ی یک لا در این سرما چه می کنی "],
        ["مردی نزد بقالی آمد و گفت پیاز هم ده تا دهان بدان خو شبوی سازم."],
        ["از مال خود پاره ای گوشت بستان و زیره بایی معطّر بساز"],
        ["یک بار هم از جهنم بگویید."],
        ["یکی اسبی به عاریت خواست"],
    ],
)
72
+
# Audio processor: built from the training config. It handles feature
# extraction and audio I/O, mainly for the dataloader and the training loggers.
ap = AudioProcessor.init_from_config(config)

# Tokenizer: converts text into sequences of token IDs. The call also returns
# the config, which is filled with the default characters when none are
# defined, so `config` is rebound to the returned object.
tokenizer, config = TTSTokenizer.init_from_config(config)
82
+
# Load the data samples and split them into training and evaluation sets.
# Each sample is a list of `[text, audio_file_path, speaker_name]`.
# A custom sample loader returning such a list, or a custom formatter passed
# to `load_tts_samples`, can be used instead; see
# `TTS.tts.datasets.load_tts_samples` for details.
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_size=config.eval_split_size,
    eval_split_max_size=config.eval_split_max_size,
)
94
+
# Build the VITS model from the config, audio processor and tokenizer.
# No speaker manager is supplied.
model = Vits(config, ap, tokenizer, speaker_manager=None)

# Create the trainer with default TrainerArgs and start training.
trainer = Trainer(
    TrainerArgs(),
    config,
    output_path,
    train_samples=train_samples,
    eval_samples=eval_samples,
    model=model,
)
trainer.fit()