Spaces:
Running
Running
Commit
•
26925fd
0
Parent(s):
Duplicate from zlc99/M4Singer
Browse filesCo-authored-by: Lichao Zhang <zlc99@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +34 -0
- README.md +14 -0
- checkpoints/m4singer_diff_e2e/config.yaml +348 -0
- checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt +3 -0
- checkpoints/m4singer_fs2_e2e/config.yaml +347 -0
- checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt +3 -0
- checkpoints/m4singer_hifigan/config.yaml +246 -0
- checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt +3 -0
- checkpoints/m4singer_pe/config.yaml +172 -0
- checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt +3 -0
- configs/config_base.yaml +42 -0
- configs/singing/base.yaml +42 -0
- configs/singing/fs2.yaml +3 -0
- configs/tts/base.yaml +95 -0
- configs/tts/base_zh.yaml +3 -0
- configs/tts/fs2.yaml +80 -0
- configs/tts/hifigan.yaml +21 -0
- configs/tts/lj/base_mel2wav.yaml +3 -0
- configs/tts/lj/base_text2mel.yaml +13 -0
- configs/tts/lj/fs2.yaml +3 -0
- configs/tts/lj/hifigan.yaml +3 -0
- configs/tts/lj/pwg.yaml +3 -0
- configs/tts/pwg.yaml +110 -0
- data_gen/singing/binarize.py +393 -0
- data_gen/tts/base_binarizer.py +224 -0
- data_gen/tts/bin/binarize.py +20 -0
- data_gen/tts/binarizer_zh.py +59 -0
- data_gen/tts/data_gen_utils.py +347 -0
- data_gen/tts/txt_processors/base_text_processor.py +8 -0
- data_gen/tts/txt_processors/en.py +78 -0
- data_gen/tts/txt_processors/zh.py +41 -0
- data_gen/tts/txt_processors/zh_g2pM.py +71 -0
- inference/m4singer/base_svs_infer.py +242 -0
- inference/m4singer/ds_e2e.py +67 -0
- inference/m4singer/gradio/gradio_settings.yaml +48 -0
- inference/m4singer/gradio/infer.py +143 -0
- inference/m4singer/gradio/share_btn.py +86 -0
- inference/m4singer/m4singer/m4singer_pinyin2ph.txt +413 -0
- inference/m4singer/m4singer/map.py +7 -0
- modules/__init__.py +0 -0
- modules/commons/common_layers.py +668 -0
- modules/commons/espnet_positional_embedding.py +113 -0
- modules/commons/ssim.py +391 -0
- modules/diffsinger_midi/fs2.py +118 -0
- modules/fastspeech/fs2.py +255 -0
- modules/fastspeech/pe.py +149 -0
- modules/fastspeech/tts_modules.py +357 -0
- modules/hifigan/hifigan.py +370 -0
- modules/hifigan/mel_utils.py +81 -0
- modules/parallel_wavegan/__init__.py +0 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: 111
|
3 |
+
emoji: 🎶
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.8.1
|
8 |
+
app_file: inference/m4singer/gradio/infer.py
|
9 |
+
pinned: false
|
10 |
+
duplicated_from: zlc99/M4Singer
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
14 |
+
|
checkpoints/m4singer_diff_e2e/config.yaml
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
K_step: 1000
|
2 |
+
accumulate_grad_batches: 1
|
3 |
+
audio_num_mel_bins: 80
|
4 |
+
audio_sample_rate: 24000
|
5 |
+
base_config:
|
6 |
+
- usr/configs/m4singer/base.yaml
|
7 |
+
binarization_args:
|
8 |
+
shuffle: false
|
9 |
+
with_align: true
|
10 |
+
with_f0: true
|
11 |
+
with_f0cwt: true
|
12 |
+
with_spk_embed: true
|
13 |
+
with_txt: true
|
14 |
+
with_wav: false
|
15 |
+
binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer
|
16 |
+
binary_data_dir: data/binary/m4singer
|
17 |
+
check_val_every_n_epoch: 10
|
18 |
+
clip_grad_norm: 1
|
19 |
+
content_cond_steps: []
|
20 |
+
cwt_add_f0_loss: false
|
21 |
+
cwt_hidden_size: 128
|
22 |
+
cwt_layers: 2
|
23 |
+
cwt_loss: l1
|
24 |
+
cwt_std_scale: 0.8
|
25 |
+
datasets:
|
26 |
+
- m4singer
|
27 |
+
debug: false
|
28 |
+
dec_ffn_kernel_size: 9
|
29 |
+
dec_layers: 4
|
30 |
+
decay_steps: 100000
|
31 |
+
decoder_type: fft
|
32 |
+
dict_dir: ''
|
33 |
+
diff_decoder_type: wavenet
|
34 |
+
diff_loss_type: l1
|
35 |
+
dilation_cycle_length: 4
|
36 |
+
dropout: 0.1
|
37 |
+
ds_workers: 4
|
38 |
+
dur_enc_hidden_stride_kernel:
|
39 |
+
- 0,2,3
|
40 |
+
- 0,2,3
|
41 |
+
- 0,1,3
|
42 |
+
dur_loss: mse
|
43 |
+
dur_predictor_kernel: 3
|
44 |
+
dur_predictor_layers: 5
|
45 |
+
enc_ffn_kernel_size: 9
|
46 |
+
enc_layers: 4
|
47 |
+
encoder_K: 8
|
48 |
+
encoder_type: fft
|
49 |
+
endless_ds: true
|
50 |
+
ffn_act: gelu
|
51 |
+
ffn_padding: SAME
|
52 |
+
fft_size: 512
|
53 |
+
fmax: 12000
|
54 |
+
fmin: 30
|
55 |
+
fs2_ckpt: checkpoints/m4singer_fs2_e2e
|
56 |
+
gaussian_start: true
|
57 |
+
gen_dir_name: ''
|
58 |
+
gen_tgt_spk_id: -1
|
59 |
+
hidden_size: 256
|
60 |
+
hop_size: 128
|
61 |
+
infer: false
|
62 |
+
keep_bins: 80
|
63 |
+
lambda_commit: 0.25
|
64 |
+
lambda_energy: 0.0
|
65 |
+
lambda_f0: 0.0
|
66 |
+
lambda_ph_dur: 1.0
|
67 |
+
lambda_sent_dur: 1.0
|
68 |
+
lambda_uv: 0.0
|
69 |
+
lambda_word_dur: 1.0
|
70 |
+
load_ckpt: ''
|
71 |
+
log_interval: 100
|
72 |
+
loud_norm: false
|
73 |
+
lr: 0.001
|
74 |
+
max_beta: 0.02
|
75 |
+
max_epochs: 1000
|
76 |
+
max_eval_sentences: 1
|
77 |
+
max_eval_tokens: 60000
|
78 |
+
max_frames: 5000
|
79 |
+
max_input_tokens: 1550
|
80 |
+
max_sentences: 28
|
81 |
+
max_tokens: 36000
|
82 |
+
max_updates: 900000
|
83 |
+
mel_loss: ssim:0.5|l1:0.5
|
84 |
+
mel_vmax: 1.5
|
85 |
+
mel_vmin: -6.0
|
86 |
+
min_level_db: -120
|
87 |
+
norm_type: gn
|
88 |
+
num_ckpt_keep: 3
|
89 |
+
num_heads: 2
|
90 |
+
num_sanity_val_steps: 1
|
91 |
+
num_spk: 20
|
92 |
+
num_test_samples: 0
|
93 |
+
num_valid_plots: 10
|
94 |
+
optimizer_adam_beta1: 0.9
|
95 |
+
optimizer_adam_beta2: 0.98
|
96 |
+
out_wav_norm: false
|
97 |
+
pe_ckpt: checkpoints/m4singer_pe
|
98 |
+
pe_enable: true
|
99 |
+
pitch_ar: false
|
100 |
+
pitch_enc_hidden_stride_kernel:
|
101 |
+
- 0,2,5
|
102 |
+
- 0,2,5
|
103 |
+
- 0,2,5
|
104 |
+
pitch_extractor: parselmouth
|
105 |
+
pitch_loss: l1
|
106 |
+
pitch_norm: log
|
107 |
+
pitch_type: frame
|
108 |
+
pndm_speedup: 10
|
109 |
+
pre_align_args:
|
110 |
+
allow_no_txt: false
|
111 |
+
denoise: false
|
112 |
+
forced_align: mfa
|
113 |
+
txt_processor: zh_g2pM
|
114 |
+
use_sox: true
|
115 |
+
use_tone: false
|
116 |
+
pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
|
117 |
+
predictor_dropout: 0.5
|
118 |
+
predictor_grad: 0.1
|
119 |
+
predictor_hidden: -1
|
120 |
+
predictor_kernel: 5
|
121 |
+
predictor_layers: 5
|
122 |
+
prenet_dropout: 0.5
|
123 |
+
prenet_hidden_size: 256
|
124 |
+
pretrain_fs_ckpt: ''
|
125 |
+
processed_data_dir: xxx
|
126 |
+
profile_infer: false
|
127 |
+
raw_data_dir: data/raw/m4singer
|
128 |
+
ref_norm_layer: bn
|
129 |
+
rel_pos: true
|
130 |
+
reset_phone_dict: true
|
131 |
+
residual_channels: 256
|
132 |
+
residual_layers: 20
|
133 |
+
save_best: false
|
134 |
+
save_ckpt: true
|
135 |
+
save_codes:
|
136 |
+
- configs
|
137 |
+
- modules
|
138 |
+
- tasks
|
139 |
+
- utils
|
140 |
+
- usr
|
141 |
+
save_f0: true
|
142 |
+
save_gt: true
|
143 |
+
schedule_type: linear
|
144 |
+
seed: 1234
|
145 |
+
sort_by_len: true
|
146 |
+
spec_max:
|
147 |
+
- -0.3894500136375427
|
148 |
+
- -0.3796464204788208
|
149 |
+
- -0.2914905250072479
|
150 |
+
- -0.15550297498703003
|
151 |
+
- -0.08502643555402756
|
152 |
+
- 0.10698417574167252
|
153 |
+
- -0.0739326998591423
|
154 |
+
- -0.0541548952460289
|
155 |
+
- 0.15501998364925385
|
156 |
+
- 0.06483431905508041
|
157 |
+
- 0.03054228238761425
|
158 |
+
- -0.013737732544541359
|
159 |
+
- -0.004876468330621719
|
160 |
+
- 0.04368264228105545
|
161 |
+
- 0.13329921662807465
|
162 |
+
- 0.16471388936042786
|
163 |
+
- 0.04605761915445328
|
164 |
+
- -0.05680707097053528
|
165 |
+
- 0.0542571023106575
|
166 |
+
- -0.0076539707370102406
|
167 |
+
- -0.00953489076346159
|
168 |
+
- -0.04434828832745552
|
169 |
+
- 0.001293870504014194
|
170 |
+
- -0.12238839268684387
|
171 |
+
- 0.06418416649103165
|
172 |
+
- 0.02843189612030983
|
173 |
+
- 0.08505241572856903
|
174 |
+
- 0.07062800228595734
|
175 |
+
- 0.00120724702719599
|
176 |
+
- -0.07675088942050934
|
177 |
+
- 0.03785804659128189
|
178 |
+
- 0.04890783503651619
|
179 |
+
- -0.06888376921415329
|
180 |
+
- -0.0839693546295166
|
181 |
+
- -0.17545585334300995
|
182 |
+
- -0.2911079525947571
|
183 |
+
- -0.4238220453262329
|
184 |
+
- -0.262084037065506
|
185 |
+
- -0.3002263605594635
|
186 |
+
- -0.3845032751560211
|
187 |
+
- -0.3906497061252594
|
188 |
+
- -0.6550108790397644
|
189 |
+
- -0.7810799479484558
|
190 |
+
- -0.7503029704093933
|
191 |
+
- -0.7995198965072632
|
192 |
+
- -0.8092347383499146
|
193 |
+
- -0.6196113228797913
|
194 |
+
- -0.6684317588806152
|
195 |
+
- -0.7735874056816101
|
196 |
+
- -0.8324533104896545
|
197 |
+
- -0.9601566791534424
|
198 |
+
- -0.955253541469574
|
199 |
+
- -0.748817503452301
|
200 |
+
- -0.9106167554855347
|
201 |
+
- -0.9707801342010498
|
202 |
+
- -1.053107500076294
|
203 |
+
- -1.0448424816131592
|
204 |
+
- -1.1082794666290283
|
205 |
+
- -1.1296544075012207
|
206 |
+
- -1.071642279624939
|
207 |
+
- -1.1003081798553467
|
208 |
+
- -1.166810154914856
|
209 |
+
- -1.1408926248550415
|
210 |
+
- -1.1330615282058716
|
211 |
+
- -1.1167492866516113
|
212 |
+
- -1.0716774463653564
|
213 |
+
- -1.035891056060791
|
214 |
+
- -1.0092483758926392
|
215 |
+
- -0.9675999879837036
|
216 |
+
- -0.938962996006012
|
217 |
+
- -1.0120564699172974
|
218 |
+
- -0.9777995347976685
|
219 |
+
- -1.029313564300537
|
220 |
+
- -0.9459163546562195
|
221 |
+
- -0.8519706130027771
|
222 |
+
- -0.7751091122627258
|
223 |
+
- -0.7933766841888428
|
224 |
+
- -0.9019735455513
|
225 |
+
- -0.9983296990394592
|
226 |
+
- -1.505873441696167
|
227 |
+
spec_min:
|
228 |
+
- -6.0
|
229 |
+
- -6.0
|
230 |
+
- -6.0
|
231 |
+
- -6.0
|
232 |
+
- -6.0
|
233 |
+
- -6.0
|
234 |
+
- -6.0
|
235 |
+
- -6.0
|
236 |
+
- -6.0
|
237 |
+
- -6.0
|
238 |
+
- -6.0
|
239 |
+
- -6.0
|
240 |
+
- -6.0
|
241 |
+
- -6.0
|
242 |
+
- -6.0
|
243 |
+
- -6.0
|
244 |
+
- -6.0
|
245 |
+
- -6.0
|
246 |
+
- -6.0
|
247 |
+
- -6.0
|
248 |
+
- -6.0
|
249 |
+
- -6.0
|
250 |
+
- -6.0
|
251 |
+
- -6.0
|
252 |
+
- -6.0
|
253 |
+
- -6.0
|
254 |
+
- -6.0
|
255 |
+
- -6.0
|
256 |
+
- -6.0
|
257 |
+
- -6.0
|
258 |
+
- -6.0
|
259 |
+
- -6.0
|
260 |
+
- -6.0
|
261 |
+
- -6.0
|
262 |
+
- -6.0
|
263 |
+
- -6.0
|
264 |
+
- -6.0
|
265 |
+
- -6.0
|
266 |
+
- -6.0
|
267 |
+
- -6.0
|
268 |
+
- -6.0
|
269 |
+
- -6.0
|
270 |
+
- -6.0
|
271 |
+
- -6.0
|
272 |
+
- -6.0
|
273 |
+
- -6.0
|
274 |
+
- -6.0
|
275 |
+
- -6.0
|
276 |
+
- -6.0
|
277 |
+
- -6.0
|
278 |
+
- -6.0
|
279 |
+
- -6.0
|
280 |
+
- -6.0
|
281 |
+
- -6.0
|
282 |
+
- -6.0
|
283 |
+
- -6.0
|
284 |
+
- -6.0
|
285 |
+
- -6.0
|
286 |
+
- -6.0
|
287 |
+
- -6.0
|
288 |
+
- -6.0
|
289 |
+
- -6.0
|
290 |
+
- -6.0
|
291 |
+
- -6.0
|
292 |
+
- -6.0
|
293 |
+
- -6.0
|
294 |
+
- -6.0
|
295 |
+
- -6.0
|
296 |
+
- -6.0
|
297 |
+
- -6.0
|
298 |
+
- -6.0
|
299 |
+
- -6.0
|
300 |
+
- -6.0
|
301 |
+
- -6.0
|
302 |
+
- -6.0
|
303 |
+
- -6.0
|
304 |
+
- -6.0
|
305 |
+
- -6.0
|
306 |
+
- -6.0
|
307 |
+
- -6.0
|
308 |
+
spk_cond_steps: []
|
309 |
+
stop_token_weight: 5.0
|
310 |
+
task_cls: usr.diffsinger_task.DiffSingerMIDITask
|
311 |
+
test_ids: []
|
312 |
+
test_input_dir: ''
|
313 |
+
test_num: 0
|
314 |
+
test_prefixes:
|
315 |
+
- "Alto-2#\u5C81\u6708\u795E\u5077"
|
316 |
+
- "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C"
|
317 |
+
- "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E"
|
318 |
+
- "Tenor-1#\u7AE5\u8BDD"
|
319 |
+
- "Tenor-2#\u6D88\u6101"
|
320 |
+
- "Tenor-2#\u4E00\u8364\u4E00\u7D20"
|
321 |
+
- "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4"
|
322 |
+
- "Soprano-1#\u95EE\u6625"
|
323 |
+
test_set_name: test
|
324 |
+
timesteps: 1000
|
325 |
+
train_set_name: train
|
326 |
+
use_denoise: false
|
327 |
+
use_energy_embed: false
|
328 |
+
use_gt_dur: false
|
329 |
+
use_gt_f0: false
|
330 |
+
use_midi: true
|
331 |
+
use_nsf: true
|
332 |
+
use_pitch_embed: false
|
333 |
+
use_pos_embed: true
|
334 |
+
use_spk_embed: false
|
335 |
+
use_spk_id: true
|
336 |
+
use_split_spk_id: false
|
337 |
+
use_uv: true
|
338 |
+
use_var_enc: false
|
339 |
+
val_check_interval: 2000
|
340 |
+
valid_num: 0
|
341 |
+
valid_set_name: valid
|
342 |
+
vocoder: vocoders.hifigan.HifiGAN
|
343 |
+
vocoder_ckpt: checkpoints/m4singer_hifigan
|
344 |
+
warmup_updates: 2000
|
345 |
+
wav2spec_eps: 1e-6
|
346 |
+
weight_decay: 0
|
347 |
+
win_size: 512
|
348 |
+
work_dir: checkpoints/m4singer_diff_e2e
|
checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dbea4e8b9712d2cca54cc07915859472a17f2f3b97a86f33a6c9974192bb5b47
|
3 |
+
size 392239086
|
checkpoints/m4singer_fs2_e2e/config.yaml
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
K_step: 51
|
2 |
+
accumulate_grad_batches: 1
|
3 |
+
audio_num_mel_bins: 80
|
4 |
+
audio_sample_rate: 24000
|
5 |
+
base_config:
|
6 |
+
- configs/singing/fs2.yaml
|
7 |
+
- usr/configs/m4singer/base.yaml
|
8 |
+
binarization_args:
|
9 |
+
shuffle: false
|
10 |
+
with_align: true
|
11 |
+
with_f0: true
|
12 |
+
with_f0cwt: true
|
13 |
+
with_spk_embed: true
|
14 |
+
with_txt: true
|
15 |
+
with_wav: false
|
16 |
+
binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer
|
17 |
+
binary_data_dir: data/binary/m4singer
|
18 |
+
check_val_every_n_epoch: 10
|
19 |
+
clip_grad_norm: 1
|
20 |
+
content_cond_steps: []
|
21 |
+
cwt_add_f0_loss: false
|
22 |
+
cwt_hidden_size: 128
|
23 |
+
cwt_layers: 2
|
24 |
+
cwt_loss: l1
|
25 |
+
cwt_std_scale: 0.8
|
26 |
+
datasets:
|
27 |
+
- m4singer
|
28 |
+
debug: false
|
29 |
+
dec_ffn_kernel_size: 9
|
30 |
+
dec_layers: 4
|
31 |
+
decay_steps: 50000
|
32 |
+
decoder_type: fft
|
33 |
+
dict_dir: ''
|
34 |
+
diff_decoder_type: wavenet
|
35 |
+
diff_loss_type: l1
|
36 |
+
dilation_cycle_length: 1
|
37 |
+
dropout: 0.1
|
38 |
+
ds_workers: 4
|
39 |
+
dur_enc_hidden_stride_kernel:
|
40 |
+
- 0,2,3
|
41 |
+
- 0,2,3
|
42 |
+
- 0,1,3
|
43 |
+
dur_loss: mse
|
44 |
+
dur_predictor_kernel: 3
|
45 |
+
dur_predictor_layers: 5
|
46 |
+
enc_ffn_kernel_size: 9
|
47 |
+
enc_layers: 4
|
48 |
+
encoder_K: 8
|
49 |
+
encoder_type: fft
|
50 |
+
endless_ds: true
|
51 |
+
ffn_act: gelu
|
52 |
+
ffn_padding: SAME
|
53 |
+
fft_size: 512
|
54 |
+
fmax: 12000
|
55 |
+
fmin: 30
|
56 |
+
fs2_ckpt: ''
|
57 |
+
gen_dir_name: ''
|
58 |
+
gen_tgt_spk_id: -1
|
59 |
+
hidden_size: 256
|
60 |
+
hop_size: 128
|
61 |
+
infer: false
|
62 |
+
keep_bins: 80
|
63 |
+
lambda_commit: 0.25
|
64 |
+
lambda_energy: 0.0
|
65 |
+
lambda_f0: 1.0
|
66 |
+
lambda_ph_dur: 1.0
|
67 |
+
lambda_sent_dur: 1.0
|
68 |
+
lambda_uv: 1.0
|
69 |
+
lambda_word_dur: 1.0
|
70 |
+
load_ckpt: ''
|
71 |
+
log_interval: 100
|
72 |
+
loud_norm: false
|
73 |
+
lr: 1
|
74 |
+
max_beta: 0.06
|
75 |
+
max_epochs: 1000
|
76 |
+
max_eval_sentences: 1
|
77 |
+
max_eval_tokens: 60000
|
78 |
+
max_frames: 5000
|
79 |
+
max_input_tokens: 1550
|
80 |
+
max_sentences: 12
|
81 |
+
max_tokens: 40000
|
82 |
+
max_updates: 320000
|
83 |
+
mel_loss: ssim:0.5|l1:0.5
|
84 |
+
mel_vmax: 1.5
|
85 |
+
mel_vmin: -6.0
|
86 |
+
min_level_db: -120
|
87 |
+
norm_type: gn
|
88 |
+
num_ckpt_keep: 3
|
89 |
+
num_heads: 2
|
90 |
+
num_sanity_val_steps: 1
|
91 |
+
num_spk: 20
|
92 |
+
num_test_samples: 0
|
93 |
+
num_valid_plots: 10
|
94 |
+
optimizer_adam_beta1: 0.9
|
95 |
+
optimizer_adam_beta2: 0.98
|
96 |
+
out_wav_norm: false
|
97 |
+
pe_ckpt: checkpoints/m4singer_pe
|
98 |
+
pe_enable: true
|
99 |
+
pitch_ar: false
|
100 |
+
pitch_enc_hidden_stride_kernel:
|
101 |
+
- 0,2,5
|
102 |
+
- 0,2,5
|
103 |
+
- 0,2,5
|
104 |
+
pitch_extractor: parselmouth
|
105 |
+
pitch_loss: l1
|
106 |
+
pitch_norm: log
|
107 |
+
pitch_type: frame
|
108 |
+
pre_align_args:
|
109 |
+
allow_no_txt: false
|
110 |
+
denoise: false
|
111 |
+
forced_align: mfa
|
112 |
+
txt_processor: zh_g2pM
|
113 |
+
use_sox: true
|
114 |
+
use_tone: false
|
115 |
+
pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
|
116 |
+
predictor_dropout: 0.5
|
117 |
+
predictor_grad: 0.1
|
118 |
+
predictor_hidden: -1
|
119 |
+
predictor_kernel: 5
|
120 |
+
predictor_layers: 5
|
121 |
+
prenet_dropout: 0.5
|
122 |
+
prenet_hidden_size: 256
|
123 |
+
pretrain_fs_ckpt: ''
|
124 |
+
processed_data_dir: xxx
|
125 |
+
profile_infer: false
|
126 |
+
raw_data_dir: data/raw/m4singer
|
127 |
+
ref_norm_layer: bn
|
128 |
+
rel_pos: true
|
129 |
+
reset_phone_dict: true
|
130 |
+
residual_channels: 256
|
131 |
+
residual_layers: 20
|
132 |
+
save_best: false
|
133 |
+
save_ckpt: true
|
134 |
+
save_codes:
|
135 |
+
- configs
|
136 |
+
- modules
|
137 |
+
- tasks
|
138 |
+
- utils
|
139 |
+
- usr
|
140 |
+
save_f0: true
|
141 |
+
save_gt: true
|
142 |
+
schedule_type: linear
|
143 |
+
seed: 1234
|
144 |
+
sort_by_len: true
|
145 |
+
spec_max:
|
146 |
+
- -0.3894500136375427
|
147 |
+
- -0.3796464204788208
|
148 |
+
- -0.2914905250072479
|
149 |
+
- -0.15550297498703003
|
150 |
+
- -0.08502643555402756
|
151 |
+
- 0.10698417574167252
|
152 |
+
- -0.0739326998591423
|
153 |
+
- -0.0541548952460289
|
154 |
+
- 0.15501998364925385
|
155 |
+
- 0.06483431905508041
|
156 |
+
- 0.03054228238761425
|
157 |
+
- -0.013737732544541359
|
158 |
+
- -0.004876468330621719
|
159 |
+
- 0.04368264228105545
|
160 |
+
- 0.13329921662807465
|
161 |
+
- 0.16471388936042786
|
162 |
+
- 0.04605761915445328
|
163 |
+
- -0.05680707097053528
|
164 |
+
- 0.0542571023106575
|
165 |
+
- -0.0076539707370102406
|
166 |
+
- -0.00953489076346159
|
167 |
+
- -0.04434828832745552
|
168 |
+
- 0.001293870504014194
|
169 |
+
- -0.12238839268684387
|
170 |
+
- 0.06418416649103165
|
171 |
+
- 0.02843189612030983
|
172 |
+
- 0.08505241572856903
|
173 |
+
- 0.07062800228595734
|
174 |
+
- 0.00120724702719599
|
175 |
+
- -0.07675088942050934
|
176 |
+
- 0.03785804659128189
|
177 |
+
- 0.04890783503651619
|
178 |
+
- -0.06888376921415329
|
179 |
+
- -0.0839693546295166
|
180 |
+
- -0.17545585334300995
|
181 |
+
- -0.2911079525947571
|
182 |
+
- -0.4238220453262329
|
183 |
+
- -0.262084037065506
|
184 |
+
- -0.3002263605594635
|
185 |
+
- -0.3845032751560211
|
186 |
+
- -0.3906497061252594
|
187 |
+
- -0.6550108790397644
|
188 |
+
- -0.7810799479484558
|
189 |
+
- -0.7503029704093933
|
190 |
+
- -0.7995198965072632
|
191 |
+
- -0.8092347383499146
|
192 |
+
- -0.6196113228797913
|
193 |
+
- -0.6684317588806152
|
194 |
+
- -0.7735874056816101
|
195 |
+
- -0.8324533104896545
|
196 |
+
- -0.9601566791534424
|
197 |
+
- -0.955253541469574
|
198 |
+
- -0.748817503452301
|
199 |
+
- -0.9106167554855347
|
200 |
+
- -0.9707801342010498
|
201 |
+
- -1.053107500076294
|
202 |
+
- -1.0448424816131592
|
203 |
+
- -1.1082794666290283
|
204 |
+
- -1.1296544075012207
|
205 |
+
- -1.071642279624939
|
206 |
+
- -1.1003081798553467
|
207 |
+
- -1.166810154914856
|
208 |
+
- -1.1408926248550415
|
209 |
+
- -1.1330615282058716
|
210 |
+
- -1.1167492866516113
|
211 |
+
- -1.0716774463653564
|
212 |
+
- -1.035891056060791
|
213 |
+
- -1.0092483758926392
|
214 |
+
- -0.9675999879837036
|
215 |
+
- -0.938962996006012
|
216 |
+
- -1.0120564699172974
|
217 |
+
- -0.9777995347976685
|
218 |
+
- -1.029313564300537
|
219 |
+
- -0.9459163546562195
|
220 |
+
- -0.8519706130027771
|
221 |
+
- -0.7751091122627258
|
222 |
+
- -0.7933766841888428
|
223 |
+
- -0.9019735455513
|
224 |
+
- -0.9983296990394592
|
225 |
+
- -1.505873441696167
|
226 |
+
spec_min:
|
227 |
+
- -6.0
|
228 |
+
- -6.0
|
229 |
+
- -6.0
|
230 |
+
- -6.0
|
231 |
+
- -6.0
|
232 |
+
- -6.0
|
233 |
+
- -6.0
|
234 |
+
- -6.0
|
235 |
+
- -6.0
|
236 |
+
- -6.0
|
237 |
+
- -6.0
|
238 |
+
- -6.0
|
239 |
+
- -6.0
|
240 |
+
- -6.0
|
241 |
+
- -6.0
|
242 |
+
- -6.0
|
243 |
+
- -6.0
|
244 |
+
- -6.0
|
245 |
+
- -6.0
|
246 |
+
- -6.0
|
247 |
+
- -6.0
|
248 |
+
- -6.0
|
249 |
+
- -6.0
|
250 |
+
- -6.0
|
251 |
+
- -6.0
|
252 |
+
- -6.0
|
253 |
+
- -6.0
|
254 |
+
- -6.0
|
255 |
+
- -6.0
|
256 |
+
- -6.0
|
257 |
+
- -6.0
|
258 |
+
- -6.0
|
259 |
+
- -6.0
|
260 |
+
- -6.0
|
261 |
+
- -6.0
|
262 |
+
- -6.0
|
263 |
+
- -6.0
|
264 |
+
- -6.0
|
265 |
+
- -6.0
|
266 |
+
- -6.0
|
267 |
+
- -6.0
|
268 |
+
- -6.0
|
269 |
+
- -6.0
|
270 |
+
- -6.0
|
271 |
+
- -6.0
|
272 |
+
- -6.0
|
273 |
+
- -6.0
|
274 |
+
- -6.0
|
275 |
+
- -6.0
|
276 |
+
- -6.0
|
277 |
+
- -6.0
|
278 |
+
- -6.0
|
279 |
+
- -6.0
|
280 |
+
- -6.0
|
281 |
+
- -6.0
|
282 |
+
- -6.0
|
283 |
+
- -6.0
|
284 |
+
- -6.0
|
285 |
+
- -6.0
|
286 |
+
- -6.0
|
287 |
+
- -6.0
|
288 |
+
- -6.0
|
289 |
+
- -6.0
|
290 |
+
- -6.0
|
291 |
+
- -6.0
|
292 |
+
- -6.0
|
293 |
+
- -6.0
|
294 |
+
- -6.0
|
295 |
+
- -6.0
|
296 |
+
- -6.0
|
297 |
+
- -6.0
|
298 |
+
- -6.0
|
299 |
+
- -6.0
|
300 |
+
- -6.0
|
301 |
+
- -6.0
|
302 |
+
- -6.0
|
303 |
+
- -6.0
|
304 |
+
- -6.0
|
305 |
+
- -6.0
|
306 |
+
- -6.0
|
307 |
+
spk_cond_steps: []
|
308 |
+
stop_token_weight: 5.0
|
309 |
+
task_cls: usr.diffsinger_task.AuxDecoderMIDITask
|
310 |
+
test_ids: []
|
311 |
+
test_input_dir: ''
|
312 |
+
test_num: 0
|
313 |
+
test_prefixes:
|
314 |
+
- "Alto-2#\u5C81\u6708\u795E\u5077"
|
315 |
+
- "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C"
|
316 |
+
- "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E"
|
317 |
+
- "Tenor-1#\u7AE5\u8BDD"
|
318 |
+
- "Tenor-2#\u6D88\u6101"
|
319 |
+
- "Tenor-2#\u4E00\u8364\u4E00\u7D20"
|
320 |
+
- "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4"
|
321 |
+
- "Soprano-1#\u95EE\u6625"
|
322 |
+
test_set_name: test
|
323 |
+
timesteps: 100
|
324 |
+
train_set_name: train
|
325 |
+
use_denoise: false
|
326 |
+
use_energy_embed: false
|
327 |
+
use_gt_dur: false
|
328 |
+
use_gt_f0: false
|
329 |
+
use_midi: true
|
330 |
+
use_nsf: true
|
331 |
+
use_pitch_embed: false
|
332 |
+
use_pos_embed: true
|
333 |
+
use_spk_embed: false
|
334 |
+
use_spk_id: true
|
335 |
+
use_split_spk_id: false
|
336 |
+
use_uv: true
|
337 |
+
use_var_enc: false
|
338 |
+
val_check_interval: 2000
|
339 |
+
valid_num: 0
|
340 |
+
valid_set_name: valid
|
341 |
+
vocoder: vocoders.hifigan.HifiGAN
|
342 |
+
vocoder_ckpt: checkpoints/m4singer_hifigan
|
343 |
+
warmup_updates: 2000
|
344 |
+
wav2spec_eps: 1e-6
|
345 |
+
weight_decay: 0
|
346 |
+
win_size: 512
|
347 |
+
work_dir: checkpoints/m4singer_fs2_e2e
|
checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:993d7063a1773bd29d2810591f98152218a4cf8440e2b10c4761516a28f9d566
|
3 |
+
size 290456153
|
checkpoints/m4singer_hifigan/config.yaml
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
max_eval_tokens: 60000
|
2 |
+
max_eval_sentences: 1
|
3 |
+
save_ckpt: true
|
4 |
+
log_interval: 100
|
5 |
+
accumulate_grad_batches: 1
|
6 |
+
adam_b1: 0.8
|
7 |
+
adam_b2: 0.99
|
8 |
+
amp: false
|
9 |
+
audio_num_mel_bins: 80
|
10 |
+
audio_sample_rate: 24000
|
11 |
+
aux_context_window: 0
|
12 |
+
#base_config:
|
13 |
+
#- egs/egs_bases/singing/pwg.yaml
|
14 |
+
#- egs/egs_bases/tts/vocoder/hifigan.yaml
|
15 |
+
binarization_args:
|
16 |
+
reset_phone_dict: true
|
17 |
+
reset_word_dict: true
|
18 |
+
shuffle: false
|
19 |
+
trim_eos_bos: false
|
20 |
+
trim_sil: false
|
21 |
+
with_align: false
|
22 |
+
with_f0: true
|
23 |
+
with_f0cwt: false
|
24 |
+
with_linear: false
|
25 |
+
with_spk_embed: false
|
26 |
+
with_spk_id: true
|
27 |
+
with_txt: false
|
28 |
+
with_wav: true
|
29 |
+
with_word: false
|
30 |
+
binarizer_cls: data_gen.tts.singing.binarize.SingingBinarizer
|
31 |
+
binary_data_dir: data/binary/m4singer_vocoder
|
32 |
+
check_val_every_n_epoch: 10
|
33 |
+
clip_grad_norm: 1
|
34 |
+
clip_grad_value: 0
|
35 |
+
datasets: []
|
36 |
+
debug: false
|
37 |
+
dec_ffn_kernel_size: 9
|
38 |
+
dec_layers: 4
|
39 |
+
dict_dir: ''
|
40 |
+
disc_start_steps: 40000
|
41 |
+
discriminator_grad_norm: 1
|
42 |
+
discriminator_optimizer_params:
|
43 |
+
eps: 1.0e-06
|
44 |
+
lr: 0.0002
|
45 |
+
weight_decay: 0.0
|
46 |
+
discriminator_params:
|
47 |
+
bias: true
|
48 |
+
conv_channels: 64
|
49 |
+
in_channels: 1
|
50 |
+
kernel_size: 3
|
51 |
+
layers: 10
|
52 |
+
nonlinear_activation: LeakyReLU
|
53 |
+
nonlinear_activation_params:
|
54 |
+
negative_slope: 0.2
|
55 |
+
out_channels: 1
|
56 |
+
use_weight_norm: true
|
57 |
+
discriminator_scheduler_params:
|
58 |
+
gamma: 0.999
|
59 |
+
step_size: 600
|
60 |
+
dropout: 0.1
|
61 |
+
ds_workers: 1
|
62 |
+
enc_ffn_kernel_size: 9
|
63 |
+
enc_layers: 4
|
64 |
+
endless_ds: true
|
65 |
+
ffn_act: gelu
|
66 |
+
ffn_padding: SAME
|
67 |
+
fft_size: 512
|
68 |
+
fmax: 12000
|
69 |
+
fmin: 30
|
70 |
+
frames_multiple: 1
|
71 |
+
gen_dir_name: ''
|
72 |
+
generator_grad_norm: 10
|
73 |
+
generator_optimizer_params:
|
74 |
+
eps: 1.0e-06
|
75 |
+
lr: 0.0002
|
76 |
+
weight_decay: 0.0
|
77 |
+
generator_params:
|
78 |
+
aux_context_window: 0
|
79 |
+
aux_channels: 80
|
80 |
+
dropout: 0.0
|
81 |
+
gate_channels: 128
|
82 |
+
in_channels: 1
|
83 |
+
kernel_size: 3
|
84 |
+
layers: 30
|
85 |
+
out_channels: 1
|
86 |
+
residual_channels: 64
|
87 |
+
skip_channels: 64
|
88 |
+
stacks: 3
|
89 |
+
upsample_net: ConvInUpsampleNetwork
|
90 |
+
upsample_params:
|
91 |
+
upsample_scales:
|
92 |
+
- 2
|
93 |
+
- 4
|
94 |
+
- 4
|
95 |
+
- 4
|
96 |
+
use_nsf: false
|
97 |
+
use_pitch_embed: true
|
98 |
+
use_weight_norm: true
|
99 |
+
generator_scheduler_params:
|
100 |
+
gamma: 0.999
|
101 |
+
step_size: 600
|
102 |
+
griffin_lim_iters: 60
|
103 |
+
hidden_size: 256
|
104 |
+
hop_size: 128
|
105 |
+
infer: false
|
106 |
+
lambda_adv: 1.0
|
107 |
+
lambda_cdisc: 4.0
|
108 |
+
lambda_energy: 0.0
|
109 |
+
lambda_f0: 0.0
|
110 |
+
lambda_mel: 5.0
|
111 |
+
lambda_mel_adv: 1.0
|
112 |
+
lambda_ph_dur: 0.0
|
113 |
+
lambda_sent_dur: 0.0
|
114 |
+
lambda_uv: 0.0
|
115 |
+
lambda_word_dur: 0.0
|
116 |
+
load_ckpt: 'checkpoints/m4singer_hifigan'
|
117 |
+
loud_norm: false
|
118 |
+
lr: 2.0
|
119 |
+
max_epochs: 1000
|
120 |
+
max_frames: 2400
|
121 |
+
max_input_tokens: 1550
|
122 |
+
max_samples: 8192
|
123 |
+
max_sentences: 20
|
124 |
+
max_tokens: 24000
|
125 |
+
max_updates: 3000000
|
126 |
+
max_valid_sentences: 1
|
127 |
+
max_valid_tokens: 60000
|
128 |
+
mel_loss: ssim:0.5|l1:0.5
|
129 |
+
mel_vmax: 1.5
|
130 |
+
mel_vmin: -6
|
131 |
+
min_frames: 0
|
132 |
+
min_level_db: -120
|
133 |
+
num_ckpt_keep: 3
|
134 |
+
num_heads: 2
|
135 |
+
num_mels: 80
|
136 |
+
num_sanity_val_steps: 5
|
137 |
+
num_spk: 100
|
138 |
+
num_test_samples: 0
|
139 |
+
num_valid_plots: 10
|
140 |
+
optimizer_adam_beta1: 0.9
|
141 |
+
optimizer_adam_beta2: 0.98
|
142 |
+
out_wav_norm: false
|
143 |
+
pitch_extractor: parselmouth
|
144 |
+
pitch_type: frame
|
145 |
+
pre_align_args:
|
146 |
+
allow_no_txt: false
|
147 |
+
denoise: false
|
148 |
+
sox_resample: true
|
149 |
+
sox_to_wav: false
|
150 |
+
trim_sil: false
|
151 |
+
txt_processor: zh
|
152 |
+
use_tone: false
|
153 |
+
pre_align_cls: data_gen.tts.singing.pre_align.SingingPreAlign
|
154 |
+
predictor_grad: 0.0
|
155 |
+
print_nan_grads: false
|
156 |
+
processed_data_dir: ''
|
157 |
+
profile_infer: false
|
158 |
+
raw_data_dir: ''
|
159 |
+
ref_level_db: 20
|
160 |
+
rename_tmux: true
|
161 |
+
rerun_gen: true
|
162 |
+
resblock: '1'
|
163 |
+
resblock_dilation_sizes:
|
164 |
+
- - 1
|
165 |
+
- 3
|
166 |
+
- 5
|
167 |
+
- - 1
|
168 |
+
- 3
|
169 |
+
- 5
|
170 |
+
- - 1
|
171 |
+
- 3
|
172 |
+
- 5
|
173 |
+
resblock_kernel_sizes:
|
174 |
+
- 3
|
175 |
+
- 7
|
176 |
+
- 11
|
177 |
+
resume_from_checkpoint: 0
|
178 |
+
save_best: true
|
179 |
+
save_codes: []
|
180 |
+
save_f0: true
|
181 |
+
save_gt: true
|
182 |
+
scheduler: rsqrt
|
183 |
+
seed: 1234
|
184 |
+
sort_by_len: true
|
185 |
+
stft_loss_params:
|
186 |
+
fft_sizes:
|
187 |
+
- 1024
|
188 |
+
- 2048
|
189 |
+
- 512
|
190 |
+
hop_sizes:
|
191 |
+
- 120
|
192 |
+
- 240
|
193 |
+
- 50
|
194 |
+
win_lengths:
|
195 |
+
- 600
|
196 |
+
- 1200
|
197 |
+
- 240
|
198 |
+
window: hann_window
|
199 |
+
task_cls: tasks.vocoder.hifigan.HifiGanTask
|
200 |
+
tb_log_interval: 100
|
201 |
+
test_ids: []
|
202 |
+
test_input_dir: ''
|
203 |
+
test_num: 50
|
204 |
+
test_prefixes: []
|
205 |
+
test_set_name: test
|
206 |
+
train_set_name: train
|
207 |
+
train_sets: ''
|
208 |
+
upsample_initial_channel: 512
|
209 |
+
upsample_kernel_sizes:
|
210 |
+
- 16
|
211 |
+
- 16
|
212 |
+
- 4
|
213 |
+
- 4
|
214 |
+
upsample_rates:
|
215 |
+
- 8
|
216 |
+
- 4
|
217 |
+
- 2
|
218 |
+
- 2
|
219 |
+
use_cdisc: false
|
220 |
+
use_cond_disc: false
|
221 |
+
use_fm_loss: false
|
222 |
+
use_gt_dur: true
|
223 |
+
use_gt_f0: true
|
224 |
+
use_mel_loss: true
|
225 |
+
use_ms_stft: false
|
226 |
+
use_pitch_embed: true
|
227 |
+
use_ref_enc: true
|
228 |
+
use_spec_disc: false
|
229 |
+
use_spk_embed: false
|
230 |
+
use_spk_id: false
|
231 |
+
use_split_spk_id: false
|
232 |
+
val_check_interval: 2000
|
233 |
+
valid_infer_interval: 10000
|
234 |
+
valid_monitor_key: val_loss
|
235 |
+
valid_monitor_mode: min
|
236 |
+
valid_set_name: valid
|
237 |
+
vocoder: pwg
|
238 |
+
vocoder_ckpt: ''
|
239 |
+
vocoder_denoise_c: 0.0
|
240 |
+
warmup_updates: 8000
|
241 |
+
weight_decay: 0
|
242 |
+
win_length: null
|
243 |
+
win_size: 512
|
244 |
+
window: hann
|
245 |
+
word_size: 3000
|
246 |
+
work_dir: checkpoints/m4singer_hifigan
|
checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c3e859bd2b1e125fe661aedfd6fa3e97e10e06f3ec3d03b7735a041984402f89
|
3 |
+
size 1016324099
|
checkpoints/m4singer_pe/config.yaml
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accumulate_grad_batches: 1
|
2 |
+
audio_num_mel_bins: 80
|
3 |
+
audio_sample_rate: 24000
|
4 |
+
base_config:
|
5 |
+
- configs/tts/lj/fs2.yaml
|
6 |
+
binarization_args:
|
7 |
+
shuffle: false
|
8 |
+
with_align: true
|
9 |
+
with_f0: true
|
10 |
+
with_f0cwt: true
|
11 |
+
with_spk_embed: true
|
12 |
+
with_txt: true
|
13 |
+
with_wav: false
|
14 |
+
binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
|
15 |
+
binary_data_dir: data/binary/m4singer
|
16 |
+
check_val_every_n_epoch: 10
|
17 |
+
clip_grad_norm: 1
|
18 |
+
cwt_add_f0_loss: false
|
19 |
+
cwt_hidden_size: 128
|
20 |
+
cwt_layers: 2
|
21 |
+
cwt_loss: l1
|
22 |
+
cwt_std_scale: 0.8
|
23 |
+
debug: false
|
24 |
+
dec_ffn_kernel_size: 9
|
25 |
+
dec_layers: 4
|
26 |
+
decoder_type: fft
|
27 |
+
dict_dir: ''
|
28 |
+
dropout: 0.1
|
29 |
+
ds_workers: 4
|
30 |
+
dur_enc_hidden_stride_kernel:
|
31 |
+
- 0,2,3
|
32 |
+
- 0,2,3
|
33 |
+
- 0,1,3
|
34 |
+
dur_loss: mse
|
35 |
+
dur_predictor_kernel: 3
|
36 |
+
dur_predictor_layers: 2
|
37 |
+
enc_ffn_kernel_size: 9
|
38 |
+
enc_layers: 4
|
39 |
+
encoder_K: 8
|
40 |
+
encoder_type: fft
|
41 |
+
endless_ds: true
|
42 |
+
ffn_act: gelu
|
43 |
+
ffn_padding: SAME
|
44 |
+
fft_size: 512
|
45 |
+
fmax: 12000
|
46 |
+
fmin: 30
|
47 |
+
gen_dir_name: ''
|
48 |
+
hidden_size: 256
|
49 |
+
hop_size: 128
|
50 |
+
infer: false
|
51 |
+
lambda_commit: 0.25
|
52 |
+
lambda_energy: 0.1
|
53 |
+
lambda_f0: 1.0
|
54 |
+
lambda_ph_dur: 1.0
|
55 |
+
lambda_sent_dur: 1.0
|
56 |
+
lambda_uv: 1.0
|
57 |
+
lambda_word_dur: 1.0
|
58 |
+
load_ckpt: ''
|
59 |
+
log_interval: 100
|
60 |
+
loud_norm: false
|
61 |
+
lr: 0.1
|
62 |
+
max_epochs: 1000
|
63 |
+
max_eval_sentences: 1
|
64 |
+
max_eval_tokens: 60000
|
65 |
+
max_frames: 5000
|
66 |
+
max_input_tokens: 1550
|
67 |
+
max_sentences: 100000
|
68 |
+
max_tokens: 20000
|
69 |
+
max_updates: 280000
|
70 |
+
mel_loss: l1
|
71 |
+
mel_vmax: 1.5
|
72 |
+
mel_vmin: -6
|
73 |
+
min_level_db: -120
|
74 |
+
norm_type: gn
|
75 |
+
num_ckpt_keep: 3
|
76 |
+
num_heads: 2
|
77 |
+
num_sanity_val_steps: 5
|
78 |
+
num_spk: 1
|
79 |
+
num_test_samples: 20
|
80 |
+
num_valid_plots: 10
|
81 |
+
optimizer_adam_beta1: 0.9
|
82 |
+
optimizer_adam_beta2: 0.98
|
83 |
+
out_wav_norm: false
|
84 |
+
pitch_ar: false
|
85 |
+
pitch_enc_hidden_stride_kernel:
|
86 |
+
- 0,2,5
|
87 |
+
- 0,2,5
|
88 |
+
- 0,2,5
|
89 |
+
pitch_extractor_conv_layers: 2
|
90 |
+
pitch_loss: l1
|
91 |
+
pitch_norm: log
|
92 |
+
pitch_type: frame
|
93 |
+
pre_align_args:
|
94 |
+
allow_no_txt: false
|
95 |
+
denoise: false
|
96 |
+
forced_align: mfa
|
97 |
+
txt_processor: en
|
98 |
+
use_sox: false
|
99 |
+
use_tone: true
|
100 |
+
pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
|
101 |
+
predictor_dropout: 0.5
|
102 |
+
predictor_grad: 0.1
|
103 |
+
predictor_hidden: -1
|
104 |
+
predictor_kernel: 5
|
105 |
+
predictor_layers: 2
|
106 |
+
prenet_dropout: 0.5
|
107 |
+
prenet_hidden_size: 256
|
108 |
+
pretrain_fs_ckpt: ''
|
109 |
+
processed_data_dir: data/processed/ljspeech
|
110 |
+
profile_infer: false
|
111 |
+
raw_data_dir: data/raw/LJSpeech-1.1
|
112 |
+
ref_norm_layer: bn
|
113 |
+
reset_phone_dict: true
|
114 |
+
save_best: false
|
115 |
+
save_ckpt: true
|
116 |
+
save_codes:
|
117 |
+
- configs
|
118 |
+
- modules
|
119 |
+
- tasks
|
120 |
+
- utils
|
121 |
+
- usr
|
122 |
+
save_f0: false
|
123 |
+
save_gt: false
|
124 |
+
seed: 1234
|
125 |
+
sort_by_len: true
|
126 |
+
stop_token_weight: 5.0
|
127 |
+
task_cls: tasks.tts.pe.PitchExtractionTask
|
128 |
+
test_ids:
|
129 |
+
- 68
|
130 |
+
- 70
|
131 |
+
- 74
|
132 |
+
- 87
|
133 |
+
- 110
|
134 |
+
- 172
|
135 |
+
- 190
|
136 |
+
- 215
|
137 |
+
- 231
|
138 |
+
- 294
|
139 |
+
- 316
|
140 |
+
- 324
|
141 |
+
- 402
|
142 |
+
- 422
|
143 |
+
- 485
|
144 |
+
- 500
|
145 |
+
- 505
|
146 |
+
- 508
|
147 |
+
- 509
|
148 |
+
- 519
|
149 |
+
test_input_dir: ''
|
150 |
+
test_num: 523
|
151 |
+
test_set_name: test
|
152 |
+
train_set_name: train
|
153 |
+
use_denoise: false
|
154 |
+
use_energy_embed: false
|
155 |
+
use_gt_dur: false
|
156 |
+
use_gt_f0: false
|
157 |
+
use_pitch_embed: true
|
158 |
+
use_pos_embed: true
|
159 |
+
use_spk_embed: false
|
160 |
+
use_spk_id: false
|
161 |
+
use_split_spk_id: false
|
162 |
+
use_uv: true
|
163 |
+
use_var_enc: false
|
164 |
+
val_check_interval: 2000
|
165 |
+
valid_num: 348
|
166 |
+
valid_set_name: valid
|
167 |
+
vocoder: pwg
|
168 |
+
vocoder_ckpt: ''
|
169 |
+
warmup_updates: 2000
|
170 |
+
weight_decay: 0
|
171 |
+
win_size: 512
|
172 |
+
work_dir: checkpoints/m4singer_pe
|
checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10cbf382bf82ecf335fbf68ba226f93c9c715b0476f6604351cbad9783f529fe
|
3 |
+
size 39146292
|
configs/config_base.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# task
|
2 |
+
binary_data_dir: ''
|
3 |
+
work_dir: '' # experiment directory.
|
4 |
+
infer: false # infer
|
5 |
+
seed: 1234
|
6 |
+
debug: false
|
7 |
+
save_codes:
|
8 |
+
- configs
|
9 |
+
- modules
|
10 |
+
- tasks
|
11 |
+
- utils
|
12 |
+
- usr
|
13 |
+
|
14 |
+
#############
|
15 |
+
# dataset
|
16 |
+
#############
|
17 |
+
ds_workers: 1
|
18 |
+
test_num: 100
|
19 |
+
valid_num: 100
|
20 |
+
endless_ds: false
|
21 |
+
sort_by_len: true
|
22 |
+
|
23 |
+
#########
|
24 |
+
# train and eval
|
25 |
+
#########
|
26 |
+
load_ckpt: ''
|
27 |
+
save_ckpt: true
|
28 |
+
save_best: false
|
29 |
+
num_ckpt_keep: 3
|
30 |
+
clip_grad_norm: 0
|
31 |
+
accumulate_grad_batches: 1
|
32 |
+
log_interval: 100
|
33 |
+
num_sanity_val_steps: 5 # steps of validation at the beginning
|
34 |
+
check_val_every_n_epoch: 10
|
35 |
+
val_check_interval: 2000
|
36 |
+
max_epochs: 1000
|
37 |
+
max_updates: 160000
|
38 |
+
max_tokens: 31250
|
39 |
+
max_sentences: 100000
|
40 |
+
max_eval_tokens: -1
|
41 |
+
max_eval_sentences: -1
|
42 |
+
test_input_dir: ''
|
configs/singing/base.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/base.yaml
|
3 |
+
- configs/tts/base_zh.yaml
|
4 |
+
|
5 |
+
|
6 |
+
datasets: []
|
7 |
+
test_prefixes: []
|
8 |
+
test_num: 0
|
9 |
+
valid_num: 0
|
10 |
+
|
11 |
+
pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
|
12 |
+
binarizer_cls: data_gen.singing.binarize.SingingBinarizer
|
13 |
+
pre_align_args:
|
14 |
+
use_tone: false # for ZH
|
15 |
+
forced_align: mfa
|
16 |
+
use_sox: true
|
17 |
+
hop_size: 128 # Hop size.
|
18 |
+
fft_size: 512 # FFT size.
|
19 |
+
win_size: 512 # FFT size.
|
20 |
+
max_frames: 8000
|
21 |
+
fmin: 50 # Minimum freq in mel basis calculation.
|
22 |
+
fmax: 11025 # Maximum frequency in mel basis calculation.
|
23 |
+
pitch_type: frame
|
24 |
+
|
25 |
+
hidden_size: 256
|
26 |
+
mel_loss: "ssim:0.5|l1:0.5"
|
27 |
+
lambda_f0: 0.0
|
28 |
+
lambda_uv: 0.0
|
29 |
+
lambda_energy: 0.0
|
30 |
+
lambda_ph_dur: 0.0
|
31 |
+
lambda_sent_dur: 0.0
|
32 |
+
lambda_word_dur: 0.0
|
33 |
+
predictor_grad: 0.0
|
34 |
+
use_spk_embed: true
|
35 |
+
use_spk_id: false
|
36 |
+
|
37 |
+
max_tokens: 20000
|
38 |
+
max_updates: 400000
|
39 |
+
num_spk: 100
|
40 |
+
save_f0: true
|
41 |
+
use_gt_dur: true
|
42 |
+
use_gt_f0: true
|
configs/singing/fs2.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/fs2.yaml
|
3 |
+
- configs/singing/base.yaml
|
configs/tts/base.yaml
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# task
|
2 |
+
base_config: configs/config_base.yaml
|
3 |
+
task_cls: ''
|
4 |
+
#############
|
5 |
+
# dataset
|
6 |
+
#############
|
7 |
+
raw_data_dir: ''
|
8 |
+
processed_data_dir: ''
|
9 |
+
binary_data_dir: ''
|
10 |
+
dict_dir: ''
|
11 |
+
pre_align_cls: ''
|
12 |
+
binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
|
13 |
+
pre_align_args:
|
14 |
+
use_tone: true # for ZH
|
15 |
+
forced_align: mfa
|
16 |
+
use_sox: false
|
17 |
+
txt_processor: en
|
18 |
+
allow_no_txt: false
|
19 |
+
denoise: false
|
20 |
+
binarization_args:
|
21 |
+
shuffle: false
|
22 |
+
with_txt: true
|
23 |
+
with_wav: false
|
24 |
+
with_align: true
|
25 |
+
with_spk_embed: true
|
26 |
+
with_f0: true
|
27 |
+
with_f0cwt: true
|
28 |
+
|
29 |
+
loud_norm: false
|
30 |
+
endless_ds: true
|
31 |
+
reset_phone_dict: true
|
32 |
+
|
33 |
+
test_num: 100
|
34 |
+
valid_num: 100
|
35 |
+
max_frames: 1550
|
36 |
+
max_input_tokens: 1550
|
37 |
+
audio_num_mel_bins: 80
|
38 |
+
audio_sample_rate: 22050
|
39 |
+
hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
|
40 |
+
win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
|
41 |
+
fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
42 |
+
fmax: 7600 # To be increased/reduced depending on data.
|
43 |
+
fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter
|
44 |
+
min_level_db: -100
|
45 |
+
num_spk: 1
|
46 |
+
mel_vmin: -6
|
47 |
+
mel_vmax: 1.5
|
48 |
+
ds_workers: 4
|
49 |
+
|
50 |
+
#########
|
51 |
+
# model
|
52 |
+
#########
|
53 |
+
dropout: 0.1
|
54 |
+
enc_layers: 4
|
55 |
+
dec_layers: 4
|
56 |
+
hidden_size: 384
|
57 |
+
num_heads: 2
|
58 |
+
prenet_dropout: 0.5
|
59 |
+
prenet_hidden_size: 256
|
60 |
+
stop_token_weight: 5.0
|
61 |
+
enc_ffn_kernel_size: 9
|
62 |
+
dec_ffn_kernel_size: 9
|
63 |
+
ffn_act: gelu
|
64 |
+
ffn_padding: 'SAME'
|
65 |
+
|
66 |
+
|
67 |
+
###########
|
68 |
+
# optimization
|
69 |
+
###########
|
70 |
+
lr: 2.0
|
71 |
+
warmup_updates: 8000
|
72 |
+
optimizer_adam_beta1: 0.9
|
73 |
+
optimizer_adam_beta2: 0.98
|
74 |
+
weight_decay: 0
|
75 |
+
clip_grad_norm: 1
|
76 |
+
|
77 |
+
|
78 |
+
###########
|
79 |
+
# train and eval
|
80 |
+
###########
|
81 |
+
max_tokens: 30000
|
82 |
+
max_sentences: 100000
|
83 |
+
max_eval_sentences: 1
|
84 |
+
max_eval_tokens: 60000
|
85 |
+
train_set_name: 'train'
|
86 |
+
valid_set_name: 'valid'
|
87 |
+
test_set_name: 'test'
|
88 |
+
vocoder: pwg
|
89 |
+
vocoder_ckpt: ''
|
90 |
+
profile_infer: false
|
91 |
+
out_wav_norm: false
|
92 |
+
save_gt: false
|
93 |
+
save_f0: false
|
94 |
+
gen_dir_name: ''
|
95 |
+
use_denoise: false
|
configs/tts/base_zh.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
pre_align_args:
|
2 |
+
txt_processor: zh_g2pM
|
3 |
+
binarizer_cls: data_gen.tts.binarizer_zh.ZhBinarizer
|
configs/tts/fs2.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: configs/tts/base.yaml
|
2 |
+
task_cls: tasks.tts.fs2.FastSpeech2Task
|
3 |
+
|
4 |
+
# model
|
5 |
+
hidden_size: 256
|
6 |
+
dropout: 0.1
|
7 |
+
encoder_type: fft # fft|tacotron|tacotron2|conformer
|
8 |
+
encoder_K: 8 # for tacotron encoder
|
9 |
+
decoder_type: fft # fft|rnn|conv|conformer
|
10 |
+
use_pos_embed: true
|
11 |
+
|
12 |
+
# duration
|
13 |
+
predictor_hidden: -1
|
14 |
+
predictor_kernel: 5
|
15 |
+
predictor_layers: 2
|
16 |
+
dur_predictor_kernel: 3
|
17 |
+
dur_predictor_layers: 2
|
18 |
+
predictor_dropout: 0.5
|
19 |
+
|
20 |
+
# pitch and energy
|
21 |
+
use_pitch_embed: true
|
22 |
+
pitch_type: ph # frame|ph|cwt
|
23 |
+
use_uv: true
|
24 |
+
cwt_hidden_size: 128
|
25 |
+
cwt_layers: 2
|
26 |
+
cwt_loss: l1
|
27 |
+
cwt_add_f0_loss: false
|
28 |
+
cwt_std_scale: 0.8
|
29 |
+
|
30 |
+
pitch_ar: false
|
31 |
+
#pitch_embed_type: 0q
|
32 |
+
pitch_loss: 'l1' # l1|l2|ssim
|
33 |
+
pitch_norm: log
|
34 |
+
use_energy_embed: false
|
35 |
+
|
36 |
+
# reference encoder and speaker embedding
|
37 |
+
use_spk_id: false
|
38 |
+
use_split_spk_id: false
|
39 |
+
use_spk_embed: false
|
40 |
+
use_var_enc: false
|
41 |
+
lambda_commit: 0.25
|
42 |
+
ref_norm_layer: bn
|
43 |
+
pitch_enc_hidden_stride_kernel:
|
44 |
+
- 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
|
45 |
+
- 0,2,5
|
46 |
+
- 0,2,5
|
47 |
+
dur_enc_hidden_stride_kernel:
|
48 |
+
- 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size
|
49 |
+
- 0,2,3
|
50 |
+
- 0,1,3
|
51 |
+
|
52 |
+
|
53 |
+
# mel
|
54 |
+
mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5
|
55 |
+
|
56 |
+
# loss lambda
|
57 |
+
lambda_f0: 1.0
|
58 |
+
lambda_uv: 1.0
|
59 |
+
lambda_energy: 0.1
|
60 |
+
lambda_ph_dur: 1.0
|
61 |
+
lambda_sent_dur: 1.0
|
62 |
+
lambda_word_dur: 1.0
|
63 |
+
predictor_grad: 0.1
|
64 |
+
|
65 |
+
# train and eval
|
66 |
+
pretrain_fs_ckpt: ''
|
67 |
+
warmup_updates: 2000
|
68 |
+
max_tokens: 32000
|
69 |
+
max_sentences: 100000
|
70 |
+
max_eval_sentences: 1
|
71 |
+
max_updates: 120000
|
72 |
+
num_valid_plots: 5
|
73 |
+
num_test_samples: 0
|
74 |
+
test_ids: []
|
75 |
+
use_gt_dur: false
|
76 |
+
use_gt_f0: false
|
77 |
+
|
78 |
+
# exp
|
79 |
+
dur_loss: mse # huber|mol
|
80 |
+
norm_type: gn
|
configs/tts/hifigan.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: configs/tts/pwg.yaml
|
2 |
+
task_cls: tasks.vocoder.hifigan.HifiGanTask
|
3 |
+
resblock: "1"
|
4 |
+
adam_b1: 0.8
|
5 |
+
adam_b2: 0.99
|
6 |
+
upsample_rates: [ 8,8,2,2 ]
|
7 |
+
upsample_kernel_sizes: [ 16,16,4,4 ]
|
8 |
+
upsample_initial_channel: 128
|
9 |
+
resblock_kernel_sizes: [ 3,7,11 ]
|
10 |
+
resblock_dilation_sizes: [ [ 1,3,5 ], [ 1,3,5 ], [ 1,3,5 ] ]
|
11 |
+
|
12 |
+
lambda_mel: 45.0
|
13 |
+
|
14 |
+
max_samples: 8192
|
15 |
+
max_sentences: 16
|
16 |
+
|
17 |
+
generator_params:
|
18 |
+
lr: 0.0002 # Generator's learning rate.
|
19 |
+
aux_context_window: 0 # Context window size for auxiliary feature.
|
20 |
+
discriminator_optimizer_params:
|
21 |
+
lr: 0.0002 # Discriminator's learning rate.
|
configs/tts/lj/base_mel2wav.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
raw_data_dir: 'data/raw/LJSpeech-1.1'
|
2 |
+
processed_data_dir: 'data/processed/ljspeech'
|
3 |
+
binary_data_dir: 'data/binary/ljspeech_wav'
|
configs/tts/lj/base_text2mel.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
raw_data_dir: 'data/raw/LJSpeech-1.1'
|
2 |
+
processed_data_dir: 'data/processed/ljspeech'
|
3 |
+
binary_data_dir: 'data/binary/ljspeech'
|
4 |
+
pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
|
5 |
+
|
6 |
+
pitch_type: cwt
|
7 |
+
mel_loss: l1
|
8 |
+
num_test_samples: 20
|
9 |
+
test_ids: [ 68, 70, 74, 87, 110, 172, 190, 215, 231, 294,
|
10 |
+
316, 324, 402, 422, 485, 500, 505, 508, 509, 519 ]
|
11 |
+
use_energy_embed: false
|
12 |
+
test_num: 523
|
13 |
+
valid_num: 348
|
configs/tts/lj/fs2.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/fs2.yaml
|
3 |
+
- configs/tts/lj/base_text2mel.yaml
|
configs/tts/lj/hifigan.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/hifigan.yaml
|
3 |
+
- configs/tts/lj/base_mel2wav.yaml
|
configs/tts/lj/pwg.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
base_config:
|
2 |
+
- configs/tts/pwg.yaml
|
3 |
+
- configs/tts/lj/base_mel2wav.yaml
|
configs/tts/pwg.yaml
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_config: configs/tts/base.yaml
|
2 |
+
task_cls: tasks.vocoder.pwg.PwgTask
|
3 |
+
|
4 |
+
binarization_args:
|
5 |
+
with_wav: true
|
6 |
+
with_spk_embed: false
|
7 |
+
with_align: false
|
8 |
+
test_input_dir: ''
|
9 |
+
|
10 |
+
###########
|
11 |
+
# train and eval
|
12 |
+
###########
|
13 |
+
max_samples: 25600
|
14 |
+
max_sentences: 5
|
15 |
+
max_eval_sentences: 1
|
16 |
+
max_updates: 1000000
|
17 |
+
val_check_interval: 2000
|
18 |
+
|
19 |
+
|
20 |
+
###########################################################
|
21 |
+
# FEATURE EXTRACTION SETTING #
|
22 |
+
###########################################################
|
23 |
+
sampling_rate: 22050 # Sampling rate.
|
24 |
+
fft_size: 1024 # FFT size.
|
25 |
+
hop_size: 256 # Hop size.
|
26 |
+
win_length: null # Window length.
|
27 |
+
# If set to null, it will be the same as fft_size.
|
28 |
+
window: "hann" # Window function.
|
29 |
+
num_mels: 80 # Number of mel basis.
|
30 |
+
fmin: 80 # Minimum freq in mel basis calculation.
|
31 |
+
fmax: 7600 # Maximum frequency in mel basis calculation.
|
32 |
+
format: "hdf5" # Feature file format. "npy" or "hdf5" is supported.
|
33 |
+
|
34 |
+
###########################################################
|
35 |
+
# GENERATOR NETWORK ARCHITECTURE SETTING #
|
36 |
+
###########################################################
|
37 |
+
generator_params:
|
38 |
+
in_channels: 1 # Number of input channels.
|
39 |
+
out_channels: 1 # Number of output channels.
|
40 |
+
kernel_size: 3 # Kernel size of dilated convolution.
|
41 |
+
layers: 30 # Number of residual block layers.
|
42 |
+
stacks: 3 # Number of stacks i.e., dilation cycles.
|
43 |
+
residual_channels: 64 # Number of channels in residual conv.
|
44 |
+
gate_channels: 128 # Number of channels in gated conv.
|
45 |
+
skip_channels: 64 # Number of channels in skip conv.
|
46 |
+
aux_channels: 80 # Number of channels for auxiliary feature conv.
|
47 |
+
# Must be the same as num_mels.
|
48 |
+
aux_context_window: 2 # Context window size for auxiliary feature.
|
49 |
+
# If set to 2, previous 2 and future 2 frames will be considered.
|
50 |
+
dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
|
51 |
+
use_weight_norm: true # Whether to use weight norm.
|
52 |
+
# If set to true, it will be applied to all of the conv layers.
|
53 |
+
upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
|
54 |
+
upsample_params: # Upsampling network parameters.
|
55 |
+
upsample_scales: [4, 4, 4, 4] # Upsampling scales. Prodcut of these must be the same as hop size.
|
56 |
+
use_pitch_embed: false
|
57 |
+
|
58 |
+
###########################################################
|
59 |
+
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
|
60 |
+
###########################################################
|
61 |
+
discriminator_params:
|
62 |
+
in_channels: 1 # Number of input channels.
|
63 |
+
out_channels: 1 # Number of output channels.
|
64 |
+
kernel_size: 3 # Number of output channels.
|
65 |
+
layers: 10 # Number of conv layers.
|
66 |
+
conv_channels: 64 # Number of chnn layers.
|
67 |
+
bias: true # Whether to use bias parameter in conv.
|
68 |
+
use_weight_norm: true # Whether to use weight norm.
|
69 |
+
# If set to true, it will be applied to all of the conv layers.
|
70 |
+
nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
|
71 |
+
nonlinear_activation_params: # Nonlinear function parameters
|
72 |
+
negative_slope: 0.2 # Alpha in LeakyReLU.
|
73 |
+
|
74 |
+
###########################################################
|
75 |
+
# STFT LOSS SETTING #
|
76 |
+
###########################################################
|
77 |
+
stft_loss_params:
|
78 |
+
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
|
79 |
+
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
|
80 |
+
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
|
81 |
+
window: "hann_window" # Window function for STFT-based loss
|
82 |
+
use_mel_loss: false
|
83 |
+
|
84 |
+
###########################################################
|
85 |
+
# ADVERSARIAL LOSS SETTING #
|
86 |
+
###########################################################
|
87 |
+
lambda_adv: 4.0 # Loss balancing coefficient.
|
88 |
+
|
89 |
+
###########################################################
|
90 |
+
# OPTIMIZER & SCHEDULER SETTING #
|
91 |
+
###########################################################
|
92 |
+
generator_optimizer_params:
|
93 |
+
lr: 0.0001 # Generator's learning rate.
|
94 |
+
eps: 1.0e-6 # Generator's epsilon.
|
95 |
+
weight_decay: 0.0 # Generator's weight decay coefficient.
|
96 |
+
generator_scheduler_params:
|
97 |
+
step_size: 200000 # Generator's scheduler step size.
|
98 |
+
gamma: 0.5 # Generator's scheduler gamma.
|
99 |
+
# At each step size, lr will be multiplied by this parameter.
|
100 |
+
generator_grad_norm: 10 # Generator's gradient norm.
|
101 |
+
discriminator_optimizer_params:
|
102 |
+
lr: 0.00005 # Discriminator's learning rate.
|
103 |
+
eps: 1.0e-6 # Discriminator's epsilon.
|
104 |
+
weight_decay: 0.0 # Discriminator's weight decay coefficient.
|
105 |
+
discriminator_scheduler_params:
|
106 |
+
step_size: 200000 # Discriminator's scheduler step size.
|
107 |
+
gamma: 0.5 # Discriminator's scheduler gamma.
|
108 |
+
# At each step size, lr will be multiplied by this parameter.
|
109 |
+
discriminator_grad_norm: 1 # Discriminator's gradient norm.
|
110 |
+
disc_start_steps: 40000 # Number of steps to start to train discriminator.
|
data_gen/singing/binarize.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from copy import deepcopy
|
4 |
+
import pandas as pd
|
5 |
+
import logging
|
6 |
+
from tqdm import tqdm
|
7 |
+
import json
|
8 |
+
import glob
|
9 |
+
import re
|
10 |
+
from resemblyzer import VoiceEncoder
|
11 |
+
import traceback
|
12 |
+
import numpy as np
|
13 |
+
import pretty_midi
|
14 |
+
import librosa
|
15 |
+
from scipy.interpolate import interp1d
|
16 |
+
import torch
|
17 |
+
from textgrid import TextGrid
|
18 |
+
|
19 |
+
from utils.hparams import hparams
|
20 |
+
from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch
|
21 |
+
from utils.pitch_utils import f0_to_coarse
|
22 |
+
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
|
23 |
+
from data_gen.tts.binarizer_zh import ZhBinarizer
|
24 |
+
from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU
|
25 |
+
from vocoders.base_vocoder import VOCODERS
|
26 |
+
|
27 |
+
|
28 |
+
class SingingBinarizer(BaseBinarizer):
|
29 |
+
def __init__(self, processed_data_dir=None):
|
30 |
+
if processed_data_dir is None:
|
31 |
+
processed_data_dir = hparams['processed_data_dir']
|
32 |
+
self.processed_data_dirs = processed_data_dir.split(",")
|
33 |
+
self.binarization_args = hparams['binarization_args']
|
34 |
+
self.pre_align_args = hparams['pre_align_args']
|
35 |
+
self.item2txt = {}
|
36 |
+
self.item2ph = {}
|
37 |
+
self.item2wavfn = {}
|
38 |
+
self.item2f0fn = {}
|
39 |
+
self.item2tgfn = {}
|
40 |
+
self.item2spk = {}
|
41 |
+
|
42 |
+
def split_train_test_set(self, item_names):
|
43 |
+
item_names = deepcopy(item_names)
|
44 |
+
test_item_names = [x for x in item_names if any([ts in x for ts in hparams['test_prefixes']])]
|
45 |
+
train_item_names = [x for x in item_names if x not in set(test_item_names)]
|
46 |
+
logging.info("train {}".format(len(train_item_names)))
|
47 |
+
logging.info("test {}".format(len(test_item_names)))
|
48 |
+
return train_item_names, test_item_names
|
49 |
+
|
50 |
+
def load_meta_data(self):
|
51 |
+
for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
|
52 |
+
wav_suffix = '_wf0.wav'
|
53 |
+
txt_suffix = '.txt'
|
54 |
+
ph_suffix = '_ph.txt'
|
55 |
+
tg_suffix = '.TextGrid'
|
56 |
+
all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')
|
57 |
+
|
58 |
+
for piece_path in all_wav_pieces:
|
59 |
+
item_name = raw_item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)]
|
60 |
+
if len(self.processed_data_dirs) > 1:
|
61 |
+
item_name = f'ds{ds_id}_{item_name}'
|
62 |
+
self.item2txt[item_name] = open(f'{piece_path.replace(wav_suffix, txt_suffix)}').readline()
|
63 |
+
self.item2ph[item_name] = open(f'{piece_path.replace(wav_suffix, ph_suffix)}').readline()
|
64 |
+
self.item2wavfn[item_name] = piece_path
|
65 |
+
|
66 |
+
self.item2spk[item_name] = re.split('-|#', piece_path.split('/')[-2])[0]
|
67 |
+
if len(self.processed_data_dirs) > 1:
|
68 |
+
self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
|
69 |
+
self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix)
|
70 |
+
print('spkers: ', set(self.item2spk.values()))
|
71 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
72 |
+
if self.binarization_args['shuffle']:
|
73 |
+
random.seed(1234)
|
74 |
+
random.shuffle(self.item_names)
|
75 |
+
self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
|
76 |
+
|
77 |
+
@property
|
78 |
+
def train_item_names(self):
|
79 |
+
return self._train_item_names
|
80 |
+
|
81 |
+
@property
|
82 |
+
def valid_item_names(self):
|
83 |
+
return self._test_item_names
|
84 |
+
|
85 |
+
@property
|
86 |
+
def test_item_names(self):
|
87 |
+
return self._test_item_names
|
88 |
+
|
89 |
+
def process(self):
|
90 |
+
self.load_meta_data()
|
91 |
+
os.makedirs(hparams['binary_data_dir'], exist_ok=True)
|
92 |
+
self.spk_map = self.build_spk_map()
|
93 |
+
print("| spk_map: ", self.spk_map)
|
94 |
+
spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
|
95 |
+
json.dump(self.spk_map, open(spk_map_fn, 'w'))
|
96 |
+
|
97 |
+
self.phone_encoder = self._phone_encoder()
|
98 |
+
self.process_data('valid')
|
99 |
+
self.process_data('test')
|
100 |
+
self.process_data('train')
|
101 |
+
|
102 |
+
def _phone_encoder(self):
|
103 |
+
ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
|
104 |
+
ph_set = []
|
105 |
+
if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
|
106 |
+
for ph_sent in self.item2ph.values():
|
107 |
+
ph_set += ph_sent.split(' ')
|
108 |
+
ph_set = sorted(set(ph_set))
|
109 |
+
json.dump(ph_set, open(ph_set_fn, 'w'))
|
110 |
+
print("| Build phone set: ", ph_set)
|
111 |
+
else:
|
112 |
+
ph_set = json.load(open(ph_set_fn, 'r'))
|
113 |
+
print("| Load phone set: ", ph_set)
|
114 |
+
return build_phone_encoder(hparams['binary_data_dir'])
|
115 |
+
|
116 |
+
# @staticmethod
|
117 |
+
# def get_pitch(wav_fn, spec, res):
|
118 |
+
# wav_suffix = '_wf0.wav'
|
119 |
+
# f0_suffix = '_f0.npy'
|
120 |
+
# f0fn = wav_fn.replace(wav_suffix, f0_suffix)
|
121 |
+
# pitch_info = np.load(f0fn)
|
122 |
+
# f0 = [x[1] for x in pitch_info]
|
123 |
+
# spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
|
124 |
+
# f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
|
125 |
+
# f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
|
126 |
+
# # f0_x_coor = np.arange(0, 1, 1 / len(f0))
|
127 |
+
# # f0_x_coor[-1] = 1
|
128 |
+
# # f0 = interp1d(f0_x_coor, f0, 'nearest')(spec_x_coor)[:len(spec)]
|
129 |
+
# if sum(f0) == 0:
|
130 |
+
# raise BinarizationError("Empty f0")
|
131 |
+
# assert len(f0) == len(spec), (len(f0), len(spec))
|
132 |
+
# pitch_coarse = f0_to_coarse(f0)
|
133 |
+
#
|
134 |
+
# # vis f0
|
135 |
+
# # import matplotlib.pyplot as plt
|
136 |
+
# # from textgrid import TextGrid
|
137 |
+
# # tg_fn = wav_fn.replace(wav_suffix, '.TextGrid')
|
138 |
+
# # fig = plt.figure(figsize=(12, 6))
|
139 |
+
# # plt.pcolor(spec.T, vmin=-5, vmax=0)
|
140 |
+
# # ax = plt.gca()
|
141 |
+
# # ax2 = ax.twinx()
|
142 |
+
# # ax2.plot(f0, color='red')
|
143 |
+
# # ax2.set_ylim(0, 800)
|
144 |
+
# # itvs = TextGrid.fromFile(tg_fn)[0]
|
145 |
+
# # for itv in itvs:
|
146 |
+
# # x = itv.maxTime * hparams['audio_sample_rate'] / hparams['hop_size']
|
147 |
+
# # plt.vlines(x=x, ymin=0, ymax=80, color='black')
|
148 |
+
# # plt.text(x=x, y=20, s=itv.mark, color='black')
|
149 |
+
# # plt.savefig('tmp/20211229_singing_plots_test.png')
|
150 |
+
#
|
151 |
+
# res['f0'] = f0
|
152 |
+
# res['pitch'] = pitch_coarse
|
153 |
+
|
154 |
+
@classmethod
|
155 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
156 |
+
if hparams['vocoder'] in VOCODERS:
|
157 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
158 |
+
else:
|
159 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
160 |
+
res = {
|
161 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
162 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
163 |
+
}
|
164 |
+
try:
|
165 |
+
if binarization_args['with_f0']:
|
166 |
+
# cls.get_pitch(wav_fn, mel, res)
|
167 |
+
cls.get_pitch(wav, mel, res)
|
168 |
+
if binarization_args['with_txt']:
|
169 |
+
try:
|
170 |
+
# print(ph)
|
171 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
172 |
+
except:
|
173 |
+
traceback.print_exc()
|
174 |
+
raise BinarizationError(f"Empty phoneme")
|
175 |
+
if binarization_args['with_align']:
|
176 |
+
cls.get_align(tg_fn, ph, mel, phone_encoded, res)
|
177 |
+
except BinarizationError as e:
|
178 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
179 |
+
return None
|
180 |
+
return res
|
181 |
+
|
182 |
+
|
183 |
+
class MidiSingingBinarizer(SingingBinarizer):
|
184 |
+
item2midi = {}
|
185 |
+
item2midi_dur = {}
|
186 |
+
item2is_slur = {}
|
187 |
+
item2ph_durs = {}
|
188 |
+
item2wdb = {}
|
189 |
+
|
190 |
+
def load_meta_data(self):
|
191 |
+
for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
|
192 |
+
meta_midi = json.load(open(os.path.join(processed_data_dir, 'meta.json'))) # [list of dict]
|
193 |
+
|
194 |
+
for song_item in meta_midi:
|
195 |
+
item_name = raw_item_name = song_item['item_name']
|
196 |
+
if len(self.processed_data_dirs) > 1:
|
197 |
+
item_name = f'ds{ds_id}_{item_name}'
|
198 |
+
self.item2wavfn[item_name] = song_item['wav_fn']
|
199 |
+
self.item2txt[item_name] = song_item['txt']
|
200 |
+
|
201 |
+
self.item2ph[item_name] = ' '.join(song_item['phs'])
|
202 |
+
self.item2wdb[item_name] = [1 if x in ALL_YUNMU + ['AP', 'SP', '<SIL>'] else 0 for x in song_item['phs']]
|
203 |
+
self.item2ph_durs[item_name] = song_item['ph_dur']
|
204 |
+
|
205 |
+
self.item2midi[item_name] = song_item['notes']
|
206 |
+
self.item2midi_dur[item_name] = song_item['notes_dur']
|
207 |
+
self.item2is_slur[item_name] = song_item['is_slur']
|
208 |
+
self.item2spk[item_name] = 'pop-cs'
|
209 |
+
if len(self.processed_data_dirs) > 1:
|
210 |
+
self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
|
211 |
+
|
212 |
+
print('spkers: ', set(self.item2spk.values()))
|
213 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
214 |
+
if self.binarization_args['shuffle']:
|
215 |
+
random.seed(1234)
|
216 |
+
random.shuffle(self.item_names)
|
217 |
+
self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
|
218 |
+
|
219 |
+
@staticmethod
|
220 |
+
def get_pitch(wav_fn, wav, spec, ph, res):
|
221 |
+
wav_suffix = '.wav'
|
222 |
+
# midi_suffix = '.mid'
|
223 |
+
wav_dir = 'wavs'
|
224 |
+
f0_dir = 'f0'
|
225 |
+
|
226 |
+
item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '')
|
227 |
+
res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name])
|
228 |
+
res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name])
|
229 |
+
res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name])
|
230 |
+
res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name])
|
231 |
+
assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
|
232 |
+
res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
|
233 |
+
|
234 |
+
# gt f0.
|
235 |
+
gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
|
236 |
+
if sum(gt_f0) == 0:
|
237 |
+
raise BinarizationError("Empty **gt** f0")
|
238 |
+
res['f0'] = gt_f0
|
239 |
+
res['pitch'] = gt_pitch_coarse
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
def get_align(ph_durs, mel, phone_encoded, res, hop_size=hparams['hop_size'], audio_sample_rate=hparams['audio_sample_rate']):
|
243 |
+
mel2ph = np.zeros([mel.shape[0]], int)
|
244 |
+
startTime = 0
|
245 |
+
|
246 |
+
for i_ph in range(len(ph_durs)):
|
247 |
+
start_frame = int(startTime * audio_sample_rate / hop_size + 0.5)
|
248 |
+
end_frame = int((startTime + ph_durs[i_ph]) * audio_sample_rate / hop_size + 0.5)
|
249 |
+
mel2ph[start_frame:end_frame] = i_ph + 1
|
250 |
+
startTime = startTime + ph_durs[i_ph]
|
251 |
+
|
252 |
+
# print('ph durs: ', ph_durs)
|
253 |
+
# print('mel2ph: ', mel2ph, len(mel2ph))
|
254 |
+
res['mel2ph'] = mel2ph
|
255 |
+
# res['dur'] = None
|
256 |
+
|
257 |
+
@classmethod
|
258 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
259 |
+
if hparams['vocoder'] in VOCODERS:
|
260 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
261 |
+
else:
|
262 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
263 |
+
res = {
|
264 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
265 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
266 |
+
}
|
267 |
+
try:
|
268 |
+
if binarization_args['with_f0']:
|
269 |
+
cls.get_pitch(wav_fn, wav, mel, ph, res)
|
270 |
+
if binarization_args['with_txt']:
|
271 |
+
try:
|
272 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
273 |
+
except:
|
274 |
+
traceback.print_exc()
|
275 |
+
raise BinarizationError(f"Empty phoneme")
|
276 |
+
if binarization_args['with_align']:
|
277 |
+
cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
|
278 |
+
except BinarizationError as e:
|
279 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
280 |
+
return None
|
281 |
+
return res
|
282 |
+
|
283 |
+
|
284 |
+
class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer):
|
285 |
+
pass
|
286 |
+
|
287 |
+
class M4SingerBinarizer(MidiSingingBinarizer):
|
288 |
+
item2midi = {}
|
289 |
+
item2midi_dur = {}
|
290 |
+
item2is_slur = {}
|
291 |
+
item2ph_durs = {}
|
292 |
+
item2wdb = {}
|
293 |
+
|
294 |
+
def split_train_test_set(self, item_names):
|
295 |
+
item_names = deepcopy(item_names)
|
296 |
+
test_item_names = [x for x in item_names if any([x.startswith(ts) for ts in hparams['test_prefixes']])]
|
297 |
+
train_item_names = [x for x in item_names if x not in set(test_item_names)]
|
298 |
+
logging.info("train {}".format(len(train_item_names)))
|
299 |
+
logging.info("test {}".format(len(test_item_names)))
|
300 |
+
return train_item_names, test_item_names
|
301 |
+
|
302 |
+
def load_meta_data(self):
|
303 |
+
raw_data_dir = hparams['raw_data_dir']
|
304 |
+
song_items = json.load(open(os.path.join(raw_data_dir, 'meta.json'))) # [list of dict]
|
305 |
+
for song_item in song_items:
|
306 |
+
item_name = raw_item_name = song_item['item_name']
|
307 |
+
singer, song_name, sent_id = item_name.split("#")
|
308 |
+
self.item2wavfn[item_name] = f'{raw_data_dir}/{singer}#{song_name}/{sent_id}.wav'
|
309 |
+
self.item2txt[item_name] = song_item['txt']
|
310 |
+
|
311 |
+
self.item2ph[item_name] = ' '.join(song_item['phs'])
|
312 |
+
self.item2ph_durs[item_name] = song_item['ph_dur']
|
313 |
+
|
314 |
+
self.item2midi[item_name] = song_item['notes']
|
315 |
+
self.item2midi_dur[item_name] = song_item['notes_dur']
|
316 |
+
self.item2is_slur[item_name] = song_item['is_slur']
|
317 |
+
self.item2wdb[item_name] = [1 if (0 < i < len(song_item['phs']) - 1 and p in ALL_YUNMU + ['<SP>', '<AP>'])\
|
318 |
+
or i == len(song_item['phs']) - 1 else 0 for i, p in enumerate(song_item['phs'])]
|
319 |
+
self.item2spk[item_name] = singer
|
320 |
+
|
321 |
+
print('spkers: ', set(self.item2spk.values()))
|
322 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
323 |
+
if self.binarization_args['shuffle']:
|
324 |
+
random.seed(1234)
|
325 |
+
random.shuffle(self.item_names)
|
326 |
+
self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
|
327 |
+
|
328 |
+
@staticmethod
|
329 |
+
def get_pitch(item_name, wav, spec, ph, res):
|
330 |
+
wav_suffix = '.wav'
|
331 |
+
# midi_suffix = '.mid'
|
332 |
+
wav_dir = 'wavs'
|
333 |
+
f0_dir = 'text_f0_align'
|
334 |
+
|
335 |
+
#item_name = os.path.splitext(os.path.basename(wav_fn))[0]
|
336 |
+
res['pitch_midi'] = np.asarray(M4SingerBinarizer.item2midi[item_name])
|
337 |
+
res['midi_dur'] = np.asarray(M4SingerBinarizer.item2midi_dur[item_name])
|
338 |
+
res['is_slur'] = np.asarray(M4SingerBinarizer.item2is_slur[item_name])
|
339 |
+
res['word_boundary'] = np.asarray(M4SingerBinarizer.item2wdb[item_name])
|
340 |
+
assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
|
341 |
+
|
342 |
+
# gt f0.
|
343 |
+
# f0 = None
|
344 |
+
# f0_suffix = '_f0.npy'
|
345 |
+
# f0fn = wav_fn.replace(wav_suffix, f0_suffix).replace(wav_dir, f0_dir)
|
346 |
+
# pitch_info = np.load(f0fn)
|
347 |
+
# f0 = [x[1] for x in pitch_info]
|
348 |
+
# spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)]
|
349 |
+
#
|
350 |
+
# f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)]
|
351 |
+
# f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)]
|
352 |
+
# if sum(f0) == 0:
|
353 |
+
# raise BinarizationError("Empty **gt** f0")
|
354 |
+
#
|
355 |
+
# pitch_coarse = f0_to_coarse(f0)
|
356 |
+
# res['f0'] = f0
|
357 |
+
# res['pitch'] = pitch_coarse
|
358 |
+
|
359 |
+
# gt f0.
|
360 |
+
gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
|
361 |
+
if sum(gt_f0) == 0:
|
362 |
+
raise BinarizationError("Empty **gt** f0")
|
363 |
+
res['f0'] = gt_f0
|
364 |
+
res['pitch'] = gt_pitch_coarse
|
365 |
+
|
366 |
+
@classmethod
|
367 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
368 |
+
if hparams['vocoder'] in VOCODERS:
|
369 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
370 |
+
else:
|
371 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
372 |
+
res = {
|
373 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
374 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
375 |
+
}
|
376 |
+
try:
|
377 |
+
if binarization_args['with_f0']:
|
378 |
+
cls.get_pitch(item_name, wav, mel, ph, res)
|
379 |
+
if binarization_args['with_txt']:
|
380 |
+
try:
|
381 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
382 |
+
except:
|
383 |
+
traceback.print_exc()
|
384 |
+
raise BinarizationError(f"Empty phoneme")
|
385 |
+
if binarization_args['with_align']:
|
386 |
+
cls.get_align(M4SingerBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
|
387 |
+
except BinarizationError as e:
|
388 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
389 |
+
return None
|
390 |
+
return res
|
391 |
+
|
392 |
+
if __name__ == "__main__":
|
393 |
+
SingingBinarizer().process()
|
data_gen/tts/base_binarizer.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
3 |
+
|
4 |
+
from utils.multiprocess_utils import chunked_multiprocess_run
|
5 |
+
import random
|
6 |
+
import traceback
|
7 |
+
import json
|
8 |
+
from resemblyzer import VoiceEncoder
|
9 |
+
from tqdm import tqdm
|
10 |
+
from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder
|
11 |
+
from utils.hparams import set_hparams, hparams
|
12 |
+
import numpy as np
|
13 |
+
from utils.indexed_datasets import IndexedDatasetBuilder
|
14 |
+
from vocoders.base_vocoder import VOCODERS
|
15 |
+
import pandas as pd
|
16 |
+
|
17 |
+
|
18 |
+
class BinarizationError(Exception):
|
19 |
+
pass
|
20 |
+
|
21 |
+
|
22 |
+
class BaseBinarizer:
|
23 |
+
def __init__(self, processed_data_dir=None):
|
24 |
+
if processed_data_dir is None:
|
25 |
+
processed_data_dir = hparams['processed_data_dir']
|
26 |
+
self.processed_data_dirs = processed_data_dir.split(",")
|
27 |
+
self.binarization_args = hparams['binarization_args']
|
28 |
+
self.pre_align_args = hparams['pre_align_args']
|
29 |
+
self.forced_align = self.pre_align_args['forced_align']
|
30 |
+
tg_dir = None
|
31 |
+
if self.forced_align == 'mfa':
|
32 |
+
tg_dir = 'mfa_outputs'
|
33 |
+
if self.forced_align == 'kaldi':
|
34 |
+
tg_dir = 'kaldi_outputs'
|
35 |
+
self.item2txt = {}
|
36 |
+
self.item2ph = {}
|
37 |
+
self.item2wavfn = {}
|
38 |
+
self.item2tgfn = {}
|
39 |
+
self.item2spk = {}
|
40 |
+
for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
|
41 |
+
self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str)
|
42 |
+
for r_idx, r in self.meta_df.iterrows():
|
43 |
+
item_name = raw_item_name = r['item_name']
|
44 |
+
if len(self.processed_data_dirs) > 1:
|
45 |
+
item_name = f'ds{ds_id}_{item_name}'
|
46 |
+
self.item2txt[item_name] = r['txt']
|
47 |
+
self.item2ph[item_name] = r['ph']
|
48 |
+
self.item2wavfn[item_name] = os.path.join(hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1])
|
49 |
+
self.item2spk[item_name] = r.get('spk', 'SPK1')
|
50 |
+
if len(self.processed_data_dirs) > 1:
|
51 |
+
self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}"
|
52 |
+
if tg_dir is not None:
|
53 |
+
self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid"
|
54 |
+
self.item_names = sorted(list(self.item2txt.keys()))
|
55 |
+
if self.binarization_args['shuffle']:
|
56 |
+
random.seed(1234)
|
57 |
+
random.shuffle(self.item_names)
|
58 |
+
|
59 |
+
@property
|
60 |
+
def train_item_names(self):
|
61 |
+
return self.item_names[hparams['test_num']+hparams['valid_num']:]
|
62 |
+
|
63 |
+
@property
|
64 |
+
def valid_item_names(self):
|
65 |
+
return self.item_names[0: hparams['test_num']+hparams['valid_num']] #
|
66 |
+
|
67 |
+
@property
|
68 |
+
def test_item_names(self):
|
69 |
+
return self.item_names[0: hparams['test_num']] # Audios for MOS testing are in 'test_ids'
|
70 |
+
|
71 |
+
def build_spk_map(self):
|
72 |
+
spk_map = set()
|
73 |
+
for item_name in self.item_names:
|
74 |
+
spk_name = self.item2spk[item_name]
|
75 |
+
spk_map.add(spk_name)
|
76 |
+
spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))}
|
77 |
+
assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map)
|
78 |
+
return spk_map
|
79 |
+
|
80 |
+
def item_name2spk_id(self, item_name):
|
81 |
+
return self.spk_map[self.item2spk[item_name]]
|
82 |
+
|
83 |
+
def _phone_encoder(self):
|
84 |
+
ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
|
85 |
+
ph_set = []
|
86 |
+
if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
|
87 |
+
for processed_data_dir in self.processed_data_dirs:
|
88 |
+
ph_set += [x.split(' ')[0] for x in open(f'{processed_data_dir}/dict.txt').readlines()]
|
89 |
+
ph_set = sorted(set(ph_set))
|
90 |
+
json.dump(ph_set, open(ph_set_fn, 'w'))
|
91 |
+
else:
|
92 |
+
ph_set = json.load(open(ph_set_fn, 'r'))
|
93 |
+
print("| phone set: ", ph_set)
|
94 |
+
return build_phone_encoder(hparams['binary_data_dir'])
|
95 |
+
|
96 |
+
def meta_data(self, prefix):
|
97 |
+
if prefix == 'valid':
|
98 |
+
item_names = self.valid_item_names
|
99 |
+
elif prefix == 'test':
|
100 |
+
item_names = self.test_item_names
|
101 |
+
else:
|
102 |
+
item_names = self.train_item_names
|
103 |
+
for item_name in item_names:
|
104 |
+
ph = self.item2ph[item_name]
|
105 |
+
txt = self.item2txt[item_name]
|
106 |
+
tg_fn = self.item2tgfn.get(item_name)
|
107 |
+
wav_fn = self.item2wavfn[item_name]
|
108 |
+
spk_id = self.item_name2spk_id(item_name)
|
109 |
+
yield item_name, ph, txt, tg_fn, wav_fn, spk_id
|
110 |
+
|
111 |
+
def process(self):
|
112 |
+
os.makedirs(hparams['binary_data_dir'], exist_ok=True)
|
113 |
+
self.spk_map = self.build_spk_map()
|
114 |
+
print("| spk_map: ", self.spk_map)
|
115 |
+
spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
|
116 |
+
json.dump(self.spk_map, open(spk_map_fn, 'w'))
|
117 |
+
|
118 |
+
self.phone_encoder = self._phone_encoder()
|
119 |
+
self.process_data('valid')
|
120 |
+
self.process_data('test')
|
121 |
+
self.process_data('train')
|
122 |
+
|
123 |
+
def process_data(self, prefix):
|
124 |
+
data_dir = hparams['binary_data_dir']
|
125 |
+
args = []
|
126 |
+
builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
|
127 |
+
lengths = []
|
128 |
+
f0s = []
|
129 |
+
total_sec = 0
|
130 |
+
if self.binarization_args['with_spk_embed']:
|
131 |
+
voice_encoder = VoiceEncoder().cuda()
|
132 |
+
|
133 |
+
meta_data = list(self.meta_data(prefix))
|
134 |
+
for m in meta_data:
|
135 |
+
args.append(list(m) + [self.phone_encoder, self.binarization_args])
|
136 |
+
num_workers = int(os.getenv('N_PROC', os.cpu_count() // 3))
|
137 |
+
for f_id, (_, item) in enumerate(
|
138 |
+
zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))):
|
139 |
+
if item is None:
|
140 |
+
continue
|
141 |
+
item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \
|
142 |
+
if self.binarization_args['with_spk_embed'] else None
|
143 |
+
if not self.binarization_args['with_wav'] and 'wav' in item:
|
144 |
+
#print("del wav")
|
145 |
+
del item['wav']
|
146 |
+
builder.add_item(item)
|
147 |
+
lengths.append(item['len'])
|
148 |
+
total_sec += item['sec']
|
149 |
+
if item.get('f0') is not None:
|
150 |
+
f0s.append(item['f0'])
|
151 |
+
builder.finalize()
|
152 |
+
np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
|
153 |
+
if len(f0s) > 0:
|
154 |
+
f0s = np.concatenate(f0s, 0)
|
155 |
+
f0s = f0s[f0s != 0]
|
156 |
+
np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()])
|
157 |
+
print(f"| {prefix} total duration: {total_sec:.3f}s")
|
158 |
+
|
159 |
+
@classmethod
|
160 |
+
def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
|
161 |
+
if hparams['vocoder'] in VOCODERS:
|
162 |
+
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn)
|
163 |
+
else:
|
164 |
+
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn)
|
165 |
+
res = {
|
166 |
+
'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
|
167 |
+
'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
|
168 |
+
}
|
169 |
+
try:
|
170 |
+
if binarization_args['with_f0']:
|
171 |
+
cls.get_pitch(wav, mel, res)
|
172 |
+
if binarization_args['with_f0cwt']:
|
173 |
+
cls.get_f0cwt(res['f0'], res)
|
174 |
+
if binarization_args['with_txt']:
|
175 |
+
try:
|
176 |
+
phone_encoded = res['phone'] = encoder.encode(ph)
|
177 |
+
except:
|
178 |
+
traceback.print_exc()
|
179 |
+
raise BinarizationError(f"Empty phoneme")
|
180 |
+
if binarization_args['with_align']:
|
181 |
+
cls.get_align(tg_fn, ph, mel, phone_encoded, res)
|
182 |
+
except BinarizationError as e:
|
183 |
+
print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
|
184 |
+
return None
|
185 |
+
return res
|
186 |
+
|
187 |
+
@staticmethod
|
188 |
+
def get_align(tg_fn, ph, mel, phone_encoded, res):
|
189 |
+
if tg_fn is not None and os.path.exists(tg_fn):
|
190 |
+
mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams)
|
191 |
+
else:
|
192 |
+
raise BinarizationError(f"Align not found")
|
193 |
+
if mel2ph.max() - 1 >= len(phone_encoded):
|
194 |
+
raise BinarizationError(
|
195 |
+
f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}")
|
196 |
+
res['mel2ph'] = mel2ph
|
197 |
+
res['dur'] = dur
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
def get_pitch(wav, mel, res):
|
201 |
+
f0, pitch_coarse = get_pitch(wav, mel, hparams)
|
202 |
+
if sum(f0) == 0:
|
203 |
+
raise BinarizationError("Empty f0")
|
204 |
+
res['f0'] = f0
|
205 |
+
res['pitch'] = pitch_coarse
|
206 |
+
|
207 |
+
@staticmethod
|
208 |
+
def get_f0cwt(f0, res):
|
209 |
+
from utils.cwt import get_cont_lf0, get_lf0_cwt
|
210 |
+
uv, cont_lf0_lpf = get_cont_lf0(f0)
|
211 |
+
logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf)
|
212 |
+
cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org
|
213 |
+
Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
|
214 |
+
if np.any(np.isnan(Wavelet_lf0)):
|
215 |
+
raise BinarizationError("NaN CWT")
|
216 |
+
res['cwt_spec'] = Wavelet_lf0
|
217 |
+
res['cwt_scales'] = scales
|
218 |
+
res['f0_mean'] = logf0s_mean_org
|
219 |
+
res['f0_std'] = logf0s_std_org
|
220 |
+
|
221 |
+
|
222 |
+
if __name__ == "__main__":
|
223 |
+
set_hparams()
|
224 |
+
BaseBinarizer().process()
|
data_gen/tts/bin/binarize.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
4 |
+
|
5 |
+
import importlib
|
6 |
+
from utils.hparams import set_hparams, hparams
|
7 |
+
|
8 |
+
|
9 |
+
def binarize():
|
10 |
+
binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
|
11 |
+
pkg = ".".join(binarizer_cls.split(".")[:-1])
|
12 |
+
cls_name = binarizer_cls.split(".")[-1]
|
13 |
+
binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
|
14 |
+
print("| Binarizer: ", binarizer_cls)
|
15 |
+
binarizer_cls().process()
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
set_hparams()
|
20 |
+
binarize()
|
data_gen/tts/binarizer_zh.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
4 |
+
|
5 |
+
from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU
|
6 |
+
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
|
7 |
+
from data_gen.tts.data_gen_utils import get_mel2ph
|
8 |
+
from utils.hparams import set_hparams, hparams
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
class ZhBinarizer(BaseBinarizer):
|
13 |
+
@staticmethod
|
14 |
+
def get_align(tg_fn, ph, mel, phone_encoded, res):
|
15 |
+
if tg_fn is not None and os.path.exists(tg_fn):
|
16 |
+
_, dur = get_mel2ph(tg_fn, ph, mel, hparams)
|
17 |
+
else:
|
18 |
+
raise BinarizationError(f"Align not found")
|
19 |
+
ph_list = ph.split(" ")
|
20 |
+
assert len(dur) == len(ph_list)
|
21 |
+
mel2ph = []
|
22 |
+
# 分隔符的时长分配给韵母
|
23 |
+
dur_cumsum = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0)
|
24 |
+
for i in range(len(dur)):
|
25 |
+
p = ph_list[i]
|
26 |
+
if p[0] != '<' and not p[0].isalpha():
|
27 |
+
uv_ = res['f0'][dur_cumsum[i]:dur_cumsum[i + 1]] == 0
|
28 |
+
j = 0
|
29 |
+
while j < len(uv_) and not uv_[j]:
|
30 |
+
j += 1
|
31 |
+
dur[i - 1] += j
|
32 |
+
dur[i] -= j
|
33 |
+
if dur[i] < 100:
|
34 |
+
dur[i - 1] += dur[i]
|
35 |
+
dur[i] = 0
|
36 |
+
# 声母和韵母等长
|
37 |
+
for i in range(len(dur)):
|
38 |
+
p = ph_list[i]
|
39 |
+
if p in ALL_SHENMU:
|
40 |
+
p_next = ph_list[i + 1]
|
41 |
+
if not (dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU):
|
42 |
+
print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, "
|
43 |
+
f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.")
|
44 |
+
continue
|
45 |
+
total = dur[i + 1] + dur[i]
|
46 |
+
dur[i] = total // 2
|
47 |
+
dur[i + 1] = total - dur[i]
|
48 |
+
for i in range(len(dur)):
|
49 |
+
mel2ph += [i + 1] * dur[i]
|
50 |
+
mel2ph = np.array(mel2ph)
|
51 |
+
if mel2ph.max() - 1 >= len(phone_encoded):
|
52 |
+
raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone_encoded))}")
|
53 |
+
res['mel2ph'] = mel2ph
|
54 |
+
res['dur'] = dur
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
set_hparams()
|
59 |
+
ZhBinarizer().process()
|
data_gen/tts/data_gen_utils.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
warnings.filterwarnings("ignore")
|
4 |
+
|
5 |
+
import parselmouth
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
from skimage.transform import resize
|
9 |
+
from utils.text_encoder import TokenTextEncoder
|
10 |
+
from utils.pitch_utils import f0_to_coarse
|
11 |
+
import struct
|
12 |
+
import webrtcvad
|
13 |
+
from scipy.ndimage.morphology import binary_dilation
|
14 |
+
import librosa
|
15 |
+
import numpy as np
|
16 |
+
from utils import audio
|
17 |
+
import pyloudnorm as pyln
|
18 |
+
import re
|
19 |
+
import json
|
20 |
+
from collections import OrderedDict
|
21 |
+
|
22 |
+
PUNCS = '!,.?;:'
|
23 |
+
|
24 |
+
int16_max = (2 ** 15) - 1
|
25 |
+
|
26 |
+
|
27 |
+
def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
|
28 |
+
"""
|
29 |
+
Ensures that segments without voice in the waveform remain no longer than a
|
30 |
+
threshold determined by the VAD parameters in params.py.
|
31 |
+
:param wav: the raw waveform as a numpy array of floats
|
32 |
+
:param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have.
|
33 |
+
:return: the same waveform with silences trimmed away (length <= original wav length)
|
34 |
+
"""
|
35 |
+
|
36 |
+
## Voice Activation Detection
|
37 |
+
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
38 |
+
# This sets the granularity of the VAD. Should not need to be changed.
|
39 |
+
sampling_rate = 16000
|
40 |
+
wav_raw, sr = librosa.core.load(path, sr=sr)
|
41 |
+
|
42 |
+
if norm:
|
43 |
+
meter = pyln.Meter(sr) # create BS.1770 meter
|
44 |
+
loudness = meter.integrated_loudness(wav_raw)
|
45 |
+
wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
|
46 |
+
if np.abs(wav_raw).max() > 1.0:
|
47 |
+
wav_raw = wav_raw / np.abs(wav_raw).max()
|
48 |
+
|
49 |
+
wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best')
|
50 |
+
|
51 |
+
vad_window_length = 30 # In milliseconds
|
52 |
+
# Number of frames to average together when performing the moving average smoothing.
|
53 |
+
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
54 |
+
vad_moving_average_width = 8
|
55 |
+
|
56 |
+
# Compute the voice detection window size
|
57 |
+
samples_per_window = (vad_window_length * sampling_rate) // 1000
|
58 |
+
|
59 |
+
# Trim the end of the audio to have a multiple of the window size
|
60 |
+
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
|
61 |
+
|
62 |
+
# Convert the float waveform to 16-bit mono PCM
|
63 |
+
pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
|
64 |
+
|
65 |
+
# Perform voice activation detection
|
66 |
+
voice_flags = []
|
67 |
+
vad = webrtcvad.Vad(mode=3)
|
68 |
+
for window_start in range(0, len(wav), samples_per_window):
|
69 |
+
window_end = window_start + samples_per_window
|
70 |
+
voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
|
71 |
+
sample_rate=sampling_rate))
|
72 |
+
voice_flags = np.array(voice_flags)
|
73 |
+
|
74 |
+
# Smooth the voice detection with a moving average
|
75 |
+
def moving_average(array, width):
|
76 |
+
array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
|
77 |
+
ret = np.cumsum(array_padded, dtype=float)
|
78 |
+
ret[width:] = ret[width:] - ret[:-width]
|
79 |
+
return ret[width - 1:] / width
|
80 |
+
|
81 |
+
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
82 |
+
audio_mask = np.round(audio_mask).astype(np.bool)
|
83 |
+
|
84 |
+
# Dilate the voiced regions
|
85 |
+
audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
|
86 |
+
audio_mask = np.repeat(audio_mask, samples_per_window)
|
87 |
+
audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
|
88 |
+
if return_raw_wav:
|
89 |
+
return wav_raw, audio_mask, sr
|
90 |
+
return wav_raw[audio_mask], audio_mask, sr
|
91 |
+
|
92 |
+
|
93 |
+
def process_utterance(wav_path,
|
94 |
+
fft_size=1024,
|
95 |
+
hop_size=256,
|
96 |
+
win_length=1024,
|
97 |
+
window="hann",
|
98 |
+
num_mels=80,
|
99 |
+
fmin=80,
|
100 |
+
fmax=7600,
|
101 |
+
eps=1e-6,
|
102 |
+
sample_rate=22050,
|
103 |
+
loud_norm=False,
|
104 |
+
min_level_db=-100,
|
105 |
+
return_linear=False,
|
106 |
+
trim_long_sil=False, vocoder='pwg'):
|
107 |
+
if isinstance(wav_path, str):
|
108 |
+
if trim_long_sil:
|
109 |
+
wav, _, _ = trim_long_silences(wav_path, sample_rate)
|
110 |
+
else:
|
111 |
+
wav, _ = librosa.core.load(wav_path, sr=sample_rate)
|
112 |
+
else:
|
113 |
+
wav = wav_path
|
114 |
+
|
115 |
+
if loud_norm:
|
116 |
+
meter = pyln.Meter(sample_rate) # create BS.1770 meter
|
117 |
+
loudness = meter.integrated_loudness(wav)
|
118 |
+
wav = pyln.normalize.loudness(wav, loudness, -22.0)
|
119 |
+
if np.abs(wav).max() > 1:
|
120 |
+
wav = wav / np.abs(wav).max()
|
121 |
+
|
122 |
+
# get amplitude spectrogram
|
123 |
+
x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
|
124 |
+
win_length=win_length, window=window, pad_mode="constant")
|
125 |
+
spc = np.abs(x_stft) # (n_bins, T)
|
126 |
+
|
127 |
+
# get mel basis
|
128 |
+
fmin = 0 if fmin == -1 else fmin
|
129 |
+
fmax = sample_rate / 2 if fmax == -1 else fmax
|
130 |
+
mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax)
|
131 |
+
mel = mel_basis @ spc
|
132 |
+
|
133 |
+
if vocoder == 'pwg':
|
134 |
+
mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
|
135 |
+
else:
|
136 |
+
assert False, f'"{vocoder}" is not in ["pwg"].'
|
137 |
+
|
138 |
+
l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
|
139 |
+
wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
|
140 |
+
wav = wav[:mel.shape[1] * hop_size]
|
141 |
+
|
142 |
+
if not return_linear:
|
143 |
+
return wav, mel
|
144 |
+
else:
|
145 |
+
spc = audio.amp_to_db(spc)
|
146 |
+
spc = audio.normalize(spc, {'min_level_db': min_level_db})
|
147 |
+
return wav, mel, spc
|
148 |
+
|
149 |
+
|
150 |
+
def get_pitch(wav_data, mel, hparams):
|
151 |
+
"""
|
152 |
+
|
153 |
+
:param wav_data: [T]
|
154 |
+
:param mel: [T, 80]
|
155 |
+
:param hparams:
|
156 |
+
:return:
|
157 |
+
"""
|
158 |
+
time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000
|
159 |
+
f0_min = 80
|
160 |
+
f0_max = 750
|
161 |
+
|
162 |
+
if hparams['hop_size'] == 128:
|
163 |
+
pad_size = 4
|
164 |
+
elif hparams['hop_size'] == 256:
|
165 |
+
pad_size = 2
|
166 |
+
else:
|
167 |
+
assert False
|
168 |
+
|
169 |
+
f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac(
|
170 |
+
time_step=time_step / 1000, voicing_threshold=0.6,
|
171 |
+
pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
|
172 |
+
lpad = pad_size * 2
|
173 |
+
rpad = len(mel) - len(f0) - lpad
|
174 |
+
f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
|
175 |
+
# mel and f0 are extracted by 2 different libraries. we should force them to have the same length.
|
176 |
+
# Attention: we find that new version of some libraries could cause ``rpad'' to be a negetive value...
|
177 |
+
# Just to be sure, we recommend users to set up the same environments as them in requirements_auto.txt (by Anaconda)
|
178 |
+
delta_l = len(mel) - len(f0)
|
179 |
+
assert np.abs(delta_l) <= 8
|
180 |
+
if delta_l > 0:
|
181 |
+
f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
|
182 |
+
f0 = f0[:len(mel)]
|
183 |
+
pitch_coarse = f0_to_coarse(f0)
|
184 |
+
return f0, pitch_coarse
|
185 |
+
|
186 |
+
|
187 |
+
def remove_empty_lines(text):
|
188 |
+
"""remove empty lines"""
|
189 |
+
assert (len(text) > 0)
|
190 |
+
assert (isinstance(text, list))
|
191 |
+
text = [t.strip() for t in text]
|
192 |
+
if "" in text:
|
193 |
+
text.remove("")
|
194 |
+
return text
|
195 |
+
|
196 |
+
|
197 |
+
class TextGrid(object):
|
198 |
+
def __init__(self, text):
|
199 |
+
text = remove_empty_lines(text)
|
200 |
+
self.text = text
|
201 |
+
self.line_count = 0
|
202 |
+
self._get_type()
|
203 |
+
self._get_time_intval()
|
204 |
+
self._get_size()
|
205 |
+
self.tier_list = []
|
206 |
+
self._get_item_list()
|
207 |
+
|
208 |
+
def _extract_pattern(self, pattern, inc):
|
209 |
+
"""
|
210 |
+
Parameters
|
211 |
+
----------
|
212 |
+
pattern : regex to extract pattern
|
213 |
+
inc : increment of line count after extraction
|
214 |
+
Returns
|
215 |
+
-------
|
216 |
+
group : extracted info
|
217 |
+
"""
|
218 |
+
try:
|
219 |
+
group = re.match(pattern, self.text[self.line_count]).group(1)
|
220 |
+
self.line_count += inc
|
221 |
+
except AttributeError:
|
222 |
+
raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
|
223 |
+
return group
|
224 |
+
|
225 |
+
def _get_type(self):
|
226 |
+
self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
|
227 |
+
|
228 |
+
def _get_time_intval(self):
|
229 |
+
self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
|
230 |
+
self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
|
231 |
+
|
232 |
+
def _get_size(self):
|
233 |
+
self.size = int(self._extract_pattern(r"size = (.*)", 2))
|
234 |
+
|
235 |
+
def _get_item_list(self):
|
236 |
+
"""Only supports IntervalTier currently"""
|
237 |
+
for itemIdx in range(1, self.size + 1):
|
238 |
+
tier = OrderedDict()
|
239 |
+
item_list = []
|
240 |
+
tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
|
241 |
+
tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
|
242 |
+
if tier_class != "IntervalTier":
|
243 |
+
raise NotImplementedError("Only IntervalTier class is supported currently")
|
244 |
+
tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
|
245 |
+
tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
|
246 |
+
tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
|
247 |
+
tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
|
248 |
+
for i in range(int(tier_size)):
|
249 |
+
item = OrderedDict()
|
250 |
+
item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
|
251 |
+
item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
|
252 |
+
item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
|
253 |
+
item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
|
254 |
+
item_list.append(item)
|
255 |
+
tier["idx"] = tier_idx
|
256 |
+
tier["class"] = tier_class
|
257 |
+
tier["name"] = tier_name
|
258 |
+
tier["xmin"] = tier_xmin
|
259 |
+
tier["xmax"] = tier_xmax
|
260 |
+
tier["size"] = tier_size
|
261 |
+
tier["items"] = item_list
|
262 |
+
self.tier_list.append(tier)
|
263 |
+
|
264 |
+
def toJson(self):
|
265 |
+
_json = OrderedDict()
|
266 |
+
_json["file_type"] = self.file_type
|
267 |
+
_json["xmin"] = self.xmin
|
268 |
+
_json["xmax"] = self.xmax
|
269 |
+
_json["size"] = self.size
|
270 |
+
_json["tiers"] = self.tier_list
|
271 |
+
return json.dumps(_json, ensure_ascii=False, indent=2)
|
272 |
+
|
273 |
+
|
274 |
+
def get_mel2ph(tg_fn, ph, mel, hparams):
|
275 |
+
ph_list = ph.split(" ")
|
276 |
+
with open(tg_fn, "r") as f:
|
277 |
+
tg = f.readlines()
|
278 |
+
tg = remove_empty_lines(tg)
|
279 |
+
tg = TextGrid(tg)
|
280 |
+
tg = json.loads(tg.toJson())
|
281 |
+
split = np.ones(len(ph_list) + 1, np.float) * -1
|
282 |
+
tg_idx = 0
|
283 |
+
ph_idx = 0
|
284 |
+
tg_align = [x for x in tg['tiers'][-1]['items']]
|
285 |
+
tg_align_ = []
|
286 |
+
for x in tg_align:
|
287 |
+
x['xmin'] = float(x['xmin'])
|
288 |
+
x['xmax'] = float(x['xmax'])
|
289 |
+
if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
|
290 |
+
x['text'] = ''
|
291 |
+
if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
|
292 |
+
tg_align_[-1]['xmax'] = x['xmax']
|
293 |
+
continue
|
294 |
+
tg_align_.append(x)
|
295 |
+
tg_align = tg_align_
|
296 |
+
tg_len = len([x for x in tg_align if x['text'] != ''])
|
297 |
+
ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
|
298 |
+
assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn)
|
299 |
+
while tg_idx < len(tg_align) or ph_idx < len(ph_list):
|
300 |
+
if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
|
301 |
+
split[ph_idx] = 1e8
|
302 |
+
ph_idx += 1
|
303 |
+
continue
|
304 |
+
x = tg_align[tg_idx]
|
305 |
+
if x['text'] == '' and ph_idx == len(ph_list):
|
306 |
+
tg_idx += 1
|
307 |
+
continue
|
308 |
+
assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn)
|
309 |
+
ph = ph_list[ph_idx]
|
310 |
+
if x['text'] == '' and not is_sil_phoneme(ph):
|
311 |
+
assert False, (ph_list, tg_align)
|
312 |
+
if x['text'] != '' and is_sil_phoneme(ph):
|
313 |
+
ph_idx += 1
|
314 |
+
else:
|
315 |
+
assert (x['text'] == '' and is_sil_phoneme(ph)) \
|
316 |
+
or x['text'].lower() == ph.lower() \
|
317 |
+
or x['text'].lower() == 'sil', (x['text'], ph)
|
318 |
+
split[ph_idx] = x['xmin']
|
319 |
+
if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]):
|
320 |
+
split[ph_idx - 1] = split[ph_idx]
|
321 |
+
ph_idx += 1
|
322 |
+
tg_idx += 1
|
323 |
+
assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
|
324 |
+
assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn)
|
325 |
+
mel2ph = np.zeros([mel.shape[0]], np.int)
|
326 |
+
split[0] = 0
|
327 |
+
split[-1] = 1e8
|
328 |
+
for i in range(len(split) - 1):
|
329 |
+
assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],)
|
330 |
+
split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split]
|
331 |
+
for ph_idx in range(len(ph_list)):
|
332 |
+
mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1
|
333 |
+
mel2ph_torch = torch.from_numpy(mel2ph)
|
334 |
+
T_t = len(ph_list)
|
335 |
+
dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch))
|
336 |
+
dur = dur[1:].numpy()
|
337 |
+
return mel2ph, dur
|
338 |
+
|
339 |
+
|
340 |
+
def build_phone_encoder(data_dir):
|
341 |
+
phone_list_file = os.path.join(data_dir, 'phone_set.json')
|
342 |
+
phone_list = json.load(open(phone_list_file))
|
343 |
+
return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
|
344 |
+
|
345 |
+
|
346 |
+
def is_sil_phoneme(p):
|
347 |
+
return not p[0].isalpha()
|
data_gen/tts/txt_processors/base_text_processor.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseTxtProcessor:
|
2 |
+
@staticmethod
|
3 |
+
def sp_phonemes():
|
4 |
+
return ['|']
|
5 |
+
|
6 |
+
@classmethod
|
7 |
+
def process(cls, txt, pre_align_args):
|
8 |
+
raise NotImplementedError
|
data_gen/tts/txt_processors/en.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from data_gen.tts.data_gen_utils import PUNCS
|
3 |
+
from g2p_en import G2p
|
4 |
+
import unicodedata
|
5 |
+
from g2p_en.expand import normalize_numbers
|
6 |
+
from nltk import pos_tag
|
7 |
+
from nltk.tokenize import TweetTokenizer
|
8 |
+
|
9 |
+
from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
|
10 |
+
|
11 |
+
|
12 |
+
class EnG2p(G2p):
|
13 |
+
word_tokenize = TweetTokenizer().tokenize
|
14 |
+
|
15 |
+
def __call__(self, text):
|
16 |
+
# preprocessing
|
17 |
+
words = EnG2p.word_tokenize(text)
|
18 |
+
tokens = pos_tag(words) # tuples of (word, tag)
|
19 |
+
|
20 |
+
# steps
|
21 |
+
prons = []
|
22 |
+
for word, pos in tokens:
|
23 |
+
if re.search("[a-z]", word) is None:
|
24 |
+
pron = [word]
|
25 |
+
|
26 |
+
elif word in self.homograph2features: # Check homograph
|
27 |
+
pron1, pron2, pos1 = self.homograph2features[word]
|
28 |
+
if pos.startswith(pos1):
|
29 |
+
pron = pron1
|
30 |
+
else:
|
31 |
+
pron = pron2
|
32 |
+
elif word in self.cmu: # lookup CMU dict
|
33 |
+
pron = self.cmu[word][0]
|
34 |
+
else: # predict for oov
|
35 |
+
pron = self.predict(word)
|
36 |
+
|
37 |
+
prons.extend(pron)
|
38 |
+
prons.extend([" "])
|
39 |
+
|
40 |
+
return prons[:-1]
|
41 |
+
|
42 |
+
|
43 |
+
class TxtProcessor(BaseTxtProcessor):
|
44 |
+
g2p = EnG2p()
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def preprocess_text(text):
|
48 |
+
text = normalize_numbers(text)
|
49 |
+
text = ''.join(char for char in unicodedata.normalize('NFD', text)
|
50 |
+
if unicodedata.category(char) != 'Mn') # Strip accents
|
51 |
+
text = text.lower()
|
52 |
+
text = re.sub("[\'\"()]+", "", text)
|
53 |
+
text = re.sub("[-]+", " ", text)
|
54 |
+
text = re.sub(f"[^ a-z{PUNCS}]", "", text)
|
55 |
+
text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> !
|
56 |
+
text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
|
57 |
+
text = text.replace("i.e.", "that is")
|
58 |
+
text = text.replace("i.e.", "that is")
|
59 |
+
text = text.replace("etc.", "etc")
|
60 |
+
text = re.sub(f"([{PUNCS}])", r" \1 ", text)
|
61 |
+
text = re.sub(rf"\s+", r" ", text)
|
62 |
+
return text
|
63 |
+
|
64 |
+
@classmethod
|
65 |
+
def process(cls, txt, pre_align_args):
|
66 |
+
txt = cls.preprocess_text(txt).strip()
|
67 |
+
phs = cls.g2p(txt)
|
68 |
+
phs_ = []
|
69 |
+
n_word_sep = 0
|
70 |
+
for p in phs:
|
71 |
+
if p.strip() == '':
|
72 |
+
phs_ += ['|']
|
73 |
+
n_word_sep += 1
|
74 |
+
else:
|
75 |
+
phs_ += p.split(" ")
|
76 |
+
phs = phs_
|
77 |
+
assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"")
|
78 |
+
return phs, txt
|
data_gen/tts/txt_processors/zh.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from pypinyin import pinyin, Style
|
3 |
+
from data_gen.tts.data_gen_utils import PUNCS
|
4 |
+
from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor
|
5 |
+
from utils.text_norm import NSWNormalizer
|
6 |
+
|
7 |
+
|
8 |
+
class TxtProcessor(BaseTxtProcessor):
|
9 |
+
table = {ord(f): ord(t) for f, t in zip(
|
10 |
+
u':,。!?【】()%#@&1234567890',
|
11 |
+
u':,.!?[]()%#@&1234567890')}
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def preprocess_text(text):
|
15 |
+
text = text.translate(TxtProcessor.table)
|
16 |
+
text = NSWNormalizer(text).normalize(remove_punc=False)
|
17 |
+
text = re.sub("[\'\"()]+", "", text)
|
18 |
+
text = re.sub("[-]+", " ", text)
|
19 |
+
text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", "", text)
|
20 |
+
text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
|
21 |
+
text = re.sub(f"([{PUNCS}])", r" \1 ", text)
|
22 |
+
text = re.sub(rf"\s+", r"", text)
|
23 |
+
return text
|
24 |
+
|
25 |
+
@classmethod
|
26 |
+
def process(cls, txt, pre_align_args):
|
27 |
+
txt = cls.preprocess_text(txt)
|
28 |
+
shengmu = pinyin(txt, style=Style.INITIALS) # https://blog.csdn.net/zhoulei124/article/details/89055403
|
29 |
+
yunmu_finals = pinyin(txt, style=Style.FINALS)
|
30 |
+
yunmu_tone3 = pinyin(txt, style=Style.FINALS_TONE3)
|
31 |
+
yunmu = [[t[0] + '5'] if t[0] == f[0] else t for f, t in zip(yunmu_finals, yunmu_tone3)] \
|
32 |
+
if pre_align_args['use_tone'] else yunmu_finals
|
33 |
+
|
34 |
+
assert len(shengmu) == len(yunmu)
|
35 |
+
phs = ["|"]
|
36 |
+
for a, b, c in zip(shengmu, yunmu, yunmu_finals):
|
37 |
+
if a[0] == c[0]:
|
38 |
+
phs += [a[0], "|"]
|
39 |
+
else:
|
40 |
+
phs += [a[0], b[0], "|"]
|
41 |
+
return phs, txt
|
data_gen/tts/txt_processors/zh_g2pM.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import jieba
|
3 |
+
from pypinyin import pinyin, Style
|
4 |
+
from data_gen.tts.data_gen_utils import PUNCS
|
5 |
+
from data_gen.tts.txt_processors import zh
|
6 |
+
from g2pM import G2pM
|
7 |
+
|
8 |
+
ALL_SHENMU = ['b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'x', 'z', 'zh']
|
9 |
+
ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian', 'iang', 'iao',
|
10 |
+
'ie', 'in', 'ing', 'iong', 'iou', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'uei',
|
11 |
+
'uen', 'uo', 'v', 'van', 've', 'vn']
|
12 |
+
|
13 |
+
|
14 |
+
class TxtProcessor(zh.TxtProcessor):
|
15 |
+
model = G2pM()
|
16 |
+
|
17 |
+
@staticmethod
|
18 |
+
def sp_phonemes():
|
19 |
+
return ['|', '#']
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def process(cls, txt, pre_align_args):
|
23 |
+
txt = cls.preprocess_text(txt)
|
24 |
+
ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True)
|
25 |
+
seg_list = '#'.join(jieba.cut(txt))
|
26 |
+
assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list)
|
27 |
+
|
28 |
+
# 加入词边界'#'
|
29 |
+
ph_list_ = []
|
30 |
+
seg_idx = 0
|
31 |
+
for p in ph_list:
|
32 |
+
p = p.replace("u:", "v")
|
33 |
+
if seg_list[seg_idx] == '#':
|
34 |
+
ph_list_.append('#')
|
35 |
+
seg_idx += 1
|
36 |
+
else:
|
37 |
+
ph_list_.append("|")
|
38 |
+
seg_idx += 1
|
39 |
+
if re.findall('[\u4e00-\u9fff]', p):
|
40 |
+
if pre_align_args['use_tone']:
|
41 |
+
p = pinyin(p, style=Style.TONE3, strict=True)[0][0]
|
42 |
+
if p[-1] not in ['1', '2', '3', '4', '5']:
|
43 |
+
p = p + '5'
|
44 |
+
else:
|
45 |
+
p = pinyin(p, style=Style.NORMAL, strict=True)[0][0]
|
46 |
+
|
47 |
+
finished = False
|
48 |
+
if len([c.isalpha() for c in p]) > 1:
|
49 |
+
for shenmu in ALL_SHENMU:
|
50 |
+
if p.startswith(shenmu) and not p.lstrip(shenmu).isnumeric():
|
51 |
+
ph_list_ += [shenmu, p.lstrip(shenmu)]
|
52 |
+
finished = True
|
53 |
+
break
|
54 |
+
if not finished:
|
55 |
+
ph_list_.append(p)
|
56 |
+
|
57 |
+
ph_list = ph_list_
|
58 |
+
|
59 |
+
# 去除静音符号周围的词边界标记 [..., '#', ',', '#', ...]
|
60 |
+
sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes()
|
61 |
+
ph_list_ = []
|
62 |
+
for i in range(0, len(ph_list), 1):
|
63 |
+
if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes):
|
64 |
+
ph_list_.append(ph_list[i])
|
65 |
+
ph_list = ph_list_
|
66 |
+
return ph_list, txt
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == '__main__':
|
70 |
+
phs, txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True})
|
71 |
+
print(phs)
|
inference/m4singer/base_svs_infer.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from modules.hifigan.hifigan import HifiGanGenerator
|
6 |
+
from vocoders.hifigan import HifiGAN
|
7 |
+
from inference.m4singer.m4singer.map import m4singer_pinyin2ph_func
|
8 |
+
|
9 |
+
from utils import load_ckpt
|
10 |
+
from utils.hparams import set_hparams, hparams
|
11 |
+
from utils.text_encoder import TokenTextEncoder
|
12 |
+
from pypinyin import pinyin, lazy_pinyin, Style
|
13 |
+
import librosa
|
14 |
+
import glob
|
15 |
+
import re
|
16 |
+
|
17 |
+
|
18 |
+
class BaseSVSInfer:
|
19 |
+
def __init__(self, hparams, device=None):
|
20 |
+
if device is None:
|
21 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
22 |
+
self.hparams = hparams
|
23 |
+
self.device = device
|
24 |
+
|
25 |
+
phone_list = ["<AP>", "<SP>", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g", "h",
|
26 |
+
"i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iou", "j", "k", "l", "m", "n", "o", "ong", "ou",
|
27 |
+
"p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "uei", "uen", "uo", "v", "van", "ve", "vn",
|
28 |
+
"x", "z", "zh"]
|
29 |
+
self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
|
30 |
+
self.pinyin2phs = m4singer_pinyin2ph_func()
|
31 |
+
self.spk_map = {"Alto-1": 0, "Alto-2": 1, "Alto-3": 2, "Alto-4": 3, "Alto-5": 4, "Alto-6": 5, "Alto-7": 6, "Bass-1": 7,
|
32 |
+
"Bass-2": 8, "Bass-3": 9, "Soprano-1": 10, "Soprano-2": 11, "Soprano-3": 12, "Tenor-1": 13, "Tenor-2": 14,
|
33 |
+
"Tenor-3": 15, "Tenor-4": 16, "Tenor-5": 17, "Tenor-6": 18, "Tenor-7": 19}
|
34 |
+
|
35 |
+
self.model = self.build_model()
|
36 |
+
self.model.eval()
|
37 |
+
self.model.to(self.device)
|
38 |
+
self.vocoder = self.build_vocoder()
|
39 |
+
self.vocoder.eval()
|
40 |
+
self.vocoder.to(self.device)
|
41 |
+
|
42 |
+
def build_model(self):
|
43 |
+
raise NotImplementedError
|
44 |
+
|
45 |
+
def forward_model(self, inp):
|
46 |
+
raise NotImplementedError
|
47 |
+
|
48 |
+
def build_vocoder(self):
|
49 |
+
base_dir = hparams['vocoder_ckpt']
|
50 |
+
config_path = f'{base_dir}/config.yaml'
|
51 |
+
ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
|
52 |
+
lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
|
53 |
+
print('| load HifiGAN: ', ckpt)
|
54 |
+
ckpt_dict = torch.load(ckpt, map_location="cpu")
|
55 |
+
config = set_hparams(config_path, global_hparams=False)
|
56 |
+
state = ckpt_dict["state_dict"]["model_gen"]
|
57 |
+
vocoder = HifiGanGenerator(config)
|
58 |
+
vocoder.load_state_dict(state, strict=True)
|
59 |
+
vocoder.remove_weight_norm()
|
60 |
+
vocoder = vocoder.eval().to(self.device)
|
61 |
+
return vocoder
|
62 |
+
|
63 |
+
def run_vocoder(self, c, **kwargs):
|
64 |
+
c = c.transpose(2, 1) # [B, 80, T]
|
65 |
+
f0 = kwargs.get('f0') # [B, T]
|
66 |
+
if f0 is not None and hparams.get('use_nsf'):
|
67 |
+
# f0 = torch.FloatTensor(f0).to(self.device)
|
68 |
+
y = self.vocoder(c, f0).view(-1)
|
69 |
+
else:
|
70 |
+
y = self.vocoder(c).view(-1)
|
71 |
+
# [T]
|
72 |
+
return y[None]
|
73 |
+
|
74 |
+
def preprocess_word_level_input(self, inp):
|
75 |
+
# Pypinyin can't solve polyphonic words
|
76 |
+
text_raw = inp['text']
|
77 |
+
|
78 |
+
# lyric
|
79 |
+
pinyins = lazy_pinyin(text_raw, strict=False)
|
80 |
+
ph_per_word_lst = [self.pinyin2phs[pinyin.strip()] for pinyin in pinyins if pinyin.strip() in self.pinyin2phs]
|
81 |
+
|
82 |
+
# Note
|
83 |
+
note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
|
84 |
+
mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']
|
85 |
+
|
86 |
+
if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
|
87 |
+
print('Pass word-notes check.')
|
88 |
+
else:
|
89 |
+
print('The number of words does\'t match the number of notes\' windows. ',
|
90 |
+
'You should split the note(s) for each word by | mark.')
|
91 |
+
print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
|
92 |
+
print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
|
93 |
+
return None
|
94 |
+
|
95 |
+
note_lst = []
|
96 |
+
ph_lst = []
|
97 |
+
midi_dur_lst = []
|
98 |
+
is_slur = []
|
99 |
+
for idx, ph_per_word in enumerate(ph_per_word_lst):
|
100 |
+
# for phs in one word:
|
101 |
+
# single ph like ['ai'] or multiple phs like ['n', 'i']
|
102 |
+
ph_in_this_word = ph_per_word.split()
|
103 |
+
|
104 |
+
# for notes in one word:
|
105 |
+
# single note like ['D4'] or multiple notes like ['D4', 'E4'] which means a 'slur' here.
|
106 |
+
note_in_this_word = note_per_word_lst[idx].split()
|
107 |
+
midi_dur_in_this_word = mididur_per_word_lst[idx].split()
|
108 |
+
# process for the model input
|
109 |
+
# Step 1.
|
110 |
+
# Deal with note of 'not slur' case or the first note of 'slur' case
|
111 |
+
# j ie
|
112 |
+
# F#4/Gb4 F#4/Gb4
|
113 |
+
# 0 0
|
114 |
+
for ph in ph_in_this_word:
|
115 |
+
ph_lst.append(ph)
|
116 |
+
note_lst.append(note_in_this_word[0])
|
117 |
+
midi_dur_lst.append(midi_dur_in_this_word[0])
|
118 |
+
is_slur.append(0)
|
119 |
+
# step 2.
|
120 |
+
# Deal with the 2nd, 3rd... notes of 'slur' case
|
121 |
+
# j ie ie
|
122 |
+
# F#4/Gb4 F#4/Gb4 C#4/Db4
|
123 |
+
# 0 0 1
|
124 |
+
if len(note_in_this_word) > 1: # is_slur = True, we should repeat the YUNMU to match the 2nd, 3rd... notes.
|
125 |
+
for idx in range(1, len(note_in_this_word)):
|
126 |
+
ph_lst.append(ph_in_this_word[-1])
|
127 |
+
note_lst.append(note_in_this_word[idx])
|
128 |
+
midi_dur_lst.append(midi_dur_in_this_word[idx])
|
129 |
+
is_slur.append(1)
|
130 |
+
ph_seq = ' '.join(ph_lst)
|
131 |
+
|
132 |
+
if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
|
133 |
+
print(len(ph_lst), len(note_lst), len(midi_dur_lst))
|
134 |
+
print('Pass word-notes check.')
|
135 |
+
else:
|
136 |
+
print('The number of words does\'t match the number of notes\' windows. ',
|
137 |
+
'You should split the note(s) for each word by | mark.')
|
138 |
+
return None
|
139 |
+
return ph_seq, note_lst, midi_dur_lst, is_slur
|
140 |
+
|
141 |
+
def preprocess_phoneme_level_input(self, inp):
|
142 |
+
ph_seq = inp['ph_seq']
|
143 |
+
note_lst = inp['note_seq'].split()
|
144 |
+
midi_dur_lst = inp['note_dur_seq'].split()
|
145 |
+
is_slur = [float(x) for x in inp['is_slur_seq'].split()]
|
146 |
+
print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
|
147 |
+
if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
|
148 |
+
print('Pass word-notes check.')
|
149 |
+
else:
|
150 |
+
print('The number of words does\'t match the number of notes\' windows. ',
|
151 |
+
'You should split the note(s) for each word by | mark.')
|
152 |
+
return None
|
153 |
+
return ph_seq, note_lst, midi_dur_lst, is_slur
|
154 |
+
|
155 |
+
def preprocess_input(self, inp, input_type='word'):
|
156 |
+
"""
|
157 |
+
|
158 |
+
:param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
|
159 |
+
:return:
|
160 |
+
"""
|
161 |
+
|
162 |
+
item_name = inp.get('item_name', '<ITEM_NAME>')
|
163 |
+
spk_name = inp.get('spk_name', 'Alto-1')
|
164 |
+
|
165 |
+
# single spk
|
166 |
+
spk_id = self.spk_map[spk_name]
|
167 |
+
|
168 |
+
# get ph seq, note lst, midi dur lst, is slur lst.
|
169 |
+
if input_type == 'word':
|
170 |
+
ret = self.preprocess_word_level_input(inp)
|
171 |
+
elif input_type == 'phoneme':
|
172 |
+
ret = self.preprocess_phoneme_level_input(inp)
|
173 |
+
else:
|
174 |
+
print('Invalid input type.')
|
175 |
+
return None
|
176 |
+
|
177 |
+
if ret:
|
178 |
+
ph_seq, note_lst, midi_dur_lst, is_slur = ret
|
179 |
+
else:
|
180 |
+
print('==========> Preprocess_word_level or phone_level input wrong.')
|
181 |
+
return None
|
182 |
+
|
183 |
+
# convert note lst to midi id; convert note dur lst to midi duration
|
184 |
+
try:
|
185 |
+
midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
|
186 |
+
for x in note_lst]
|
187 |
+
midi_dur_lst = [float(x) for x in midi_dur_lst]
|
188 |
+
except Exception as e:
|
189 |
+
print(e)
|
190 |
+
print('Invalid Input Type.')
|
191 |
+
return None
|
192 |
+
|
193 |
+
ph_token = self.ph_encoder.encode(ph_seq)
|
194 |
+
item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
|
195 |
+
'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
|
196 |
+
'is_slur': np.asarray(is_slur), }
|
197 |
+
item['ph_len'] = len(item['ph_token'])
|
198 |
+
return item
|
199 |
+
|
200 |
+
def input_to_batch(self, item):
|
201 |
+
item_names = [item['item_name']]
|
202 |
+
text = [item['text']]
|
203 |
+
ph = [item['ph']]
|
204 |
+
txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
|
205 |
+
txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
|
206 |
+
spk_ids = torch.LongTensor([item['spk_id']])[:].to(self.device)
|
207 |
+
|
208 |
+
pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
|
209 |
+
midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
|
210 |
+
is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)
|
211 |
+
|
212 |
+
batch = {
|
213 |
+
'item_name': item_names,
|
214 |
+
'text': text,
|
215 |
+
'ph': ph,
|
216 |
+
'txt_tokens': txt_tokens,
|
217 |
+
'txt_lengths': txt_lengths,
|
218 |
+
'spk_ids': spk_ids,
|
219 |
+
'pitch_midi': pitch_midi,
|
220 |
+
'midi_dur': midi_dur,
|
221 |
+
'is_slur': is_slur
|
222 |
+
}
|
223 |
+
return batch
|
224 |
+
|
225 |
+
def postprocess_output(self, output):
|
226 |
+
return output
|
227 |
+
|
228 |
+
def infer_once(self, inp):
|
229 |
+
inp = self.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
|
230 |
+
output = self.forward_model(inp)
|
231 |
+
output = self.postprocess_output(output)
|
232 |
+
return output
|
233 |
+
|
234 |
+
@classmethod
|
235 |
+
def example_run(cls, inp):
|
236 |
+
from utils.audio import save_wav
|
237 |
+
set_hparams(print_hparams=False)
|
238 |
+
infer_ins = cls(hparams)
|
239 |
+
out = infer_ins.infer_once(inp)
|
240 |
+
os.makedirs('infer_out', exist_ok=True)
|
241 |
+
f_name = inp['spk_name'] + ' | ' + inp['text']
|
242 |
+
save_wav(out, f'infer_out/{f_name}.wav', hparams['audio_sample_rate'])
|
inference/m4singer/ds_e2e.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
# from inference.tts.fs import FastSpeechInfer
|
3 |
+
# from modules.tts.fs2_orig import FastSpeech2Orig
|
4 |
+
from inference.m4singer.base_svs_infer import BaseSVSInfer
|
5 |
+
from utils import load_ckpt
|
6 |
+
from utils.hparams import hparams
|
7 |
+
from usr.diff.shallow_diffusion_tts import GaussianDiffusion
|
8 |
+
from usr.diffsinger_task import DIFF_DECODERS
|
9 |
+
from modules.fastspeech.pe import PitchExtractor
|
10 |
+
import utils
|
11 |
+
|
12 |
+
|
13 |
+
class DiffSingerE2EInfer(BaseSVSInfer):
|
14 |
+
def build_model(self):
|
15 |
+
model = GaussianDiffusion(
|
16 |
+
phone_encoder=self.ph_encoder,
|
17 |
+
out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
|
18 |
+
timesteps=hparams['timesteps'],
|
19 |
+
K_step=hparams['K_step'],
|
20 |
+
loss_type=hparams['diff_loss_type'],
|
21 |
+
spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
|
22 |
+
)
|
23 |
+
model.eval()
|
24 |
+
load_ckpt(model, hparams['work_dir'], 'model')
|
25 |
+
|
26 |
+
if hparams.get('pe_enable') is not None and hparams['pe_enable']:
|
27 |
+
self.pe = PitchExtractor().to(self.device)
|
28 |
+
utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
|
29 |
+
self.pe.eval()
|
30 |
+
return model
|
31 |
+
|
32 |
+
def forward_model(self, inp):
|
33 |
+
sample = self.input_to_batch(inp)
|
34 |
+
txt_tokens = sample['txt_tokens'] # [B, T_t]
|
35 |
+
spk_id = sample.get('spk_ids')
|
36 |
+
with torch.no_grad():
|
37 |
+
output = self.model(txt_tokens, spk_embed=spk_id, ref_mels=None, infer=True,
|
38 |
+
pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
|
39 |
+
is_slur=sample['is_slur'])
|
40 |
+
mel_out = output['mel_out'] # [B, T,80]
|
41 |
+
if hparams.get('pe_enable') is not None and hparams['pe_enable']:
|
42 |
+
f0_pred = self.pe(mel_out)['f0_denorm_pred'] # pe predict from Pred mel
|
43 |
+
else:
|
44 |
+
f0_pred = output['f0_denorm']
|
45 |
+
wav_out = self.run_vocoder(mel_out, f0=f0_pred)
|
46 |
+
wav_out = wav_out.cpu().numpy()
|
47 |
+
return wav_out[0]
|
48 |
+
|
49 |
+
if __name__ == '__main__':
|
50 |
+
inp = {
|
51 |
+
'spk_name': 'Tenor-1',
|
52 |
+
'text': 'AP你要相信AP相信我们会像童话故事里AP',
|
53 |
+
'notes': 'rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest',
|
54 |
+
'notes_duration': '0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14',
|
55 |
+
'input_type': 'word',
|
56 |
+
}
|
57 |
+
|
58 |
+
c = {
|
59 |
+
'spk_name': 'Tenor-1',
|
60 |
+
'text': '你要相信相信我们会像童话故事里',
|
61 |
+
'ph_seq': '<AP> n i iao iao x iang x in in <AP> x iang iang x in uo uo m en h uei x iang t ong ong h ua g u u sh i l i <AP>',
|
62 |
+
'note_seq': 'rest G#3 G#3 A#3 C4 D#4 D#4 D#4 D#4 F4 rest E4 E4 F4 F4 F4 D#4 A#3 A#3 A#3 A#3 A#3 C#4 C#4 B3 B3 C4 C#4 C#4 B3 B3 C4 A#3 A#3 G#3 G#3 rest',
|
63 |
+
'note_dur_seq': '0.14 0.47 0.47 0.1905 0.1895 0.41 0.41 0.3005 0.3005 0.3895 0.21 0.2391 0.2391 0.1809 0.32 0.32 0.4105 0.2095 0.35 0.35 0.43 0.43 0.45 0.45 0.2309 0.2309 0.2291 0.48 0.48 0.225 0.225 0.195 0.29 0.29 0.71 0.71 0.14',
|
64 |
+
'is_slur_seq': '0 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0',
|
65 |
+
'input_type': 'phoneme'
|
66 |
+
}
|
67 |
+
DiffSingerE2EInfer.example_run(inp)
|
inference/m4singer/gradio/gradio_settings.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
title: 'M4Singer'
|
2 |
+
description: |
|
3 |
+
This page aims to display the singing voice synthesis function of M4Singer. SingerID can be switched freely to preview the timbre of each singer. Click examples below to quickly load scores and audio.
|
4 |
+
(本页面为M4Singer歌声合成功能展示。SingerID可以自由切换用以预览各歌手的音色。点击下方Examples可以快速加载乐谱和音频。)
|
5 |
+
|
6 |
+
Please assign pitch and duration values to each Chinese character. The corresponding pitch and duration value of each character should be separated by a | separator. It is necessary to ensure that the note window separated by the separator is consistent with the number of Chinese characters. AP (aspirate) or SP (silence) is also viewed as a Chinese character.
|
7 |
+
(请给每个汉字分配音高和时值, 每个字对应的音高和时值需要用 | 分隔符隔开。需要保证分隔符分割出来的音符窗口与汉字个数一致。换气或静音符也算一个汉字。)
|
8 |
+
|
9 |
+
The notes corresponding to AP and SP are fixed as rest. If there are multiple notes in a window (| .... |), it means that the Chinese character corresponding to the window is glissando, and each note needs to be assigned a duration.
|
10 |
+
(AP和SP对应的音符固定为rest。若一个窗口(| .... |)内有多个音符, 代表该窗口对应的汉字为滑音, 需要为每个音符都分配时长。)
|
11 |
+
|
12 |
+
article: |
|
13 |
+
Note: This page is running on CPU, please refer to <a href='https://github.com/M4Singer/M4Singer' style='color:blue;' target='_blank\'>Github REPO</a> for the local running solutions and for our dataset.
|
14 |
+
|
15 |
+
--------
|
16 |
+
If our work is useful for your research, please consider citing:
|
17 |
+
```bibtex
|
18 |
+
@inproceedings{
|
19 |
+
zhang2022msinger,
|
20 |
+
title={M4Singer: A Multi-Style, Multi-Singer and Musical Score Provided Mandarin Singing Corpus},
|
21 |
+
author={Lichao Zhang and Ruiqi Li and Shoutong Wang and Liqun Deng and Jinglin Liu and Yi Ren and Jinzheng He and Rongjie Huang and Jieming Zhu and Xiao Chen and Zhou Zhao},
|
22 |
+
booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
|
23 |
+
year={2022},
|
24 |
+
}
|
25 |
+
```
|
26 |
+
|
27 |
+
![visitors](https://visitor-badge.laobi.icu/badge?page_id=zlc99/M4Singer)
|
28 |
+
example_inputs:
|
29 |
+
- |-
|
30 |
+
Tenor-1<sep>AP你要相信AP相信我们会像童话故事里AP<sep>rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest<sep>0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14
|
31 |
+
- |-
|
32 |
+
Tenor-1<sep>AP因为在一千年以后AP世界早已没有我AP<sep>rest | C#4 | D4 | E4 | F#4 | E4 | D4 G#3 | A3 | D4 E4 | rest | F#4 | E4 | D4 | C#4 | B3 F#3 | F#3 | C4 C#4 | rest<sep>0.18 | 0.32 | 0.38 | 0.81 | 0.38 | 0.39 | 0.3155 0.2045 | 0.28 | 0.4609 1.0291 | 0.27 | 0.42 | 0.15 | 0.53 | 0.22 | 0.3059 0.2841 | 0.4 | 0.2909 1.1091 | 0.3
|
33 |
+
- |-
|
34 |
+
Tenor-2<sep>AP可是你在敲打AP我的窗棂AP<sep>rest | G#3 | B3 | B3 C#4 | E4 | C#4 B3 | G#3 | rest | C3 | E3 | B3 G#3 | F#3 | rest<sep>0.2 | 0.38 | 0.48 | 0.41 0.72 | 0.39 | 0.5195 0.2905 | 0.5 | 0.33 | 0.4 | 0.31 | 0.565 0.265 | 1.15 | 0.24
|
35 |
+
- |-
|
36 |
+
Tenor-2<sep>SP一杯敬朝阳一杯敬月光AP<sep>rest | G#3 | G#3 | G#3 | G3 | G3 G#3 | G3 | C4 | C4 | A#3 | C4 | rest<sep>0.33 | 0.26 | 0.23 | 0.27 | 0.36 | 0.3159 0.4041 | 0.54 | 0.21 | 0.32 | 0.24 | 0.58 | 0.17
|
37 |
+
- |-
|
38 |
+
Soprano-1<sep>SP乱石穿空AP惊涛拍岸AP<sep>rest | C#5 | D#5 | F5 D#5 | C#5 | rest | C#5 | C#5 | C#5 G#4 | G#4 | rest<sep>0.325 | 0.75 | 0.54 | 0.48 0.55 | 1.38 | 0.31 | 0.55 | 0.48 | 0.4891 0.4709 | 1.15 | 0.22
|
39 |
+
- |-
|
40 |
+
Soprano-1<sep>AP点点滴滴染绿了村寨AP<sep>rest | C5 | A#4 | C5 | D#5 F5 D#5 | D#5 | C5 | C5 | C5 | A#4 | rest<sep>0.175 | 0.24 | 0.26 | 1.08 | 0.3541 0.4364 0.2195 | 0.47 | 0.27 | 0.12 | 0.51 | 0.72 | 0.15
|
41 |
+
- |-
|
42 |
+
Alto-2<sep>AP拒绝声色的张扬AP不拒绝你AP<sep>rest | C4 | C4 | C4 | B3 A3 | C4 | C4 D4 | D4 | rest | D4 | D4 | C4 | G4 E4 | rest<sep>0.49 | 0.31 | 0.18 | 0.48 | 0.3 0.4 | 0.25 | 0.3591 0.2409 | 0.46 | 0.34 | 0.4 | 0.45 | 0.45 | 2.4545 0.9855 | 0.215
|
43 |
+
- |-
|
44 |
+
Alto-2<sep>AP半醒着AP笑着哭着都快活AP<sep>rest | D4 | B3 | C4 D4 | rest | E4 | D4 | E4 | D4 | E4 | E4 F#4 | F4 F#4 | rest<sep>0.165 | 0.45 | 0.53 | 0.3859 0.2441 | 0.35 | 0.38 | 0.17 | 0.32 | 0.26 | 0.33 | 0.38 0.21 | 0.3309 0.9491 | 0.125
|
45 |
+
|
46 |
+
|
47 |
+
inference_cls: inference.m4singer.ds_e2e.DiffSingerE2EInfer
|
48 |
+
exp_name: m4singer_diff_e2e
|
inference/m4singer/gradio/infer.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import re
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import yaml
|
6 |
+
from gradio.components import Textbox, Dropdown
|
7 |
+
|
8 |
+
from inference.m4singer.base_svs_infer import BaseSVSInfer
|
9 |
+
from utils.hparams import set_hparams
|
10 |
+
from utils.hparams import hparams as hp
|
11 |
+
import numpy as np
|
12 |
+
from inference.m4singer.gradio.share_btn import community_icon_html, loading_icon_html, share_js
|
13 |
+
|
14 |
+
class GradioInfer:
|
15 |
+
def __init__(self, exp_name, inference_cls, title, description, article, example_inputs):
|
16 |
+
self.exp_name = exp_name
|
17 |
+
self.title = title
|
18 |
+
self.description = description
|
19 |
+
self.article = article
|
20 |
+
self.example_inputs = example_inputs
|
21 |
+
pkg = ".".join(inference_cls.split(".")[:-1])
|
22 |
+
cls_name = inference_cls.split(".")[-1]
|
23 |
+
self.inference_cls = getattr(importlib.import_module(pkg), cls_name)
|
24 |
+
|
25 |
+
def greet(self, singer, text, notes, notes_duration):
|
26 |
+
PUNCS = '。?;:'
|
27 |
+
sents = re.split(rf'([{PUNCS}])', text.replace('\n', ','))
|
28 |
+
sents_notes = re.split(rf'([{PUNCS}])', notes.replace('\n', ','))
|
29 |
+
sents_notes_dur = re.split(rf'([{PUNCS}])', notes_duration.replace('\n', ','))
|
30 |
+
|
31 |
+
if sents[-1] not in list(PUNCS):
|
32 |
+
sents = sents + ['']
|
33 |
+
sents_notes = sents_notes + ['']
|
34 |
+
sents_notes_dur = sents_notes_dur + ['']
|
35 |
+
|
36 |
+
audio_outs = []
|
37 |
+
s, n, n_dur = "", "", ""
|
38 |
+
for i in range(0, len(sents), 2):
|
39 |
+
if len(sents[i]) > 0:
|
40 |
+
s += sents[i] + sents[i + 1]
|
41 |
+
n += sents_notes[i] + sents_notes[i+1]
|
42 |
+
n_dur += sents_notes_dur[i] + sents_notes_dur[i+1]
|
43 |
+
if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0):
|
44 |
+
audio_out = self.infer_ins.infer_once({
|
45 |
+
'spk_name': singer,
|
46 |
+
'text': s,
|
47 |
+
'notes': n,
|
48 |
+
'notes_duration': n_dur,
|
49 |
+
})
|
50 |
+
audio_out = audio_out * 32767
|
51 |
+
audio_out = audio_out.astype(np.int16)
|
52 |
+
audio_outs.append(audio_out)
|
53 |
+
audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16))
|
54 |
+
s = ""
|
55 |
+
n = ""
|
56 |
+
audio_outs = np.concatenate(audio_outs)
|
57 |
+
return (hp['audio_sample_rate'], audio_outs), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
58 |
+
|
59 |
+
def run(self):
|
60 |
+
set_hparams(config=f'checkpoints/{self.exp_name}/config.yaml', exp_name=self.exp_name, print_hparams=False)
|
61 |
+
infer_cls = self.inference_cls
|
62 |
+
self.infer_ins: BaseSVSInfer = infer_cls(hp)
|
63 |
+
example_inputs = self.example_inputs
|
64 |
+
for i in range(len(example_inputs)):
|
65 |
+
singer, text, notes, notes_dur = example_inputs[i].split('<sep>')
|
66 |
+
example_inputs[i] = [singer, text, notes, notes_dur]
|
67 |
+
|
68 |
+
singerList = \
|
69 |
+
[
|
70 |
+
'Tenor-1', 'Tenor-2', 'Tenor-3', 'Tenor-4', 'Tenor-5', 'Tenor-6', 'Tenor-7',
|
71 |
+
'Alto-1', 'Alto-2', 'Alto-3', 'Alto-4', 'Alto-5', 'Alto-6', 'Alto-7',
|
72 |
+
'Soprano-1', 'Soprano-2', 'Soprano-3',
|
73 |
+
'Bass-1', 'Bass-2', 'Bass-3',
|
74 |
+
]
|
75 |
+
|
76 |
+
css = """
|
77 |
+
#share-btn-container {
|
78 |
+
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
|
79 |
+
}
|
80 |
+
#share-btn {
|
81 |
+
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0;
|
82 |
+
}
|
83 |
+
#share-btn * {
|
84 |
+
all: unset;
|
85 |
+
}
|
86 |
+
#share-btn-container div:nth-child(-n+2){
|
87 |
+
width: auto !important;
|
88 |
+
min-height: 0px !important;
|
89 |
+
}
|
90 |
+
#share-btn-container .wrap {
|
91 |
+
display: none !important;
|
92 |
+
}
|
93 |
+
"""
|
94 |
+
with gr.Blocks(css=css) as demo:
|
95 |
+
gr.HTML("""<div style="text-align: center; margin: 0 auto;">
|
96 |
+
<div
|
97 |
+
style="
|
98 |
+
display: inline-flex;
|
99 |
+
align-items: center;
|
100 |
+
gap: 0.8rem;
|
101 |
+
font-size: 1.75rem;
|
102 |
+
"
|
103 |
+
>
|
104 |
+
<h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 14px;">
|
105 |
+
M4Singer
|
106 |
+
</h1>
|
107 |
+
</div>
|
108 |
+
</div>
|
109 |
+
"""
|
110 |
+
)
|
111 |
+
gr.Markdown(self.description)
|
112 |
+
with gr.Row():
|
113 |
+
with gr.Column():
|
114 |
+
singer_l = Dropdown(choices=singerList, value=example_inputs[0][0], label="SingerID", elem_id="inp_singer")
|
115 |
+
inp_text = Textbox(lines=2, placeholder=None, value=example_inputs[0][1], label="input text", elem_id="inp_text")
|
116 |
+
inp_note = Textbox(lines=2, placeholder=None, value=example_inputs[0][2], label="input note", elem_id="inp_note")
|
117 |
+
inp_duration = Textbox(lines=2, placeholder=None, value=example_inputs[0][3], label="input duration", elem_id="inp_duration")
|
118 |
+
generate = gr.Button("Generate Singing Voice from Musical Score")
|
119 |
+
with gr.Column(lem_id="col-container"):
|
120 |
+
singing_output = gr.Audio(label="Result", type="numpy", elem_id="music-output")
|
121 |
+
|
122 |
+
with gr.Group(elem_id="share-btn-container"):
|
123 |
+
community_icon = gr.HTML(community_icon_html, visible=False)
|
124 |
+
loading_icon = gr.HTML(loading_icon_html, visible=False)
|
125 |
+
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
126 |
+
gr.Examples(examples=self.example_inputs,
|
127 |
+
inputs=[singer_l, inp_text, inp_note, inp_duration],
|
128 |
+
outputs=[singing_output, share_button, community_icon, loading_icon],
|
129 |
+
fn=self.greet,
|
130 |
+
cache_examples=True)
|
131 |
+
gr.Markdown(self.article)
|
132 |
+
generate.click(self.greet,
|
133 |
+
inputs=[singer_l, inp_text, inp_note, inp_duration],
|
134 |
+
outputs=[singing_output, share_button, community_icon, loading_icon],)
|
135 |
+
share_button.click(None, [], [], _js=share_js)
|
136 |
+
demo.queue().launch(share=False)
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == '__main__':
|
140 |
+
gradio_config = yaml.safe_load(open('inference/m4singer/gradio/gradio_settings.yaml'))
|
141 |
+
g = GradioInfer(**gradio_config)
|
142 |
+
g.run()
|
143 |
+
|
inference/m4singer/gradio/share_btn.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
community_icon_html = """<svg id="share-btn-share-icon" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32">
|
2 |
+
<path d="M20.6081 3C21.7684 3 22.8053 3.49196 23.5284 4.38415C23.9756 4.93678 24.4428 5.82749 24.4808 7.16133C24.9674 7.01707 25.4353 6.93643 25.8725 6.93643C26.9833 6.93643 27.9865 7.37587 28.696 8.17411C29.6075 9.19872 30.0124 10.4579 29.8361 11.7177C29.7523 12.3177 29.5581 12.8555 29.2678 13.3534C29.8798 13.8646 30.3306 14.5763 30.5485 15.4322C30.719 16.1032 30.8939 17.5006 29.9808 18.9403C30.0389 19.0342 30.0934 19.1319 30.1442 19.2318C30.6932 20.3074 30.7283 21.5229 30.2439 22.6548C29.5093 24.3704 27.6841 25.7219 24.1397 27.1727C21.9347 28.0753 19.9174 28.6523 19.8994 28.6575C16.9842 29.4379 14.3477 29.8345 12.0653 29.8345C7.87017 29.8345 4.8668 28.508 3.13831 25.8921C0.356375 21.6797 0.754104 17.8269 4.35369 14.1131C6.34591 12.058 7.67023 9.02782 7.94613 8.36275C8.50224 6.39343 9.97271 4.20438 12.4172 4.20438H12.4179C12.6236 4.20438 12.8314 4.2214 13.0364 4.25468C14.107 4.42854 15.0428 5.06476 15.7115 6.02205C16.4331 5.09583 17.134 4.359 17.7682 3.94323C18.7242 3.31737 19.6794 3 20.6081 3ZM20.6081 5.95917C20.2427 5.95917 19.7963 6.1197 19.3039 6.44225C17.7754 7.44319 14.8258 12.6772 13.7458 14.7131C13.3839 15.3952 12.7655 15.6837 12.2086 15.6837C11.1036 15.6837 10.2408 14.5497 12.1076 13.1085C14.9146 10.9402 13.9299 7.39584 12.5898 7.1776C12.5311 7.16799 12.4731 7.16355 12.4172 7.16355C11.1989 7.16355 10.6615 9.33114 10.6615 9.33114C10.6615 9.33114 9.0863 13.4148 6.38031 16.206C3.67434 18.998 3.5346 21.2388 5.50675 24.2246C6.85185 26.2606 9.42666 26.8753 12.0653 26.8753C14.8021 26.8753 17.6077 26.2139 19.1799 25.793C19.2574 25.7723 28.8193 22.984 27.6081 20.6107C27.4046 20.212 27.0693 20.0522 26.6471 20.0522C24.9416 20.0522 21.8393 22.6726 20.5057 22.6726C20.2076 22.6726 19.9976 22.5416 19.9116 22.222C19.3433 20.1173 28.552 19.2325 27.7758 16.1839C27.639 15.6445 27.2677 15.4256 26.746 15.4263C24.4923 15.4263 19.4358 19.5181 18.3759 19.5181C18.2949 19.5181 18.2368 19.4937 18.2053 19.4419C17.6743 18.557 17.9653 17.9394 21.7082 15.6009C25.4511 13.2617 28.0783 11.8545 26.5841 10.1752C26.4121 9.98141 26.1684 9.8956 25.8725 9.8956C23.6001 9.89634 18.2311 14.9403 18.2311 14.9403C18.2311 14.9403 16.7821 16.496 15.9057 16.496C15.7043 16.496 15.533 16.4139 15.4169 16.2112C14.7956 15.1296 21.1879 10.1286 21.5484 8.06535C21.7928 6.66715 21.3771 5.95917 20.6081 5.95917Z" fill="#FF9D00"></path>
|
3 |
+
<path d="M5.50686 24.2246C3.53472 21.2387 3.67446 18.9979 6.38043 16.206C9.08641 13.4147 10.6615 9.33111 10.6615 9.33111C10.6615 9.33111 11.2499 6.95933 12.59 7.17757C13.93 7.39581 14.9139 10.9401 12.1069 13.1084C9.29997 15.276 12.6659 16.7489 13.7459 14.713C14.8258 12.6772 17.7747 7.44316 19.304 6.44221C20.8326 5.44128 21.9089 6.00204 21.5484 8.06532C21.188 10.1286 14.795 15.1295 15.4171 16.2118C16.0391 17.2934 18.2312 14.9402 18.2312 14.9402C18.2312 14.9402 25.0907 8.49588 26.5842 10.1752C28.0776 11.8545 25.4512 13.2616 21.7082 15.6008C17.9646 17.9393 17.6744 18.557 18.2054 19.4418C18.7372 20.3266 26.9998 13.1351 27.7759 16.1838C28.5513 19.2324 19.3434 20.1173 19.9117 22.2219C20.48 24.3274 26.3979 18.2382 27.6082 20.6107C28.8193 22.9839 19.2574 25.7722 19.18 25.7929C16.0914 26.62 8.24723 28.3726 5.50686 24.2246Z" fill="#FFD21E"></path>
|
4 |
+
</svg>"""
|
5 |
+
|
6 |
+
loading_icon_html = """<svg id="share-btn-loading-icon" style="display:none;" class="animate-spin"
|
7 |
+
style="color: #ffffff;
|
8 |
+
"
|
9 |
+
xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" fill="none" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 24 24"><circle style="opacity: 0.25;" cx="12" cy="12" r="10" stroke="white" stroke-width="4"></circle><path style="opacity: 0.75;" fill="white" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>"""
|
10 |
+
|
11 |
+
share_js = """async () => {
|
12 |
+
async function uploadFile(file){
|
13 |
+
const UPLOAD_URL = 'https://huggingface.co/uploads';
|
14 |
+
const response = await fetch(UPLOAD_URL, {
|
15 |
+
method: 'POST',
|
16 |
+
headers: {
|
17 |
+
'Content-Type': file.type,
|
18 |
+
'X-Requested-With': 'XMLHttpRequest',
|
19 |
+
},
|
20 |
+
body: file, /// <- File inherits from Blob
|
21 |
+
});
|
22 |
+
const url = await response.text();
|
23 |
+
return url;
|
24 |
+
}
|
25 |
+
|
26 |
+
async function getOutputMusicFile(audioEL){
|
27 |
+
const res = await fetch(audioEL.src);
|
28 |
+
const blob = await res.blob();
|
29 |
+
const audioId = Date.now() % 200;
|
30 |
+
const fileName = `SVS-${{audioId}}.wav`;
|
31 |
+
const musicBlob = new File([blob], fileName, { type: 'audio/wav' });
|
32 |
+
return musicBlob;
|
33 |
+
}
|
34 |
+
|
35 |
+
const gradioEl = document.querySelector('body > gradio-app');
|
36 |
+
|
37 |
+
//const gradioEl = document.querySelector("gradio-app").shadowRoot;
|
38 |
+
const inputSinger = gradioEl.querySelector('#inp_singer select').value;
|
39 |
+
const inputText = gradioEl.querySelector('#inp_text textarea').value;
|
40 |
+
const inputNote = gradioEl.querySelector('#inp_note textarea').value;
|
41 |
+
const inputDuration = gradioEl.querySelector('#inp_duration textarea').value;
|
42 |
+
const outputMusic = gradioEl.querySelector('#music-output audio');
|
43 |
+
const outputMusic_src = gradioEl.querySelector('#music-output audio').src;
|
44 |
+
|
45 |
+
const outputMusic_name = outputMusic_src.split('/').pop();
|
46 |
+
let titleTxt = outputMusic_name;
|
47 |
+
if(titleTxt.length > 30){
|
48 |
+
titleTxt = 'demo';
|
49 |
+
}
|
50 |
+
const shareBtnEl = gradioEl.querySelector('#share-btn');
|
51 |
+
const shareIconEl = gradioEl.querySelector('#share-btn-share-icon');
|
52 |
+
const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon');
|
53 |
+
if(!outputMusic){
|
54 |
+
return;
|
55 |
+
};
|
56 |
+
shareBtnEl.style.pointerEvents = 'none';
|
57 |
+
shareIconEl.style.display = 'none';
|
58 |
+
loadingIconEl.style.removeProperty('display');
|
59 |
+
const musicFile = await getOutputMusicFile(outputMusic);
|
60 |
+
const dataOutputMusic = await uploadFile(musicFile);
|
61 |
+
const descriptionMd = `#### Input Musical Score:
|
62 |
+
${inputSinger}
|
63 |
+
|
64 |
+
${inputText}
|
65 |
+
|
66 |
+
${inputNote}
|
67 |
+
|
68 |
+
${inputDuration}
|
69 |
+
|
70 |
+
#### Singing Voice:
|
71 |
+
|
72 |
+
<audio controls>
|
73 |
+
<source src="${dataOutputMusic}" type="audio/wav">
|
74 |
+
Your browser does not support the audio element.
|
75 |
+
</audio>
|
76 |
+
`;
|
77 |
+
const params = new URLSearchParams({
|
78 |
+
title: titleTxt,
|
79 |
+
description: descriptionMd,
|
80 |
+
});
|
81 |
+
const paramsStr = params.toString();
|
82 |
+
window.open(`https://huggingface.co/spaces/zlc99/M4Singer/discussions/new?${paramsStr}`, '_blank');
|
83 |
+
shareBtnEl.style.removeProperty('pointer-events');
|
84 |
+
shareIconEl.style.removeProperty('display');
|
85 |
+
loadingIconEl.style.display = 'none';
|
86 |
+
}"""
|
inference/m4singer/m4singer/m4singer_pinyin2ph.txt
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
| a | a |
|
2 |
+
| ai | ai |
|
3 |
+
| an | an |
|
4 |
+
| ang | ang |
|
5 |
+
| ao | ao |
|
6 |
+
| ba | b a |
|
7 |
+
| bai | b ai |
|
8 |
+
| ban | b an |
|
9 |
+
| bang | b ang |
|
10 |
+
| bao | b ao |
|
11 |
+
| bei | b ei |
|
12 |
+
| ben | b en |
|
13 |
+
| beng | b eng |
|
14 |
+
| bi | b i |
|
15 |
+
| bian | b ian |
|
16 |
+
| biao | b iao |
|
17 |
+
| bie | b ie |
|
18 |
+
| bin | b in |
|
19 |
+
| bing | b ing |
|
20 |
+
| bo | b o |
|
21 |
+
| bu | b u |
|
22 |
+
| ca | c a |
|
23 |
+
| cai | c ai |
|
24 |
+
| can | c an |
|
25 |
+
| cang | c ang |
|
26 |
+
| cao | c ao |
|
27 |
+
| ce | c e |
|
28 |
+
| cei | c ei |
|
29 |
+
| cen | c en |
|
30 |
+
| ceng | c eng |
|
31 |
+
| cha | ch a |
|
32 |
+
| chai | ch ai |
|
33 |
+
| chan | ch an |
|
34 |
+
| chang | ch ang |
|
35 |
+
| chao | ch ao |
|
36 |
+
| che | ch e |
|
37 |
+
| chen | ch en |
|
38 |
+
| cheng | ch eng |
|
39 |
+
| chi | ch i |
|
40 |
+
| chong | ch ong |
|
41 |
+
| chou | ch ou |
|
42 |
+
| chu | ch u |
|
43 |
+
| chua | ch ua |
|
44 |
+
| chuai | ch uai |
|
45 |
+
| chuan | ch uan |
|
46 |
+
| chuang | ch uang |
|
47 |
+
| chui | ch uei |
|
48 |
+
| chun | ch uen |
|
49 |
+
| chuo | ch uo |
|
50 |
+
| ci | c i |
|
51 |
+
| cong | c ong |
|
52 |
+
| cou | c ou |
|
53 |
+
| cu | c u |
|
54 |
+
| cuan | c uan |
|
55 |
+
| cui | c uei |
|
56 |
+
| cun | c uen |
|
57 |
+
| cuo | c uo |
|
58 |
+
| da | d a |
|
59 |
+
| dai | d ai |
|
60 |
+
| dan | d an |
|
61 |
+
| dang | d ang |
|
62 |
+
| dao | d ao |
|
63 |
+
| de | d e |
|
64 |
+
| dei | d ei |
|
65 |
+
| den | d en |
|
66 |
+
| deng | d eng |
|
67 |
+
| di | d i |
|
68 |
+
| dia | d ia |
|
69 |
+
| dian | d ian |
|
70 |
+
| diao | d iao |
|
71 |
+
| die | d ie |
|
72 |
+
| ding | d ing |
|
73 |
+
| diu | d iou |
|
74 |
+
| dong | d ong |
|
75 |
+
| dou | d ou |
|
76 |
+
| du | d u |
|
77 |
+
| duan | d uan |
|
78 |
+
| dui | d uei |
|
79 |
+
| dun | d uen |
|
80 |
+
| duo | d uo |
|
81 |
+
| e | e |
|
82 |
+
| ei | ei |
|
83 |
+
| en | en |
|
84 |
+
| eng | eng |
|
85 |
+
| er | er |
|
86 |
+
| fa | f a |
|
87 |
+
| fan | f an |
|
88 |
+
| fang | f ang |
|
89 |
+
| fei | f ei |
|
90 |
+
| fen | f en |
|
91 |
+
| feng | f eng |
|
92 |
+
| fo | f o |
|
93 |
+
| fou | f ou |
|
94 |
+
| fu | f u |
|
95 |
+
| ga | g a |
|
96 |
+
| gai | g ai |
|
97 |
+
| gan | g an |
|
98 |
+
| gang | g ang |
|
99 |
+
| gao | g ao |
|
100 |
+
| ge | g e |
|
101 |
+
| gei | g ei |
|
102 |
+
| gen | g en |
|
103 |
+
| geng | g eng |
|
104 |
+
| gong | g ong |
|
105 |
+
| gou | g ou |
|
106 |
+
| gu | g u |
|
107 |
+
| gua | g ua |
|
108 |
+
| guai | g uai |
|
109 |
+
| guan | g uan |
|
110 |
+
| guang | g uang |
|
111 |
+
| gui | g uei |
|
112 |
+
| gun | g uen |
|
113 |
+
| guo | g uo |
|
114 |
+
| ha | h a |
|
115 |
+
| hai | h ai |
|
116 |
+
| han | h an |
|
117 |
+
| hang | h ang |
|
118 |
+
| hao | h ao |
|
119 |
+
| he | h e |
|
120 |
+
| hei | h ei |
|
121 |
+
| hen | h en |
|
122 |
+
| heng | h eng |
|
123 |
+
| hong | h ong |
|
124 |
+
| hou | h ou |
|
125 |
+
| hu | h u |
|
126 |
+
| hua | h ua |
|
127 |
+
| huai | h uai |
|
128 |
+
| huan | h uan |
|
129 |
+
| huang | h uang |
|
130 |
+
| hui | h uei |
|
131 |
+
| hun | h uen |
|
132 |
+
| huo | h uo |
|
133 |
+
| ji | j i |
|
134 |
+
| jia | j ia |
|
135 |
+
| jian | j ian |
|
136 |
+
| jiang | j iang |
|
137 |
+
| jiao | j iao |
|
138 |
+
| jie | j ie |
|
139 |
+
| jin | j in |
|
140 |
+
| jing | j ing |
|
141 |
+
| jiong | j iong |
|
142 |
+
| jiu | j iou |
|
143 |
+
| ju | j v |
|
144 |
+
| juan | j van |
|
145 |
+
| jue | j ve |
|
146 |
+
| jun | j vn |
|
147 |
+
| ka | k a |
|
148 |
+
| kai | k ai |
|
149 |
+
| kan | k an |
|
150 |
+
| kang | k ang |
|
151 |
+
| kao | k ao |
|
152 |
+
| ke | k e |
|
153 |
+
| kei | k ei |
|
154 |
+
| ken | k en |
|
155 |
+
| keng | k eng |
|
156 |
+
| kong | k ong |
|
157 |
+
| kou | k ou |
|
158 |
+
| ku | k u |
|
159 |
+
| kua | k ua |
|
160 |
+
| kuai | k uai |
|
161 |
+
| kuan | k uan |
|
162 |
+
| kuang | k uang |
|
163 |
+
| kui | k uei |
|
164 |
+
| kun | k uen |
|
165 |
+
| kuo | k uo |
|
166 |
+
| la | l a |
|
167 |
+
| lai | l ai |
|
168 |
+
| lan | l an |
|
169 |
+
| lang | l ang |
|
170 |
+
| lao | l ao |
|
171 |
+
| le | l e |
|
172 |
+
| lei | l ei |
|
173 |
+
| leng | l eng |
|
174 |
+
| li | l i |
|
175 |
+
| lia | l ia |
|
176 |
+
| lian | l ian |
|
177 |
+
| liang | l iang |
|
178 |
+
| liao | l iao |
|
179 |
+
| lie | l ie |
|
180 |
+
| lin | l in |
|
181 |
+
| ling | l ing |
|
182 |
+
| liu | l iou |
|
183 |
+
| lo | l o |
|
184 |
+
| long | l ong |
|
185 |
+
| lou | l ou |
|
186 |
+
| lu | l u |
|
187 |
+
| luan | l uan |
|
188 |
+
| lun | l uen |
|
189 |
+
| luo | l uo |
|
190 |
+
| lv | l v |
|
191 |
+
| lve | l ve |
|
192 |
+
| m | m |
|
193 |
+
| ma | m a |
|
194 |
+
| mai | m ai |
|
195 |
+
| man | m an |
|
196 |
+
| mang | m ang |
|
197 |
+
| mao | m ao |
|
198 |
+
| me | m e |
|
199 |
+
| mei | m ei |
|
200 |
+
| men | m en |
|
201 |
+
| meng | m eng |
|
202 |
+
| mi | m i |
|
203 |
+
| mian | m ian |
|
204 |
+
| miao | m iao |
|
205 |
+
| mie | m ie |
|
206 |
+
| min | m in |
|
207 |
+
| ming | m ing |
|
208 |
+
| miu | m iou |
|
209 |
+
| mo | m o |
|
210 |
+
| mou | m ou |
|
211 |
+
| mu | m u |
|
212 |
+
| n | n |
|
213 |
+
| na | n a |
|
214 |
+
| nai | n ai |
|
215 |
+
| nan | n an |
|
216 |
+
| nang | n ang |
|
217 |
+
| nao | n ao |
|
218 |
+
| ne | n e |
|
219 |
+
| nei | n ei |
|
220 |
+
| nen | n en |
|
221 |
+
| neng | n eng |
|
222 |
+
| ni | n i |
|
223 |
+
| nian | n ian |
|
224 |
+
| niang | n iang |
|
225 |
+
| niao | n iao |
|
226 |
+
| nie | n ie |
|
227 |
+
| nin | n in |
|
228 |
+
| ning | n ing |
|
229 |
+
| niu | n iou |
|
230 |
+
| nong | n ong |
|
231 |
+
| nou | n ou |
|
232 |
+
| nu | n u |
|
233 |
+
| nuan | n uan |
|
234 |
+
| nuo | n uo |
|
235 |
+
| nv | n v |
|
236 |
+
| nve | n ve |
|
237 |
+
| o | o |
|
238 |
+
| ou | ou |
|
239 |
+
| pa | p a |
|
240 |
+
| pai | p ai |
|
241 |
+
| pan | p an |
|
242 |
+
| pang | p ang |
|
243 |
+
| pao | p ao |
|
244 |
+
| pei | p ei |
|
245 |
+
| pen | p en |
|
246 |
+
| peng | p eng |
|
247 |
+
| pi | p i |
|
248 |
+
| pian | p ian |
|
249 |
+
| piao | p iao |
|
250 |
+
| pie | p ie |
|
251 |
+
| pin | p in |
|
252 |
+
| ping | p ing |
|
253 |
+
| po | p o |
|
254 |
+
| pou | p ou |
|
255 |
+
| pu | p u |
|
256 |
+
| qi | q i |
|
257 |
+
| qia | q ia |
|
258 |
+
| qian | q ian |
|
259 |
+
| qiang | q iang |
|
260 |
+
| qiao | q iao |
|
261 |
+
| qie | q ie |
|
262 |
+
| qin | q in |
|
263 |
+
| qing | q ing |
|
264 |
+
| qiong | q iong |
|
265 |
+
| qiu | q iou |
|
266 |
+
| qu | q v |
|
267 |
+
| quan | q van |
|
268 |
+
| que | q ve |
|
269 |
+
| qun | q vn |
|
270 |
+
| ran | r an |
|
271 |
+
| rang | r ang |
|
272 |
+
| rao | r ao |
|
273 |
+
| re | r e |
|
274 |
+
| ren | r en |
|
275 |
+
| reng | r eng |
|
276 |
+
| ri | r i |
|
277 |
+
| rong | r ong |
|
278 |
+
| rou | r ou |
|
279 |
+
| ru | r u |
|
280 |
+
| rua | r ua |
|
281 |
+
| ruan | r uan |
|
282 |
+
| rui | r uei |
|
283 |
+
| run | r uen |
|
284 |
+
| ruo | r uo |
|
285 |
+
| sa | s a |
|
286 |
+
| sai | s ai |
|
287 |
+
| san | s an |
|
288 |
+
| sang | s ang |
|
289 |
+
| sao | s ao |
|
290 |
+
| se | s e |
|
291 |
+
| sen | s en |
|
292 |
+
| seng | s eng |
|
293 |
+
| sha | sh a |
|
294 |
+
| shai | sh ai |
|
295 |
+
| shan | sh an |
|
296 |
+
| shang | sh ang |
|
297 |
+
| shao | sh ao |
|
298 |
+
| she | sh e |
|
299 |
+
| shei | sh ei |
|
300 |
+
| shen | sh en |
|
301 |
+
| sheng | sh eng |
|
302 |
+
| shi | sh i |
|
303 |
+
| shou | sh ou |
|
304 |
+
| shu | sh u |
|
305 |
+
| shua | sh ua |
|
306 |
+
| shuai | sh uai |
|
307 |
+
| shuan | sh uan |
|
308 |
+
| shuang | sh uang |
|
309 |
+
| shui | sh uei |
|
310 |
+
| shun | sh uen |
|
311 |
+
| shuo | sh uo |
|
312 |
+
| si | s i |
|
313 |
+
| song | s ong |
|
314 |
+
| sou | s ou |
|
315 |
+
| su | s u |
|
316 |
+
| suan | s uan |
|
317 |
+
| sui | s uei |
|
318 |
+
| sun | s uen |
|
319 |
+
| suo | s uo |
|
320 |
+
| ta | t a |
|
321 |
+
| tai | t ai |
|
322 |
+
| tan | t an |
|
323 |
+
| tang | t ang |
|
324 |
+
| tao | t ao |
|
325 |
+
| te | t e |
|
326 |
+
| tei | t ei |
|
327 |
+
| teng | t eng |
|
328 |
+
| ti | t i |
|
329 |
+
| tian | t ian |
|
330 |
+
| tiao | t iao |
|
331 |
+
| tie | t ie |
|
332 |
+
| ting | t ing |
|
333 |
+
| tong | t ong |
|
334 |
+
| tou | t ou |
|
335 |
+
| tu | t u |
|
336 |
+
| tuan | t uan |
|
337 |
+
| tui | t uei |
|
338 |
+
| tun | t uen |
|
339 |
+
| tuo | t uo |
|
340 |
+
| wa | ua |
|
341 |
+
| wai | uai |
|
342 |
+
| wan | uan |
|
343 |
+
| wang | uang |
|
344 |
+
| wei | uei |
|
345 |
+
| wen | uen |
|
346 |
+
| weng | ueng |
|
347 |
+
| wo | uo |
|
348 |
+
| wu | u |
|
349 |
+
| xi | x i |
|
350 |
+
| xia | x ia |
|
351 |
+
| xian | x ian |
|
352 |
+
| xiang | x iang |
|
353 |
+
| xiao | x iao |
|
354 |
+
| xie | x ie |
|
355 |
+
| xin | x in |
|
356 |
+
| xing | x ing |
|
357 |
+
| xiong | x iong |
|
358 |
+
| xiu | x iou |
|
359 |
+
| xu | x v |
|
360 |
+
| xuan | x van |
|
361 |
+
| xue | x ve |
|
362 |
+
| xun | x vn |
|
363 |
+
| ya | ia |
|
364 |
+
| yan | ian |
|
365 |
+
| yang | iang |
|
366 |
+
| yao | iao |
|
367 |
+
| ye | ie |
|
368 |
+
| yi | i |
|
369 |
+
| yin | in |
|
370 |
+
| ying | ing |
|
371 |
+
| yong | iong |
|
372 |
+
| you | iou |
|
373 |
+
| yu | v |
|
374 |
+
| yuan | van |
|
375 |
+
| yue | ve |
|
376 |
+
| yun | vn |
|
377 |
+
| za | z a |
|
378 |
+
| zai | z ai |
|
379 |
+
| zan | z an |
|
380 |
+
| zang | z ang |
|
381 |
+
| zao | z ao |
|
382 |
+
| ze | z e |
|
383 |
+
| zei | z ei |
|
384 |
+
| zen | z en |
|
385 |
+
| zeng | z eng |
|
386 |
+
| zha | zh a |
|
387 |
+
| zhai | zh ai |
|
388 |
+
| zhan | zh an |
|
389 |
+
| zhang | zh ang |
|
390 |
+
| zhao | zh ao |
|
391 |
+
| zhe | zh e |
|
392 |
+
| zhei | zh ei |
|
393 |
+
| zhen | zh en |
|
394 |
+
| zheng | zh eng |
|
395 |
+
| zhi | zh i |
|
396 |
+
| zhong | zh ong |
|
397 |
+
| zhou | zh ou |
|
398 |
+
| zhu | zh u |
|
399 |
+
| zhua | zh ua |
|
400 |
+
| zhuai | zh uai |
|
401 |
+
| zhuan | zh uan |
|
402 |
+
| zhuang | zh uang |
|
403 |
+
| zhui | zh uei |
|
404 |
+
| zhun | zh uen |
|
405 |
+
| zhuo | zh uo |
|
406 |
+
| zi | z i |
|
407 |
+
| zong | z ong |
|
408 |
+
| zou | z ou |
|
409 |
+
| zu | z u |
|
410 |
+
| zuan | z uan |
|
411 |
+
| zui | z uei |
|
412 |
+
| zun | z uen |
|
413 |
+
| zuo | z uo |
|
inference/m4singer/m4singer/map.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def m4singer_pinyin2ph_func():
|
2 |
+
pinyin2phs = {'AP': '<AP>', 'SP': '<SP>'}
|
3 |
+
with open('inference/m4singer/m4singer/m4singer_pinyin2ph.txt') as rf:
|
4 |
+
for line in rf.readlines():
|
5 |
+
elements = [x.strip() for x in line.split('|') if x.strip() != '']
|
6 |
+
pinyin2phs[elements[0]] = elements[1]
|
7 |
+
return pinyin2phs
|
modules/__init__.py
ADDED
File without changes
|
modules/commons/common_layers.py
ADDED
@@ -0,0 +1,668 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import Parameter
|
5 |
+
import torch.onnx.operators
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import utils
|
8 |
+
|
9 |
+
|
10 |
+
class Reshape(nn.Module):
|
11 |
+
def __init__(self, *args):
|
12 |
+
super(Reshape, self).__init__()
|
13 |
+
self.shape = args
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
return x.view(self.shape)
|
17 |
+
|
18 |
+
|
19 |
+
class Permute(nn.Module):
|
20 |
+
def __init__(self, *args):
|
21 |
+
super(Permute, self).__init__()
|
22 |
+
self.args = args
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return x.permute(self.args)
|
26 |
+
|
27 |
+
|
28 |
+
class LinearNorm(torch.nn.Module):
|
29 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
30 |
+
super(LinearNorm, self).__init__()
|
31 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
32 |
+
|
33 |
+
torch.nn.init.xavier_uniform_(
|
34 |
+
self.linear_layer.weight,
|
35 |
+
gain=torch.nn.init.calculate_gain(w_init_gain))
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.linear_layer(x)
|
39 |
+
|
40 |
+
|
41 |
+
class ConvNorm(torch.nn.Module):
|
42 |
+
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
43 |
+
padding=None, dilation=1, bias=True, w_init_gain='linear'):
|
44 |
+
super(ConvNorm, self).__init__()
|
45 |
+
if padding is None:
|
46 |
+
assert (kernel_size % 2 == 1)
|
47 |
+
padding = int(dilation * (kernel_size - 1) / 2)
|
48 |
+
|
49 |
+
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
50 |
+
kernel_size=kernel_size, stride=stride,
|
51 |
+
padding=padding, dilation=dilation,
|
52 |
+
bias=bias)
|
53 |
+
|
54 |
+
torch.nn.init.xavier_uniform_(
|
55 |
+
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
56 |
+
|
57 |
+
def forward(self, signal):
|
58 |
+
conv_signal = self.conv(signal)
|
59 |
+
return conv_signal
|
60 |
+
|
61 |
+
|
62 |
+
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
|
63 |
+
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
64 |
+
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
65 |
+
if padding_idx is not None:
|
66 |
+
nn.init.constant_(m.weight[padding_idx], 0)
|
67 |
+
return m
|
68 |
+
|
69 |
+
|
70 |
+
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
|
71 |
+
if not export and torch.cuda.is_available():
|
72 |
+
try:
|
73 |
+
from apex.normalization import FusedLayerNorm
|
74 |
+
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
|
75 |
+
except ImportError:
|
76 |
+
pass
|
77 |
+
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
|
78 |
+
|
79 |
+
|
80 |
+
def Linear(in_features, out_features, bias=True):
|
81 |
+
m = nn.Linear(in_features, out_features, bias)
|
82 |
+
nn.init.xavier_uniform_(m.weight)
|
83 |
+
if bias:
|
84 |
+
nn.init.constant_(m.bias, 0.)
|
85 |
+
return m
|
86 |
+
|
87 |
+
|
88 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
89 |
+
"""This module produces sinusoidal positional embeddings of any length.
|
90 |
+
|
91 |
+
Padding symbols are ignored.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
95 |
+
super().__init__()
|
96 |
+
self.embedding_dim = embedding_dim
|
97 |
+
self.padding_idx = padding_idx
|
98 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
99 |
+
init_size,
|
100 |
+
embedding_dim,
|
101 |
+
padding_idx,
|
102 |
+
)
|
103 |
+
self.register_buffer('_float_tensor', torch.FloatTensor(1))
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
|
107 |
+
"""Build sinusoidal embeddings.
|
108 |
+
|
109 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
110 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
111 |
+
"""
|
112 |
+
half_dim = embedding_dim // 2
|
113 |
+
emb = math.log(10000) / (half_dim - 1)
|
114 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
115 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
|
116 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
117 |
+
if embedding_dim % 2 == 1:
|
118 |
+
# zero pad
|
119 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
120 |
+
if padding_idx is not None:
|
121 |
+
emb[padding_idx, :] = 0
|
122 |
+
return emb
|
123 |
+
|
124 |
+
def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
|
125 |
+
"""Input is expected to be of size [bsz x seqlen]."""
|
126 |
+
bsz, seq_len = input.shape[:2]
|
127 |
+
max_pos = self.padding_idx + 1 + seq_len
|
128 |
+
if self.weights is None or max_pos > self.weights.size(0):
|
129 |
+
# recompute/expand embeddings if needed
|
130 |
+
self.weights = SinusoidalPositionalEmbedding.get_embedding(
|
131 |
+
max_pos,
|
132 |
+
self.embedding_dim,
|
133 |
+
self.padding_idx,
|
134 |
+
)
|
135 |
+
self.weights = self.weights.to(self._float_tensor)
|
136 |
+
|
137 |
+
if incremental_state is not None:
|
138 |
+
# positions is the same for every token when decoding a single step
|
139 |
+
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
|
140 |
+
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
|
141 |
+
|
142 |
+
positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
|
143 |
+
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
|
144 |
+
|
145 |
+
def max_positions(self):
|
146 |
+
"""Maximum number of supported positions."""
|
147 |
+
return int(1e5) # an arbitrary large number
|
148 |
+
|
149 |
+
|
150 |
+
class ConvTBC(nn.Module):
|
151 |
+
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
|
152 |
+
super(ConvTBC, self).__init__()
|
153 |
+
self.in_channels = in_channels
|
154 |
+
self.out_channels = out_channels
|
155 |
+
self.kernel_size = kernel_size
|
156 |
+
self.padding = padding
|
157 |
+
|
158 |
+
self.weight = torch.nn.Parameter(torch.Tensor(
|
159 |
+
self.kernel_size, in_channels, out_channels))
|
160 |
+
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
|
161 |
+
|
162 |
+
def forward(self, input):
|
163 |
+
return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
|
164 |
+
|
165 |
+
|
166 |
+
class MultiheadAttention(nn.Module):
|
167 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
|
168 |
+
add_bias_kv=False, add_zero_attn=False, self_attention=False,
|
169 |
+
encoder_decoder_attention=False):
|
170 |
+
super().__init__()
|
171 |
+
self.embed_dim = embed_dim
|
172 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
173 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
174 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
175 |
+
|
176 |
+
self.num_heads = num_heads
|
177 |
+
self.dropout = dropout
|
178 |
+
self.head_dim = embed_dim // num_heads
|
179 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
180 |
+
self.scaling = self.head_dim ** -0.5
|
181 |
+
|
182 |
+
self.self_attention = self_attention
|
183 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
184 |
+
|
185 |
+
assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
|
186 |
+
'value to be of the same size'
|
187 |
+
|
188 |
+
if self.qkv_same_dim:
|
189 |
+
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
|
190 |
+
else:
|
191 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
192 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
193 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
194 |
+
|
195 |
+
if bias:
|
196 |
+
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
|
197 |
+
else:
|
198 |
+
self.register_parameter('in_proj_bias', None)
|
199 |
+
|
200 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
201 |
+
|
202 |
+
if add_bias_kv:
|
203 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
204 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
205 |
+
else:
|
206 |
+
self.bias_k = self.bias_v = None
|
207 |
+
|
208 |
+
self.add_zero_attn = add_zero_attn
|
209 |
+
|
210 |
+
self.reset_parameters()
|
211 |
+
|
212 |
+
self.enable_torch_version = False
|
213 |
+
if hasattr(F, "multi_head_attention_forward"):
|
214 |
+
self.enable_torch_version = True
|
215 |
+
else:
|
216 |
+
self.enable_torch_version = False
|
217 |
+
self.last_attn_probs = None
|
218 |
+
|
219 |
+
def reset_parameters(self):
|
220 |
+
if self.qkv_same_dim:
|
221 |
+
nn.init.xavier_uniform_(self.in_proj_weight)
|
222 |
+
else:
|
223 |
+
nn.init.xavier_uniform_(self.k_proj_weight)
|
224 |
+
nn.init.xavier_uniform_(self.v_proj_weight)
|
225 |
+
nn.init.xavier_uniform_(self.q_proj_weight)
|
226 |
+
|
227 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
228 |
+
if self.in_proj_bias is not None:
|
229 |
+
nn.init.constant_(self.in_proj_bias, 0.)
|
230 |
+
nn.init.constant_(self.out_proj.bias, 0.)
|
231 |
+
if self.bias_k is not None:
|
232 |
+
nn.init.xavier_normal_(self.bias_k)
|
233 |
+
if self.bias_v is not None:
|
234 |
+
nn.init.xavier_normal_(self.bias_v)
|
235 |
+
|
236 |
+
def forward(
|
237 |
+
self,
|
238 |
+
query, key, value,
|
239 |
+
key_padding_mask=None,
|
240 |
+
incremental_state=None,
|
241 |
+
need_weights=True,
|
242 |
+
static_kv=False,
|
243 |
+
attn_mask=None,
|
244 |
+
before_softmax=False,
|
245 |
+
need_head_weights=False,
|
246 |
+
enc_dec_attn_constraint_mask=None,
|
247 |
+
reset_attn_weight=None
|
248 |
+
):
|
249 |
+
"""Input shape: Time x Batch x Channel
|
250 |
+
|
251 |
+
Args:
|
252 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
253 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
254 |
+
padding elements are indicated by 1s.
|
255 |
+
need_weights (bool, optional): return the attention weights,
|
256 |
+
averaged over heads (default: False).
|
257 |
+
attn_mask (ByteTensor, optional): typically used to
|
258 |
+
implement causal attention, where the mask prevents the
|
259 |
+
attention from looking forward in time (default: None).
|
260 |
+
before_softmax (bool, optional): return the raw attention
|
261 |
+
weights and values before the attention softmax.
|
262 |
+
need_head_weights (bool, optional): return the attention
|
263 |
+
weights for each head. Implies *need_weights*. Default:
|
264 |
+
return the average attention weights over all heads.
|
265 |
+
"""
|
266 |
+
if need_head_weights:
|
267 |
+
need_weights = True
|
268 |
+
|
269 |
+
tgt_len, bsz, embed_dim = query.size()
|
270 |
+
assert embed_dim == self.embed_dim
|
271 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
272 |
+
|
273 |
+
if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
|
274 |
+
if self.qkv_same_dim:
|
275 |
+
return F.multi_head_attention_forward(query, key, value,
|
276 |
+
self.embed_dim, self.num_heads,
|
277 |
+
self.in_proj_weight,
|
278 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
279 |
+
self.add_zero_attn, self.dropout,
|
280 |
+
self.out_proj.weight, self.out_proj.bias,
|
281 |
+
self.training, key_padding_mask, need_weights,
|
282 |
+
attn_mask)
|
283 |
+
else:
|
284 |
+
return F.multi_head_attention_forward(query, key, value,
|
285 |
+
self.embed_dim, self.num_heads,
|
286 |
+
torch.empty([0]),
|
287 |
+
self.in_proj_bias, self.bias_k, self.bias_v,
|
288 |
+
self.add_zero_attn, self.dropout,
|
289 |
+
self.out_proj.weight, self.out_proj.bias,
|
290 |
+
self.training, key_padding_mask, need_weights,
|
291 |
+
attn_mask, use_separate_proj_weight=True,
|
292 |
+
q_proj_weight=self.q_proj_weight,
|
293 |
+
k_proj_weight=self.k_proj_weight,
|
294 |
+
v_proj_weight=self.v_proj_weight)
|
295 |
+
|
296 |
+
if incremental_state is not None:
|
297 |
+
print('Not implemented error.')
|
298 |
+
exit()
|
299 |
+
else:
|
300 |
+
saved_state = None
|
301 |
+
|
302 |
+
if self.self_attention:
|
303 |
+
# self-attention
|
304 |
+
q, k, v = self.in_proj_qkv(query)
|
305 |
+
elif self.encoder_decoder_attention:
|
306 |
+
# encoder-decoder attention
|
307 |
+
q = self.in_proj_q(query)
|
308 |
+
if key is None:
|
309 |
+
assert value is None
|
310 |
+
k = v = None
|
311 |
+
else:
|
312 |
+
k = self.in_proj_k(key)
|
313 |
+
v = self.in_proj_v(key)
|
314 |
+
|
315 |
+
else:
|
316 |
+
q = self.in_proj_q(query)
|
317 |
+
k = self.in_proj_k(key)
|
318 |
+
v = self.in_proj_v(value)
|
319 |
+
q *= self.scaling
|
320 |
+
|
321 |
+
if self.bias_k is not None:
|
322 |
+
assert self.bias_v is not None
|
323 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
324 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
325 |
+
if attn_mask is not None:
|
326 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
327 |
+
if key_padding_mask is not None:
|
328 |
+
key_padding_mask = torch.cat(
|
329 |
+
[key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
|
330 |
+
|
331 |
+
q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
332 |
+
if k is not None:
|
333 |
+
k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
334 |
+
if v is not None:
|
335 |
+
v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
336 |
+
|
337 |
+
if saved_state is not None:
|
338 |
+
print('Not implemented error.')
|
339 |
+
exit()
|
340 |
+
|
341 |
+
src_len = k.size(1)
|
342 |
+
|
343 |
+
# This is part of a workaround to get around fork/join parallelism
|
344 |
+
# not supporting Optional types.
|
345 |
+
if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
|
346 |
+
key_padding_mask = None
|
347 |
+
|
348 |
+
if key_padding_mask is not None:
|
349 |
+
assert key_padding_mask.size(0) == bsz
|
350 |
+
assert key_padding_mask.size(1) == src_len
|
351 |
+
|
352 |
+
if self.add_zero_attn:
|
353 |
+
src_len += 1
|
354 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
355 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
356 |
+
if attn_mask is not None:
|
357 |
+
attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
|
358 |
+
if key_padding_mask is not None:
|
359 |
+
key_padding_mask = torch.cat(
|
360 |
+
[key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
|
361 |
+
|
362 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
363 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
364 |
+
|
365 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
366 |
+
|
367 |
+
if attn_mask is not None:
|
368 |
+
if len(attn_mask.shape) == 2:
|
369 |
+
attn_mask = attn_mask.unsqueeze(0)
|
370 |
+
elif len(attn_mask.shape) == 3:
|
371 |
+
attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
|
372 |
+
bsz * self.num_heads, tgt_len, src_len)
|
373 |
+
attn_weights = attn_weights + attn_mask
|
374 |
+
|
375 |
+
if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
|
376 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
377 |
+
attn_weights = attn_weights.masked_fill(
|
378 |
+
enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
|
379 |
+
-1e9,
|
380 |
+
)
|
381 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
382 |
+
|
383 |
+
if key_padding_mask is not None:
|
384 |
+
# don't attend to padding symbols
|
385 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
386 |
+
attn_weights = attn_weights.masked_fill(
|
387 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
388 |
+
-1e9,
|
389 |
+
)
|
390 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
391 |
+
|
392 |
+
attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
393 |
+
|
394 |
+
if before_softmax:
|
395 |
+
return attn_weights, v
|
396 |
+
|
397 |
+
attn_weights_float = utils.softmax(attn_weights, dim=-1)
|
398 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
399 |
+
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
|
400 |
+
|
401 |
+
if reset_attn_weight is not None:
|
402 |
+
if reset_attn_weight:
|
403 |
+
self.last_attn_probs = attn_probs.detach()
|
404 |
+
else:
|
405 |
+
assert self.last_attn_probs is not None
|
406 |
+
attn_probs = self.last_attn_probs
|
407 |
+
attn = torch.bmm(attn_probs, v)
|
408 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
409 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
410 |
+
attn = self.out_proj(attn)
|
411 |
+
|
412 |
+
if need_weights:
|
413 |
+
attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
|
414 |
+
if not need_head_weights:
|
415 |
+
# average attention weights over heads
|
416 |
+
attn_weights = attn_weights.mean(dim=0)
|
417 |
+
else:
|
418 |
+
attn_weights = None
|
419 |
+
|
420 |
+
return attn, (attn_weights, attn_logits)
|
421 |
+
|
422 |
+
def in_proj_qkv(self, query):
|
423 |
+
return self._in_proj(query).chunk(3, dim=-1)
|
424 |
+
|
425 |
+
def in_proj_q(self, query):
|
426 |
+
if self.qkv_same_dim:
|
427 |
+
return self._in_proj(query, end=self.embed_dim)
|
428 |
+
else:
|
429 |
+
bias = self.in_proj_bias
|
430 |
+
if bias is not None:
|
431 |
+
bias = bias[:self.embed_dim]
|
432 |
+
return F.linear(query, self.q_proj_weight, bias)
|
433 |
+
|
434 |
+
def in_proj_k(self, key):
|
435 |
+
if self.qkv_same_dim:
|
436 |
+
return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
|
437 |
+
else:
|
438 |
+
weight = self.k_proj_weight
|
439 |
+
bias = self.in_proj_bias
|
440 |
+
if bias is not None:
|
441 |
+
bias = bias[self.embed_dim:2 * self.embed_dim]
|
442 |
+
return F.linear(key, weight, bias)
|
443 |
+
|
444 |
+
def in_proj_v(self, value):
|
445 |
+
if self.qkv_same_dim:
|
446 |
+
return self._in_proj(value, start=2 * self.embed_dim)
|
447 |
+
else:
|
448 |
+
weight = self.v_proj_weight
|
449 |
+
bias = self.in_proj_bias
|
450 |
+
if bias is not None:
|
451 |
+
bias = bias[2 * self.embed_dim:]
|
452 |
+
return F.linear(value, weight, bias)
|
453 |
+
|
454 |
+
def _in_proj(self, input, start=0, end=None):
|
455 |
+
weight = self.in_proj_weight
|
456 |
+
bias = self.in_proj_bias
|
457 |
+
weight = weight[start:end, :]
|
458 |
+
if bias is not None:
|
459 |
+
bias = bias[start:end]
|
460 |
+
return F.linear(input, weight, bias)
|
461 |
+
|
462 |
+
|
463 |
+
def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
|
464 |
+
return attn_weights
|
465 |
+
|
466 |
+
|
467 |
+
class Swish(torch.autograd.Function):
|
468 |
+
@staticmethod
|
469 |
+
def forward(ctx, i):
|
470 |
+
result = i * torch.sigmoid(i)
|
471 |
+
ctx.save_for_backward(i)
|
472 |
+
return result
|
473 |
+
|
474 |
+
@staticmethod
|
475 |
+
def backward(ctx, grad_output):
|
476 |
+
i = ctx.saved_variables[0]
|
477 |
+
sigmoid_i = torch.sigmoid(i)
|
478 |
+
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
479 |
+
|
480 |
+
|
481 |
+
class CustomSwish(nn.Module):
|
482 |
+
def forward(self, input_tensor):
|
483 |
+
return Swish.apply(input_tensor)
|
484 |
+
|
485 |
+
|
486 |
+
class TransformerFFNLayer(nn.Module):
|
487 |
+
def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
|
488 |
+
super().__init__()
|
489 |
+
self.kernel_size = kernel_size
|
490 |
+
self.dropout = dropout
|
491 |
+
self.act = act
|
492 |
+
if padding == 'SAME':
|
493 |
+
self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
|
494 |
+
elif padding == 'LEFT':
|
495 |
+
self.ffn_1 = nn.Sequential(
|
496 |
+
nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
|
497 |
+
nn.Conv1d(hidden_size, filter_size, kernel_size)
|
498 |
+
)
|
499 |
+
self.ffn_2 = Linear(filter_size, hidden_size)
|
500 |
+
if self.act == 'swish':
|
501 |
+
self.swish_fn = CustomSwish()
|
502 |
+
|
503 |
+
def forward(self, x, incremental_state=None):
|
504 |
+
# x: T x B x C
|
505 |
+
if incremental_state is not None:
|
506 |
+
assert incremental_state is None, 'Nar-generation does not allow this.'
|
507 |
+
exit(1)
|
508 |
+
|
509 |
+
x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
|
510 |
+
x = x * self.kernel_size ** -0.5
|
511 |
+
|
512 |
+
if incremental_state is not None:
|
513 |
+
x = x[-1:]
|
514 |
+
if self.act == 'gelu':
|
515 |
+
x = F.gelu(x)
|
516 |
+
if self.act == 'relu':
|
517 |
+
x = F.relu(x)
|
518 |
+
if self.act == 'swish':
|
519 |
+
x = self.swish_fn(x)
|
520 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
521 |
+
x = self.ffn_2(x)
|
522 |
+
return x
|
523 |
+
|
524 |
+
|
525 |
+
class BatchNorm1dTBC(nn.Module):
|
526 |
+
def __init__(self, c):
|
527 |
+
super(BatchNorm1dTBC, self).__init__()
|
528 |
+
self.bn = nn.BatchNorm1d(c)
|
529 |
+
|
530 |
+
def forward(self, x):
|
531 |
+
"""
|
532 |
+
|
533 |
+
:param x: [T, B, C]
|
534 |
+
:return: [T, B, C]
|
535 |
+
"""
|
536 |
+
x = x.permute(1, 2, 0) # [B, C, T]
|
537 |
+
x = self.bn(x) # [B, C, T]
|
538 |
+
x = x.permute(2, 0, 1) # [T, B, C]
|
539 |
+
return x
|
540 |
+
|
541 |
+
|
542 |
+
class EncSALayer(nn.Module):
|
543 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
|
544 |
+
relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
|
545 |
+
super().__init__()
|
546 |
+
self.c = c
|
547 |
+
self.dropout = dropout
|
548 |
+
self.num_heads = num_heads
|
549 |
+
if num_heads > 0:
|
550 |
+
if norm == 'ln':
|
551 |
+
self.layer_norm1 = LayerNorm(c)
|
552 |
+
elif norm == 'bn':
|
553 |
+
self.layer_norm1 = BatchNorm1dTBC(c)
|
554 |
+
self.self_attn = MultiheadAttention(
|
555 |
+
self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
|
556 |
+
)
|
557 |
+
if norm == 'ln':
|
558 |
+
self.layer_norm2 = LayerNorm(c)
|
559 |
+
elif norm == 'bn':
|
560 |
+
self.layer_norm2 = BatchNorm1dTBC(c)
|
561 |
+
self.ffn = TransformerFFNLayer(
|
562 |
+
c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
|
563 |
+
|
564 |
+
def forward(self, x, encoder_padding_mask=None, **kwargs):
|
565 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
566 |
+
if layer_norm_training is not None:
|
567 |
+
self.layer_norm1.training = layer_norm_training
|
568 |
+
self.layer_norm2.training = layer_norm_training
|
569 |
+
if self.num_heads > 0:
|
570 |
+
residual = x
|
571 |
+
x = self.layer_norm1(x)
|
572 |
+
x, _, = self.self_attn(
|
573 |
+
query=x,
|
574 |
+
key=x,
|
575 |
+
value=x,
|
576 |
+
key_padding_mask=encoder_padding_mask
|
577 |
+
)
|
578 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
579 |
+
x = residual + x
|
580 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
581 |
+
|
582 |
+
residual = x
|
583 |
+
x = self.layer_norm2(x)
|
584 |
+
x = self.ffn(x)
|
585 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
586 |
+
x = residual + x
|
587 |
+
x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
|
588 |
+
return x
|
589 |
+
|
590 |
+
|
591 |
+
class DecSALayer(nn.Module):
|
592 |
+
def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
|
593 |
+
super().__init__()
|
594 |
+
self.c = c
|
595 |
+
self.dropout = dropout
|
596 |
+
self.layer_norm1 = LayerNorm(c)
|
597 |
+
self.self_attn = MultiheadAttention(
|
598 |
+
c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
|
599 |
+
)
|
600 |
+
self.layer_norm2 = LayerNorm(c)
|
601 |
+
self.encoder_attn = MultiheadAttention(
|
602 |
+
c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
|
603 |
+
)
|
604 |
+
self.layer_norm3 = LayerNorm(c)
|
605 |
+
self.ffn = TransformerFFNLayer(
|
606 |
+
c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
|
607 |
+
|
608 |
+
def forward(
|
609 |
+
self,
|
610 |
+
x,
|
611 |
+
encoder_out=None,
|
612 |
+
encoder_padding_mask=None,
|
613 |
+
incremental_state=None,
|
614 |
+
self_attn_mask=None,
|
615 |
+
self_attn_padding_mask=None,
|
616 |
+
attn_out=None,
|
617 |
+
reset_attn_weight=None,
|
618 |
+
**kwargs,
|
619 |
+
):
|
620 |
+
layer_norm_training = kwargs.get('layer_norm_training', None)
|
621 |
+
if layer_norm_training is not None:
|
622 |
+
self.layer_norm1.training = layer_norm_training
|
623 |
+
self.layer_norm2.training = layer_norm_training
|
624 |
+
self.layer_norm3.training = layer_norm_training
|
625 |
+
residual = x
|
626 |
+
x = self.layer_norm1(x)
|
627 |
+
x, _ = self.self_attn(
|
628 |
+
query=x,
|
629 |
+
key=x,
|
630 |
+
value=x,
|
631 |
+
key_padding_mask=self_attn_padding_mask,
|
632 |
+
incremental_state=incremental_state,
|
633 |
+
attn_mask=self_attn_mask
|
634 |
+
)
|
635 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
636 |
+
x = residual + x
|
637 |
+
|
638 |
+
residual = x
|
639 |
+
x = self.layer_norm2(x)
|
640 |
+
if encoder_out is not None:
|
641 |
+
x, attn = self.encoder_attn(
|
642 |
+
query=x,
|
643 |
+
key=encoder_out,
|
644 |
+
value=encoder_out,
|
645 |
+
key_padding_mask=encoder_padding_mask,
|
646 |
+
incremental_state=incremental_state,
|
647 |
+
static_kv=True,
|
648 |
+
enc_dec_attn_constraint_mask=None, #utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
|
649 |
+
reset_attn_weight=reset_attn_weight
|
650 |
+
)
|
651 |
+
attn_logits = attn[1]
|
652 |
+
else:
|
653 |
+
assert attn_out is not None
|
654 |
+
x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
|
655 |
+
attn_logits = None
|
656 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
657 |
+
x = residual + x
|
658 |
+
|
659 |
+
residual = x
|
660 |
+
x = self.layer_norm3(x)
|
661 |
+
x = self.ffn(x, incremental_state=incremental_state)
|
662 |
+
x = F.dropout(x, self.dropout, training=self.training)
|
663 |
+
x = residual + x
|
664 |
+
# if len(attn_logits.size()) > 3:
|
665 |
+
# indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
|
666 |
+
# attn_logits = attn_logits.gather(1,
|
667 |
+
# indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
|
668 |
+
return x, attn_logits
|
modules/commons/espnet_positional_embedding.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class PositionalEncoding(torch.nn.Module):
|
6 |
+
"""Positional encoding.
|
7 |
+
Args:
|
8 |
+
d_model (int): Embedding dimension.
|
9 |
+
dropout_rate (float): Dropout rate.
|
10 |
+
max_len (int): Maximum input length.
|
11 |
+
reverse (bool): Whether to reverse the input position.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
15 |
+
"""Construct an PositionalEncoding object."""
|
16 |
+
super(PositionalEncoding, self).__init__()
|
17 |
+
self.d_model = d_model
|
18 |
+
self.reverse = reverse
|
19 |
+
self.xscale = math.sqrt(self.d_model)
|
20 |
+
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
21 |
+
self.pe = None
|
22 |
+
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
23 |
+
|
24 |
+
def extend_pe(self, x):
|
25 |
+
"""Reset the positional encodings."""
|
26 |
+
if self.pe is not None:
|
27 |
+
if self.pe.size(1) >= x.size(1):
|
28 |
+
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
29 |
+
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
30 |
+
return
|
31 |
+
pe = torch.zeros(x.size(1), self.d_model)
|
32 |
+
if self.reverse:
|
33 |
+
position = torch.arange(
|
34 |
+
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
35 |
+
).unsqueeze(1)
|
36 |
+
else:
|
37 |
+
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
38 |
+
div_term = torch.exp(
|
39 |
+
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
40 |
+
* -(math.log(10000.0) / self.d_model)
|
41 |
+
)
|
42 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
43 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
44 |
+
pe = pe.unsqueeze(0)
|
45 |
+
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
46 |
+
|
47 |
+
def forward(self, x: torch.Tensor):
|
48 |
+
"""Add positional encoding.
|
49 |
+
Args:
|
50 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
51 |
+
Returns:
|
52 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
53 |
+
"""
|
54 |
+
self.extend_pe(x)
|
55 |
+
x = x * self.xscale + self.pe[:, : x.size(1)]
|
56 |
+
return self.dropout(x)
|
57 |
+
|
58 |
+
|
59 |
+
class ScaledPositionalEncoding(PositionalEncoding):
|
60 |
+
"""Scaled positional encoding module.
|
61 |
+
See Sec. 3.2 https://arxiv.org/abs/1809.08895
|
62 |
+
Args:
|
63 |
+
d_model (int): Embedding dimension.
|
64 |
+
dropout_rate (float): Dropout rate.
|
65 |
+
max_len (int): Maximum input length.
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
69 |
+
"""Initialize class."""
|
70 |
+
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
71 |
+
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
72 |
+
|
73 |
+
def reset_parameters(self):
|
74 |
+
"""Reset parameters."""
|
75 |
+
self.alpha.data = torch.tensor(1.0)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
"""Add positional encoding.
|
79 |
+
Args:
|
80 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
81 |
+
Returns:
|
82 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
83 |
+
"""
|
84 |
+
self.extend_pe(x)
|
85 |
+
x = x + self.alpha * self.pe[:, : x.size(1)]
|
86 |
+
return self.dropout(x)
|
87 |
+
|
88 |
+
|
89 |
+
class RelPositionalEncoding(PositionalEncoding):
|
90 |
+
"""Relative positional encoding module.
|
91 |
+
See : Appendix B in https://arxiv.org/abs/1901.02860
|
92 |
+
Args:
|
93 |
+
d_model (int): Embedding dimension.
|
94 |
+
dropout_rate (float): Dropout rate.
|
95 |
+
max_len (int): Maximum input length.
|
96 |
+
"""
|
97 |
+
|
98 |
+
def __init__(self, d_model, dropout_rate, max_len=5000):
|
99 |
+
"""Initialize class."""
|
100 |
+
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
"""Compute positional encoding.
|
104 |
+
Args:
|
105 |
+
x (torch.Tensor): Input tensor (batch, time, `*`).
|
106 |
+
Returns:
|
107 |
+
torch.Tensor: Encoded tensor (batch, time, `*`).
|
108 |
+
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
109 |
+
"""
|
110 |
+
self.extend_pe(x)
|
111 |
+
x = x * self.xscale
|
112 |
+
pos_emb = self.pe[:, : x.size(1)]
|
113 |
+
return self.dropout(x) + self.dropout(pos_emb)
|
modules/commons/ssim.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# '''
|
2 |
+
# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
|
3 |
+
# '''
|
4 |
+
#
|
5 |
+
# import torch
|
6 |
+
# import torch.jit
|
7 |
+
# import torch.nn.functional as F
|
8 |
+
#
|
9 |
+
#
|
10 |
+
# @torch.jit.script
|
11 |
+
# def create_window(window_size: int, sigma: float, channel: int):
|
12 |
+
# '''
|
13 |
+
# Create 1-D gauss kernel
|
14 |
+
# :param window_size: the size of gauss kernel
|
15 |
+
# :param sigma: sigma of normal distribution
|
16 |
+
# :param channel: input channel
|
17 |
+
# :return: 1D kernel
|
18 |
+
# '''
|
19 |
+
# coords = torch.arange(window_size, dtype=torch.float)
|
20 |
+
# coords -= window_size // 2
|
21 |
+
#
|
22 |
+
# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
|
23 |
+
# g /= g.sum()
|
24 |
+
#
|
25 |
+
# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
|
26 |
+
# return g
|
27 |
+
#
|
28 |
+
#
|
29 |
+
# @torch.jit.script
|
30 |
+
# def _gaussian_filter(x, window_1d, use_padding: bool):
|
31 |
+
# '''
|
32 |
+
# Blur input with 1-D kernel
|
33 |
+
# :param x: batch of tensors to be blured
|
34 |
+
# :param window_1d: 1-D gauss kernel
|
35 |
+
# :param use_padding: padding image before conv
|
36 |
+
# :return: blured tensors
|
37 |
+
# '''
|
38 |
+
# C = x.shape[1]
|
39 |
+
# padding = 0
|
40 |
+
# if use_padding:
|
41 |
+
# window_size = window_1d.shape[3]
|
42 |
+
# padding = window_size // 2
|
43 |
+
# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
|
44 |
+
# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
|
45 |
+
# return out
|
46 |
+
#
|
47 |
+
#
|
48 |
+
# @torch.jit.script
|
49 |
+
# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
|
50 |
+
# '''
|
51 |
+
# Calculate ssim index for X and Y
|
52 |
+
# :param X: images [B, C, H, N_bins]
|
53 |
+
# :param Y: images [B, C, H, N_bins]
|
54 |
+
# :param window: 1-D gauss kernel
|
55 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
56 |
+
# :param use_padding: padding image before conv
|
57 |
+
# :return:
|
58 |
+
# '''
|
59 |
+
#
|
60 |
+
# K1 = 0.01
|
61 |
+
# K2 = 0.03
|
62 |
+
# compensation = 1.0
|
63 |
+
#
|
64 |
+
# C1 = (K1 * data_range) ** 2
|
65 |
+
# C2 = (K2 * data_range) ** 2
|
66 |
+
#
|
67 |
+
# mu1 = _gaussian_filter(X, window, use_padding)
|
68 |
+
# mu2 = _gaussian_filter(Y, window, use_padding)
|
69 |
+
# sigma1_sq = _gaussian_filter(X * X, window, use_padding)
|
70 |
+
# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
|
71 |
+
# sigma12 = _gaussian_filter(X * Y, window, use_padding)
|
72 |
+
#
|
73 |
+
# mu1_sq = mu1.pow(2)
|
74 |
+
# mu2_sq = mu2.pow(2)
|
75 |
+
# mu1_mu2 = mu1 * mu2
|
76 |
+
#
|
77 |
+
# sigma1_sq = compensation * (sigma1_sq - mu1_sq)
|
78 |
+
# sigma2_sq = compensation * (sigma2_sq - mu2_sq)
|
79 |
+
# sigma12 = compensation * (sigma12 - mu1_mu2)
|
80 |
+
#
|
81 |
+
# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
|
82 |
+
# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
|
83 |
+
# cs_map = cs_map.clamp_min(0.)
|
84 |
+
# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
|
85 |
+
#
|
86 |
+
# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
|
87 |
+
# cs = cs_map.mean(dim=(1, 2, 3))
|
88 |
+
#
|
89 |
+
# return ssim_val, cs
|
90 |
+
#
|
91 |
+
#
|
92 |
+
# @torch.jit.script
|
93 |
+
# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
|
94 |
+
# '''
|
95 |
+
# interface of ms-ssim
|
96 |
+
# :param X: a batch of images, (N,C,H,W)
|
97 |
+
# :param Y: a batch of images, (N,C,H,W)
|
98 |
+
# :param window: 1-D gauss kernel
|
99 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
100 |
+
# :param weights: weights for different levels
|
101 |
+
# :param use_padding: padding image before conv
|
102 |
+
# :param eps: use for avoid grad nan.
|
103 |
+
# :return:
|
104 |
+
# '''
|
105 |
+
# levels = weights.shape[0]
|
106 |
+
# cs_vals = []
|
107 |
+
# ssim_vals = []
|
108 |
+
# for _ in range(levels):
|
109 |
+
# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
|
110 |
+
# # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
|
111 |
+
# ssim_val = ssim_val.clamp_min(eps)
|
112 |
+
# cs = cs.clamp_min(eps)
|
113 |
+
# cs_vals.append(cs)
|
114 |
+
#
|
115 |
+
# ssim_vals.append(ssim_val)
|
116 |
+
# padding = (X.shape[2] % 2, X.shape[3] % 2)
|
117 |
+
# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
|
118 |
+
# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
|
119 |
+
#
|
120 |
+
# cs_vals = torch.stack(cs_vals, dim=0)
|
121 |
+
# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
|
122 |
+
# return ms_ssim_val
|
123 |
+
#
|
124 |
+
#
|
125 |
+
# class SSIM(torch.jit.ScriptModule):
|
126 |
+
# __constants__ = ['data_range', 'use_padding']
|
127 |
+
#
|
128 |
+
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
|
129 |
+
# '''
|
130 |
+
# :param window_size: the size of gauss kernel
|
131 |
+
# :param window_sigma: sigma of normal distribution
|
132 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
133 |
+
# :param channel: input channels (default: 3)
|
134 |
+
# :param use_padding: padding image before conv
|
135 |
+
# '''
|
136 |
+
# super().__init__()
|
137 |
+
# assert window_size % 2 == 1, 'Window size must be odd.'
|
138 |
+
# window = create_window(window_size, window_sigma, channel)
|
139 |
+
# self.register_buffer('window', window)
|
140 |
+
# self.data_range = data_range
|
141 |
+
# self.use_padding = use_padding
|
142 |
+
#
|
143 |
+
# @torch.jit.script_method
|
144 |
+
# def forward(self, X, Y):
|
145 |
+
# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
|
146 |
+
# return r[0]
|
147 |
+
#
|
148 |
+
#
|
149 |
+
# class MS_SSIM(torch.jit.ScriptModule):
|
150 |
+
# __constants__ = ['data_range', 'use_padding', 'eps']
|
151 |
+
#
|
152 |
+
# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
|
153 |
+
# levels=None, eps=1e-8):
|
154 |
+
# '''
|
155 |
+
# class for ms-ssim
|
156 |
+
# :param window_size: the size of gauss kernel
|
157 |
+
# :param window_sigma: sigma of normal distribution
|
158 |
+
# :param data_range: value range of input images. (usually 1.0 or 255)
|
159 |
+
# :param channel: input channels
|
160 |
+
# :param use_padding: padding image before conv
|
161 |
+
# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
|
162 |
+
# :param levels: number of downsampling
|
163 |
+
# :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
|
164 |
+
# '''
|
165 |
+
# super().__init__()
|
166 |
+
# assert window_size % 2 == 1, 'Window size must be odd.'
|
167 |
+
# self.data_range = data_range
|
168 |
+
# self.use_padding = use_padding
|
169 |
+
# self.eps = eps
|
170 |
+
#
|
171 |
+
# window = create_window(window_size, window_sigma, channel)
|
172 |
+
# self.register_buffer('window', window)
|
173 |
+
#
|
174 |
+
# if weights is None:
|
175 |
+
# weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
|
176 |
+
# weights = torch.tensor(weights, dtype=torch.float)
|
177 |
+
#
|
178 |
+
# if levels is not None:
|
179 |
+
# weights = weights[:levels]
|
180 |
+
# weights = weights / weights.sum()
|
181 |
+
#
|
182 |
+
# self.register_buffer('weights', weights)
|
183 |
+
#
|
184 |
+
# @torch.jit.script_method
|
185 |
+
# def forward(self, X, Y):
|
186 |
+
# return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
|
187 |
+
# use_padding=self.use_padding, eps=self.eps)
|
188 |
+
#
|
189 |
+
#
|
190 |
+
# if __name__ == '__main__':
|
191 |
+
# print('Simple Test')
|
192 |
+
# im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
|
193 |
+
# img1 = im / 255
|
194 |
+
# img2 = img1 * 0.5
|
195 |
+
#
|
196 |
+
# losser = SSIM(data_range=1.).cuda()
|
197 |
+
# loss = losser(img1, img2).mean()
|
198 |
+
#
|
199 |
+
# losser2 = MS_SSIM(data_range=1.).cuda()
|
200 |
+
# loss2 = losser2(img1, img2).mean()
|
201 |
+
#
|
202 |
+
# print(loss.item())
|
203 |
+
# print(loss2.item())
|
204 |
+
#
|
205 |
+
# if __name__ == '__main__':
|
206 |
+
# print('Training Test')
|
207 |
+
# import cv2
|
208 |
+
# import torch.optim
|
209 |
+
# import numpy as np
|
210 |
+
# import imageio
|
211 |
+
# import time
|
212 |
+
#
|
213 |
+
# out_test_video = False
|
214 |
+
# # 最好不要直接输出gif图,会非常大,最好先输出mkv文件后用ffmpeg转换到GIF
|
215 |
+
# video_use_gif = False
|
216 |
+
#
|
217 |
+
# im = cv2.imread('test_img1.jpg', 1)
|
218 |
+
# t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
|
219 |
+
#
|
220 |
+
# if out_test_video:
|
221 |
+
# if video_use_gif:
|
222 |
+
# fps = 0.5
|
223 |
+
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
|
224 |
+
# suffix = '.gif'
|
225 |
+
# else:
|
226 |
+
# fps = 5
|
227 |
+
# out_wh = (im.shape[1], im.shape[0])
|
228 |
+
# suffix = '.mkv'
|
229 |
+
# video_last_time = time.perf_counter()
|
230 |
+
# video = imageio.get_writer('ssim_test' + suffix, fps=fps)
|
231 |
+
#
|
232 |
+
# # 测试ssim
|
233 |
+
# print('Training SSIM')
|
234 |
+
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
|
235 |
+
# rand_im.requires_grad = True
|
236 |
+
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
|
237 |
+
# losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
|
238 |
+
# ssim_score = 0
|
239 |
+
# while ssim_score < 0.999:
|
240 |
+
# optim.zero_grad()
|
241 |
+
# loss = losser(rand_im, t_im)
|
242 |
+
# (-loss).sum().backward()
|
243 |
+
# ssim_score = loss.item()
|
244 |
+
# optim.step()
|
245 |
+
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
|
246 |
+
# r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
|
247 |
+
#
|
248 |
+
# if out_test_video:
|
249 |
+
# if time.perf_counter() - video_last_time > 1. / fps:
|
250 |
+
# video_last_time = time.perf_counter()
|
251 |
+
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
|
252 |
+
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
|
253 |
+
# if isinstance(out_frame, cv2.UMat):
|
254 |
+
# out_frame = out_frame.get()
|
255 |
+
# video.append_data(out_frame)
|
256 |
+
#
|
257 |
+
# cv2.imshow('ssim', r_im)
|
258 |
+
# cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
|
259 |
+
# cv2.waitKey(1)
|
260 |
+
#
|
261 |
+
# if out_test_video:
|
262 |
+
# video.close()
|
263 |
+
#
|
264 |
+
# # 测试ms_ssim
|
265 |
+
# if out_test_video:
|
266 |
+
# if video_use_gif:
|
267 |
+
# fps = 0.5
|
268 |
+
# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
|
269 |
+
# suffix = '.gif'
|
270 |
+
# else:
|
271 |
+
# fps = 5
|
272 |
+
# out_wh = (im.shape[1], im.shape[0])
|
273 |
+
# suffix = '.mkv'
|
274 |
+
# video_last_time = time.perf_counter()
|
275 |
+
# video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
|
276 |
+
#
|
277 |
+
# print('Training MS_SSIM')
|
278 |
+
# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
|
279 |
+
# rand_im.requires_grad = True
|
280 |
+
# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
|
281 |
+
# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
|
282 |
+
# ssim_score = 0
|
283 |
+
# while ssim_score < 0.999:
|
284 |
+
# optim.zero_grad()
|
285 |
+
# loss = losser(rand_im, t_im)
|
286 |
+
# (-loss).sum().backward()
|
287 |
+
# ssim_score = loss.item()
|
288 |
+
# optim.step()
|
289 |
+
# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
|
290 |
+
# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
|
291 |
+
#
|
292 |
+
# if out_test_video:
|
293 |
+
# if time.perf_counter() - video_last_time > 1. / fps:
|
294 |
+
# video_last_time = time.perf_counter()
|
295 |
+
# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
|
296 |
+
# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
|
297 |
+
# if isinstance(out_frame, cv2.UMat):
|
298 |
+
# out_frame = out_frame.get()
|
299 |
+
# video.append_data(out_frame)
|
300 |
+
#
|
301 |
+
# cv2.imshow('ms_ssim', r_im)
|
302 |
+
# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
|
303 |
+
# cv2.waitKey(1)
|
304 |
+
#
|
305 |
+
# if out_test_video:
|
306 |
+
# video.close()
|
307 |
+
|
308 |
+
"""
|
309 |
+
Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
|
310 |
+
"""
|
311 |
+
|
312 |
+
import torch
|
313 |
+
import torch.nn.functional as F
|
314 |
+
from torch.autograd import Variable
|
315 |
+
import numpy as np
|
316 |
+
from math import exp
|
317 |
+
|
318 |
+
|
319 |
+
def gaussian(window_size, sigma):
|
320 |
+
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
321 |
+
return gauss / gauss.sum()
|
322 |
+
|
323 |
+
|
324 |
+
def create_window(window_size, channel):
|
325 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
326 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
327 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
328 |
+
return window
|
329 |
+
|
330 |
+
|
331 |
+
def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
332 |
+
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
333 |
+
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
334 |
+
|
335 |
+
mu1_sq = mu1.pow(2)
|
336 |
+
mu2_sq = mu2.pow(2)
|
337 |
+
mu1_mu2 = mu1 * mu2
|
338 |
+
|
339 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
340 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
341 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
342 |
+
|
343 |
+
C1 = 0.01 ** 2
|
344 |
+
C2 = 0.03 ** 2
|
345 |
+
|
346 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
347 |
+
|
348 |
+
if size_average:
|
349 |
+
return ssim_map.mean()
|
350 |
+
else:
|
351 |
+
return ssim_map.mean(1)
|
352 |
+
|
353 |
+
|
354 |
+
class SSIM(torch.nn.Module):
|
355 |
+
def __init__(self, window_size=11, size_average=True):
|
356 |
+
super(SSIM, self).__init__()
|
357 |
+
self.window_size = window_size
|
358 |
+
self.size_average = size_average
|
359 |
+
self.channel = 1
|
360 |
+
self.window = create_window(window_size, self.channel)
|
361 |
+
|
362 |
+
def forward(self, img1, img2):
|
363 |
+
(_, channel, _, _) = img1.size()
|
364 |
+
|
365 |
+
if channel == self.channel and self.window.data.type() == img1.data.type():
|
366 |
+
window = self.window
|
367 |
+
else:
|
368 |
+
window = create_window(self.window_size, channel)
|
369 |
+
|
370 |
+
if img1.is_cuda:
|
371 |
+
window = window.cuda(img1.get_device())
|
372 |
+
window = window.type_as(img1)
|
373 |
+
|
374 |
+
self.window = window
|
375 |
+
self.channel = channel
|
376 |
+
|
377 |
+
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
|
378 |
+
|
379 |
+
|
380 |
+
window = None
|
381 |
+
|
382 |
+
|
383 |
+
def ssim(img1, img2, window_size=11, size_average=True):
|
384 |
+
(_, channel, _, _) = img1.size()
|
385 |
+
global window
|
386 |
+
if window is None:
|
387 |
+
window = create_window(window_size, channel)
|
388 |
+
if img1.is_cuda:
|
389 |
+
window = window.cuda(img1.get_device())
|
390 |
+
window = window.type_as(img1)
|
391 |
+
return _ssim(img1, img2, window, window_size, channel, size_average)
|
modules/diffsinger_midi/fs2.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.commons.common_layers import *
|
2 |
+
from modules.commons.common_layers import Embedding
|
3 |
+
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
|
4 |
+
EnergyPredictor, FastspeechEncoder
|
5 |
+
from utils.cwt import cwt2f0
|
6 |
+
from utils.hparams import hparams
|
7 |
+
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
|
8 |
+
from modules.fastspeech.fs2 import FastSpeech2
|
9 |
+
|
10 |
+
|
11 |
+
class FastspeechMIDIEncoder(FastspeechEncoder):
|
12 |
+
def forward_embedding(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
|
13 |
+
# embed tokens and positions
|
14 |
+
x = self.embed_scale * self.embed_tokens(txt_tokens)
|
15 |
+
x = x + midi_embedding + midi_dur_embedding + slur_embedding
|
16 |
+
if hparams['use_pos_embed']:
|
17 |
+
if hparams.get('rel_pos') is not None and hparams['rel_pos']:
|
18 |
+
x = self.embed_positions(x)
|
19 |
+
else:
|
20 |
+
positions = self.embed_positions(txt_tokens)
|
21 |
+
x = x + positions
|
22 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
23 |
+
return x
|
24 |
+
|
25 |
+
def forward(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding):
|
26 |
+
"""
|
27 |
+
|
28 |
+
:param txt_tokens: [B, T]
|
29 |
+
:return: {
|
30 |
+
'encoder_out': [T x B x C]
|
31 |
+
}
|
32 |
+
"""
|
33 |
+
encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
|
34 |
+
x = self.forward_embedding(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, H]
|
35 |
+
x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
FS_ENCODERS = {
|
40 |
+
'fft': lambda hp, embed_tokens, d: FastspeechMIDIEncoder(
|
41 |
+
embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
|
42 |
+
num_heads=hp['num_heads']),
|
43 |
+
}
|
44 |
+
|
45 |
+
|
46 |
+
class FastSpeech2MIDI(FastSpeech2):
|
47 |
+
def __init__(self, dictionary, out_dims=None):
|
48 |
+
super().__init__(dictionary, out_dims)
|
49 |
+
del self.encoder
|
50 |
+
self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
|
51 |
+
self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx)
|
52 |
+
self.midi_dur_layer = Linear(1, self.hidden_size)
|
53 |
+
self.is_slur_embed = Embedding(2, self.hidden_size)
|
54 |
+
|
55 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
56 |
+
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
|
57 |
+
spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
|
58 |
+
ret = {}
|
59 |
+
|
60 |
+
midi_embedding = self.midi_embed(kwargs['pitch_midi'])
|
61 |
+
midi_dur_embedding, slur_embedding = 0, 0
|
62 |
+
if kwargs.get('midi_dur') is not None:
|
63 |
+
midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H]
|
64 |
+
if kwargs.get('is_slur') is not None:
|
65 |
+
slur_embedding = self.is_slur_embed(kwargs['is_slur'])
|
66 |
+
encoder_out = self.encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, C]
|
67 |
+
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
68 |
+
|
69 |
+
# add ref style embed
|
70 |
+
# Not implemented
|
71 |
+
# variance encoder
|
72 |
+
var_embed = 0
|
73 |
+
|
74 |
+
# encoder_out_dur denotes encoder outputs for duration predictor
|
75 |
+
# in speech adaptation, duration predictor use old speaker embedding
|
76 |
+
if hparams['use_spk_embed']:
|
77 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
|
78 |
+
elif hparams['use_spk_id']:
|
79 |
+
spk_embed_id = spk_embed
|
80 |
+
if spk_embed_dur_id is None:
|
81 |
+
spk_embed_dur_id = spk_embed_id
|
82 |
+
if spk_embed_f0_id is None:
|
83 |
+
spk_embed_f0_id = spk_embed_id
|
84 |
+
spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
|
85 |
+
spk_embed_dur = spk_embed_f0 = spk_embed
|
86 |
+
if hparams['use_split_spk_id']:
|
87 |
+
spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
|
88 |
+
spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
|
89 |
+
else:
|
90 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = 0
|
91 |
+
|
92 |
+
# add dur
|
93 |
+
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
|
94 |
+
|
95 |
+
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
|
96 |
+
|
97 |
+
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
|
98 |
+
|
99 |
+
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
|
100 |
+
decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
101 |
+
|
102 |
+
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
103 |
+
|
104 |
+
# add pitch and energy embed
|
105 |
+
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
|
106 |
+
if hparams['use_pitch_embed']:
|
107 |
+
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
|
108 |
+
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
|
109 |
+
if hparams['use_energy_embed']:
|
110 |
+
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
|
111 |
+
|
112 |
+
ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
|
113 |
+
|
114 |
+
if skip_decoder:
|
115 |
+
return ret
|
116 |
+
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
117 |
+
|
118 |
+
return ret
|
modules/fastspeech/fs2.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.commons.common_layers import *
|
2 |
+
from modules.commons.common_layers import Embedding
|
3 |
+
from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
|
4 |
+
EnergyPredictor, FastspeechEncoder
|
5 |
+
from utils.cwt import cwt2f0
|
6 |
+
from utils.hparams import hparams
|
7 |
+
from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
|
8 |
+
|
9 |
+
FS_ENCODERS = {
|
10 |
+
'fft': lambda hp, embed_tokens, d: FastspeechEncoder(
|
11 |
+
embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
|
12 |
+
num_heads=hp['num_heads']),
|
13 |
+
}
|
14 |
+
|
15 |
+
FS_DECODERS = {
|
16 |
+
'fft': lambda hp: FastspeechDecoder(
|
17 |
+
hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
|
18 |
+
}
|
19 |
+
|
20 |
+
|
21 |
+
class FastSpeech2(nn.Module):
|
22 |
+
def __init__(self, dictionary, out_dims=None):
|
23 |
+
super().__init__()
|
24 |
+
self.dictionary = dictionary
|
25 |
+
self.padding_idx = dictionary.pad()
|
26 |
+
self.enc_layers = hparams['enc_layers']
|
27 |
+
self.dec_layers = hparams['dec_layers']
|
28 |
+
self.hidden_size = hparams['hidden_size']
|
29 |
+
self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
|
30 |
+
self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
|
31 |
+
self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
|
32 |
+
self.out_dims = out_dims
|
33 |
+
if out_dims is None:
|
34 |
+
self.out_dims = hparams['audio_num_mel_bins']
|
35 |
+
self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
|
36 |
+
|
37 |
+
if hparams['use_spk_id']:
|
38 |
+
self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
|
39 |
+
if hparams['use_split_spk_id']:
|
40 |
+
self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
|
41 |
+
self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
|
42 |
+
elif hparams['use_spk_embed']:
|
43 |
+
self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
|
44 |
+
predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
|
45 |
+
self.dur_predictor = DurationPredictor(
|
46 |
+
self.hidden_size,
|
47 |
+
n_chans=predictor_hidden,
|
48 |
+
n_layers=hparams['dur_predictor_layers'],
|
49 |
+
dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
|
50 |
+
kernel_size=hparams['dur_predictor_kernel'])
|
51 |
+
self.length_regulator = LengthRegulator()
|
52 |
+
if hparams['use_pitch_embed']:
|
53 |
+
self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
|
54 |
+
if hparams['pitch_type'] == 'cwt':
|
55 |
+
h = hparams['cwt_hidden_size']
|
56 |
+
cwt_out_dims = 10
|
57 |
+
if hparams['use_uv']:
|
58 |
+
cwt_out_dims = cwt_out_dims + 1
|
59 |
+
self.cwt_predictor = nn.Sequential(
|
60 |
+
nn.Linear(self.hidden_size, h),
|
61 |
+
PitchPredictor(
|
62 |
+
h,
|
63 |
+
n_chans=predictor_hidden,
|
64 |
+
n_layers=hparams['predictor_layers'],
|
65 |
+
dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
|
66 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
|
67 |
+
self.cwt_stats_layers = nn.Sequential(
|
68 |
+
nn.Linear(self.hidden_size, h), nn.ReLU(),
|
69 |
+
nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
self.pitch_predictor = PitchPredictor(
|
73 |
+
self.hidden_size,
|
74 |
+
n_chans=predictor_hidden,
|
75 |
+
n_layers=hparams['predictor_layers'],
|
76 |
+
dropout_rate=hparams['predictor_dropout'],
|
77 |
+
odim=2 if hparams['pitch_type'] == 'frame' else 1,
|
78 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
|
79 |
+
if hparams['use_energy_embed']:
|
80 |
+
self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
|
81 |
+
self.energy_predictor = EnergyPredictor(
|
82 |
+
self.hidden_size,
|
83 |
+
n_chans=predictor_hidden,
|
84 |
+
n_layers=hparams['predictor_layers'],
|
85 |
+
dropout_rate=hparams['predictor_dropout'], odim=1,
|
86 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
|
87 |
+
|
88 |
+
def build_embedding(self, dictionary, embed_dim):
|
89 |
+
num_embeddings = len(dictionary)
|
90 |
+
emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
|
91 |
+
return emb
|
92 |
+
|
93 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
94 |
+
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
|
95 |
+
spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
|
96 |
+
ret = {}
|
97 |
+
encoder_out = self.encoder(txt_tokens) # [B, T, C]
|
98 |
+
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
99 |
+
|
100 |
+
# add ref style embed
|
101 |
+
# Not implemented
|
102 |
+
# variance encoder
|
103 |
+
var_embed = 0
|
104 |
+
|
105 |
+
# encoder_out_dur denotes encoder outputs for duration predictor
|
106 |
+
# in speech adaptation, duration predictor use old speaker embedding
|
107 |
+
if hparams['use_spk_embed']:
|
108 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
|
109 |
+
elif hparams['use_spk_id']:
|
110 |
+
spk_embed_id = spk_embed
|
111 |
+
if spk_embed_dur_id is None:
|
112 |
+
spk_embed_dur_id = spk_embed_id
|
113 |
+
if spk_embed_f0_id is None:
|
114 |
+
spk_embed_f0_id = spk_embed_id
|
115 |
+
spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
|
116 |
+
spk_embed_dur = spk_embed_f0 = spk_embed
|
117 |
+
if hparams['use_split_spk_id']:
|
118 |
+
spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
|
119 |
+
spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
|
120 |
+
else:
|
121 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = 0
|
122 |
+
|
123 |
+
# add dur
|
124 |
+
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
|
125 |
+
|
126 |
+
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
|
127 |
+
|
128 |
+
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
|
129 |
+
|
130 |
+
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
|
131 |
+
decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
132 |
+
|
133 |
+
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
134 |
+
|
135 |
+
# add pitch and energy embed
|
136 |
+
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
|
137 |
+
if hparams['use_pitch_embed']:
|
138 |
+
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
|
139 |
+
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
|
140 |
+
if hparams['use_energy_embed']:
|
141 |
+
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
|
142 |
+
|
143 |
+
ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
|
144 |
+
|
145 |
+
if skip_decoder:
|
146 |
+
return ret
|
147 |
+
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
148 |
+
|
149 |
+
return ret
|
150 |
+
|
151 |
+
def add_dur(self, dur_input, mel2ph, txt_tokens, ret):
|
152 |
+
"""
|
153 |
+
|
154 |
+
:param dur_input: [B, T_txt, H]
|
155 |
+
:param mel2ph: [B, T_mel]
|
156 |
+
:param txt_tokens: [B, T_txt]
|
157 |
+
:param ret:
|
158 |
+
:return:
|
159 |
+
"""
|
160 |
+
src_padding = txt_tokens == 0
|
161 |
+
dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
|
162 |
+
if mel2ph is None:
|
163 |
+
dur, xs = self.dur_predictor.inference(dur_input, src_padding)
|
164 |
+
ret['dur'] = xs
|
165 |
+
ret['dur_choice'] = dur
|
166 |
+
mel2ph = self.length_regulator(dur, src_padding).detach()
|
167 |
+
# from modules.fastspeech.fake_modules import FakeLengthRegulator
|
168 |
+
# fake_lr = FakeLengthRegulator()
|
169 |
+
# fake_mel2ph = fake_lr(dur, (1 - src_padding.long()).sum(-1))[..., 0].detach()
|
170 |
+
# print(mel2ph == fake_mel2ph)
|
171 |
+
else:
|
172 |
+
ret['dur'] = self.dur_predictor(dur_input, src_padding)
|
173 |
+
ret['mel2ph'] = mel2ph
|
174 |
+
return mel2ph
|
175 |
+
|
176 |
+
def add_energy(self, decoder_inp, energy, ret):
|
177 |
+
decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
|
178 |
+
ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
|
179 |
+
if energy is None:
|
180 |
+
energy = energy_pred
|
181 |
+
energy = torch.clamp(energy * 256 // 4, max=255).long()
|
182 |
+
energy_embed = self.energy_embed(energy)
|
183 |
+
return energy_embed
|
184 |
+
|
185 |
+
def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
|
186 |
+
if hparams['pitch_type'] == 'ph':
|
187 |
+
pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
|
188 |
+
pitch_padding = encoder_out.sum().abs() == 0
|
189 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
|
190 |
+
if f0 is None:
|
191 |
+
f0 = pitch_pred[:, :, 0]
|
192 |
+
ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
|
193 |
+
pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt]
|
194 |
+
pitch = F.pad(pitch, [1, 0])
|
195 |
+
pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel]
|
196 |
+
pitch_embed = self.pitch_embed(pitch)
|
197 |
+
return pitch_embed
|
198 |
+
decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
|
199 |
+
|
200 |
+
pitch_padding = mel2ph == 0
|
201 |
+
|
202 |
+
if hparams['pitch_type'] == 'cwt':
|
203 |
+
pitch_padding = None
|
204 |
+
ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
|
205 |
+
stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
|
206 |
+
mean = ret['f0_mean'] = stats_out[:, 0]
|
207 |
+
std = ret['f0_std'] = stats_out[:, 1]
|
208 |
+
cwt_spec = cwt_out[:, :, :10]
|
209 |
+
if f0 is None:
|
210 |
+
std = std * hparams['cwt_std_scale']
|
211 |
+
f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
|
212 |
+
if hparams['use_uv']:
|
213 |
+
assert cwt_out.shape[-1] == 11
|
214 |
+
uv = cwt_out[:, :, -1] > 0
|
215 |
+
elif hparams['pitch_ar']:
|
216 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
|
217 |
+
if f0 is None:
|
218 |
+
f0 = pitch_pred[:, :, 0]
|
219 |
+
else:
|
220 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
|
221 |
+
if f0 is None:
|
222 |
+
f0 = pitch_pred[:, :, 0]
|
223 |
+
if hparams['use_uv'] and uv is None:
|
224 |
+
uv = pitch_pred[:, :, 1] > 0
|
225 |
+
ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
|
226 |
+
if pitch_padding is not None:
|
227 |
+
f0[pitch_padding] = 0
|
228 |
+
|
229 |
+
pitch = f0_to_coarse(f0_denorm) # start from 0
|
230 |
+
pitch_embed = self.pitch_embed(pitch)
|
231 |
+
return pitch_embed
|
232 |
+
|
233 |
+
def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
|
234 |
+
x = decoder_inp # [B, T, H]
|
235 |
+
x = self.decoder(x)
|
236 |
+
x = self.mel_out(x)
|
237 |
+
return x * tgt_nonpadding
|
238 |
+
|
239 |
+
def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
|
240 |
+
f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
|
241 |
+
f0 = torch.cat(
|
242 |
+
[f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
|
243 |
+
f0_norm = norm_f0(f0, None, hparams)
|
244 |
+
return f0_norm
|
245 |
+
|
246 |
+
def out2mel(self, out):
|
247 |
+
return out
|
248 |
+
|
249 |
+
@staticmethod
|
250 |
+
def mel_norm(x):
|
251 |
+
return (x + 5.5) / (6.3 / 2) - 1
|
252 |
+
|
253 |
+
@staticmethod
|
254 |
+
def mel_denorm(x):
|
255 |
+
return (x + 1) * (6.3 / 2) - 5.5
|
modules/fastspeech/pe.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.commons.common_layers import *
|
2 |
+
from utils.hparams import hparams
|
3 |
+
from modules.fastspeech.tts_modules import PitchPredictor
|
4 |
+
from utils.pitch_utils import denorm_f0
|
5 |
+
|
6 |
+
|
7 |
+
class Prenet(nn.Module):
|
8 |
+
def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
|
9 |
+
super(Prenet, self).__init__()
|
10 |
+
padding = kernel // 2
|
11 |
+
self.layers = []
|
12 |
+
self.strides = strides if strides is not None else [1] * n_layers
|
13 |
+
for l in range(n_layers):
|
14 |
+
self.layers.append(nn.Sequential(
|
15 |
+
nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
|
16 |
+
nn.ReLU(),
|
17 |
+
nn.BatchNorm1d(out_dim)
|
18 |
+
))
|
19 |
+
in_dim = out_dim
|
20 |
+
self.layers = nn.ModuleList(self.layers)
|
21 |
+
self.out_proj = nn.Linear(out_dim, out_dim)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
"""
|
25 |
+
|
26 |
+
:param x: [B, T, 80]
|
27 |
+
:return: [L, B, T, H], [B, T, H]
|
28 |
+
"""
|
29 |
+
padding_mask = x.abs().sum(-1).eq(0).data # [B, T]
|
30 |
+
nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T]
|
31 |
+
x = x.transpose(1, 2)
|
32 |
+
hiddens = []
|
33 |
+
for i, l in enumerate(self.layers):
|
34 |
+
nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
|
35 |
+
x = l(x) * nonpadding_mask_TB
|
36 |
+
hiddens.append(x)
|
37 |
+
hiddens = torch.stack(hiddens, 0) # [L, B, H, T]
|
38 |
+
hiddens = hiddens.transpose(2, 3) # [L, B, T, H]
|
39 |
+
x = self.out_proj(x.transpose(1, 2)) # [B, T, H]
|
40 |
+
x = x * nonpadding_mask_TB.transpose(1, 2)
|
41 |
+
return hiddens, x
|
42 |
+
|
43 |
+
|
44 |
+
class ConvBlock(nn.Module):
|
45 |
+
def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
|
46 |
+
super().__init__()
|
47 |
+
self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
|
48 |
+
self.norm = norm
|
49 |
+
if self.norm == 'bn':
|
50 |
+
self.norm = nn.BatchNorm1d(n_chans)
|
51 |
+
elif self.norm == 'in':
|
52 |
+
self.norm = nn.InstanceNorm1d(n_chans, affine=True)
|
53 |
+
elif self.norm == 'gn':
|
54 |
+
self.norm = nn.GroupNorm(n_chans // 16, n_chans)
|
55 |
+
elif self.norm == 'ln':
|
56 |
+
self.norm = LayerNorm(n_chans // 16, n_chans)
|
57 |
+
elif self.norm == 'wn':
|
58 |
+
self.conv = torch.nn.utils.weight_norm(self.conv.conv)
|
59 |
+
self.dropout = nn.Dropout(dropout)
|
60 |
+
self.relu = nn.ReLU()
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
"""
|
64 |
+
|
65 |
+
:param x: [B, C, T]
|
66 |
+
:return: [B, C, T]
|
67 |
+
"""
|
68 |
+
x = self.conv(x)
|
69 |
+
if not isinstance(self.norm, str):
|
70 |
+
if self.norm == 'none':
|
71 |
+
pass
|
72 |
+
elif self.norm == 'ln':
|
73 |
+
x = self.norm(x.transpose(1, 2)).transpose(1, 2)
|
74 |
+
else:
|
75 |
+
x = self.norm(x)
|
76 |
+
x = self.relu(x)
|
77 |
+
x = self.dropout(x)
|
78 |
+
return x
|
79 |
+
|
80 |
+
|
81 |
+
class ConvStacks(nn.Module):
|
82 |
+
def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
|
83 |
+
dropout=0, strides=None, res=True):
|
84 |
+
super().__init__()
|
85 |
+
self.conv = torch.nn.ModuleList()
|
86 |
+
self.kernel_size = kernel_size
|
87 |
+
self.res = res
|
88 |
+
self.in_proj = Linear(idim, n_chans)
|
89 |
+
if strides is None:
|
90 |
+
strides = [1] * n_layers
|
91 |
+
else:
|
92 |
+
assert len(strides) == n_layers
|
93 |
+
for idx in range(n_layers):
|
94 |
+
self.conv.append(ConvBlock(
|
95 |
+
n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
|
96 |
+
self.out_proj = Linear(n_chans, odim)
|
97 |
+
|
98 |
+
def forward(self, x, return_hiddens=False):
|
99 |
+
"""
|
100 |
+
|
101 |
+
:param x: [B, T, H]
|
102 |
+
:return: [B, T, H]
|
103 |
+
"""
|
104 |
+
x = self.in_proj(x)
|
105 |
+
x = x.transpose(1, -1) # (B, idim, Tmax)
|
106 |
+
hiddens = []
|
107 |
+
for f in self.conv:
|
108 |
+
x_ = f(x)
|
109 |
+
x = x + x_ if self.res else x_ # (B, C, Tmax)
|
110 |
+
hiddens.append(x)
|
111 |
+
x = x.transpose(1, -1)
|
112 |
+
x = self.out_proj(x) # (B, Tmax, H)
|
113 |
+
if return_hiddens:
|
114 |
+
hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
|
115 |
+
return x, hiddens
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
class PitchExtractor(nn.Module):
|
120 |
+
def __init__(self, n_mel_bins=80, conv_layers=2):
|
121 |
+
super().__init__()
|
122 |
+
self.hidden_size = 256
|
123 |
+
self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
|
124 |
+
self.conv_layers = conv_layers
|
125 |
+
|
126 |
+
self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
|
127 |
+
if self.conv_layers > 0:
|
128 |
+
self.mel_encoder = ConvStacks(
|
129 |
+
idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
|
130 |
+
self.pitch_predictor = PitchPredictor(
|
131 |
+
self.hidden_size, n_chans=self.predictor_hidden,
|
132 |
+
n_layers=5, dropout_rate=0.5, odim=2,
|
133 |
+
padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
|
134 |
+
|
135 |
+
def forward(self, mel_input=None):
|
136 |
+
ret = {}
|
137 |
+
mel_hidden = self.mel_prenet(mel_input)[1]
|
138 |
+
if self.conv_layers > 0:
|
139 |
+
mel_hidden = self.mel_encoder(mel_hidden)
|
140 |
+
|
141 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)
|
142 |
+
|
143 |
+
pitch_padding = mel_input.abs().sum(-1) == 0
|
144 |
+
use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
|
145 |
+
|
146 |
+
ret['f0_denorm_pred'] = denorm_f0(
|
147 |
+
pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
|
148 |
+
hparams, pitch_padding=pitch_padding)
|
149 |
+
return ret
|
modules/fastspeech/tts_modules.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from modules.commons.espnet_positional_embedding import RelPositionalEncoding
|
9 |
+
from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
|
10 |
+
from utils.hparams import hparams
|
11 |
+
|
12 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 2000
|
13 |
+
DEFAULT_MAX_TARGET_POSITIONS = 2000
|
14 |
+
|
15 |
+
|
16 |
+
class TransformerEncoderLayer(nn.Module):
|
17 |
+
def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
|
18 |
+
super().__init__()
|
19 |
+
self.hidden_size = hidden_size
|
20 |
+
self.dropout = dropout
|
21 |
+
self.num_heads = num_heads
|
22 |
+
self.op = EncSALayer(
|
23 |
+
hidden_size, num_heads, dropout=dropout,
|
24 |
+
attention_dropout=0.0, relu_dropout=dropout,
|
25 |
+
kernel_size=kernel_size
|
26 |
+
if kernel_size is not None else hparams['enc_ffn_kernel_size'],
|
27 |
+
padding=hparams['ffn_padding'],
|
28 |
+
norm=norm, act=hparams['ffn_act'])
|
29 |
+
|
30 |
+
def forward(self, x, **kwargs):
|
31 |
+
return self.op(x, **kwargs)
|
32 |
+
|
33 |
+
|
34 |
+
######################
|
35 |
+
# fastspeech modules
|
36 |
+
######################
|
37 |
+
class LayerNorm(torch.nn.LayerNorm):
|
38 |
+
"""Layer normalization module.
|
39 |
+
:param int nout: output dim size
|
40 |
+
:param int dim: dimension to be normalized
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, nout, dim=-1):
|
44 |
+
"""Construct an LayerNorm object."""
|
45 |
+
super(LayerNorm, self).__init__(nout, eps=1e-12)
|
46 |
+
self.dim = dim
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
"""Apply layer normalization.
|
50 |
+
:param torch.Tensor x: input tensor
|
51 |
+
:return: layer normalized tensor
|
52 |
+
:rtype torch.Tensor
|
53 |
+
"""
|
54 |
+
if self.dim == -1:
|
55 |
+
return super(LayerNorm, self).forward(x)
|
56 |
+
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
57 |
+
|
58 |
+
|
59 |
+
class DurationPredictor(torch.nn.Module):
|
60 |
+
"""Duration predictor module.
|
61 |
+
This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
62 |
+
The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
|
63 |
+
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
64 |
+
https://arxiv.org/pdf/1905.09263.pdf
|
65 |
+
Note:
|
66 |
+
The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
|
67 |
+
the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
|
71 |
+
"""Initilize duration predictor module.
|
72 |
+
Args:
|
73 |
+
idim (int): Input dimension.
|
74 |
+
n_layers (int, optional): Number of convolutional layers.
|
75 |
+
n_chans (int, optional): Number of channels of convolutional layers.
|
76 |
+
kernel_size (int, optional): Kernel size of convolutional layers.
|
77 |
+
dropout_rate (float, optional): Dropout rate.
|
78 |
+
offset (float, optional): Offset value to avoid nan in log domain.
|
79 |
+
"""
|
80 |
+
super(DurationPredictor, self).__init__()
|
81 |
+
self.offset = offset
|
82 |
+
self.conv = torch.nn.ModuleList()
|
83 |
+
self.kernel_size = kernel_size
|
84 |
+
self.padding = padding
|
85 |
+
for idx in range(n_layers):
|
86 |
+
in_chans = idim if idx == 0 else n_chans
|
87 |
+
self.conv += [torch.nn.Sequential(
|
88 |
+
torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
|
89 |
+
if padding == 'SAME'
|
90 |
+
else (kernel_size - 1, 0), 0),
|
91 |
+
torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
|
92 |
+
torch.nn.ReLU(),
|
93 |
+
LayerNorm(n_chans, dim=1),
|
94 |
+
torch.nn.Dropout(dropout_rate)
|
95 |
+
)]
|
96 |
+
if hparams['dur_loss'] in ['mse', 'huber']:
|
97 |
+
odims = 1
|
98 |
+
elif hparams['dur_loss'] == 'mog':
|
99 |
+
odims = 15
|
100 |
+
elif hparams['dur_loss'] == 'crf':
|
101 |
+
odims = 32
|
102 |
+
from torchcrf import CRF
|
103 |
+
self.crf = CRF(odims, batch_first=True)
|
104 |
+
self.linear = torch.nn.Linear(n_chans, odims)
|
105 |
+
|
106 |
+
def _forward(self, xs, x_masks=None, is_inference=False):
|
107 |
+
xs = xs.transpose(1, -1) # (B, idim, Tmax)
|
108 |
+
for f in self.conv:
|
109 |
+
xs = f(xs) # (B, C, Tmax)
|
110 |
+
if x_masks is not None:
|
111 |
+
xs = xs * (1 - x_masks.float())[:, None, :]
|
112 |
+
|
113 |
+
xs = self.linear(xs.transpose(1, -1)) # [B, T, C]
|
114 |
+
xs = xs * (1 - x_masks.float())[:, :, None] # (B, T, C)
|
115 |
+
if is_inference:
|
116 |
+
return self.out2dur(xs), xs
|
117 |
+
else:
|
118 |
+
if hparams['dur_loss'] in ['mse']:
|
119 |
+
xs = xs.squeeze(-1) # (B, Tmax)
|
120 |
+
return xs
|
121 |
+
|
122 |
+
def out2dur(self, xs):
|
123 |
+
if hparams['dur_loss'] in ['mse']:
|
124 |
+
# NOTE: calculate in log domain
|
125 |
+
xs = xs.squeeze(-1) # (B, Tmax)
|
126 |
+
dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
|
127 |
+
elif hparams['dur_loss'] == 'mog':
|
128 |
+
return NotImplementedError
|
129 |
+
elif hparams['dur_loss'] == 'crf':
|
130 |
+
dur = torch.LongTensor(self.crf.decode(xs)).cuda()
|
131 |
+
return dur
|
132 |
+
|
133 |
+
def forward(self, xs, x_masks=None):
|
134 |
+
"""Calculate forward propagation.
|
135 |
+
Args:
|
136 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
137 |
+
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
|
138 |
+
Returns:
|
139 |
+
Tensor: Batch of predicted durations in log domain (B, Tmax).
|
140 |
+
"""
|
141 |
+
return self._forward(xs, x_masks, False)
|
142 |
+
|
143 |
+
def inference(self, xs, x_masks=None):
|
144 |
+
"""Inference duration.
|
145 |
+
Args:
|
146 |
+
xs (Tensor): Batch of input sequences (B, Tmax, idim).
|
147 |
+
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
|
148 |
+
Returns:
|
149 |
+
LongTensor: Batch of predicted durations in linear domain (B, Tmax).
|
150 |
+
"""
|
151 |
+
return self._forward(xs, x_masks, True)
|
152 |
+
|
153 |
+
|
154 |
+
class LengthRegulator(torch.nn.Module):
|
155 |
+
def __init__(self, pad_value=0.0):
|
156 |
+
super(LengthRegulator, self).__init__()
|
157 |
+
self.pad_value = pad_value
|
158 |
+
|
159 |
+
def forward(self, dur, dur_padding=None, alpha=1.0):
|
160 |
+
"""
|
161 |
+
Example (no batch dim version):
|
162 |
+
1. dur = [2,2,3]
|
163 |
+
2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4]
|
164 |
+
3. token_mask = [[1,1,0,0,0,0,0],
|
165 |
+
[0,0,1,1,0,0,0],
|
166 |
+
[0,0,0,0,1,1,1]]
|
167 |
+
4. token_idx * token_mask = [[1,1,0,0,0,0,0],
|
168 |
+
[0,0,2,2,0,0,0],
|
169 |
+
[0,0,0,0,3,3,3]]
|
170 |
+
5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3]
|
171 |
+
|
172 |
+
:param dur: Batch of durations of each frame (B, T_txt)
|
173 |
+
:param dur_padding: Batch of padding of each frame (B, T_txt)
|
174 |
+
:param alpha: duration rescale coefficient
|
175 |
+
:return:
|
176 |
+
mel2ph (B, T_speech)
|
177 |
+
"""
|
178 |
+
assert alpha > 0
|
179 |
+
dur = torch.round(dur.float() * alpha).long()
|
180 |
+
if dur_padding is not None:
|
181 |
+
dur = dur * (1 - dur_padding.long())
|
182 |
+
token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device)
|
183 |
+
dur_cumsum = torch.cumsum(dur, 1)
|
184 |
+
dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0)
|
185 |
+
|
186 |
+
pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device)
|
187 |
+
token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None])
|
188 |
+
mel2ph = (token_idx * token_mask.long()).sum(1)
|
189 |
+
return mel2ph
|
190 |
+
|
191 |
+
|
192 |
+
class PitchPredictor(torch.nn.Module):
|
193 |
+
def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
|
194 |
+
dropout_rate=0.1, padding='SAME'):
|
195 |
+
"""Initilize pitch predictor module.
|
196 |
+
Args:
|
197 |
+
idim (int): Input dimension.
|
198 |
+
n_layers (int, optional): Number of convolutional layers.
|
199 |
+
n_chans (int, optional): Number of channels of convolutional layers.
|
200 |
+
kernel_size (int, optional): Kernel size of convolutional layers.
|
201 |
+
dropout_rate (float, optional): Dropout rate.
|
202 |
+
"""
|
203 |
+
super(PitchPredictor, self).__init__()
|
204 |
+
self.conv = torch.nn.ModuleList()
|
205 |
+
self.kernel_size = kernel_size
|
206 |
+
self.padding = padding
|
207 |
+
for idx in range(n_layers):
|
208 |
+
in_chans = idim if idx == 0 else n_chans
|
209 |
+
self.conv += [torch.nn.Sequential(
|
210 |
+
torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
|
211 |
+
if padding == 'SAME'
|
212 |
+
else (kernel_size - 1, 0), 0),
|
213 |
+
torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
|
214 |
+
torch.nn.ReLU(),
|
215 |
+
LayerNorm(n_chans, dim=1),
|
216 |
+
torch.nn.Dropout(dropout_rate)
|
217 |
+
)]
|
218 |
+
self.linear = torch.nn.Linear(n_chans, odim)
|
219 |
+
self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
|
220 |
+
self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
|
221 |
+
|
222 |
+
def forward(self, xs):
|
223 |
+
"""
|
224 |
+
|
225 |
+
:param xs: [B, T, H]
|
226 |
+
:return: [B, T, H]
|
227 |
+
"""
|
228 |
+
positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
|
229 |
+
xs = xs + positions
|
230 |
+
xs = xs.transpose(1, -1) # (B, idim, Tmax)
|
231 |
+
for f in self.conv:
|
232 |
+
xs = f(xs) # (B, C, Tmax)
|
233 |
+
# NOTE: calculate in log domain
|
234 |
+
xs = self.linear(xs.transpose(1, -1)) # (B, Tmax, H)
|
235 |
+
return xs
|
236 |
+
|
237 |
+
|
238 |
+
class EnergyPredictor(PitchPredictor):
|
239 |
+
pass
|
240 |
+
|
241 |
+
|
242 |
+
def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
|
243 |
+
B, _ = mel2ph.shape
|
244 |
+
dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph))
|
245 |
+
dur = dur[:, 1:]
|
246 |
+
if max_dur is not None:
|
247 |
+
dur = dur.clamp(max=max_dur)
|
248 |
+
return dur
|
249 |
+
|
250 |
+
|
251 |
+
class FFTBlocks(nn.Module):
|
252 |
+
def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
|
253 |
+
use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
|
254 |
+
super().__init__()
|
255 |
+
self.num_layers = num_layers
|
256 |
+
embed_dim = self.hidden_size = hidden_size
|
257 |
+
self.dropout = dropout if dropout is not None else hparams['dropout']
|
258 |
+
self.use_pos_embed = use_pos_embed
|
259 |
+
self.use_last_norm = use_last_norm
|
260 |
+
if use_pos_embed:
|
261 |
+
self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
|
262 |
+
self.padding_idx = 0
|
263 |
+
self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
|
264 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
265 |
+
embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
266 |
+
)
|
267 |
+
|
268 |
+
self.layers = nn.ModuleList([])
|
269 |
+
self.layers.extend([
|
270 |
+
TransformerEncoderLayer(self.hidden_size, self.dropout,
|
271 |
+
kernel_size=ffn_kernel_size, num_heads=num_heads)
|
272 |
+
for _ in range(self.num_layers)
|
273 |
+
])
|
274 |
+
if self.use_last_norm:
|
275 |
+
if norm == 'ln':
|
276 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
277 |
+
elif norm == 'bn':
|
278 |
+
self.layer_norm = BatchNorm1dTBC(embed_dim)
|
279 |
+
else:
|
280 |
+
self.layer_norm = None
|
281 |
+
|
282 |
+
def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
|
283 |
+
"""
|
284 |
+
:param x: [B, T, C]
|
285 |
+
:param padding_mask: [B, T]
|
286 |
+
:return: [B, T, C] or [L, B, T, C]
|
287 |
+
"""
|
288 |
+
padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
|
289 |
+
nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
|
290 |
+
if self.use_pos_embed:
|
291 |
+
positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
|
292 |
+
x = x + positions
|
293 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
294 |
+
# B x T x C -> T x B x C
|
295 |
+
x = x.transpose(0, 1) * nonpadding_mask_TB
|
296 |
+
hiddens = []
|
297 |
+
for layer in self.layers:
|
298 |
+
x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
|
299 |
+
hiddens.append(x)
|
300 |
+
if self.use_last_norm:
|
301 |
+
x = self.layer_norm(x) * nonpadding_mask_TB
|
302 |
+
if return_hiddens:
|
303 |
+
x = torch.stack(hiddens, 0) # [L, T, B, C]
|
304 |
+
x = x.transpose(1, 2) # [L, B, T, C]
|
305 |
+
else:
|
306 |
+
x = x.transpose(0, 1) # [B, T, C]
|
307 |
+
return x
|
308 |
+
|
309 |
+
|
310 |
+
class FastspeechEncoder(FFTBlocks):
|
311 |
+
def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
|
312 |
+
hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
|
313 |
+
kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size
|
314 |
+
num_layers = hparams['dec_layers'] if num_layers is None else num_layers
|
315 |
+
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
|
316 |
+
use_pos_embed=False) # use_pos_embed_alpha for compatibility
|
317 |
+
self.embed_tokens = embed_tokens
|
318 |
+
self.embed_scale = math.sqrt(hidden_size)
|
319 |
+
self.padding_idx = 0
|
320 |
+
if hparams.get('rel_pos') is not None and hparams['rel_pos']:
|
321 |
+
self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
|
322 |
+
else:
|
323 |
+
self.embed_positions = SinusoidalPositionalEmbedding(
|
324 |
+
hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
|
325 |
+
)
|
326 |
+
|
327 |
+
def forward(self, txt_tokens):
|
328 |
+
"""
|
329 |
+
|
330 |
+
:param txt_tokens: [B, T]
|
331 |
+
:return: {
|
332 |
+
'encoder_out': [T x B x C]
|
333 |
+
}
|
334 |
+
"""
|
335 |
+
encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
|
336 |
+
x = self.forward_embedding(txt_tokens) # [B, T, H]
|
337 |
+
x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask)
|
338 |
+
return x
|
339 |
+
|
340 |
+
def forward_embedding(self, txt_tokens):
|
341 |
+
# embed tokens and positions
|
342 |
+
x = self.embed_scale * self.embed_tokens(txt_tokens)
|
343 |
+
if hparams['use_pos_embed']:
|
344 |
+
positions = self.embed_positions(txt_tokens)
|
345 |
+
x = x + positions
|
346 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
class FastspeechDecoder(FFTBlocks):
|
351 |
+
def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
|
352 |
+
num_heads = hparams['num_heads'] if num_heads is None else num_heads
|
353 |
+
hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size
|
354 |
+
kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size
|
355 |
+
num_layers = hparams['dec_layers'] if num_layers is None else num_layers
|
356 |
+
super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
|
357 |
+
|
modules/hifigan/hifigan.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
6 |
+
|
7 |
+
from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
|
8 |
+
from modules.parallel_wavegan.models.source import SourceModuleHnNSF
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
LRELU_SLOPE = 0.1
|
12 |
+
|
13 |
+
|
14 |
+
def init_weights(m, mean=0.0, std=0.01):
|
15 |
+
classname = m.__class__.__name__
|
16 |
+
if classname.find("Conv") != -1:
|
17 |
+
m.weight.data.normal_(mean, std)
|
18 |
+
|
19 |
+
|
20 |
+
def apply_weight_norm(m):
|
21 |
+
classname = m.__class__.__name__
|
22 |
+
if classname.find("Conv") != -1:
|
23 |
+
weight_norm(m)
|
24 |
+
|
25 |
+
|
26 |
+
def get_padding(kernel_size, dilation=1):
|
27 |
+
return int((kernel_size * dilation - dilation) / 2)
|
28 |
+
|
29 |
+
|
30 |
+
class ResBlock1(torch.nn.Module):
|
31 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
32 |
+
super(ResBlock1, self).__init__()
|
33 |
+
self.h = h
|
34 |
+
self.convs1 = nn.ModuleList([
|
35 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
36 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
37 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
38 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
39 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
40 |
+
padding=get_padding(kernel_size, dilation[2])))
|
41 |
+
])
|
42 |
+
self.convs1.apply(init_weights)
|
43 |
+
|
44 |
+
self.convs2 = nn.ModuleList([
|
45 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
46 |
+
padding=get_padding(kernel_size, 1))),
|
47 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
48 |
+
padding=get_padding(kernel_size, 1))),
|
49 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
50 |
+
padding=get_padding(kernel_size, 1)))
|
51 |
+
])
|
52 |
+
self.convs2.apply(init_weights)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
56 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
57 |
+
xt = c1(xt)
|
58 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
59 |
+
xt = c2(xt)
|
60 |
+
x = xt + x
|
61 |
+
return x
|
62 |
+
|
63 |
+
def remove_weight_norm(self):
|
64 |
+
for l in self.convs1:
|
65 |
+
remove_weight_norm(l)
|
66 |
+
for l in self.convs2:
|
67 |
+
remove_weight_norm(l)
|
68 |
+
|
69 |
+
|
70 |
+
class ResBlock2(torch.nn.Module):
|
71 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
72 |
+
super(ResBlock2, self).__init__()
|
73 |
+
self.h = h
|
74 |
+
self.convs = nn.ModuleList([
|
75 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
76 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
77 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
78 |
+
padding=get_padding(kernel_size, dilation[1])))
|
79 |
+
])
|
80 |
+
self.convs.apply(init_weights)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
for c in self.convs:
|
84 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
85 |
+
xt = c(xt)
|
86 |
+
x = xt + x
|
87 |
+
return x
|
88 |
+
|
89 |
+
def remove_weight_norm(self):
|
90 |
+
for l in self.convs:
|
91 |
+
remove_weight_norm(l)
|
92 |
+
|
93 |
+
|
94 |
+
class Conv1d1x1(Conv1d):
|
95 |
+
"""1x1 Conv1d with customized initialization."""
|
96 |
+
|
97 |
+
def __init__(self, in_channels, out_channels, bias):
|
98 |
+
"""Initialize 1x1 Conv1d module."""
|
99 |
+
super(Conv1d1x1, self).__init__(in_channels, out_channels,
|
100 |
+
kernel_size=1, padding=0,
|
101 |
+
dilation=1, bias=bias)
|
102 |
+
|
103 |
+
|
104 |
+
class HifiGanGenerator(torch.nn.Module):
|
105 |
+
def __init__(self, h, c_out=1):
|
106 |
+
super(HifiGanGenerator, self).__init__()
|
107 |
+
self.h = h
|
108 |
+
self.num_kernels = len(h['resblock_kernel_sizes'])
|
109 |
+
self.num_upsamples = len(h['upsample_rates'])
|
110 |
+
|
111 |
+
if h['use_pitch_embed']:
|
112 |
+
self.harmonic_num = 8
|
113 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
|
114 |
+
self.m_source = SourceModuleHnNSF(
|
115 |
+
sampling_rate=h['audio_sample_rate'],
|
116 |
+
harmonic_num=self.harmonic_num)
|
117 |
+
self.noise_convs = nn.ModuleList()
|
118 |
+
self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
|
119 |
+
resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2
|
120 |
+
|
121 |
+
self.ups = nn.ModuleList()
|
122 |
+
for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
|
123 |
+
c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
|
124 |
+
self.ups.append(weight_norm(
|
125 |
+
ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
|
126 |
+
if h['use_pitch_embed']:
|
127 |
+
if i + 1 < len(h['upsample_rates']):
|
128 |
+
stride_f0 = np.prod(h['upsample_rates'][i + 1:])
|
129 |
+
self.noise_convs.append(Conv1d(
|
130 |
+
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
|
131 |
+
else:
|
132 |
+
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
133 |
+
|
134 |
+
self.resblocks = nn.ModuleList()
|
135 |
+
for i in range(len(self.ups)):
|
136 |
+
ch = h['upsample_initial_channel'] // (2 ** (i + 1))
|
137 |
+
for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
|
138 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
139 |
+
|
140 |
+
self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
|
141 |
+
self.ups.apply(init_weights)
|
142 |
+
self.conv_post.apply(init_weights)
|
143 |
+
|
144 |
+
def forward(self, x, f0=None):
|
145 |
+
if f0 is not None:
|
146 |
+
# harmonic-source signal, noise-source signal, uv flag
|
147 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
|
148 |
+
har_source, noi_source, uv = self.m_source(f0)
|
149 |
+
har_source = har_source.transpose(1, 2)
|
150 |
+
|
151 |
+
x = self.conv_pre(x)
|
152 |
+
for i in range(self.num_upsamples):
|
153 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
154 |
+
x = self.ups[i](x)
|
155 |
+
if f0 is not None:
|
156 |
+
x_source = self.noise_convs[i](har_source)
|
157 |
+
x_source = torch.nn.functional.relu(x_source)
|
158 |
+
tmp_shape = x_source.shape[1]
|
159 |
+
x_source = torch.nn.functional.layer_norm(x_source.transpose(1, -1), (tmp_shape, )).transpose(1, -1)
|
160 |
+
x = x + x_source
|
161 |
+
xs = None
|
162 |
+
for j in range(self.num_kernels):
|
163 |
+
xs_ = self.resblocks[i * self.num_kernels + j](x)
|
164 |
+
if xs is None:
|
165 |
+
xs = xs_
|
166 |
+
else:
|
167 |
+
xs += xs_
|
168 |
+
x = xs / self.num_kernels
|
169 |
+
x = F.leaky_relu(x)
|
170 |
+
x = self.conv_post(x)
|
171 |
+
x = torch.tanh(x)
|
172 |
+
|
173 |
+
return x
|
174 |
+
|
175 |
+
def remove_weight_norm(self):
|
176 |
+
print('Removing weight norm...')
|
177 |
+
for l in self.ups:
|
178 |
+
remove_weight_norm(l)
|
179 |
+
for l in self.resblocks:
|
180 |
+
l.remove_weight_norm()
|
181 |
+
remove_weight_norm(self.conv_pre)
|
182 |
+
remove_weight_norm(self.conv_post)
|
183 |
+
|
184 |
+
|
185 |
+
class DiscriminatorP(torch.nn.Module):
|
186 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
|
187 |
+
super(DiscriminatorP, self).__init__()
|
188 |
+
self.use_cond = use_cond
|
189 |
+
if use_cond:
|
190 |
+
from utils.hparams import hparams
|
191 |
+
t = hparams['hop_size']
|
192 |
+
self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
|
193 |
+
c_in = 2
|
194 |
+
|
195 |
+
self.period = period
|
196 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
197 |
+
self.convs = nn.ModuleList([
|
198 |
+
norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
199 |
+
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
200 |
+
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
201 |
+
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
202 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
203 |
+
])
|
204 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
205 |
+
|
206 |
+
def forward(self, x, mel):
|
207 |
+
fmap = []
|
208 |
+
if self.use_cond:
|
209 |
+
x_mel = self.cond_net(mel)
|
210 |
+
x = torch.cat([x_mel, x], 1)
|
211 |
+
# 1d to 2d
|
212 |
+
b, c, t = x.shape
|
213 |
+
if t % self.period != 0: # pad first
|
214 |
+
n_pad = self.period - (t % self.period)
|
215 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
216 |
+
t = t + n_pad
|
217 |
+
x = x.view(b, c, t // self.period, self.period)
|
218 |
+
|
219 |
+
for l in self.convs:
|
220 |
+
x = l(x)
|
221 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
222 |
+
fmap.append(x)
|
223 |
+
x = self.conv_post(x)
|
224 |
+
fmap.append(x)
|
225 |
+
x = torch.flatten(x, 1, -1)
|
226 |
+
|
227 |
+
return x, fmap
|
228 |
+
|
229 |
+
|
230 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
231 |
+
def __init__(self, use_cond=False, c_in=1):
|
232 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
233 |
+
self.discriminators = nn.ModuleList([
|
234 |
+
DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
|
235 |
+
DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
|
236 |
+
DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
|
237 |
+
DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
|
238 |
+
DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
|
239 |
+
])
|
240 |
+
|
241 |
+
def forward(self, y, y_hat, mel=None):
|
242 |
+
y_d_rs = []
|
243 |
+
y_d_gs = []
|
244 |
+
fmap_rs = []
|
245 |
+
fmap_gs = []
|
246 |
+
for i, d in enumerate(self.discriminators):
|
247 |
+
y_d_r, fmap_r = d(y, mel)
|
248 |
+
y_d_g, fmap_g = d(y_hat, mel)
|
249 |
+
y_d_rs.append(y_d_r)
|
250 |
+
fmap_rs.append(fmap_r)
|
251 |
+
y_d_gs.append(y_d_g)
|
252 |
+
fmap_gs.append(fmap_g)
|
253 |
+
|
254 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
255 |
+
|
256 |
+
|
257 |
+
class DiscriminatorS(torch.nn.Module):
|
258 |
+
def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
|
259 |
+
super(DiscriminatorS, self).__init__()
|
260 |
+
self.use_cond = use_cond
|
261 |
+
if use_cond:
|
262 |
+
t = np.prod(upsample_rates)
|
263 |
+
self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
|
264 |
+
c_in = 2
|
265 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
266 |
+
self.convs = nn.ModuleList([
|
267 |
+
norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
|
268 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
269 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
270 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
271 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
272 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
273 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
274 |
+
])
|
275 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
276 |
+
|
277 |
+
def forward(self, x, mel):
|
278 |
+
if self.use_cond:
|
279 |
+
x_mel = self.cond_net(mel)
|
280 |
+
x = torch.cat([x_mel, x], 1)
|
281 |
+
fmap = []
|
282 |
+
for l in self.convs:
|
283 |
+
x = l(x)
|
284 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
285 |
+
fmap.append(x)
|
286 |
+
x = self.conv_post(x)
|
287 |
+
fmap.append(x)
|
288 |
+
x = torch.flatten(x, 1, -1)
|
289 |
+
|
290 |
+
return x, fmap
|
291 |
+
|
292 |
+
|
293 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
294 |
+
def __init__(self, use_cond=False, c_in=1):
|
295 |
+
super(MultiScaleDiscriminator, self).__init__()
|
296 |
+
from utils.hparams import hparams
|
297 |
+
self.discriminators = nn.ModuleList([
|
298 |
+
DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
|
299 |
+
upsample_rates=[4, 4, hparams['hop_size'] // 16],
|
300 |
+
c_in=c_in),
|
301 |
+
DiscriminatorS(use_cond=use_cond,
|
302 |
+
upsample_rates=[4, 4, hparams['hop_size'] // 32],
|
303 |
+
c_in=c_in),
|
304 |
+
DiscriminatorS(use_cond=use_cond,
|
305 |
+
upsample_rates=[4, 4, hparams['hop_size'] // 64],
|
306 |
+
c_in=c_in),
|
307 |
+
])
|
308 |
+
self.meanpools = nn.ModuleList([
|
309 |
+
AvgPool1d(4, 2, padding=1),
|
310 |
+
AvgPool1d(4, 2, padding=1)
|
311 |
+
])
|
312 |
+
|
313 |
+
def forward(self, y, y_hat, mel=None):
|
314 |
+
y_d_rs = []
|
315 |
+
y_d_gs = []
|
316 |
+
fmap_rs = []
|
317 |
+
fmap_gs = []
|
318 |
+
for i, d in enumerate(self.discriminators):
|
319 |
+
if i != 0:
|
320 |
+
y = self.meanpools[i - 1](y)
|
321 |
+
y_hat = self.meanpools[i - 1](y_hat)
|
322 |
+
y_d_r, fmap_r = d(y, mel)
|
323 |
+
y_d_g, fmap_g = d(y_hat, mel)
|
324 |
+
y_d_rs.append(y_d_r)
|
325 |
+
fmap_rs.append(fmap_r)
|
326 |
+
y_d_gs.append(y_d_g)
|
327 |
+
fmap_gs.append(fmap_g)
|
328 |
+
|
329 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
330 |
+
|
331 |
+
|
332 |
+
def feature_loss(fmap_r, fmap_g):
|
333 |
+
loss = 0
|
334 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
335 |
+
for rl, gl in zip(dr, dg):
|
336 |
+
loss += torch.mean(torch.abs(rl - gl))
|
337 |
+
|
338 |
+
return loss * 2
|
339 |
+
|
340 |
+
|
341 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
342 |
+
r_losses = 0
|
343 |
+
g_losses = 0
|
344 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
345 |
+
r_loss = torch.mean((1 - dr) ** 2)
|
346 |
+
g_loss = torch.mean(dg ** 2)
|
347 |
+
r_losses += r_loss
|
348 |
+
g_losses += g_loss
|
349 |
+
r_losses = r_losses / len(disc_real_outputs)
|
350 |
+
g_losses = g_losses / len(disc_real_outputs)
|
351 |
+
return r_losses, g_losses
|
352 |
+
|
353 |
+
|
354 |
+
def cond_discriminator_loss(outputs):
|
355 |
+
loss = 0
|
356 |
+
for dg in outputs:
|
357 |
+
g_loss = torch.mean(dg ** 2)
|
358 |
+
loss += g_loss
|
359 |
+
loss = loss / len(outputs)
|
360 |
+
return loss
|
361 |
+
|
362 |
+
|
363 |
+
def generator_loss(disc_outputs):
|
364 |
+
loss = 0
|
365 |
+
for dg in disc_outputs:
|
366 |
+
l = torch.mean((1 - dg) ** 2)
|
367 |
+
loss += l
|
368 |
+
loss = loss / len(disc_outputs)
|
369 |
+
return loss
|
370 |
+
|
modules/hifigan/mel_utils.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.utils.data
|
4 |
+
from librosa.filters import mel as librosa_mel_fn
|
5 |
+
from scipy.io.wavfile import read
|
6 |
+
|
7 |
+
MAX_WAV_VALUE = 32768.0
|
8 |
+
|
9 |
+
|
10 |
+
def load_wav(full_path):
|
11 |
+
sampling_rate, data = read(full_path)
|
12 |
+
return data, sampling_rate
|
13 |
+
|
14 |
+
|
15 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
16 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
17 |
+
|
18 |
+
|
19 |
+
def dynamic_range_decompression(x, C=1):
|
20 |
+
return np.exp(x) / C
|
21 |
+
|
22 |
+
|
23 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
24 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
25 |
+
|
26 |
+
|
27 |
+
def dynamic_range_decompression_torch(x, C=1):
|
28 |
+
return torch.exp(x) / C
|
29 |
+
|
30 |
+
|
31 |
+
def spectral_normalize_torch(magnitudes):
|
32 |
+
output = dynamic_range_compression_torch(magnitudes)
|
33 |
+
return output
|
34 |
+
|
35 |
+
|
36 |
+
def spectral_de_normalize_torch(magnitudes):
|
37 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
38 |
+
return output
|
39 |
+
|
40 |
+
|
41 |
+
mel_basis = {}
|
42 |
+
hann_window = {}
|
43 |
+
|
44 |
+
|
45 |
+
def mel_spectrogram(y, hparams, center=False, complex=False):
|
46 |
+
# hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
|
47 |
+
# win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
|
48 |
+
# fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
49 |
+
# fmax: 10000 # To be increased/reduced depending on data.
|
50 |
+
# fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
|
51 |
+
# n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
|
52 |
+
n_fft = hparams['fft_size']
|
53 |
+
num_mels = hparams['audio_num_mel_bins']
|
54 |
+
sampling_rate = hparams['audio_sample_rate']
|
55 |
+
hop_size = hparams['hop_size']
|
56 |
+
win_size = hparams['win_size']
|
57 |
+
fmin = hparams['fmin']
|
58 |
+
fmax = hparams['fmax']
|
59 |
+
y = y.clamp(min=-1., max=1.)
|
60 |
+
global mel_basis, hann_window
|
61 |
+
if fmax not in mel_basis:
|
62 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
63 |
+
mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
64 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
65 |
+
|
66 |
+
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
67 |
+
mode='reflect')
|
68 |
+
y = y.squeeze(1)
|
69 |
+
|
70 |
+
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
71 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True)
|
72 |
+
|
73 |
+
if not complex:
|
74 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
75 |
+
spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
|
76 |
+
spec = spectral_normalize_torch(spec)
|
77 |
+
else:
|
78 |
+
B, C, T, _ = spec.shape
|
79 |
+
spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
|
80 |
+
return spec
|
81 |
+
|
modules/parallel_wavegan/__init__.py
ADDED
File without changes
|