Spaces:
Sleeping
Sleeping
yuancwang
commited on
Commit
·
b725c5a
1
Parent(s):
3e8a9fc
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +61 -0
- app.py +31 -0
- config/audioldm.json +92 -0
- config/autoencoderkl.json +69 -0
- config/base.json +220 -0
- config/comosvc.json +216 -0
- config/diffusion.json +227 -0
- config/fs2.json +118 -0
- config/ns2.json +88 -0
- config/transformer.json +180 -0
- config/tts.json +23 -0
- config/valle.json +53 -0
- config/vits.json +101 -0
- config/vitssvc.json +192 -0
- config/vocoder.json +84 -0
- evaluation/__init__.py +0 -0
- evaluation/features/__init__.py +0 -0
- evaluation/features/long_term_average_spectrum.py +19 -0
- evaluation/features/signal_to_noise_ratio.py +133 -0
- evaluation/features/singing_power_ratio.py +108 -0
- evaluation/metrics/__init__.py +0 -0
- evaluation/metrics/energy/__init__.py +0 -0
- evaluation/metrics/energy/energy_pearson_coefficients.py +91 -0
- evaluation/metrics/energy/energy_rmse.py +86 -0
- evaluation/metrics/f0/__init__.py +0 -0
- evaluation/metrics/f0/f0_pearson_coefficients.py +111 -0
- evaluation/metrics/f0/f0_periodicity_rmse.py +112 -0
- evaluation/metrics/f0/f0_rmse.py +110 -0
- evaluation/metrics/f0/v_uv_f1.py +110 -0
- evaluation/metrics/intelligibility/__init__.py +0 -0
- evaluation/metrics/intelligibility/character_error_rate.py +81 -0
- evaluation/metrics/intelligibility/word_error_rate.py +81 -0
- evaluation/metrics/similarity/__init__.py +0 -0
- evaluation/metrics/similarity/models/RawNetBasicBlock.py +146 -0
- evaluation/metrics/similarity/models/RawNetModel.py +142 -0
- evaluation/metrics/similarity/models/__init__.py +0 -0
- evaluation/metrics/similarity/speaker_similarity.py +119 -0
- evaluation/metrics/spectrogram/__init__.py +0 -0
- evaluation/metrics/spectrogram/frechet_distance.py +31 -0
- evaluation/metrics/spectrogram/mel_cepstral_distortion.py +21 -0
- evaluation/metrics/spectrogram/multi_resolution_stft_distance.py +225 -0
- evaluation/metrics/spectrogram/pesq.py +56 -0
- evaluation/metrics/spectrogram/scale_invariant_signal_to_distortion_ratio.py +45 -0
- evaluation/metrics/spectrogram/scale_invariant_signal_to_noise_ratio.py +45 -0
- evaluation/metrics/spectrogram/short_time_objective_intelligibility.py +56 -0
- models/tts/base/__init__.py +7 -0
- models/tts/base/tts_dataset.py +389 -0
- models/tts/base/tts_inferece.py +268 -0
- models/tts/base/tts_trainer.py +699 -0
- models/tts/fastspeech2/__init__.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Mac OS files
|
2 |
+
.DS_Store
|
3 |
+
|
4 |
+
# IDEs
|
5 |
+
.idea
|
6 |
+
.vs
|
7 |
+
.vscode
|
8 |
+
.cache
|
9 |
+
|
10 |
+
# GitHub files
|
11 |
+
.github
|
12 |
+
|
13 |
+
# Byte-compiled / optimized / DLL / cached files
|
14 |
+
__pycache__/
|
15 |
+
*.py[cod]
|
16 |
+
*$py.class
|
17 |
+
*.pyc
|
18 |
+
.temp
|
19 |
+
*.c
|
20 |
+
*.so
|
21 |
+
*.o
|
22 |
+
|
23 |
+
# Developing mode
|
24 |
+
_*.sh
|
25 |
+
_*.json
|
26 |
+
*.lst
|
27 |
+
yard*
|
28 |
+
*.out
|
29 |
+
evaluation/evalset_selection
|
30 |
+
mfa
|
31 |
+
egs/svc/*wavmark
|
32 |
+
egs/svc/custom
|
33 |
+
egs/svc/*/dev*
|
34 |
+
egs/svc/dev_exp_config.json
|
35 |
+
bins/svc/demo*
|
36 |
+
bins/svc/preprocess_custom.py
|
37 |
+
data
|
38 |
+
|
39 |
+
# Data and ckpt
|
40 |
+
*.pkl
|
41 |
+
*.pt
|
42 |
+
*.npy
|
43 |
+
*.npz
|
44 |
+
!modules/whisper_extractor/assets/mel_filters.npz
|
45 |
+
*.tar.gz
|
46 |
+
*.ckpt
|
47 |
+
*.wav
|
48 |
+
*.flac
|
49 |
+
pretrained/wenet/*conformer_exp
|
50 |
+
|
51 |
+
# Runtime data dirs
|
52 |
+
processed_data
|
53 |
+
data
|
54 |
+
model_ckpt
|
55 |
+
logs
|
56 |
+
*.ipynb
|
57 |
+
*.lst
|
58 |
+
source_audio
|
59 |
+
result
|
60 |
+
conversion_results
|
61 |
+
get_available_gpu.py
|
app.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
def build_codec():
|
8 |
+
...
|
9 |
+
|
10 |
+
def build_model():
|
11 |
+
...
|
12 |
+
|
13 |
+
def ns2_inference(
|
14 |
+
prmopt_audio_path,
|
15 |
+
text,
|
16 |
+
diffusion_steps=100,
|
17 |
+
):
|
18 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
19 |
+
|
20 |
+
demo_inputs = ...
|
21 |
+
demo_outputs = ...
|
22 |
+
|
23 |
+
demo = gr.Interface(
|
24 |
+
fn=ns2_inference,
|
25 |
+
inputs=demo_inputs,
|
26 |
+
outputs=demo_outputs,
|
27 |
+
title="Amphion Zero-Shot TTS NaturalSpeech2"
|
28 |
+
)
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
demo.launch()
|
config/audioldm.json
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"model_type": "AudioLDM",
|
4 |
+
"task_type": "tta",
|
5 |
+
"dataset": [
|
6 |
+
"AudioCaps"
|
7 |
+
],
|
8 |
+
"preprocess": {
|
9 |
+
// feature used for model training
|
10 |
+
"use_spkid": false,
|
11 |
+
"use_uv": false,
|
12 |
+
"use_frame_pitch": false,
|
13 |
+
"use_phone_pitch": false,
|
14 |
+
"use_frame_energy": false,
|
15 |
+
"use_phone_energy": false,
|
16 |
+
"use_mel": false,
|
17 |
+
"use_audio": false,
|
18 |
+
"use_label": false,
|
19 |
+
"use_one_hot": false,
|
20 |
+
"cond_mask_prob": 0.1
|
21 |
+
},
|
22 |
+
// model
|
23 |
+
"model": {
|
24 |
+
"audioldm": {
|
25 |
+
"image_size": 32,
|
26 |
+
"in_channels": 4,
|
27 |
+
"out_channels": 4,
|
28 |
+
"model_channels": 256,
|
29 |
+
"attention_resolutions": [
|
30 |
+
4,
|
31 |
+
2,
|
32 |
+
1
|
33 |
+
],
|
34 |
+
"num_res_blocks": 2,
|
35 |
+
"channel_mult": [
|
36 |
+
1,
|
37 |
+
2,
|
38 |
+
4
|
39 |
+
],
|
40 |
+
"num_heads": 8,
|
41 |
+
"use_spatial_transformer": true,
|
42 |
+
"transformer_depth": 1,
|
43 |
+
"context_dim": 768,
|
44 |
+
"use_checkpoint": true,
|
45 |
+
"legacy": false
|
46 |
+
},
|
47 |
+
"autoencoderkl": {
|
48 |
+
"ch": 128,
|
49 |
+
"ch_mult": [
|
50 |
+
1,
|
51 |
+
1,
|
52 |
+
2,
|
53 |
+
2,
|
54 |
+
4
|
55 |
+
],
|
56 |
+
"num_res_blocks": 2,
|
57 |
+
"in_channels": 1,
|
58 |
+
"z_channels": 4,
|
59 |
+
"out_ch": 1,
|
60 |
+
"double_z": true
|
61 |
+
},
|
62 |
+
"noise_scheduler": {
|
63 |
+
"num_train_timesteps": 1000,
|
64 |
+
"beta_start": 0.00085,
|
65 |
+
"beta_end": 0.012,
|
66 |
+
"beta_schedule": "scaled_linear",
|
67 |
+
"clip_sample": false,
|
68 |
+
"steps_offset": 1,
|
69 |
+
"set_alpha_to_one": false,
|
70 |
+
"skip_prk_steps": true,
|
71 |
+
"prediction_type": "epsilon"
|
72 |
+
}
|
73 |
+
},
|
74 |
+
// train
|
75 |
+
"train": {
|
76 |
+
"lronPlateau": {
|
77 |
+
"factor": 0.9,
|
78 |
+
"patience": 100,
|
79 |
+
"min_lr": 4.0e-5,
|
80 |
+
"verbose": true
|
81 |
+
},
|
82 |
+
"adam": {
|
83 |
+
"lr": 5.0e-5,
|
84 |
+
"betas": [
|
85 |
+
0.9,
|
86 |
+
0.999
|
87 |
+
],
|
88 |
+
"weight_decay": 1.0e-2,
|
89 |
+
"eps": 1.0e-8
|
90 |
+
}
|
91 |
+
}
|
92 |
+
}
|
config/autoencoderkl.json
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"model_type": "AutoencoderKL",
|
4 |
+
"task_type": "tta",
|
5 |
+
"dataset": [
|
6 |
+
"AudioCaps"
|
7 |
+
],
|
8 |
+
"preprocess": {
|
9 |
+
// feature used for model training
|
10 |
+
"use_spkid": false,
|
11 |
+
"use_uv": false,
|
12 |
+
"use_frame_pitch": false,
|
13 |
+
"use_phone_pitch": false,
|
14 |
+
"use_frame_energy": false,
|
15 |
+
"use_phone_energy": false,
|
16 |
+
"use_mel": false,
|
17 |
+
"use_audio": false,
|
18 |
+
"use_label": false,
|
19 |
+
"use_one_hot": false
|
20 |
+
},
|
21 |
+
// model
|
22 |
+
"model": {
|
23 |
+
"autoencoderkl": {
|
24 |
+
"ch": 128,
|
25 |
+
"ch_mult": [
|
26 |
+
1,
|
27 |
+
1,
|
28 |
+
2,
|
29 |
+
2,
|
30 |
+
4
|
31 |
+
],
|
32 |
+
"num_res_blocks": 2,
|
33 |
+
"in_channels": 1,
|
34 |
+
"z_channels": 4,
|
35 |
+
"out_ch": 1,
|
36 |
+
"double_z": true
|
37 |
+
},
|
38 |
+
"loss": {
|
39 |
+
"kl_weight": 1e-8,
|
40 |
+
"disc_weight": 0.5,
|
41 |
+
"disc_factor": 1.0,
|
42 |
+
"logvar_init": 0.0,
|
43 |
+
"min_adapt_d_weight": 0.0,
|
44 |
+
"max_adapt_d_weight": 10.0,
|
45 |
+
"disc_start": 50001,
|
46 |
+
"disc_in_channels": 1,
|
47 |
+
"disc_num_layers": 3,
|
48 |
+
"use_actnorm": false
|
49 |
+
}
|
50 |
+
},
|
51 |
+
// train
|
52 |
+
"train": {
|
53 |
+
"lronPlateau": {
|
54 |
+
"factor": 0.9,
|
55 |
+
"patience": 100,
|
56 |
+
"min_lr": 4.0e-5,
|
57 |
+
"verbose": true
|
58 |
+
},
|
59 |
+
"adam": {
|
60 |
+
"lr": 4.0e-4,
|
61 |
+
"betas": [
|
62 |
+
0.9,
|
63 |
+
0.999
|
64 |
+
],
|
65 |
+
"weight_decay": 1.0e-2,
|
66 |
+
"eps": 1.0e-8
|
67 |
+
}
|
68 |
+
}
|
69 |
+
}
|
config/base.json
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"supported_model_type": [
|
3 |
+
"GANVocoder",
|
4 |
+
"Fastspeech2",
|
5 |
+
"DiffSVC",
|
6 |
+
"Transformer",
|
7 |
+
"EDM",
|
8 |
+
"CD"
|
9 |
+
],
|
10 |
+
"task_type": "",
|
11 |
+
"dataset": [],
|
12 |
+
"use_custom_dataset": false,
|
13 |
+
"preprocess": {
|
14 |
+
"phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon"
|
15 |
+
// trim audio silence
|
16 |
+
"data_augment": false,
|
17 |
+
"trim_silence": false,
|
18 |
+
"num_silent_frames": 8,
|
19 |
+
"trim_fft_size": 512, // fft size used in trimming
|
20 |
+
"trim_hop_size": 128, // hop size used in trimming
|
21 |
+
"trim_top_db": 30, // top db used in trimming sensitive to each dataset
|
22 |
+
// acoustic features
|
23 |
+
"extract_mel": false,
|
24 |
+
"mel_extract_mode": "",
|
25 |
+
"extract_linear_spec": false,
|
26 |
+
"extract_mcep": false,
|
27 |
+
"extract_pitch": false,
|
28 |
+
"extract_acoustic_token": false,
|
29 |
+
"pitch_remove_outlier": false,
|
30 |
+
"extract_uv": false,
|
31 |
+
"pitch_norm": false,
|
32 |
+
"extract_audio": false,
|
33 |
+
"extract_label": false,
|
34 |
+
"pitch_extractor": "parselmouth", // pyin, dio, pyworld, pyreaper, parselmouth, CWT (Continuous Wavelet Transform)
|
35 |
+
"extract_energy": false,
|
36 |
+
"energy_remove_outlier": false,
|
37 |
+
"energy_norm": false,
|
38 |
+
"energy_extract_mode": "from_mel",
|
39 |
+
"extract_duration": false,
|
40 |
+
"extract_amplitude_phase": false,
|
41 |
+
"mel_min_max_norm": false,
|
42 |
+
// lingusitic features
|
43 |
+
"extract_phone": false,
|
44 |
+
"lexicon_path": "./text/lexicon/librispeech-lexicon.txt",
|
45 |
+
// content features
|
46 |
+
"extract_whisper_feature": false,
|
47 |
+
"extract_contentvec_feature": false,
|
48 |
+
"extract_mert_feature": false,
|
49 |
+
"extract_wenet_feature": false,
|
50 |
+
// Settings for data preprocessing
|
51 |
+
"n_mel": 80,
|
52 |
+
"win_size": 480,
|
53 |
+
"hop_size": 120,
|
54 |
+
"sample_rate": 24000,
|
55 |
+
"n_fft": 1024,
|
56 |
+
"fmin": 0,
|
57 |
+
"fmax": 12000,
|
58 |
+
"min_level_db": -115,
|
59 |
+
"ref_level_db": 20,
|
60 |
+
"bits": 8,
|
61 |
+
// Directory names of processed data or extracted features
|
62 |
+
"processed_dir": "processed_data",
|
63 |
+
"trimmed_wav_dir": "trimmed_wavs", // directory name of silence trimed wav
|
64 |
+
"raw_data": "raw_data",
|
65 |
+
"phone_dir": "phones",
|
66 |
+
"wav_dir": "wavs", // directory name of processed wav (such as downsampled waveform)
|
67 |
+
"audio_dir": "audios",
|
68 |
+
"log_amplitude_dir": "log_amplitudes",
|
69 |
+
"phase_dir": "phases",
|
70 |
+
"real_dir": "reals",
|
71 |
+
"imaginary_dir": "imaginarys",
|
72 |
+
"label_dir": "labels",
|
73 |
+
"linear_dir": "linears",
|
74 |
+
"mel_dir": "mels", // directory name of extraced mel features
|
75 |
+
"mcep_dir": "mcep", // directory name of extraced mcep features
|
76 |
+
"dur_dir": "durs",
|
77 |
+
"symbols_dict": "symbols.dict",
|
78 |
+
"lab_dir": "labs", // directory name of extraced label features
|
79 |
+
"wenet_dir": "wenet", // directory name of extraced wenet features
|
80 |
+
"contentvec_dir": "contentvec", // directory name of extraced wenet features
|
81 |
+
"pitch_dir": "pitches", // directory name of extraced pitch features
|
82 |
+
"energy_dir": "energys", // directory name of extracted energy features
|
83 |
+
"phone_pitch_dir": "phone_pitches", // directory name of extraced pitch features
|
84 |
+
"phone_energy_dir": "phone_energys", // directory name of extracted energy features
|
85 |
+
"uv_dir": "uvs", // directory name of extracted unvoiced features
|
86 |
+
"duration_dir": "duration", // ground-truth duration file
|
87 |
+
"phone_seq_file": "phone_seq_file", // phoneme sequence file
|
88 |
+
"file_lst": "file.lst",
|
89 |
+
"train_file": "train.json", // training set, the json file contains detailed information about the dataset, including dataset name, utterance id, duration of the utterance
|
90 |
+
"valid_file": "valid.json", // validattion set
|
91 |
+
"spk2id": "spk2id.json", // used for multi-speaker dataset
|
92 |
+
"utt2spk": "utt2spk", // used for multi-speaker dataset
|
93 |
+
"emo2id": "emo2id.json", // used for multi-emotion dataset
|
94 |
+
"utt2emo": "utt2emo", // used for multi-emotion dataset
|
95 |
+
// Features used for model training
|
96 |
+
"use_text": false,
|
97 |
+
"use_phone": false,
|
98 |
+
"use_phn_seq": false,
|
99 |
+
"use_lab": false,
|
100 |
+
"use_linear": false,
|
101 |
+
"use_mel": false,
|
102 |
+
"use_min_max_norm_mel": false,
|
103 |
+
"use_wav": false,
|
104 |
+
"use_phone_pitch": false,
|
105 |
+
"use_log_scale_pitch": false,
|
106 |
+
"use_phone_energy": false,
|
107 |
+
"use_phone_duration": false,
|
108 |
+
"use_log_scale_energy": false,
|
109 |
+
"use_wenet": false,
|
110 |
+
"use_dur": false,
|
111 |
+
"use_spkid": false, // True: use speaker id for multi-speaker dataset
|
112 |
+
"use_emoid": false, // True: use emotion id for multi-emotion dataset
|
113 |
+
"use_frame_pitch": false,
|
114 |
+
"use_uv": false,
|
115 |
+
"use_frame_energy": false,
|
116 |
+
"use_frame_duration": false,
|
117 |
+
"use_audio": false,
|
118 |
+
"use_label": false,
|
119 |
+
"use_one_hot": false,
|
120 |
+
"use_amplitude_phase": false,
|
121 |
+
"data_augment": false,
|
122 |
+
"align_mel_duration": false
|
123 |
+
},
|
124 |
+
"train": {
|
125 |
+
"ddp": true,
|
126 |
+
"random_seed": 970227,
|
127 |
+
"batch_size": 16,
|
128 |
+
"max_steps": 1000000,
|
129 |
+
// Trackers
|
130 |
+
"tracker": [
|
131 |
+
"tensorboard"
|
132 |
+
// "wandb",
|
133 |
+
// "cometml",
|
134 |
+
// "mlflow",
|
135 |
+
],
|
136 |
+
"max_epoch": -1,
|
137 |
+
// -1 means no limit
|
138 |
+
"save_checkpoint_stride": [
|
139 |
+
5,
|
140 |
+
20
|
141 |
+
],
|
142 |
+
// unit is epoch
|
143 |
+
"keep_last": [
|
144 |
+
3,
|
145 |
+
-1
|
146 |
+
],
|
147 |
+
// -1 means infinite, if one number will broadcast
|
148 |
+
"run_eval": [
|
149 |
+
false,
|
150 |
+
true
|
151 |
+
],
|
152 |
+
// if one number will broadcast
|
153 |
+
// Fix the random seed
|
154 |
+
"random_seed": 10086,
|
155 |
+
// Optimizer
|
156 |
+
"optimizer": "AdamW",
|
157 |
+
"adamw": {
|
158 |
+
"lr": 4.0e-4
|
159 |
+
// nn model lr
|
160 |
+
},
|
161 |
+
// LR Scheduler
|
162 |
+
"scheduler": "ReduceLROnPlateau",
|
163 |
+
"reducelronplateau": {
|
164 |
+
"factor": 0.8,
|
165 |
+
"patience": 10,
|
166 |
+
// unit is epoch
|
167 |
+
"min_lr": 1.0e-4
|
168 |
+
},
|
169 |
+
// Batchsampler
|
170 |
+
"sampler": {
|
171 |
+
"holistic_shuffle": true,
|
172 |
+
"drop_last": true
|
173 |
+
},
|
174 |
+
// Dataloader
|
175 |
+
"dataloader": {
|
176 |
+
"num_worker": 32,
|
177 |
+
"pin_memory": true
|
178 |
+
},
|
179 |
+
"gradient_accumulation_step": 1,
|
180 |
+
"total_training_steps": 50000,
|
181 |
+
"save_summary_steps": 500,
|
182 |
+
"save_checkpoints_steps": 10000,
|
183 |
+
"valid_interval": 10000,
|
184 |
+
"keep_checkpoint_max": 5,
|
185 |
+
"multi_speaker_training": false, // True: train multi-speaker model; False: training single-speaker model;
|
186 |
+
"max_epoch": -1,
|
187 |
+
// -1 means no limit
|
188 |
+
"save_checkpoint_stride": [
|
189 |
+
5,
|
190 |
+
20
|
191 |
+
],
|
192 |
+
// unit is epoch
|
193 |
+
"keep_last": [
|
194 |
+
3,
|
195 |
+
-1
|
196 |
+
],
|
197 |
+
// -1 means infinite, if one number will broadcast
|
198 |
+
"run_eval": [
|
199 |
+
false,
|
200 |
+
true
|
201 |
+
],
|
202 |
+
// Batchsampler
|
203 |
+
"sampler": {
|
204 |
+
"holistic_shuffle": true,
|
205 |
+
"drop_last": true
|
206 |
+
},
|
207 |
+
// Dataloader
|
208 |
+
"dataloader": {
|
209 |
+
"num_worker": 32,
|
210 |
+
"pin_memory": true
|
211 |
+
},
|
212 |
+
// Trackers
|
213 |
+
"tracker": [
|
214 |
+
"tensorboard"
|
215 |
+
// "wandb",
|
216 |
+
// "cometml",
|
217 |
+
// "mlflow",
|
218 |
+
],
|
219 |
+
},
|
220 |
+
}
|
config/comosvc.json
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"model_type": "DiffComoSVC",
|
4 |
+
"task_type": "svc",
|
5 |
+
"use_custom_dataset": false,
|
6 |
+
"preprocess": {
|
7 |
+
// data augmentations
|
8 |
+
"use_pitch_shift": false,
|
9 |
+
"use_formant_shift": false,
|
10 |
+
"use_time_stretch": false,
|
11 |
+
"use_equalizer": false,
|
12 |
+
// acoustic features
|
13 |
+
"extract_mel": true,
|
14 |
+
"mel_min_max_norm": true,
|
15 |
+
"extract_pitch": true,
|
16 |
+
"pitch_extractor": "parselmouth",
|
17 |
+
"extract_uv": true,
|
18 |
+
"extract_energy": true,
|
19 |
+
// content features
|
20 |
+
"extract_whisper_feature": false,
|
21 |
+
"whisper_sample_rate": 16000,
|
22 |
+
"extract_contentvec_feature": false,
|
23 |
+
"contentvec_sample_rate": 16000,
|
24 |
+
"extract_wenet_feature": false,
|
25 |
+
"wenet_sample_rate": 16000,
|
26 |
+
"extract_mert_feature": false,
|
27 |
+
"mert_sample_rate": 16000,
|
28 |
+
// Default config for whisper
|
29 |
+
"whisper_frameshift": 0.01,
|
30 |
+
"whisper_downsample_rate": 2,
|
31 |
+
// Default config for content vector
|
32 |
+
"contentvec_frameshift": 0.02,
|
33 |
+
// Default config for mert
|
34 |
+
"mert_model": "m-a-p/MERT-v1-330M",
|
35 |
+
"mert_feature_layer": -1,
|
36 |
+
"mert_hop_size": 320,
|
37 |
+
// 24k
|
38 |
+
"mert_frameshit": 0.01333,
|
39 |
+
// 10ms
|
40 |
+
"wenet_frameshift": 0.01,
|
41 |
+
// wenetspeech is 4, gigaspeech is 6
|
42 |
+
"wenet_downsample_rate": 4,
|
43 |
+
// Default config
|
44 |
+
"n_mel": 100,
|
45 |
+
"win_size": 1024,
|
46 |
+
// todo
|
47 |
+
"hop_size": 256,
|
48 |
+
"sample_rate": 24000,
|
49 |
+
"n_fft": 1024,
|
50 |
+
// todo
|
51 |
+
"fmin": 0,
|
52 |
+
"fmax": 12000,
|
53 |
+
// todo
|
54 |
+
"f0_min": 50,
|
55 |
+
// ~C2
|
56 |
+
"f0_max": 1100,
|
57 |
+
//1100, // ~C6(1100), ~G5(800)
|
58 |
+
"pitch_bin": 256,
|
59 |
+
"pitch_max": 1100.0,
|
60 |
+
"pitch_min": 50.0,
|
61 |
+
"is_label": true,
|
62 |
+
"is_mu_law": true,
|
63 |
+
"bits": 8,
|
64 |
+
"mel_min_max_stats_dir": "mel_min_max_stats",
|
65 |
+
"whisper_dir": "whisper",
|
66 |
+
"contentvec_dir": "contentvec",
|
67 |
+
"wenet_dir": "wenet",
|
68 |
+
"mert_dir": "mert",
|
69 |
+
// Extract content features using dataloader
|
70 |
+
"pin_memory": true,
|
71 |
+
"num_workers": 8,
|
72 |
+
"content_feature_batch_size": 16,
|
73 |
+
// Features used for model training
|
74 |
+
"use_mel": true,
|
75 |
+
"use_min_max_norm_mel": true,
|
76 |
+
"use_frame_pitch": true,
|
77 |
+
"use_uv": true,
|
78 |
+
"use_frame_energy": true,
|
79 |
+
"use_log_scale_pitch": false,
|
80 |
+
"use_log_scale_energy": false,
|
81 |
+
"use_spkid": true,
|
82 |
+
// Meta file
|
83 |
+
"train_file": "train.json",
|
84 |
+
"valid_file": "test.json",
|
85 |
+
"spk2id": "singers.json",
|
86 |
+
"utt2spk": "utt2singer"
|
87 |
+
},
|
88 |
+
"model": {
|
89 |
+
"teacher_model_path": "[Your Teacher Model Path].bin",
|
90 |
+
"condition_encoder": {
|
91 |
+
"merge_mode": "add",
|
92 |
+
"input_melody_dim": 1,
|
93 |
+
"use_log_f0": true,
|
94 |
+
"n_bins_melody": 256,
|
95 |
+
//# Quantization (0 for not quantization)
|
96 |
+
"output_melody_dim": 384,
|
97 |
+
"input_loudness_dim": 1,
|
98 |
+
"use_log_loudness": true,
|
99 |
+
"n_bins_loudness": 256,
|
100 |
+
"output_loudness_dim": 384,
|
101 |
+
"use_whisper": false,
|
102 |
+
"use_contentvec": false,
|
103 |
+
"use_wenet": false,
|
104 |
+
"use_mert": false,
|
105 |
+
"whisper_dim": 1024,
|
106 |
+
"contentvec_dim": 256,
|
107 |
+
"mert_dim": 256,
|
108 |
+
"wenet_dim": 512,
|
109 |
+
"content_encoder_dim": 384,
|
110 |
+
"output_singer_dim": 384,
|
111 |
+
"singer_table_size": 512,
|
112 |
+
"output_content_dim": 384,
|
113 |
+
"use_spkid": true
|
114 |
+
},
|
115 |
+
"comosvc": {
|
116 |
+
"distill": false,
|
117 |
+
// conformer encoder
|
118 |
+
"input_dim": 384,
|
119 |
+
"output_dim": 100,
|
120 |
+
"n_heads": 2,
|
121 |
+
"n_layers": 6,
|
122 |
+
"filter_channels": 512,
|
123 |
+
"dropout": 0.1,
|
124 |
+
// karras diffusion
|
125 |
+
"P_mean": -1.2,
|
126 |
+
"P_std": 1.2,
|
127 |
+
"sigma_data": 0.5,
|
128 |
+
"sigma_min": 0.002,
|
129 |
+
"sigma_max": 80,
|
130 |
+
"rho": 7,
|
131 |
+
"n_timesteps": 40,
|
132 |
+
},
|
133 |
+
"diffusion": {
|
134 |
+
// Diffusion steps encoder
|
135 |
+
"step_encoder": {
|
136 |
+
"dim_raw_embedding": 128,
|
137 |
+
"dim_hidden_layer": 512,
|
138 |
+
"activation": "SiLU",
|
139 |
+
"num_layer": 2,
|
140 |
+
"max_period": 10000
|
141 |
+
},
|
142 |
+
// Diffusion decoder
|
143 |
+
"model_type": "bidilconv",
|
144 |
+
// bidilconv, unet2d, TODO: unet1d
|
145 |
+
"bidilconv": {
|
146 |
+
"base_channel": 384,
|
147 |
+
"n_res_block": 20,
|
148 |
+
"conv_kernel_size": 3,
|
149 |
+
"dilation_cycle_length": 4,
|
150 |
+
// specially, 1 means no dilation
|
151 |
+
"conditioner_size": 100
|
152 |
+
}
|
153 |
+
},
|
154 |
+
},
|
155 |
+
"train": {
|
156 |
+
// Basic settings
|
157 |
+
"fast_steps": 0,
|
158 |
+
"batch_size": 32,
|
159 |
+
"gradient_accumulation_step": 1,
|
160 |
+
"max_epoch": -1,
|
161 |
+
// -1 means no limit
|
162 |
+
"save_checkpoint_stride": [
|
163 |
+
10,
|
164 |
+
100
|
165 |
+
],
|
166 |
+
// unit is epoch
|
167 |
+
"keep_last": [
|
168 |
+
3,
|
169 |
+
-1
|
170 |
+
],
|
171 |
+
// -1 means infinite, if one number will broadcast
|
172 |
+
"run_eval": [
|
173 |
+
false,
|
174 |
+
true
|
175 |
+
],
|
176 |
+
// if one number will broadcast
|
177 |
+
// Fix the random seed
|
178 |
+
"random_seed": 10086,
|
179 |
+
// Batchsampler
|
180 |
+
"sampler": {
|
181 |
+
"holistic_shuffle": true,
|
182 |
+
"drop_last": true
|
183 |
+
},
|
184 |
+
// Dataloader
|
185 |
+
"dataloader": {
|
186 |
+
"num_worker": 32,
|
187 |
+
"pin_memory": true
|
188 |
+
},
|
189 |
+
// Trackers
|
190 |
+
"tracker": [
|
191 |
+
"tensorboard"
|
192 |
+
// "wandb",
|
193 |
+
// "cometml",
|
194 |
+
// "mlflow",
|
195 |
+
],
|
196 |
+
// Optimizer
|
197 |
+
"optimizer": "AdamW",
|
198 |
+
"adamw": {
|
199 |
+
"lr": 4.0e-4
|
200 |
+
// nn model lr
|
201 |
+
},
|
202 |
+
// LR Scheduler
|
203 |
+
"scheduler": "ReduceLROnPlateau",
|
204 |
+
"reducelronplateau": {
|
205 |
+
"factor": 0.8,
|
206 |
+
"patience": 10,
|
207 |
+
// unit is epoch
|
208 |
+
"min_lr": 1.0e-4
|
209 |
+
}
|
210 |
+
},
|
211 |
+
"inference": {
|
212 |
+
"comosvc": {
|
213 |
+
"inference_steps": 40
|
214 |
+
}
|
215 |
+
}
|
216 |
+
}
|
config/diffusion.json
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
// FIXME: THESE ARE LEGACY
|
3 |
+
"base_config": "config/base.json",
|
4 |
+
"model_type": "diffusion",
|
5 |
+
"task_type": "svc",
|
6 |
+
"use_custom_dataset": false,
|
7 |
+
"preprocess": {
|
8 |
+
// data augmentations
|
9 |
+
"use_pitch_shift": false,
|
10 |
+
"use_formant_shift": false,
|
11 |
+
"use_time_stretch": false,
|
12 |
+
"use_equalizer": false,
|
13 |
+
// acoustic features
|
14 |
+
"extract_mel": true,
|
15 |
+
"mel_min_max_norm": true,
|
16 |
+
"extract_pitch": true,
|
17 |
+
"pitch_extractor": "parselmouth",
|
18 |
+
"extract_uv": true,
|
19 |
+
"extract_energy": true,
|
20 |
+
// content features
|
21 |
+
"extract_whisper_feature": false,
|
22 |
+
"whisper_sample_rate": 16000,
|
23 |
+
"extract_contentvec_feature": false,
|
24 |
+
"contentvec_sample_rate": 16000,
|
25 |
+
"extract_wenet_feature": false,
|
26 |
+
"wenet_sample_rate": 16000,
|
27 |
+
"extract_mert_feature": false,
|
28 |
+
"mert_sample_rate": 16000,
|
29 |
+
// Default config for whisper
|
30 |
+
"whisper_frameshift": 0.01,
|
31 |
+
"whisper_downsample_rate": 2,
|
32 |
+
// Default config for content vector
|
33 |
+
"contentvec_frameshift": 0.02,
|
34 |
+
// Default config for mert
|
35 |
+
"mert_model": "m-a-p/MERT-v1-330M",
|
36 |
+
"mert_feature_layer": -1,
|
37 |
+
"mert_hop_size": 320,
|
38 |
+
// 24k
|
39 |
+
"mert_frameshit": 0.01333,
|
40 |
+
// 10ms
|
41 |
+
"wenet_frameshift": 0.01,
|
42 |
+
// wenetspeech is 4, gigaspeech is 6
|
43 |
+
"wenet_downsample_rate": 4,
|
44 |
+
// Default config
|
45 |
+
"n_mel": 100,
|
46 |
+
"win_size": 1024,
|
47 |
+
// todo
|
48 |
+
"hop_size": 256,
|
49 |
+
"sample_rate": 24000,
|
50 |
+
"n_fft": 1024,
|
51 |
+
// todo
|
52 |
+
"fmin": 0,
|
53 |
+
"fmax": 12000,
|
54 |
+
// todo
|
55 |
+
"f0_min": 50,
|
56 |
+
// ~C2
|
57 |
+
"f0_max": 1100,
|
58 |
+
//1100, // ~C6(1100), ~G5(800)
|
59 |
+
"pitch_bin": 256,
|
60 |
+
"pitch_max": 1100.0,
|
61 |
+
"pitch_min": 50.0,
|
62 |
+
"is_label": true,
|
63 |
+
"is_mu_law": true,
|
64 |
+
"bits": 8,
|
65 |
+
"mel_min_max_stats_dir": "mel_min_max_stats",
|
66 |
+
"whisper_dir": "whisper",
|
67 |
+
"contentvec_dir": "contentvec",
|
68 |
+
"wenet_dir": "wenet",
|
69 |
+
"mert_dir": "mert",
|
70 |
+
// Extract content features using dataloader
|
71 |
+
"pin_memory": true,
|
72 |
+
"num_workers": 8,
|
73 |
+
"content_feature_batch_size": 16,
|
74 |
+
// Features used for model training
|
75 |
+
"use_mel": true,
|
76 |
+
"use_min_max_norm_mel": true,
|
77 |
+
"use_frame_pitch": true,
|
78 |
+
"use_uv": true,
|
79 |
+
"use_frame_energy": true,
|
80 |
+
"use_log_scale_pitch": false,
|
81 |
+
"use_log_scale_energy": false,
|
82 |
+
"use_spkid": true,
|
83 |
+
// Meta file
|
84 |
+
"train_file": "train.json",
|
85 |
+
"valid_file": "test.json",
|
86 |
+
"spk2id": "singers.json",
|
87 |
+
"utt2spk": "utt2singer"
|
88 |
+
},
|
89 |
+
"model": {
|
90 |
+
"condition_encoder": {
|
91 |
+
"merge_mode": "add",
|
92 |
+
"input_melody_dim": 1,
|
93 |
+
"use_log_f0": true,
|
94 |
+
"n_bins_melody": 256,
|
95 |
+
//# Quantization (0 for not quantization)
|
96 |
+
"output_melody_dim": 384,
|
97 |
+
"input_loudness_dim": 1,
|
98 |
+
"use_log_loudness": true,
|
99 |
+
"n_bins_loudness": 256,
|
100 |
+
"output_loudness_dim": 384,
|
101 |
+
"use_whisper": false,
|
102 |
+
"use_contentvec": false,
|
103 |
+
"use_wenet": false,
|
104 |
+
"use_mert": false,
|
105 |
+
"whisper_dim": 1024,
|
106 |
+
"contentvec_dim": 256,
|
107 |
+
"mert_dim": 256,
|
108 |
+
"wenet_dim": 512,
|
109 |
+
"content_encoder_dim": 384,
|
110 |
+
"output_singer_dim": 384,
|
111 |
+
"singer_table_size": 512,
|
112 |
+
"output_content_dim": 384,
|
113 |
+
"use_spkid": true
|
114 |
+
},
|
115 |
+
// FIXME: FOLLOWING ARE NEW!!
|
116 |
+
"diffusion": {
|
117 |
+
"scheduler": "ddpm",
|
118 |
+
"scheduler_settings": {
|
119 |
+
"num_train_timesteps": 1000,
|
120 |
+
"beta_start": 1.0e-4,
|
121 |
+
"beta_end": 0.02,
|
122 |
+
"beta_schedule": "linear"
|
123 |
+
},
|
124 |
+
// Diffusion steps encoder
|
125 |
+
"step_encoder": {
|
126 |
+
"dim_raw_embedding": 128,
|
127 |
+
"dim_hidden_layer": 512,
|
128 |
+
"activation": "SiLU",
|
129 |
+
"num_layer": 2,
|
130 |
+
"max_period": 10000
|
131 |
+
},
|
132 |
+
// Diffusion decoder
|
133 |
+
"model_type": "bidilconv",
|
134 |
+
// bidilconv, unet2d, TODO: unet1d
|
135 |
+
"bidilconv": {
|
136 |
+
"base_channel": 384,
|
137 |
+
"n_res_block": 20,
|
138 |
+
"conv_kernel_size": 3,
|
139 |
+
"dilation_cycle_length": 4,
|
140 |
+
// specially, 1 means no dilation
|
141 |
+
"conditioner_size": 384
|
142 |
+
},
|
143 |
+
"unet2d": {
|
144 |
+
"in_channels": 1,
|
145 |
+
"out_channels": 1,
|
146 |
+
"down_block_types": [
|
147 |
+
"CrossAttnDownBlock2D",
|
148 |
+
"CrossAttnDownBlock2D",
|
149 |
+
"CrossAttnDownBlock2D",
|
150 |
+
"DownBlock2D"
|
151 |
+
],
|
152 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
153 |
+
"up_block_types": [
|
154 |
+
"UpBlock2D",
|
155 |
+
"CrossAttnUpBlock2D",
|
156 |
+
"CrossAttnUpBlock2D",
|
157 |
+
"CrossAttnUpBlock2D"
|
158 |
+
],
|
159 |
+
"only_cross_attention": false
|
160 |
+
}
|
161 |
+
}
|
162 |
+
},
|
163 |
+
// FIXME: FOLLOWING ARE NEW!!
|
164 |
+
"train": {
|
165 |
+
// Basic settings
|
166 |
+
"batch_size": 64,
|
167 |
+
"gradient_accumulation_step": 1,
|
168 |
+
"max_epoch": -1,
|
169 |
+
// -1 means no limit
|
170 |
+
"save_checkpoint_stride": [
|
171 |
+
5,
|
172 |
+
20
|
173 |
+
],
|
174 |
+
// unit is epoch
|
175 |
+
"keep_last": [
|
176 |
+
3,
|
177 |
+
-1
|
178 |
+
],
|
179 |
+
// -1 means infinite, if one number will broadcast
|
180 |
+
"run_eval": [
|
181 |
+
false,
|
182 |
+
true
|
183 |
+
],
|
184 |
+
// if one number will broadcast
|
185 |
+
// Fix the random seed
|
186 |
+
"random_seed": 10086,
|
187 |
+
// Batchsampler
|
188 |
+
"sampler": {
|
189 |
+
"holistic_shuffle": true,
|
190 |
+
"drop_last": true
|
191 |
+
},
|
192 |
+
// Dataloader
|
193 |
+
"dataloader": {
|
194 |
+
"num_worker": 32,
|
195 |
+
"pin_memory": true
|
196 |
+
},
|
197 |
+
// Trackers
|
198 |
+
"tracker": [
|
199 |
+
"tensorboard"
|
200 |
+
// "wandb",
|
201 |
+
// "cometml",
|
202 |
+
// "mlflow",
|
203 |
+
],
|
204 |
+
// Optimizer
|
205 |
+
"optimizer": "AdamW",
|
206 |
+
"adamw": {
|
207 |
+
"lr": 4.0e-4
|
208 |
+
// nn model lr
|
209 |
+
},
|
210 |
+
// LR Scheduler
|
211 |
+
"scheduler": "ReduceLROnPlateau",
|
212 |
+
"reducelronplateau": {
|
213 |
+
"factor": 0.8,
|
214 |
+
"patience": 10,
|
215 |
+
// unit is epoch
|
216 |
+
"min_lr": 1.0e-4
|
217 |
+
}
|
218 |
+
},
|
219 |
+
"inference": {
|
220 |
+
"diffusion": {
|
221 |
+
"scheduler": "pndm",
|
222 |
+
"scheduler_settings": {
|
223 |
+
"num_inference_timesteps": 1000
|
224 |
+
}
|
225 |
+
}
|
226 |
+
}
|
227 |
+
}
|
config/fs2.json
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/tts.json",
|
3 |
+
"model_type": "FastSpeech2",
|
4 |
+
"task_type": "tts",
|
5 |
+
"dataset": ["LJSpeech"],
|
6 |
+
"preprocess": {
|
7 |
+
// acoustic features
|
8 |
+
"extract_audio": true,
|
9 |
+
"extract_mel": true,
|
10 |
+
"mel_extract_mode": "taco",
|
11 |
+
"mel_min_max_norm": false,
|
12 |
+
"extract_pitch": true,
|
13 |
+
"extract_uv": false,
|
14 |
+
"pitch_extractor": "dio",
|
15 |
+
"extract_energy": true,
|
16 |
+
"energy_extract_mode": "from_tacotron_stft",
|
17 |
+
"extract_duration": true,
|
18 |
+
"use_phone": true,
|
19 |
+
"pitch_norm": true,
|
20 |
+
"energy_norm": true,
|
21 |
+
"pitch_remove_outlier": true,
|
22 |
+
"energy_remove_outlier": true,
|
23 |
+
|
24 |
+
// Default config
|
25 |
+
"n_mel": 80,
|
26 |
+
"win_size": 1024, // todo
|
27 |
+
"hop_size": 256,
|
28 |
+
"sample_rate": 22050,
|
29 |
+
"n_fft": 1024, // todo
|
30 |
+
"fmin": 0,
|
31 |
+
"fmax": 8000, // todo
|
32 |
+
"raw_data": "raw_data",
|
33 |
+
"text_cleaners": ["english_cleaners"],
|
34 |
+
"f0_min": 71, // ~C2
|
35 |
+
"f0_max": 800, //1100, // ~C6(1100), ~G5(800)
|
36 |
+
"pitch_bin": 256,
|
37 |
+
"pitch_max": 1100.0,
|
38 |
+
"pitch_min": 50.0,
|
39 |
+
"is_label": true,
|
40 |
+
"is_mu_law": true,
|
41 |
+
"bits": 8,
|
42 |
+
|
43 |
+
"mel_min_max_stats_dir": "mel_min_max_stats",
|
44 |
+
"whisper_dir": "whisper",
|
45 |
+
"content_vector_dir": "content_vector",
|
46 |
+
"wenet_dir": "wenet",
|
47 |
+
"mert_dir": "mert",
|
48 |
+
"spk2id":"spk2id.json",
|
49 |
+
"utt2spk":"utt2spk",
|
50 |
+
|
51 |
+
// Features used for model training
|
52 |
+
"use_mel": true,
|
53 |
+
"use_min_max_norm_mel": false,
|
54 |
+
"use_frame_pitch": false,
|
55 |
+
"use_frame_energy": false,
|
56 |
+
"use_phone_pitch": true,
|
57 |
+
"use_phone_energy": true,
|
58 |
+
"use_log_scale_pitch": false,
|
59 |
+
"use_log_scale_energy": false,
|
60 |
+
"use_spkid": false,
|
61 |
+
"align_mel_duration": true,
|
62 |
+
"text_cleaners": ["english_cleaners"],
|
63 |
+
"phone_extractor": "lexicon", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
|
64 |
+
},
|
65 |
+
"model": {
|
66 |
+
// Settings for transformer
|
67 |
+
"transformer": {
|
68 |
+
"encoder_layer": 4,
|
69 |
+
"encoder_head": 2,
|
70 |
+
"encoder_hidden": 256,
|
71 |
+
"decoder_layer": 6,
|
72 |
+
"decoder_head": 2,
|
73 |
+
"decoder_hidden": 256,
|
74 |
+
"conv_filter_size": 1024,
|
75 |
+
"conv_kernel_size": [9, 1],
|
76 |
+
"encoder_dropout": 0.2,
|
77 |
+
"decoder_dropout": 0.2
|
78 |
+
},
|
79 |
+
|
80 |
+
// Settings for variance_predictor
|
81 |
+
"variance_predictor":{
|
82 |
+
"filter_size": 256,
|
83 |
+
"kernel_size": 3,
|
84 |
+
"dropout": 0.5
|
85 |
+
},
|
86 |
+
"variance_embedding":{
|
87 |
+
"pitch_quantization": "linear", // support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing
|
88 |
+
"energy_quantization": "linear", // support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing
|
89 |
+
"n_bins": 256
|
90 |
+
},
|
91 |
+
"max_seq_len": 1000
|
92 |
+
},
|
93 |
+
"train":{
|
94 |
+
"batch_size": 16,
|
95 |
+
"sort_sample": true,
|
96 |
+
"drop_last": true,
|
97 |
+
"group_size": 4,
|
98 |
+
"grad_clip_thresh": 1.0,
|
99 |
+
"dataloader": {
|
100 |
+
"num_worker": 8,
|
101 |
+
"pin_memory": true
|
102 |
+
},
|
103 |
+
"lr_scheduler":{
|
104 |
+
"num_warmup": 4000
|
105 |
+
},
|
106 |
+
// LR Scheduler
|
107 |
+
"scheduler": "NoamLR",
|
108 |
+
// Optimizer
|
109 |
+
"optimizer": "Adam",
|
110 |
+
"adam": {
|
111 |
+
"lr": 0.0625,
|
112 |
+
"betas": [0.9, 0.98],
|
113 |
+
"eps": 0.000000001,
|
114 |
+
"weight_decay": 0.0
|
115 |
+
},
|
116 |
+
}
|
117 |
+
|
118 |
+
}
|
config/ns2.json
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"model_type": "NaturalSpeech2",
|
4 |
+
"dataset": ["LibriTTS"],
|
5 |
+
"preprocess": {
|
6 |
+
"use_mel": false,
|
7 |
+
"use_code": true,
|
8 |
+
"use_spkid": true,
|
9 |
+
"use_pitch": true,
|
10 |
+
"use_duration": true,
|
11 |
+
"use_phone": true,
|
12 |
+
"use_len": true,
|
13 |
+
"use_cross_reference": true,
|
14 |
+
"train_file": "train.json",
|
15 |
+
"melspec_dir": "mel",
|
16 |
+
"code_dir": "code",
|
17 |
+
"pitch_dir": "pitch",
|
18 |
+
"duration_dir": "duration",
|
19 |
+
"clip_mode": "start"
|
20 |
+
},
|
21 |
+
"model": {
|
22 |
+
"latent_dim": 128,
|
23 |
+
"prior_encoder": {
|
24 |
+
"vocab_size": 100,
|
25 |
+
"pitch_min": 50,
|
26 |
+
"pitch_max": 1100,
|
27 |
+
"pitch_bins_num": 512,
|
28 |
+
"encoder": {
|
29 |
+
"encoder_layer": 6,
|
30 |
+
"encoder_hidden": 512,
|
31 |
+
"encoder_head": 8,
|
32 |
+
"conv_filter_size": 2048,
|
33 |
+
"conv_kernel_size": 9,
|
34 |
+
"encoder_dropout": 0.2,
|
35 |
+
"use_cln": true
|
36 |
+
},
|
37 |
+
"duration_predictor": {
|
38 |
+
"input_size": 512,
|
39 |
+
"filter_size": 512,
|
40 |
+
"kernel_size": 3,
|
41 |
+
"conv_layers": 30,
|
42 |
+
"cross_attn_per_layer": 3,
|
43 |
+
"attn_head": 8,
|
44 |
+
"drop_out": 0.5
|
45 |
+
},
|
46 |
+
"pitch_predictor": {
|
47 |
+
"input_size": 512,
|
48 |
+
"filter_size": 512,
|
49 |
+
"kernel_size": 5,
|
50 |
+
"conv_layers": 30,
|
51 |
+
"cross_attn_per_layer": 3,
|
52 |
+
"attn_head": 8,
|
53 |
+
"drop_out": 0.5
|
54 |
+
}
|
55 |
+
},
|
56 |
+
"diffusion": {
|
57 |
+
"wavenet": {
|
58 |
+
"input_size": 128,
|
59 |
+
"hidden_size": 512,
|
60 |
+
"out_size": 128,
|
61 |
+
"num_layers": 40,
|
62 |
+
"cross_attn_per_layer": 3,
|
63 |
+
"dilation_cycle": 2,
|
64 |
+
"attn_head": 8,
|
65 |
+
"drop_out": 0.2
|
66 |
+
},
|
67 |
+
"beta_min": 0.05,
|
68 |
+
"beta_max": 20,
|
69 |
+
"sigma": 1.0,
|
70 |
+
"noise_factor": 1.0,
|
71 |
+
"ode_solver": "euler"
|
72 |
+
},
|
73 |
+
"prompt_encoder": {
|
74 |
+
"encoder_layer": 6,
|
75 |
+
"encoder_hidden": 512,
|
76 |
+
"encoder_head": 8,
|
77 |
+
"conv_filter_size": 2048,
|
78 |
+
"conv_kernel_size": 9,
|
79 |
+
"encoder_dropout": 0.2,
|
80 |
+
"use_cln": false
|
81 |
+
},
|
82 |
+
"query_emb": {
|
83 |
+
"query_token_num": 32,
|
84 |
+
"hidden_size": 512,
|
85 |
+
"head_num": 8
|
86 |
+
}
|
87 |
+
}
|
88 |
+
}
|
config/transformer.json
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"model_type": "Transformer",
|
4 |
+
"task_type": "svc",
|
5 |
+
"use_custom_dataset": false,
|
6 |
+
"preprocess": {
|
7 |
+
// data augmentations
|
8 |
+
"use_pitch_shift": false,
|
9 |
+
"use_formant_shift": false,
|
10 |
+
"use_time_stretch": false,
|
11 |
+
"use_equalizer": false,
|
12 |
+
// acoustic features
|
13 |
+
"extract_mel": true,
|
14 |
+
"mel_min_max_norm": true,
|
15 |
+
"extract_pitch": true,
|
16 |
+
"pitch_extractor": "parselmouth",
|
17 |
+
"extract_uv": true,
|
18 |
+
"extract_energy": true,
|
19 |
+
// content features
|
20 |
+
"extract_whisper_feature": false,
|
21 |
+
"whisper_sample_rate": 16000,
|
22 |
+
"extract_contentvec_feature": false,
|
23 |
+
"contentvec_sample_rate": 16000,
|
24 |
+
"extract_wenet_feature": false,
|
25 |
+
"wenet_sample_rate": 16000,
|
26 |
+
"extract_mert_feature": false,
|
27 |
+
"mert_sample_rate": 16000,
|
28 |
+
// Default config for whisper
|
29 |
+
"whisper_frameshift": 0.01,
|
30 |
+
"whisper_downsample_rate": 2,
|
31 |
+
// Default config for content vector
|
32 |
+
"contentvec_frameshift": 0.02,
|
33 |
+
// Default config for mert
|
34 |
+
"mert_model": "m-a-p/MERT-v1-330M",
|
35 |
+
"mert_feature_layer": -1,
|
36 |
+
"mert_hop_size": 320,
|
37 |
+
// 24k
|
38 |
+
"mert_frameshit": 0.01333,
|
39 |
+
// 10ms
|
40 |
+
"wenet_frameshift": 0.01,
|
41 |
+
// wenetspeech is 4, gigaspeech is 6
|
42 |
+
"wenet_downsample_rate": 4,
|
43 |
+
// Default config
|
44 |
+
"n_mel": 100,
|
45 |
+
"win_size": 1024,
|
46 |
+
// todo
|
47 |
+
"hop_size": 256,
|
48 |
+
"sample_rate": 24000,
|
49 |
+
"n_fft": 1024,
|
50 |
+
// todo
|
51 |
+
"fmin": 0,
|
52 |
+
"fmax": 12000,
|
53 |
+
// todo
|
54 |
+
"f0_min": 50,
|
55 |
+
// ~C2
|
56 |
+
"f0_max": 1100,
|
57 |
+
//1100, // ~C6(1100), ~G5(800)
|
58 |
+
"pitch_bin": 256,
|
59 |
+
"pitch_max": 1100.0,
|
60 |
+
"pitch_min": 50.0,
|
61 |
+
"is_label": true,
|
62 |
+
"is_mu_law": true,
|
63 |
+
"bits": 8,
|
64 |
+
"mel_min_max_stats_dir": "mel_min_max_stats",
|
65 |
+
"whisper_dir": "whisper",
|
66 |
+
"contentvec_dir": "contentvec",
|
67 |
+
"wenet_dir": "wenet",
|
68 |
+
"mert_dir": "mert",
|
69 |
+
// Extract content features using dataloader
|
70 |
+
"pin_memory": true,
|
71 |
+
"num_workers": 8,
|
72 |
+
"content_feature_batch_size": 16,
|
73 |
+
// Features used for model training
|
74 |
+
"use_mel": true,
|
75 |
+
"use_min_max_norm_mel": true,
|
76 |
+
"use_frame_pitch": true,
|
77 |
+
"use_uv": true,
|
78 |
+
"use_frame_energy": true,
|
79 |
+
"use_log_scale_pitch": false,
|
80 |
+
"use_log_scale_energy": false,
|
81 |
+
"use_spkid": true,
|
82 |
+
// Meta file
|
83 |
+
"train_file": "train.json",
|
84 |
+
"valid_file": "test.json",
|
85 |
+
"spk2id": "singers.json",
|
86 |
+
"utt2spk": "utt2singer"
|
87 |
+
},
|
88 |
+
"model": {
|
89 |
+
"condition_encoder": {
|
90 |
+
"merge_mode": "add",
|
91 |
+
"input_melody_dim": 1,
|
92 |
+
"use_log_f0": true,
|
93 |
+
"n_bins_melody": 256,
|
94 |
+
//# Quantization (0 for not quantization)
|
95 |
+
"output_melody_dim": 384,
|
96 |
+
"input_loudness_dim": 1,
|
97 |
+
"use_log_loudness": true,
|
98 |
+
"n_bins_loudness": 256,
|
99 |
+
"output_loudness_dim": 384,
|
100 |
+
"use_whisper": false,
|
101 |
+
"use_contentvec": true,
|
102 |
+
"use_wenet": false,
|
103 |
+
"use_mert": false,
|
104 |
+
"whisper_dim": 1024,
|
105 |
+
"contentvec_dim": 256,
|
106 |
+
"mert_dim": 256,
|
107 |
+
"wenet_dim": 512,
|
108 |
+
"content_encoder_dim": 384,
|
109 |
+
"output_singer_dim": 384,
|
110 |
+
"singer_table_size": 512,
|
111 |
+
"output_content_dim": 384,
|
112 |
+
"use_spkid": true
|
113 |
+
},
|
114 |
+
"transformer": {
|
115 |
+
"type": "conformer",
|
116 |
+
// 'conformer' or 'transformer'
|
117 |
+
"input_dim": 384,
|
118 |
+
"output_dim": 100,
|
119 |
+
"n_heads": 2,
|
120 |
+
"n_layers": 6,
|
121 |
+
"filter_channels": 512,
|
122 |
+
"dropout": 0.1,
|
123 |
+
}
|
124 |
+
},
|
125 |
+
"train": {
|
126 |
+
// Basic settings
|
127 |
+
"batch_size": 64,
|
128 |
+
"gradient_accumulation_step": 1,
|
129 |
+
"max_epoch": -1,
|
130 |
+
// -1 means no limit
|
131 |
+
"save_checkpoint_stride": [
|
132 |
+
10,
|
133 |
+
100
|
134 |
+
],
|
135 |
+
// unit is epoch
|
136 |
+
"keep_last": [
|
137 |
+
3,
|
138 |
+
-1
|
139 |
+
],
|
140 |
+
// -1 means infinite, if one number will broadcast
|
141 |
+
"run_eval": [
|
142 |
+
false,
|
143 |
+
true
|
144 |
+
],
|
145 |
+
// if one number will broadcast
|
146 |
+
// Fix the random seed
|
147 |
+
"random_seed": 10086,
|
148 |
+
// Batchsampler
|
149 |
+
"sampler": {
|
150 |
+
"holistic_shuffle": true,
|
151 |
+
"drop_last": true
|
152 |
+
},
|
153 |
+
// Dataloader
|
154 |
+
"dataloader": {
|
155 |
+
"num_worker": 32,
|
156 |
+
"pin_memory": true
|
157 |
+
},
|
158 |
+
// Trackers
|
159 |
+
"tracker": [
|
160 |
+
"tensorboard"
|
161 |
+
// "wandb",
|
162 |
+
// "cometml",
|
163 |
+
// "mlflow",
|
164 |
+
],
|
165 |
+
// Optimizer
|
166 |
+
"optimizer": "AdamW",
|
167 |
+
"adamw": {
|
168 |
+
"lr": 4.0e-4
|
169 |
+
// nn model lr
|
170 |
+
},
|
171 |
+
// LR Scheduler
|
172 |
+
"scheduler": "ReduceLROnPlateau",
|
173 |
+
"reducelronplateau": {
|
174 |
+
"factor": 0.8,
|
175 |
+
"patience": 10,
|
176 |
+
// unit is epoch
|
177 |
+
"min_lr": 1.0e-4
|
178 |
+
}
|
179 |
+
}
|
180 |
+
}
|
config/tts.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"supported_model_type": [
|
4 |
+
"Fastspeech2",
|
5 |
+
"VITS",
|
6 |
+
"VALLE",
|
7 |
+
],
|
8 |
+
"task_type": "tts",
|
9 |
+
"preprocess": {
|
10 |
+
"language": "en-us",
|
11 |
+
// linguistic features
|
12 |
+
"extract_phone": true,
|
13 |
+
"phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
|
14 |
+
"lexicon_path": "./text/lexicon/librispeech-lexicon.txt",
|
15 |
+
// Directory names of processed data or extracted features
|
16 |
+
"phone_dir": "phones",
|
17 |
+
"use_phone": true,
|
18 |
+
},
|
19 |
+
"model": {
|
20 |
+
"text_token_num": 512,
|
21 |
+
}
|
22 |
+
|
23 |
+
}
|
config/valle.json
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/tts.json",
|
3 |
+
"model_type": "VALLE",
|
4 |
+
"task_type": "tts",
|
5 |
+
"dataset": [
|
6 |
+
"libritts"
|
7 |
+
],
|
8 |
+
"preprocess": {
|
9 |
+
"extract_phone": true,
|
10 |
+
"phone_extractor": "espeak", // phoneme extractor: espeak, pypinyin, pypinyin_initials_finals or lexicon
|
11 |
+
"extract_acoustic_token": true,
|
12 |
+
"acoustic_token_extractor": "Encodec", // acoustic token extractor: encodec, dac(todo)
|
13 |
+
"acoustic_token_dir": "acoutic_tokens",
|
14 |
+
"use_text": false,
|
15 |
+
"use_phone": true,
|
16 |
+
"use_acoustic_token": true,
|
17 |
+
"symbols_dict": "symbols.dict",
|
18 |
+
"min_duration": 0.5, // the duration lowerbound to filter the audio with duration < min_duration
|
19 |
+
"max_duration": 14, // the duration uperbound to filter the audio with duration > max_duration.
|
20 |
+
"sample_rate": 24000,
|
21 |
+
"codec_hop_size": 320
|
22 |
+
},
|
23 |
+
"model": {
|
24 |
+
"text_token_num": 512,
|
25 |
+
"audio_token_num": 1024,
|
26 |
+
"decoder_dim": 1024, // embedding dimension of the decoder model
|
27 |
+
"nhead": 16, // number of attention heads in the decoder layers
|
28 |
+
"num_decoder_layers": 12, // number of decoder layers
|
29 |
+
"norm_first": true, // pre or post Normalization.
|
30 |
+
"add_prenet": false, // whether add PreNet after Inputs
|
31 |
+
"prefix_mode": 0, // mode for how to prefix VALL-E NAR Decoder, 0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance
|
32 |
+
"share_embedding": true, // share the parameters of the output projection layer with the parameters of the acoustic embedding
|
33 |
+
"nar_scale_factor": 1, // model scale factor which will be assigned different meanings in different models
|
34 |
+
"prepend_bos": false, // whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs
|
35 |
+
"num_quantizers": 8, // numbert of the audio quantization layers
|
36 |
+
// "scaling_xformers": false, // Apply Reworked Conformer scaling on Transformers
|
37 |
+
},
|
38 |
+
"train": {
|
39 |
+
"ddp": false,
|
40 |
+
"train_stage": 1, // 0: train all modules, For VALL_E, support 1: AR Decoder 2: NAR Decoder(s)
|
41 |
+
"max_epoch": 20,
|
42 |
+
"optimizer": "AdamW",
|
43 |
+
"scheduler": "cosine",
|
44 |
+
"warmup_steps": 16000, // number of steps that affects how rapidly the learning rate decreases
|
45 |
+
"base_lr": 1e-4, // base learning rate."
|
46 |
+
"valid_interval": 1000,
|
47 |
+
"log_epoch_step": 1000,
|
48 |
+
"save_checkpoint_stride": [
|
49 |
+
1,
|
50 |
+
1
|
51 |
+
]
|
52 |
+
}
|
53 |
+
}
|
config/vits.json
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/tts.json",
|
3 |
+
"model_type": "VITS",
|
4 |
+
"task_type": "tts",
|
5 |
+
"preprocess": {
|
6 |
+
"extract_phone": true,
|
7 |
+
"extract_mel": true,
|
8 |
+
"n_mel": 80,
|
9 |
+
"fmin": 0,
|
10 |
+
"fmax": null,
|
11 |
+
"extract_linear_spec": true,
|
12 |
+
"extract_audio": true,
|
13 |
+
"use_linear": true,
|
14 |
+
"use_mel": true,
|
15 |
+
"use_audio": true,
|
16 |
+
"use_text": false,
|
17 |
+
"use_phone": true,
|
18 |
+
"lexicon_path": "./text/lexicon/librispeech-lexicon.txt",
|
19 |
+
"n_fft": 1024,
|
20 |
+
"win_size": 1024,
|
21 |
+
"hop_size": 256,
|
22 |
+
"segment_size": 8192,
|
23 |
+
"text_cleaners": [
|
24 |
+
"english_cleaners"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
"model": {
|
28 |
+
"text_token_num": 512,
|
29 |
+
"inter_channels": 192,
|
30 |
+
"hidden_channels": 192,
|
31 |
+
"filter_channels": 768,
|
32 |
+
"n_heads": 2,
|
33 |
+
"n_layers": 6,
|
34 |
+
"kernel_size": 3,
|
35 |
+
"p_dropout": 0.1,
|
36 |
+
"resblock": "1",
|
37 |
+
"resblock_kernel_sizes": [
|
38 |
+
3,
|
39 |
+
7,
|
40 |
+
11
|
41 |
+
],
|
42 |
+
"resblock_dilation_sizes": [
|
43 |
+
[
|
44 |
+
1,
|
45 |
+
3,
|
46 |
+
5
|
47 |
+
],
|
48 |
+
[
|
49 |
+
1,
|
50 |
+
3,
|
51 |
+
5
|
52 |
+
],
|
53 |
+
[
|
54 |
+
1,
|
55 |
+
3,
|
56 |
+
5
|
57 |
+
]
|
58 |
+
],
|
59 |
+
"upsample_rates": [
|
60 |
+
8,
|
61 |
+
8,
|
62 |
+
2,
|
63 |
+
2
|
64 |
+
],
|
65 |
+
"upsample_initial_channel": 512,
|
66 |
+
"upsample_kernel_sizes": [
|
67 |
+
16,
|
68 |
+
16,
|
69 |
+
4,
|
70 |
+
4
|
71 |
+
],
|
72 |
+
"n_layers_q": 3,
|
73 |
+
"use_spectral_norm": false,
|
74 |
+
"n_speakers": 0, // number of speakers, while be automatically set if n_speakers is 0 and multi_speaker_training is true
|
75 |
+
"gin_channels": 256,
|
76 |
+
"use_sdp": true
|
77 |
+
},
|
78 |
+
"train": {
|
79 |
+
"fp16_run": true,
|
80 |
+
"learning_rate": 2e-4,
|
81 |
+
"betas": [
|
82 |
+
0.8,
|
83 |
+
0.99
|
84 |
+
],
|
85 |
+
"eps": 1e-9,
|
86 |
+
"batch_size": 16,
|
87 |
+
"lr_decay": 0.999875,
|
88 |
+
// "segment_size": 8192,
|
89 |
+
"init_lr_ratio": 1,
|
90 |
+
"warmup_epochs": 0,
|
91 |
+
"c_mel": 45,
|
92 |
+
"c_kl": 1.0,
|
93 |
+
"AdamW": {
|
94 |
+
"betas": [
|
95 |
+
0.8,
|
96 |
+
0.99
|
97 |
+
],
|
98 |
+
"eps": 1e-9,
|
99 |
+
}
|
100 |
+
}
|
101 |
+
}
|
config/vitssvc.json
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"model_type": "VITS",
|
4 |
+
"task_type": "svc",
|
5 |
+
"preprocess": {
|
6 |
+
"extract_phone": false,
|
7 |
+
"extract_mel": true,
|
8 |
+
"extract_linear_spec": true,
|
9 |
+
"extract_audio": true,
|
10 |
+
"use_linear": true,
|
11 |
+
"use_mel": true,
|
12 |
+
"use_audio": true,
|
13 |
+
"use_text": false,
|
14 |
+
"use_phone": true,
|
15 |
+
|
16 |
+
"fmin": 0,
|
17 |
+
"fmax": null,
|
18 |
+
"f0_min": 50,
|
19 |
+
"f0_max": 1100,
|
20 |
+
// f0_bin in sovits
|
21 |
+
"pitch_bin": 256,
|
22 |
+
// filter_length in sovits
|
23 |
+
"n_fft": 2048,
|
24 |
+
// hop_length in sovits
|
25 |
+
"hop_size": 512,
|
26 |
+
// win_length in sovits
|
27 |
+
"win_size": 2048,
|
28 |
+
"segment_size": 8192,
|
29 |
+
"n_mel": 100,
|
30 |
+
"sample_rate": 44100,
|
31 |
+
|
32 |
+
"mel_min_max_stats_dir": "mel_min_max_stats",
|
33 |
+
"whisper_dir": "whisper",
|
34 |
+
"contentvec_dir": "contentvec",
|
35 |
+
"wenet_dir": "wenet",
|
36 |
+
"mert_dir": "mert",
|
37 |
+
},
|
38 |
+
"model": {
|
39 |
+
"condition_encoder": {
|
40 |
+
"merge_mode": "add",
|
41 |
+
"input_melody_dim": 1,
|
42 |
+
"use_log_f0": true,
|
43 |
+
"n_bins_melody": 256,
|
44 |
+
//# Quantization (0 for not quantization)
|
45 |
+
"output_melody_dim": 196,
|
46 |
+
"input_loudness_dim": 1,
|
47 |
+
"use_log_loudness": false,
|
48 |
+
"n_bins_loudness": 256,
|
49 |
+
"output_loudness_dim": 196,
|
50 |
+
"use_whisper": false,
|
51 |
+
"use_contentvec": false,
|
52 |
+
"use_wenet": false,
|
53 |
+
"use_mert": false,
|
54 |
+
"whisper_dim": 1024,
|
55 |
+
"contentvec_dim": 256,
|
56 |
+
"mert_dim": 256,
|
57 |
+
"wenet_dim": 512,
|
58 |
+
"content_encoder_dim": 196,
|
59 |
+
"output_singer_dim": 196,
|
60 |
+
"singer_table_size": 512,
|
61 |
+
"output_content_dim": 196,
|
62 |
+
"use_spkid": true
|
63 |
+
},
|
64 |
+
"vits": {
|
65 |
+
"filter_channels": 256,
|
66 |
+
"gin_channels": 256,
|
67 |
+
"hidden_channels": 192,
|
68 |
+
"inter_channels": 192,
|
69 |
+
"kernel_size": 3,
|
70 |
+
"n_flow_layer": 4,
|
71 |
+
"n_heads": 2,
|
72 |
+
"n_layers": 6,
|
73 |
+
"n_layers_q": 3,
|
74 |
+
"n_speakers": 512,
|
75 |
+
"p_dropout": 0.1,
|
76 |
+
"ssl_dim": 256,
|
77 |
+
"use_spectral_norm": false,
|
78 |
+
},
|
79 |
+
"generator": "hifigan",
|
80 |
+
"generator_config": {
|
81 |
+
"hifigan": {
|
82 |
+
"resblock": "1",
|
83 |
+
"resblock_kernel_sizes": [
|
84 |
+
3,
|
85 |
+
7,
|
86 |
+
11
|
87 |
+
],
|
88 |
+
"upsample_rates": [
|
89 |
+
8,8,2,2,2
|
90 |
+
],
|
91 |
+
"upsample_kernel_sizes": [
|
92 |
+
16,16,4,4,4
|
93 |
+
],
|
94 |
+
"upsample_initial_channel": 512,
|
95 |
+
"resblock_dilation_sizes": [
|
96 |
+
[1,3,5],
|
97 |
+
[1,3,5],
|
98 |
+
[1,3,5]
|
99 |
+
]
|
100 |
+
},
|
101 |
+
"melgan": {
|
102 |
+
"ratios": [8, 8, 2, 2, 2],
|
103 |
+
"ngf": 32,
|
104 |
+
"n_residual_layers": 3,
|
105 |
+
"num_D": 3,
|
106 |
+
"ndf": 16,
|
107 |
+
"n_layers": 4,
|
108 |
+
"downsampling_factor": 4
|
109 |
+
},
|
110 |
+
"bigvgan": {
|
111 |
+
"resblock": "1",
|
112 |
+
"activation": "snakebeta",
|
113 |
+
"snake_logscale": true,
|
114 |
+
"upsample_rates": [
|
115 |
+
8,8,2,2,2,
|
116 |
+
],
|
117 |
+
"upsample_kernel_sizes": [
|
118 |
+
16,16,4,4,4,
|
119 |
+
],
|
120 |
+
"upsample_initial_channel": 512,
|
121 |
+
"resblock_kernel_sizes": [
|
122 |
+
3,
|
123 |
+
7,
|
124 |
+
11
|
125 |
+
],
|
126 |
+
"resblock_dilation_sizes": [
|
127 |
+
[1,3,5],
|
128 |
+
[1,3,5],
|
129 |
+
[1,3,5]
|
130 |
+
]
|
131 |
+
},
|
132 |
+
"nsfhifigan": {
|
133 |
+
"resblock": "1",
|
134 |
+
"harmonic_num": 8,
|
135 |
+
"upsample_rates": [
|
136 |
+
8,8,2,2,2,
|
137 |
+
],
|
138 |
+
"upsample_kernel_sizes": [
|
139 |
+
16,16,4,4,4,
|
140 |
+
],
|
141 |
+
"upsample_initial_channel": 768,
|
142 |
+
"resblock_kernel_sizes": [
|
143 |
+
3,
|
144 |
+
7,
|
145 |
+
11
|
146 |
+
],
|
147 |
+
"resblock_dilation_sizes": [
|
148 |
+
[1,3,5],
|
149 |
+
[1,3,5],
|
150 |
+
[1,3,5]
|
151 |
+
]
|
152 |
+
},
|
153 |
+
"apnet": {
|
154 |
+
"ASP_channel": 512,
|
155 |
+
"ASP_resblock_kernel_sizes": [3,7,11],
|
156 |
+
"ASP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
157 |
+
"ASP_input_conv_kernel_size": 7,
|
158 |
+
"ASP_output_conv_kernel_size": 7,
|
159 |
+
|
160 |
+
"PSP_channel": 512,
|
161 |
+
"PSP_resblock_kernel_sizes": [3,7,11],
|
162 |
+
"PSP_resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
163 |
+
"PSP_input_conv_kernel_size": 7,
|
164 |
+
"PSP_output_R_conv_kernel_size": 7,
|
165 |
+
"PSP_output_I_conv_kernel_size": 7,
|
166 |
+
}
|
167 |
+
},
|
168 |
+
},
|
169 |
+
"train": {
|
170 |
+
"fp16_run": true,
|
171 |
+
"learning_rate": 2e-4,
|
172 |
+
"betas": [
|
173 |
+
0.8,
|
174 |
+
0.99
|
175 |
+
],
|
176 |
+
"eps": 1e-9,
|
177 |
+
"batch_size": 16,
|
178 |
+
"lr_decay": 0.999875,
|
179 |
+
// "segment_size": 8192,
|
180 |
+
"init_lr_ratio": 1,
|
181 |
+
"warmup_epochs": 0,
|
182 |
+
"c_mel": 45,
|
183 |
+
"c_kl": 1.0,
|
184 |
+
"AdamW": {
|
185 |
+
"betas": [
|
186 |
+
0.8,
|
187 |
+
0.99
|
188 |
+
],
|
189 |
+
"eps": 1e-9,
|
190 |
+
}
|
191 |
+
}
|
192 |
+
}
|
config/vocoder.json
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/base.json",
|
3 |
+
"dataset": [
|
4 |
+
"LJSpeech",
|
5 |
+
"LibriTTS",
|
6 |
+
"opencpop",
|
7 |
+
"m4singer",
|
8 |
+
"svcc",
|
9 |
+
"svcceval",
|
10 |
+
"pjs",
|
11 |
+
"opensinger",
|
12 |
+
"popbutfy",
|
13 |
+
"nus48e",
|
14 |
+
"popcs",
|
15 |
+
"kising",
|
16 |
+
"csd",
|
17 |
+
"opera",
|
18 |
+
"vctk",
|
19 |
+
"lijian",
|
20 |
+
"cdmusiceval"
|
21 |
+
],
|
22 |
+
"task_type": "vocoder",
|
23 |
+
"preprocess": {
|
24 |
+
// acoustic features
|
25 |
+
"extract_mel": true,
|
26 |
+
"extract_pitch": false,
|
27 |
+
"extract_uv": false,
|
28 |
+
"extract_audio": true,
|
29 |
+
"extract_label": false,
|
30 |
+
"extract_one_hot": false,
|
31 |
+
"extract_amplitude_phase": false,
|
32 |
+
"pitch_extractor": "parselmouth",
|
33 |
+
// Settings for data preprocessing
|
34 |
+
"n_mel": 100,
|
35 |
+
"win_size": 1024,
|
36 |
+
"hop_size": 256,
|
37 |
+
"sample_rate": 24000,
|
38 |
+
"n_fft": 1024,
|
39 |
+
"fmin": 0,
|
40 |
+
"fmax": 12000,
|
41 |
+
"f0_min": 50,
|
42 |
+
"f0_max": 1100,
|
43 |
+
"pitch_bin": 256,
|
44 |
+
"pitch_max": 1100.0,
|
45 |
+
"pitch_min": 50.0,
|
46 |
+
"is_mu_law": false,
|
47 |
+
"bits": 8,
|
48 |
+
"cut_mel_frame": 32,
|
49 |
+
// Directory names of processed data or extracted features
|
50 |
+
"spk2id": "singers.json",
|
51 |
+
// Features used for model training
|
52 |
+
"use_mel": true,
|
53 |
+
"use_frame_pitch": false,
|
54 |
+
"use_uv": false,
|
55 |
+
"use_audio": true,
|
56 |
+
"use_label": false,
|
57 |
+
"use_one_hot": false,
|
58 |
+
"train_file": "train.json",
|
59 |
+
"valid_file": "test.json"
|
60 |
+
},
|
61 |
+
"train": {
|
62 |
+
"random_seed": 114514,
|
63 |
+
"batch_size": 64,
|
64 |
+
"gradient_accumulation_step": 1,
|
65 |
+
"max_epoch": 1000000,
|
66 |
+
"save_checkpoint_stride": [
|
67 |
+
20
|
68 |
+
],
|
69 |
+
"run_eval": [
|
70 |
+
true
|
71 |
+
],
|
72 |
+
"sampler": {
|
73 |
+
"holistic_shuffle": true,
|
74 |
+
"drop_last": true
|
75 |
+
},
|
76 |
+
"dataloader": {
|
77 |
+
"num_worker": 4,
|
78 |
+
"pin_memory": true
|
79 |
+
},
|
80 |
+
"tracker": [
|
81 |
+
"tensorboard"
|
82 |
+
],
|
83 |
+
}
|
84 |
+
}
|
evaluation/__init__.py
ADDED
File without changes
|
evaluation/features/__init__.py
ADDED
File without changes
|
evaluation/features/long_term_average_spectrum.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
from scipy import signal
|
8 |
+
|
9 |
+
|
10 |
+
def extract_ltas(audio, fs=None, n_fft=1024, hop_length=256):
|
11 |
+
"""Extract Long-Term Average Spectrum for a given audio."""
|
12 |
+
if fs != None:
|
13 |
+
y, _ = librosa.load(audio, sr=fs)
|
14 |
+
else:
|
15 |
+
y, fs = librosa.load(audio)
|
16 |
+
frequency, density = signal.welch(
|
17 |
+
x=y, fs=fs, window="hann", nperseg=hop_length, nfft=n_fft
|
18 |
+
)
|
19 |
+
return frequency, density
|
evaluation/features/signal_to_noise_ratio.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import scipy.signal as sig
|
8 |
+
import copy
|
9 |
+
import librosa
|
10 |
+
|
11 |
+
|
12 |
+
def bandpower(ps, mode="time"):
|
13 |
+
"""
|
14 |
+
estimate bandpower, see https://de.mathworks.com/help/signal/ref/bandpower.html
|
15 |
+
"""
|
16 |
+
if mode == "time":
|
17 |
+
x = ps
|
18 |
+
l2norm = np.linalg.norm(x) ** 2.0 / len(x)
|
19 |
+
return l2norm
|
20 |
+
elif mode == "psd":
|
21 |
+
return sum(ps)
|
22 |
+
|
23 |
+
|
24 |
+
def getIndizesAroundPeak(arr, peakIndex, searchWidth=1000):
|
25 |
+
peakBins = []
|
26 |
+
magMax = arr[peakIndex]
|
27 |
+
curVal = magMax
|
28 |
+
for i in range(searchWidth):
|
29 |
+
newBin = peakIndex + i
|
30 |
+
if newBin >= len(arr):
|
31 |
+
break
|
32 |
+
newVal = arr[newBin]
|
33 |
+
if newVal > curVal:
|
34 |
+
break
|
35 |
+
else:
|
36 |
+
peakBins.append(int(newBin))
|
37 |
+
curVal = newVal
|
38 |
+
curVal = magMax
|
39 |
+
for i in range(searchWidth):
|
40 |
+
newBin = peakIndex - i
|
41 |
+
if newBin < 0:
|
42 |
+
break
|
43 |
+
newVal = arr[newBin]
|
44 |
+
if newVal > curVal:
|
45 |
+
break
|
46 |
+
else:
|
47 |
+
peakBins.append(int(newBin))
|
48 |
+
curVal = newVal
|
49 |
+
return np.array(list(set(peakBins)))
|
50 |
+
|
51 |
+
|
52 |
+
def freqToBin(fAxis, Freq):
|
53 |
+
return np.argmin(abs(fAxis - Freq))
|
54 |
+
|
55 |
+
|
56 |
+
def getPeakInArea(psd, faxis, estimation, searchWidthHz=10):
|
57 |
+
"""
|
58 |
+
returns bin and frequency of the maximum in an area
|
59 |
+
"""
|
60 |
+
binLow = freqToBin(faxis, estimation - searchWidthHz)
|
61 |
+
binHi = freqToBin(faxis, estimation + searchWidthHz)
|
62 |
+
peakbin = binLow + np.argmax(psd[binLow : binHi + 1])
|
63 |
+
return peakbin, faxis[peakbin]
|
64 |
+
|
65 |
+
|
66 |
+
def getHarmonics(fund, sr, nHarmonics=6, aliased=False):
|
67 |
+
harmonicMultipliers = np.arange(2, nHarmonics + 2)
|
68 |
+
harmonicFs = fund * harmonicMultipliers
|
69 |
+
if not aliased:
|
70 |
+
harmonicFs[harmonicFs > sr / 2] = -1
|
71 |
+
harmonicFs = np.delete(harmonicFs, harmonicFs == -1)
|
72 |
+
else:
|
73 |
+
nyqZone = np.floor(harmonicFs / (sr / 2))
|
74 |
+
oddEvenNyq = nyqZone % 2
|
75 |
+
harmonicFs = np.mod(harmonicFs, sr / 2)
|
76 |
+
harmonicFs[oddEvenNyq == 1] = (sr / 2) - harmonicFs[oddEvenNyq == 1]
|
77 |
+
return harmonicFs
|
78 |
+
|
79 |
+
|
80 |
+
def extract_snr(audio, sr=None):
|
81 |
+
"""Extract Signal-to-Noise Ratio for a given audio."""
|
82 |
+
if sr != None:
|
83 |
+
audio, _ = librosa.load(audio, sr=sr)
|
84 |
+
else:
|
85 |
+
audio, sr = librosa.load(audio, sr=sr)
|
86 |
+
faxis, ps = sig.periodogram(
|
87 |
+
audio, fs=sr, window=("kaiser", 38)
|
88 |
+
) # get periodogram, parametrized like in matlab
|
89 |
+
fundBin = np.argmax(
|
90 |
+
ps
|
91 |
+
) # estimate fundamental at maximum amplitude, get the bin number
|
92 |
+
fundIndizes = getIndizesAroundPeak(
|
93 |
+
ps, fundBin
|
94 |
+
) # get bin numbers around fundamental peak
|
95 |
+
fundFrequency = faxis[fundBin] # frequency of fundamental
|
96 |
+
|
97 |
+
nHarmonics = 18
|
98 |
+
harmonicFs = getHarmonics(
|
99 |
+
fundFrequency, sr, nHarmonics=nHarmonics, aliased=True
|
100 |
+
) # get harmonic frequencies
|
101 |
+
|
102 |
+
harmonicBorders = np.zeros([2, nHarmonics], dtype=np.int16).T
|
103 |
+
fullHarmonicBins = np.array([], dtype=np.int16)
|
104 |
+
fullHarmonicBinList = []
|
105 |
+
harmPeakFreqs = []
|
106 |
+
harmPeaks = []
|
107 |
+
for i, harmonic in enumerate(harmonicFs):
|
108 |
+
searcharea = 0.1 * fundFrequency
|
109 |
+
estimation = harmonic
|
110 |
+
|
111 |
+
binNum, freq = getPeakInArea(ps, faxis, estimation, searcharea)
|
112 |
+
harmPeakFreqs.append(freq)
|
113 |
+
harmPeaks.append(ps[binNum])
|
114 |
+
allBins = getIndizesAroundPeak(ps, binNum, searchWidth=1000)
|
115 |
+
fullHarmonicBins = np.append(fullHarmonicBins, allBins)
|
116 |
+
fullHarmonicBinList.append(allBins)
|
117 |
+
harmonicBorders[i, :] = [allBins[0], allBins[-1]]
|
118 |
+
|
119 |
+
fundIndizes.sort()
|
120 |
+
pFund = bandpower(ps[fundIndizes[0] : fundIndizes[-1]]) # get power of fundamental
|
121 |
+
|
122 |
+
noisePrepared = copy.copy(ps)
|
123 |
+
noisePrepared[fundIndizes] = 0
|
124 |
+
noisePrepared[fullHarmonicBins] = 0
|
125 |
+
noiseMean = np.median(noisePrepared[noisePrepared != 0])
|
126 |
+
noisePrepared[fundIndizes] = noiseMean
|
127 |
+
noisePrepared[fullHarmonicBins] = noiseMean
|
128 |
+
|
129 |
+
noisePower = bandpower(noisePrepared)
|
130 |
+
|
131 |
+
r = 10 * np.log10(pFund / noisePower)
|
132 |
+
|
133 |
+
return r, 10 * np.log10(noisePower)
|
evaluation/features/singing_power_ratio.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import librosa
|
8 |
+
|
9 |
+
from utils.util import JsonHParams
|
10 |
+
from utils.f0 import get_f0_features_using_parselmouth, get_pitch_sub_median
|
11 |
+
from utils.mel import extract_mel_features
|
12 |
+
|
13 |
+
|
14 |
+
def extract_spr(
|
15 |
+
audio,
|
16 |
+
fs=None,
|
17 |
+
hop_length=256,
|
18 |
+
win_length=1024,
|
19 |
+
n_fft=1024,
|
20 |
+
n_mels=128,
|
21 |
+
f0_min=37,
|
22 |
+
f0_max=1000,
|
23 |
+
pitch_bin=256,
|
24 |
+
pitch_max=1100.0,
|
25 |
+
pitch_min=50.0,
|
26 |
+
):
|
27 |
+
"""Compute Singing Power Ratio (SPR) from a given audio.
|
28 |
+
audio: path to the audio.
|
29 |
+
fs: sampling rate.
|
30 |
+
hop_length: hop length.
|
31 |
+
win_length: window length.
|
32 |
+
n_mels: number of mel filters.
|
33 |
+
f0_min: lower limit for f0.
|
34 |
+
f0_max: upper limit for f0.
|
35 |
+
pitch_bin: number of bins for f0 quantization.
|
36 |
+
pitch_max: upper limit for f0 quantization.
|
37 |
+
pitch_min: lower limit for f0 quantization.
|
38 |
+
"""
|
39 |
+
# Load audio
|
40 |
+
if fs != None:
|
41 |
+
audio, _ = librosa.load(audio, sr=fs)
|
42 |
+
else:
|
43 |
+
audio, fs = librosa.load(audio)
|
44 |
+
audio = torch.from_numpy(audio)
|
45 |
+
|
46 |
+
# Initialize config
|
47 |
+
cfg = JsonHParams()
|
48 |
+
cfg.sample_rate = fs
|
49 |
+
cfg.hop_size = hop_length
|
50 |
+
cfg.win_size = win_length
|
51 |
+
cfg.n_fft = n_fft
|
52 |
+
cfg.n_mel = n_mels
|
53 |
+
cfg.f0_min = f0_min
|
54 |
+
cfg.f0_max = f0_max
|
55 |
+
cfg.pitch_bin = pitch_bin
|
56 |
+
cfg.pitch_max = pitch_max
|
57 |
+
cfg.pitch_min = pitch_min
|
58 |
+
|
59 |
+
# Extract mel spectrograms
|
60 |
+
|
61 |
+
cfg.fmin = 2000
|
62 |
+
cfg.fmax = 4000
|
63 |
+
|
64 |
+
mel1 = extract_mel_features(
|
65 |
+
y=audio.unsqueeze(0),
|
66 |
+
cfg=cfg,
|
67 |
+
).squeeze(0)
|
68 |
+
|
69 |
+
cfg.fmin = 0
|
70 |
+
cfg.fmax = 2000
|
71 |
+
|
72 |
+
mel2 = extract_mel_features(
|
73 |
+
y=audio.unsqueeze(0),
|
74 |
+
cfg=cfg,
|
75 |
+
).squeeze(0)
|
76 |
+
|
77 |
+
f0 = get_f0_features_using_parselmouth(
|
78 |
+
audio,
|
79 |
+
cfg,
|
80 |
+
)[0]
|
81 |
+
|
82 |
+
# Mel length alignment
|
83 |
+
length = min(len(f0), mel1.shape[-1])
|
84 |
+
f0 = f0[:length]
|
85 |
+
mel1 = mel1[:, :length]
|
86 |
+
mel2 = mel2[:, :length]
|
87 |
+
|
88 |
+
# Compute SPR
|
89 |
+
res = []
|
90 |
+
|
91 |
+
for i in range(mel1.shape[-1]):
|
92 |
+
if f0[i] <= 1:
|
93 |
+
continue
|
94 |
+
|
95 |
+
chunk1 = mel1[:, i]
|
96 |
+
chunk2 = mel2[:, i]
|
97 |
+
|
98 |
+
max1 = max(chunk1.numpy())
|
99 |
+
max2 = max(chunk2.numpy())
|
100 |
+
|
101 |
+
tmp_res = max2 - max1
|
102 |
+
|
103 |
+
res.append(tmp_res)
|
104 |
+
|
105 |
+
if len(res) == 0:
|
106 |
+
return False
|
107 |
+
else:
|
108 |
+
return sum(res) / len(res)
|
evaluation/metrics/__init__.py
ADDED
File without changes
|
evaluation/metrics/energy/__init__.py
ADDED
File without changes
|
evaluation/metrics/energy/energy_pearson_coefficients.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import librosa
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from numpy import linalg as LA
|
12 |
+
|
13 |
+
from torchmetrics import PearsonCorrCoef
|
14 |
+
|
15 |
+
|
16 |
+
def extract_energy_pearson_coeffcients(
|
17 |
+
audio_ref,
|
18 |
+
audio_deg,
|
19 |
+
fs=None,
|
20 |
+
n_fft=1024,
|
21 |
+
hop_length=256,
|
22 |
+
win_length=1024,
|
23 |
+
method="dtw",
|
24 |
+
db_scale=True,
|
25 |
+
):
|
26 |
+
"""Compute Energy Pearson Coefficients between the predicted and the ground truth audio.
|
27 |
+
audio_ref: path to the ground truth audio.
|
28 |
+
audio_deg: path to the predicted audio.
|
29 |
+
fs: sampling rate.
|
30 |
+
n_fft: fft size.
|
31 |
+
hop_length: hop length.
|
32 |
+
win_length: window length.
|
33 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
34 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
35 |
+
db_scale: the ground truth and predicted audio will be converted to db_scale if "True".
|
36 |
+
"""
|
37 |
+
# Initialize method
|
38 |
+
pearson = PearsonCorrCoef()
|
39 |
+
|
40 |
+
# Load audio
|
41 |
+
if fs != None:
|
42 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
43 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
44 |
+
else:
|
45 |
+
audio_ref, fs = librosa.load(audio_ref)
|
46 |
+
audio_deg, fs = librosa.load(audio_deg)
|
47 |
+
|
48 |
+
# STFT
|
49 |
+
spec_ref = librosa.stft(
|
50 |
+
y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
|
51 |
+
)
|
52 |
+
spec_deg = librosa.stft(
|
53 |
+
y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
|
54 |
+
)
|
55 |
+
|
56 |
+
# Get magnitudes
|
57 |
+
mag_ref = np.abs(spec_ref).T
|
58 |
+
mag_deg = np.abs(spec_deg).T
|
59 |
+
|
60 |
+
# Convert spectrogram to energy
|
61 |
+
energy_ref = LA.norm(mag_ref, axis=1)
|
62 |
+
energy_deg = LA.norm(mag_deg, axis=1)
|
63 |
+
|
64 |
+
# Convert to db_scale
|
65 |
+
if db_scale:
|
66 |
+
energy_ref = 20 * np.log10(energy_ref)
|
67 |
+
energy_deg = 20 * np.log10(energy_deg)
|
68 |
+
|
69 |
+
# Audio length alignment
|
70 |
+
if method == "cut":
|
71 |
+
length = min(len(energy_ref), len(energy_deg))
|
72 |
+
energy_ref = energy_ref[:length]
|
73 |
+
energy_deg = energy_deg[:length]
|
74 |
+
elif method == "dtw":
|
75 |
+
_, wp = librosa.sequence.dtw(energy_ref, energy_deg, backtrack=True)
|
76 |
+
energy_gt_new = []
|
77 |
+
energy_pred_new = []
|
78 |
+
for i in range(wp.shape[0]):
|
79 |
+
gt_index = wp[i][0]
|
80 |
+
pred_index = wp[i][1]
|
81 |
+
energy_gt_new.append(energy_ref[gt_index])
|
82 |
+
energy_pred_new.append(energy_deg[pred_index])
|
83 |
+
energy_ref = np.array(energy_gt_new)
|
84 |
+
energy_deg = np.array(energy_pred_new)
|
85 |
+
assert len(energy_ref) == len(energy_deg)
|
86 |
+
|
87 |
+
# Convert to tensor
|
88 |
+
energy_ref = torch.from_numpy(energy_ref)
|
89 |
+
energy_deg = torch.from_numpy(energy_deg)
|
90 |
+
|
91 |
+
return pearson(energy_ref, energy_deg).numpy().tolist()
|
evaluation/metrics/energy/energy_rmse.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import librosa
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from numpy import linalg as LA
|
12 |
+
|
13 |
+
|
14 |
+
def extract_energy_rmse(
|
15 |
+
audio_ref,
|
16 |
+
audio_deg,
|
17 |
+
fs=None,
|
18 |
+
n_fft=1024,
|
19 |
+
hop_length=256,
|
20 |
+
win_length=1024,
|
21 |
+
method="dtw",
|
22 |
+
db_scale=True,
|
23 |
+
):
|
24 |
+
"""Compute Energy Root Mean Square Error (RMSE) between the predicted and the ground truth audio.
|
25 |
+
audio_ref: path to the ground truth audio.
|
26 |
+
audio_deg: path to the predicted audio.
|
27 |
+
fs: sampling rate.
|
28 |
+
n_fft: fft size.
|
29 |
+
hop_length: hop length.
|
30 |
+
win_length: window length.
|
31 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
32 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
33 |
+
db_scale: the ground truth and predicted audio will be converted to db_scale if "True".
|
34 |
+
"""
|
35 |
+
# Load audio
|
36 |
+
if fs != None:
|
37 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
38 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
39 |
+
else:
|
40 |
+
audio_ref, fs = librosa.load(audio_ref)
|
41 |
+
audio_deg, fs = librosa.load(audio_deg)
|
42 |
+
|
43 |
+
# STFT
|
44 |
+
spec_ref = librosa.stft(
|
45 |
+
y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
|
46 |
+
)
|
47 |
+
spec_deg = librosa.stft(
|
48 |
+
y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
|
49 |
+
)
|
50 |
+
|
51 |
+
# Get magnitudes
|
52 |
+
mag_ref = np.abs(spec_ref).T
|
53 |
+
mag_deg = np.abs(spec_deg).T
|
54 |
+
|
55 |
+
# Convert spectrogram to energy
|
56 |
+
energy_ref = LA.norm(mag_ref, axis=1)
|
57 |
+
energy_deg = LA.norm(mag_deg, axis=1)
|
58 |
+
|
59 |
+
# Convert to db_scale
|
60 |
+
if db_scale:
|
61 |
+
energy_ref = 20 * np.log10(energy_ref)
|
62 |
+
energy_deg = 20 * np.log10(energy_deg)
|
63 |
+
|
64 |
+
# Audio length alignment
|
65 |
+
if method == "cut":
|
66 |
+
length = min(len(energy_ref), len(energy_deg))
|
67 |
+
energy_ref = energy_ref[:length]
|
68 |
+
energy_deg = energy_deg[:length]
|
69 |
+
elif method == "dtw":
|
70 |
+
_, wp = librosa.sequence.dtw(energy_ref, energy_deg, backtrack=True)
|
71 |
+
energy_gt_new = []
|
72 |
+
energy_pred_new = []
|
73 |
+
for i in range(wp.shape[0]):
|
74 |
+
gt_index = wp[i][0]
|
75 |
+
pred_index = wp[i][1]
|
76 |
+
energy_gt_new.append(energy_ref[gt_index])
|
77 |
+
energy_pred_new.append(energy_deg[pred_index])
|
78 |
+
energy_ref = np.array(energy_gt_new)
|
79 |
+
energy_deg = np.array(energy_pred_new)
|
80 |
+
assert len(energy_ref) == len(energy_deg)
|
81 |
+
|
82 |
+
# Compute RMSE
|
83 |
+
energy_mse = np.square(np.subtract(energy_ref, energy_deg)).mean()
|
84 |
+
energy_rmse = math.sqrt(energy_mse)
|
85 |
+
|
86 |
+
return energy_rmse
|
evaluation/metrics/f0/__init__.py
ADDED
File without changes
|
evaluation/metrics/f0/f0_pearson_coefficients.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import librosa
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchmetrics import PearsonCorrCoef
|
12 |
+
|
13 |
+
from utils.util import JsonHParams
|
14 |
+
from utils.f0 import get_f0_features_using_parselmouth, get_pitch_sub_median
|
15 |
+
|
16 |
+
|
17 |
+
def extract_fpc(
|
18 |
+
audio_ref,
|
19 |
+
audio_deg,
|
20 |
+
fs=None,
|
21 |
+
hop_length=256,
|
22 |
+
f0_min=50,
|
23 |
+
f0_max=1100,
|
24 |
+
pitch_bin=256,
|
25 |
+
pitch_min=50,
|
26 |
+
pitch_max=1100,
|
27 |
+
need_mean=True,
|
28 |
+
method="dtw",
|
29 |
+
):
|
30 |
+
"""Compute F0 Pearson Distance (FPC) between the predicted and the ground truth audio.
|
31 |
+
audio_ref: path to the ground truth audio.
|
32 |
+
audio_deg: path to the predicted audio.
|
33 |
+
fs: sampling rate.
|
34 |
+
hop_length: hop length.
|
35 |
+
f0_min: lower limit for f0.
|
36 |
+
f0_max: upper limit for f0.
|
37 |
+
pitch_bin: number of bins for f0 quantization.
|
38 |
+
pitch_max: upper limit for f0 quantization.
|
39 |
+
pitch_min: lower limit for f0 quantization.
|
40 |
+
need_mean: subtract the mean value from f0 if "True".
|
41 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
42 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
43 |
+
"""
|
44 |
+
# Initialize method
|
45 |
+
pearson = PearsonCorrCoef()
|
46 |
+
|
47 |
+
# Load audio
|
48 |
+
if fs != None:
|
49 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
50 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
51 |
+
else:
|
52 |
+
audio_ref, fs = librosa.load(audio_ref)
|
53 |
+
audio_deg, fs = librosa.load(audio_deg)
|
54 |
+
|
55 |
+
# Initialize config
|
56 |
+
cfg = JsonHParams()
|
57 |
+
cfg.sample_rate = fs
|
58 |
+
cfg.hop_size = hop_length
|
59 |
+
cfg.f0_min = f0_min
|
60 |
+
cfg.f0_max = f0_max
|
61 |
+
cfg.pitch_bin = pitch_bin
|
62 |
+
cfg.pitch_max = pitch_max
|
63 |
+
cfg.pitch_min = pitch_min
|
64 |
+
|
65 |
+
# Compute f0
|
66 |
+
f0_ref = get_f0_features_using_parselmouth(
|
67 |
+
audio_ref,
|
68 |
+
cfg,
|
69 |
+
)[0]
|
70 |
+
|
71 |
+
f0_deg = get_f0_features_using_parselmouth(
|
72 |
+
audio_deg,
|
73 |
+
cfg,
|
74 |
+
)[0]
|
75 |
+
|
76 |
+
# Subtract mean value from f0
|
77 |
+
if need_mean:
|
78 |
+
f0_ref = torch.from_numpy(f0_ref)
|
79 |
+
f0_deg = torch.from_numpy(f0_deg)
|
80 |
+
|
81 |
+
f0_ref = get_pitch_sub_median(f0_ref).numpy()
|
82 |
+
f0_deg = get_pitch_sub_median(f0_deg).numpy()
|
83 |
+
|
84 |
+
# Avoid silence
|
85 |
+
min_length = min(len(f0_ref), len(f0_deg))
|
86 |
+
if min_length <= 1:
|
87 |
+
return 1
|
88 |
+
|
89 |
+
# F0 length alignment
|
90 |
+
if method == "cut":
|
91 |
+
length = min(len(f0_ref), len(f0_deg))
|
92 |
+
f0_ref = f0_ref[:length]
|
93 |
+
f0_deg = f0_deg[:length]
|
94 |
+
elif method == "dtw":
|
95 |
+
_, wp = librosa.sequence.dtw(f0_ref, f0_deg, backtrack=True)
|
96 |
+
f0_gt_new = []
|
97 |
+
f0_pred_new = []
|
98 |
+
for i in range(wp.shape[0]):
|
99 |
+
gt_index = wp[i][0]
|
100 |
+
pred_index = wp[i][1]
|
101 |
+
f0_gt_new.append(f0_ref[gt_index])
|
102 |
+
f0_pred_new.append(f0_deg[pred_index])
|
103 |
+
f0_ref = np.array(f0_gt_new)
|
104 |
+
f0_deg = np.array(f0_pred_new)
|
105 |
+
assert len(f0_ref) == len(f0_deg)
|
106 |
+
|
107 |
+
# Convert to tensor
|
108 |
+
f0_ref = torch.from_numpy(f0_ref)
|
109 |
+
f0_deg = torch.from_numpy(f0_deg)
|
110 |
+
|
111 |
+
return pearson(f0_ref, f0_deg).numpy().tolist()
|
evaluation/metrics/f0/f0_periodicity_rmse.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torchcrepe
|
7 |
+
import math
|
8 |
+
import librosa
|
9 |
+
import torch
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
|
14 |
+
def extract_f0_periodicity_rmse(
|
15 |
+
audio_ref,
|
16 |
+
audio_deg,
|
17 |
+
fs=None,
|
18 |
+
hop_length=256,
|
19 |
+
method="dtw",
|
20 |
+
):
|
21 |
+
"""Compute f0 periodicity Root Mean Square Error (RMSE) between the predicted and the ground truth audio.
|
22 |
+
audio_ref: path to the ground truth audio.
|
23 |
+
audio_deg: path to the predicted audio.
|
24 |
+
fs: sampling rate.
|
25 |
+
hop_length: hop length.
|
26 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
27 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
28 |
+
"""
|
29 |
+
# Load audio
|
30 |
+
if fs != None:
|
31 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
32 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
33 |
+
else:
|
34 |
+
audio_ref, fs = librosa.load(audio_ref)
|
35 |
+
audio_deg, fs = librosa.load(audio_deg)
|
36 |
+
|
37 |
+
# Convert to torch
|
38 |
+
audio_ref = torch.from_numpy(audio_ref).unsqueeze(0)
|
39 |
+
audio_deg = torch.from_numpy(audio_deg).unsqueeze(0)
|
40 |
+
|
41 |
+
# Get periodicity
|
42 |
+
pitch_ref, periodicity_ref = torchcrepe.predict(
|
43 |
+
audio_ref,
|
44 |
+
sample_rate=fs,
|
45 |
+
hop_length=hop_length,
|
46 |
+
fmin=0,
|
47 |
+
fmax=1500,
|
48 |
+
model="full",
|
49 |
+
return_periodicity=True,
|
50 |
+
device="cuda:0",
|
51 |
+
)
|
52 |
+
pitch_deg, periodicity_deg = torchcrepe.predict(
|
53 |
+
audio_deg,
|
54 |
+
sample_rate=fs,
|
55 |
+
hop_length=hop_length,
|
56 |
+
fmin=0,
|
57 |
+
fmax=1500,
|
58 |
+
model="full",
|
59 |
+
return_periodicity=True,
|
60 |
+
device="cuda:0",
|
61 |
+
)
|
62 |
+
|
63 |
+
# Cut silence
|
64 |
+
periodicity_ref = (
|
65 |
+
torchcrepe.threshold.Silence()(
|
66 |
+
periodicity_ref,
|
67 |
+
audio_ref,
|
68 |
+
fs,
|
69 |
+
hop_length=hop_length,
|
70 |
+
)
|
71 |
+
.squeeze(0)
|
72 |
+
.numpy()
|
73 |
+
)
|
74 |
+
periodicity_deg = (
|
75 |
+
torchcrepe.threshold.Silence()(
|
76 |
+
periodicity_deg,
|
77 |
+
audio_deg,
|
78 |
+
fs,
|
79 |
+
hop_length=hop_length,
|
80 |
+
)
|
81 |
+
.squeeze(0)
|
82 |
+
.numpy()
|
83 |
+
)
|
84 |
+
|
85 |
+
# Avoid silence audio
|
86 |
+
min_length = min(len(periodicity_ref), len(periodicity_deg))
|
87 |
+
if min_length <= 1:
|
88 |
+
return 0
|
89 |
+
|
90 |
+
# Periodicity length alignment
|
91 |
+
if method == "cut":
|
92 |
+
length = min(len(periodicity_ref), len(periodicity_deg))
|
93 |
+
periodicity_ref = periodicity_ref[:length]
|
94 |
+
periodicity_deg = periodicity_deg[:length]
|
95 |
+
elif method == "dtw":
|
96 |
+
_, wp = librosa.sequence.dtw(periodicity_ref, periodicity_deg, backtrack=True)
|
97 |
+
periodicity_ref_new = []
|
98 |
+
periodicity_deg_new = []
|
99 |
+
for i in range(wp.shape[0]):
|
100 |
+
ref_index = wp[i][0]
|
101 |
+
deg_index = wp[i][1]
|
102 |
+
periodicity_ref_new.append(periodicity_ref[ref_index])
|
103 |
+
periodicity_deg_new.append(periodicity_deg[deg_index])
|
104 |
+
periodicity_ref = np.array(periodicity_ref_new)
|
105 |
+
periodicity_deg = np.array(periodicity_deg_new)
|
106 |
+
assert len(periodicity_ref) == len(periodicity_deg)
|
107 |
+
|
108 |
+
# Compute RMSE
|
109 |
+
periodicity_mse = np.square(np.subtract(periodicity_ref, periodicity_deg)).mean()
|
110 |
+
periodicity_rmse = math.sqrt(periodicity_mse)
|
111 |
+
|
112 |
+
return periodicity_rmse
|
evaluation/metrics/f0/f0_rmse.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import librosa
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from utils.util import JsonHParams
|
13 |
+
from utils.f0 import get_f0_features_using_parselmouth, get_pitch_sub_median
|
14 |
+
|
15 |
+
|
16 |
+
ZERO = 1e-8
|
17 |
+
|
18 |
+
|
19 |
+
def extract_f0rmse(
|
20 |
+
audio_ref,
|
21 |
+
audio_deg,
|
22 |
+
fs=None,
|
23 |
+
hop_length=256,
|
24 |
+
f0_min=37,
|
25 |
+
f0_max=1000,
|
26 |
+
pitch_bin=256,
|
27 |
+
pitch_max=1100.0,
|
28 |
+
pitch_min=50.0,
|
29 |
+
need_mean=True,
|
30 |
+
method="dtw",
|
31 |
+
):
|
32 |
+
"""Compute F0 Root Mean Square Error (RMSE) between the predicted and the ground truth audio.
|
33 |
+
audio_ref: path to the ground truth audio.
|
34 |
+
audio_deg: path to the predicted audio.
|
35 |
+
fs: sampling rate.
|
36 |
+
hop_length: hop length.
|
37 |
+
f0_min: lower limit for f0.
|
38 |
+
f0_max: upper limit for f0.
|
39 |
+
pitch_bin: number of bins for f0 quantization.
|
40 |
+
pitch_max: upper limit for f0 quantization.
|
41 |
+
pitch_min: lower limit for f0 quantization.
|
42 |
+
need_mean: subtract the mean value from f0 if "True".
|
43 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
44 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
45 |
+
"""
|
46 |
+
# Load audio
|
47 |
+
if fs != None:
|
48 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
49 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
50 |
+
else:
|
51 |
+
audio_ref, fs = librosa.load(audio_ref)
|
52 |
+
audio_deg, fs = librosa.load(audio_deg)
|
53 |
+
|
54 |
+
# Initialize config for f0 extraction
|
55 |
+
cfg = JsonHParams()
|
56 |
+
cfg.sample_rate = fs
|
57 |
+
cfg.hop_size = hop_length
|
58 |
+
cfg.f0_min = f0_min
|
59 |
+
cfg.f0_max = f0_max
|
60 |
+
cfg.pitch_bin = pitch_bin
|
61 |
+
cfg.pitch_max = pitch_max
|
62 |
+
cfg.pitch_min = pitch_min
|
63 |
+
|
64 |
+
# Extract f0
|
65 |
+
f0_ref = get_f0_features_using_parselmouth(
|
66 |
+
audio_ref,
|
67 |
+
cfg,
|
68 |
+
)[0]
|
69 |
+
|
70 |
+
f0_deg = get_f0_features_using_parselmouth(
|
71 |
+
audio_deg,
|
72 |
+
cfg,
|
73 |
+
)[0]
|
74 |
+
|
75 |
+
# Subtract mean value from f0
|
76 |
+
if need_mean:
|
77 |
+
f0_ref = torch.from_numpy(f0_ref)
|
78 |
+
f0_deg = torch.from_numpy(f0_deg)
|
79 |
+
|
80 |
+
f0_ref = get_pitch_sub_median(f0_ref).numpy()
|
81 |
+
f0_deg = get_pitch_sub_median(f0_deg).numpy()
|
82 |
+
|
83 |
+
# Avoid silence
|
84 |
+
min_length = min(len(f0_ref), len(f0_deg))
|
85 |
+
if min_length <= 1:
|
86 |
+
return 0
|
87 |
+
|
88 |
+
# F0 length alignment
|
89 |
+
if method == "cut":
|
90 |
+
length = min(len(f0_ref), len(f0_deg))
|
91 |
+
f0_ref = f0_ref[:length]
|
92 |
+
f0_deg = f0_deg[:length]
|
93 |
+
elif method == "dtw":
|
94 |
+
_, wp = librosa.sequence.dtw(f0_ref, f0_deg, backtrack=True)
|
95 |
+
f0_gt_new = []
|
96 |
+
f0_pred_new = []
|
97 |
+
for i in range(wp.shape[0]):
|
98 |
+
gt_index = wp[i][0]
|
99 |
+
pred_index = wp[i][1]
|
100 |
+
f0_gt_new.append(f0_ref[gt_index])
|
101 |
+
f0_pred_new.append(f0_deg[pred_index])
|
102 |
+
f0_ref = np.array(f0_gt_new)
|
103 |
+
f0_deg = np.array(f0_pred_new)
|
104 |
+
assert len(f0_ref) == len(f0_deg)
|
105 |
+
|
106 |
+
# Compute RMSE
|
107 |
+
f0_mse = np.square(np.subtract(f0_ref, f0_deg)).mean()
|
108 |
+
f0_rmse = math.sqrt(f0_mse)
|
109 |
+
|
110 |
+
return f0_rmse
|
evaluation/metrics/f0/v_uv_f1.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import librosa
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from utils.util import JsonHParams
|
13 |
+
from utils.f0 import get_f0_features_using_parselmouth
|
14 |
+
|
15 |
+
|
16 |
+
ZERO = 1e-8
|
17 |
+
|
18 |
+
|
19 |
+
def extract_f1_v_uv(
|
20 |
+
audio_ref,
|
21 |
+
audio_deg,
|
22 |
+
fs=None,
|
23 |
+
hop_length=256,
|
24 |
+
f0_min=37,
|
25 |
+
f0_max=1000,
|
26 |
+
pitch_bin=256,
|
27 |
+
pitch_max=1100.0,
|
28 |
+
pitch_min=50.0,
|
29 |
+
method="dtw",
|
30 |
+
):
|
31 |
+
"""Compute F1 socre of voiced/unvoiced accuracy between the predicted and the ground truth audio.
|
32 |
+
audio_ref: path to the ground truth audio.
|
33 |
+
audio_deg: path to the predicted audio.
|
34 |
+
fs: sampling rate.
|
35 |
+
hop_length: hop length.
|
36 |
+
f0_min: lower limit for f0.
|
37 |
+
f0_max: upper limit for f0.
|
38 |
+
pitch_bin: number of bins for f0 quantization.
|
39 |
+
pitch_max: upper limit for f0 quantization.
|
40 |
+
pitch_min: lower limit for f0 quantization.
|
41 |
+
need_mean: subtract the mean value from f0 if "True".
|
42 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
43 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
44 |
+
"""
|
45 |
+
# Load audio
|
46 |
+
if fs != None:
|
47 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
48 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
49 |
+
else:
|
50 |
+
audio_ref, fs = librosa.load(audio_ref)
|
51 |
+
audio_deg, fs = librosa.load(audio_deg)
|
52 |
+
|
53 |
+
# Initialize config
|
54 |
+
cfg = JsonHParams()
|
55 |
+
cfg.sample_rate = fs
|
56 |
+
cfg.hop_size = hop_length
|
57 |
+
cfg.f0_min = f0_min
|
58 |
+
cfg.f0_max = f0_max
|
59 |
+
cfg.pitch_bin = pitch_bin
|
60 |
+
cfg.pitch_max = pitch_max
|
61 |
+
cfg.pitch_min = pitch_min
|
62 |
+
|
63 |
+
# Compute f0
|
64 |
+
f0_ref = get_f0_features_using_parselmouth(
|
65 |
+
audio_ref,
|
66 |
+
cfg,
|
67 |
+
)[0]
|
68 |
+
|
69 |
+
f0_deg = get_f0_features_using_parselmouth(
|
70 |
+
audio_deg,
|
71 |
+
cfg,
|
72 |
+
)[0]
|
73 |
+
|
74 |
+
# Avoid silence
|
75 |
+
min_length = min(len(f0_ref), len(f0_deg))
|
76 |
+
if min_length <= 1:
|
77 |
+
return 0, 0, 0
|
78 |
+
|
79 |
+
# F0 length alignment
|
80 |
+
if method == "cut":
|
81 |
+
length = min(len(f0_ref), len(f0_deg))
|
82 |
+
f0_ref = f0_ref[:length]
|
83 |
+
f0_deg = f0_deg[:length]
|
84 |
+
elif method == "dtw":
|
85 |
+
_, wp = librosa.sequence.dtw(f0_ref, f0_deg, backtrack=True)
|
86 |
+
f0_gt_new = []
|
87 |
+
f0_pred_new = []
|
88 |
+
for i in range(wp.shape[0]):
|
89 |
+
gt_index = wp[i][0]
|
90 |
+
pred_index = wp[i][1]
|
91 |
+
f0_gt_new.append(f0_ref[gt_index])
|
92 |
+
f0_pred_new.append(f0_deg[pred_index])
|
93 |
+
f0_ref = np.array(f0_gt_new)
|
94 |
+
f0_deg = np.array(f0_pred_new)
|
95 |
+
assert len(f0_ref) == len(f0_deg)
|
96 |
+
|
97 |
+
# Get voiced/unvoiced parts
|
98 |
+
ref_voiced = torch.Tensor([f0_ref != 0]).bool()
|
99 |
+
deg_voiced = torch.Tensor([f0_deg != 0]).bool()
|
100 |
+
|
101 |
+
# Compute TP, FP, FN
|
102 |
+
true_postives = (ref_voiced & deg_voiced).sum()
|
103 |
+
false_postives = (~ref_voiced & deg_voiced).sum()
|
104 |
+
false_negatives = (ref_voiced & ~deg_voiced).sum()
|
105 |
+
|
106 |
+
return (
|
107 |
+
true_postives.numpy().tolist(),
|
108 |
+
false_postives.numpy().tolist(),
|
109 |
+
false_negatives.numpy().tolist(),
|
110 |
+
)
|
evaluation/metrics/intelligibility/__init__.py
ADDED
File without changes
|
evaluation/metrics/intelligibility/character_error_rate.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import whisper
|
7 |
+
|
8 |
+
from torchmetrics import CharErrorRate
|
9 |
+
|
10 |
+
|
11 |
+
def extract_cer(
|
12 |
+
content_gt=None,
|
13 |
+
audio_ref=None,
|
14 |
+
audio_deg=None,
|
15 |
+
fs=None,
|
16 |
+
language="chinese",
|
17 |
+
remove_space=True,
|
18 |
+
remove_punctuation=True,
|
19 |
+
mode="gt_audio",
|
20 |
+
):
|
21 |
+
"""Compute Character Error Rate (CER) between the predicted and the ground truth audio.
|
22 |
+
content_gt: the ground truth content.
|
23 |
+
audio_ref: path to the ground truth audio.
|
24 |
+
audio_deg: path to the predicted audio.
|
25 |
+
mode: "gt_content" computes the CER between the predicted content obtained from the whisper model and the ground truth content.
|
26 |
+
both content_gt and audio_deg are needed.
|
27 |
+
"gt_audio" computes the CER between the extracted ground truth and predicted contents obtained from the whisper model.
|
28 |
+
both audio_ref and audio_deg are needed.
|
29 |
+
"""
|
30 |
+
# Get ground truth content
|
31 |
+
if mode == "gt_content":
|
32 |
+
assert content_gt != None
|
33 |
+
if language == "chinese":
|
34 |
+
prompt = "以下是普通话的句子"
|
35 |
+
model = whisper.load_model("large").cuda()
|
36 |
+
result_deg = model.transcribe(
|
37 |
+
audio_deg, language="zh", verbose=True, initial_prompt=prompt
|
38 |
+
)
|
39 |
+
elif language == "english":
|
40 |
+
model = whisper.load_model("large").cuda()
|
41 |
+
result_deg = model.transcribe(audio_deg, language="en", verbose=True)
|
42 |
+
elif mode == "gt_audio":
|
43 |
+
assert audio_ref != None
|
44 |
+
if language == "chinese":
|
45 |
+
prompt = "以下是普通话的句子"
|
46 |
+
model = whisper.load_model("large").cuda()
|
47 |
+
result_ref = model.transcribe(
|
48 |
+
audio_ref, language="zh", verbose=True, initial_prompt=prompt
|
49 |
+
)
|
50 |
+
result_deg = model.transcribe(
|
51 |
+
audio_deg, language="zh", verbose=True, initial_prompt=prompt
|
52 |
+
)
|
53 |
+
elif language == "english":
|
54 |
+
model = whisper.load_model("large").cuda()
|
55 |
+
result_ref = model.transcribe(audio_deg, language="en", verbose=True)
|
56 |
+
result_deg = model.transcribe(audio_deg, language="en", verbose=True)
|
57 |
+
content_gt = result_ref["text"]
|
58 |
+
if remove_space:
|
59 |
+
content_gt = content_gt.replace(" ", "")
|
60 |
+
if remove_punctuation:
|
61 |
+
content_gt = content_gt.replace(".", "")
|
62 |
+
content_gt = content_gt.replace("'", "")
|
63 |
+
content_gt = content_gt.replace("-", "")
|
64 |
+
content_gt = content_gt.replace(",", "")
|
65 |
+
content_gt = content_gt.replace("!", "")
|
66 |
+
content_gt = content_gt.lower()
|
67 |
+
|
68 |
+
# Get predicted truth content
|
69 |
+
content_pred = result_deg["text"]
|
70 |
+
if remove_space:
|
71 |
+
content_pred = content_pred.replace(" ", "")
|
72 |
+
if remove_punctuation:
|
73 |
+
content_pred = content_pred.replace(".", "")
|
74 |
+
content_pred = content_pred.replace("'", "")
|
75 |
+
content_pred = content_pred.replace("-", "")
|
76 |
+
content_pred = content_pred.replace(",", "")
|
77 |
+
content_pred = content_pred.replace("!", "")
|
78 |
+
content_pred = content_pred.lower()
|
79 |
+
cer = CharErrorRate()
|
80 |
+
|
81 |
+
return cer(content_pred, content_gt).numpy().tolist()
|
evaluation/metrics/intelligibility/word_error_rate.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import whisper
|
7 |
+
|
8 |
+
from torchmetrics import WordErrorRate
|
9 |
+
|
10 |
+
|
11 |
+
def extract_wer(
|
12 |
+
content_gt=None,
|
13 |
+
audio_ref=None,
|
14 |
+
audio_deg=None,
|
15 |
+
fs=None,
|
16 |
+
language="chinese",
|
17 |
+
remove_space=True,
|
18 |
+
remove_punctuation=True,
|
19 |
+
mode="gt_audio",
|
20 |
+
):
|
21 |
+
"""Compute Word Error Rate (WER) between the predicted and the ground truth audio.
|
22 |
+
content_gt: the ground truth content.
|
23 |
+
audio_ref: path to the ground truth audio.
|
24 |
+
audio_deg: path to the predicted audio.
|
25 |
+
mode: "gt_content" computes the WER between the predicted content obtained from the whisper model and the ground truth content.
|
26 |
+
both content_gt and audio_deg are needed.
|
27 |
+
"gt_audio" computes the WER between the extracted ground truth and predicted contents obtained from the whisper model.
|
28 |
+
both audio_ref and audio_deg are needed.
|
29 |
+
"""
|
30 |
+
# Get ground truth content
|
31 |
+
if mode == "gt_content":
|
32 |
+
assert content_gt != None
|
33 |
+
if language == "chinese":
|
34 |
+
prompt = "以下是普通话的句子"
|
35 |
+
model = whisper.load_model("large").cuda()
|
36 |
+
result_deg = model.transcribe(
|
37 |
+
audio_deg, language="zh", verbose=True, initial_prompt=prompt
|
38 |
+
)
|
39 |
+
elif language == "english":
|
40 |
+
model = whisper.load_model("large").cuda()
|
41 |
+
result_deg = model.transcribe(audio_deg, language="en", verbose=True)
|
42 |
+
elif mode == "gt_audio":
|
43 |
+
assert audio_ref != None
|
44 |
+
if language == "chinese":
|
45 |
+
prompt = "以下是普通话的句子"
|
46 |
+
model = whisper.load_model("large").cuda()
|
47 |
+
result_ref = model.transcribe(
|
48 |
+
audio_ref, language="zh", verbose=True, initial_prompt=prompt
|
49 |
+
)
|
50 |
+
result_deg = model.transcribe(
|
51 |
+
audio_deg, language="zh", verbose=True, initial_prompt=prompt
|
52 |
+
)
|
53 |
+
elif language == "english":
|
54 |
+
model = whisper.load_model("large").cuda()
|
55 |
+
result_ref = model.transcribe(audio_deg, language="en", verbose=True)
|
56 |
+
result_deg = model.transcribe(audio_deg, language="en", verbose=True)
|
57 |
+
content_gt = result_ref["text"]
|
58 |
+
if remove_space:
|
59 |
+
content_gt = content_gt.replace(" ", "")
|
60 |
+
if remove_punctuation:
|
61 |
+
content_gt = content_gt.replace(".", "")
|
62 |
+
content_gt = content_gt.replace("'", "")
|
63 |
+
content_gt = content_gt.replace("-", "")
|
64 |
+
content_gt = content_gt.replace(",", "")
|
65 |
+
content_gt = content_gt.replace("!", "")
|
66 |
+
content_gt = content_gt.lower()
|
67 |
+
|
68 |
+
# Get predicted content
|
69 |
+
content_pred = result_deg["text"]
|
70 |
+
if remove_space:
|
71 |
+
content_pred = content_pred.replace(" ", "")
|
72 |
+
if remove_punctuation:
|
73 |
+
content_pred = content_pred.replace(".", "")
|
74 |
+
content_pred = content_pred.replace("'", "")
|
75 |
+
content_pred = content_pred.replace("-", "")
|
76 |
+
content_pred = content_pred.replace(",", "")
|
77 |
+
content_pred = content_pred.replace("!", "")
|
78 |
+
content_pred = content_pred.lower()
|
79 |
+
wer = WordErrorRate()
|
80 |
+
|
81 |
+
return wer(content_pred, content_gt).numpy().tolist()
|
evaluation/metrics/similarity/__init__.py
ADDED
File without changes
|
evaluation/metrics/similarity/models/RawNetBasicBlock.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class PreEmphasis(torch.nn.Module):
|
14 |
+
def __init__(self, coef: float = 0.97) -> None:
|
15 |
+
super().__init__()
|
16 |
+
self.coef = coef
|
17 |
+
# make kernel
|
18 |
+
# In pytorch, the convolution operation uses cross-correlation. So, filter is flipped.
|
19 |
+
self.register_buffer(
|
20 |
+
"flipped_filter",
|
21 |
+
torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
|
22 |
+
)
|
23 |
+
|
24 |
+
def forward(self, input: torch.tensor) -> torch.tensor:
|
25 |
+
assert (
|
26 |
+
len(input.size()) == 2
|
27 |
+
), "The number of dimensions of input tensor must be 2!"
|
28 |
+
# reflect padding to match lengths of in/out
|
29 |
+
input = input.unsqueeze(1)
|
30 |
+
input = F.pad(input, (1, 0), "reflect")
|
31 |
+
return F.conv1d(input, self.flipped_filter)
|
32 |
+
|
33 |
+
|
34 |
+
class AFMS(nn.Module):
|
35 |
+
"""
|
36 |
+
Alpha-Feature map scaling, added to the output of each residual block[1,2].
|
37 |
+
|
38 |
+
Reference:
|
39 |
+
[1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
|
40 |
+
[2] AMFS : https://www.koreascience.or.kr/article/JAKO202029757857763.page
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, nb_dim: int) -> None:
|
44 |
+
super().__init__()
|
45 |
+
self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
|
46 |
+
self.fc = nn.Linear(nb_dim, nb_dim)
|
47 |
+
self.sig = nn.Sigmoid()
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
|
51 |
+
y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)
|
52 |
+
|
53 |
+
x = x + self.alpha
|
54 |
+
x = x * y
|
55 |
+
return x
|
56 |
+
|
57 |
+
|
58 |
+
class Bottle2neck(nn.Module):
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
inplanes,
|
62 |
+
planes,
|
63 |
+
kernel_size=None,
|
64 |
+
dilation=None,
|
65 |
+
scale=4,
|
66 |
+
pool=False,
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
|
70 |
+
width = int(math.floor(planes / scale))
|
71 |
+
|
72 |
+
self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
|
73 |
+
self.bn1 = nn.BatchNorm1d(width * scale)
|
74 |
+
|
75 |
+
self.nums = scale - 1
|
76 |
+
|
77 |
+
convs = []
|
78 |
+
bns = []
|
79 |
+
|
80 |
+
num_pad = math.floor(kernel_size / 2) * dilation
|
81 |
+
|
82 |
+
for i in range(self.nums):
|
83 |
+
convs.append(
|
84 |
+
nn.Conv1d(
|
85 |
+
width,
|
86 |
+
width,
|
87 |
+
kernel_size=kernel_size,
|
88 |
+
dilation=dilation,
|
89 |
+
padding=num_pad,
|
90 |
+
)
|
91 |
+
)
|
92 |
+
bns.append(nn.BatchNorm1d(width))
|
93 |
+
|
94 |
+
self.convs = nn.ModuleList(convs)
|
95 |
+
self.bns = nn.ModuleList(bns)
|
96 |
+
|
97 |
+
self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
|
98 |
+
self.bn3 = nn.BatchNorm1d(planes)
|
99 |
+
|
100 |
+
self.relu = nn.ReLU()
|
101 |
+
|
102 |
+
self.width = width
|
103 |
+
|
104 |
+
self.mp = nn.MaxPool1d(pool) if pool else False
|
105 |
+
self.afms = AFMS(planes)
|
106 |
+
|
107 |
+
if inplanes != planes: # if change in number of filters
|
108 |
+
self.residual = nn.Sequential(
|
109 |
+
nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
|
110 |
+
)
|
111 |
+
else:
|
112 |
+
self.residual = nn.Identity()
|
113 |
+
|
114 |
+
def forward(self, x):
|
115 |
+
residual = self.residual(x)
|
116 |
+
|
117 |
+
out = self.conv1(x)
|
118 |
+
out = self.relu(out)
|
119 |
+
out = self.bn1(out)
|
120 |
+
|
121 |
+
spx = torch.split(out, self.width, 1)
|
122 |
+
for i in range(self.nums):
|
123 |
+
if i == 0:
|
124 |
+
sp = spx[i]
|
125 |
+
else:
|
126 |
+
sp = sp + spx[i]
|
127 |
+
sp = self.convs[i](sp)
|
128 |
+
sp = self.relu(sp)
|
129 |
+
sp = self.bns[i](sp)
|
130 |
+
if i == 0:
|
131 |
+
out = sp
|
132 |
+
else:
|
133 |
+
out = torch.cat((out, sp), 1)
|
134 |
+
|
135 |
+
out = torch.cat((out, spx[self.nums]), 1)
|
136 |
+
|
137 |
+
out = self.conv3(out)
|
138 |
+
out = self.relu(out)
|
139 |
+
out = self.bn3(out)
|
140 |
+
|
141 |
+
out += residual
|
142 |
+
if self.mp:
|
143 |
+
out = self.mp(out)
|
144 |
+
out = self.afms(out)
|
145 |
+
|
146 |
+
return out
|
evaluation/metrics/similarity/models/RawNetModel.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# -*- encoding: utf-8 -*-
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from asteroid_filterbanks import Encoder, ParamSincFB
|
11 |
+
|
12 |
+
from .RawNetBasicBlock import Bottle2neck, PreEmphasis
|
13 |
+
|
14 |
+
|
15 |
+
class RawNet3(nn.Module):
|
16 |
+
def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
nOut = kwargs["nOut"]
|
20 |
+
|
21 |
+
self.context = context
|
22 |
+
self.encoder_type = kwargs["encoder_type"]
|
23 |
+
self.log_sinc = kwargs["log_sinc"]
|
24 |
+
self.norm_sinc = kwargs["norm_sinc"]
|
25 |
+
self.out_bn = kwargs["out_bn"]
|
26 |
+
self.summed = summed
|
27 |
+
|
28 |
+
self.preprocess = nn.Sequential(
|
29 |
+
PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
|
30 |
+
)
|
31 |
+
self.conv1 = Encoder(
|
32 |
+
ParamSincFB(
|
33 |
+
C // 4,
|
34 |
+
251,
|
35 |
+
stride=kwargs["sinc_stride"],
|
36 |
+
)
|
37 |
+
)
|
38 |
+
self.relu = nn.ReLU()
|
39 |
+
self.bn1 = nn.BatchNorm1d(C // 4)
|
40 |
+
|
41 |
+
self.layer1 = block(
|
42 |
+
C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
|
43 |
+
)
|
44 |
+
self.layer2 = block(C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3)
|
45 |
+
self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
|
46 |
+
self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
|
47 |
+
|
48 |
+
if self.context:
|
49 |
+
attn_input = 1536 * 3
|
50 |
+
else:
|
51 |
+
attn_input = 1536
|
52 |
+
print("self.encoder_type", self.encoder_type)
|
53 |
+
if self.encoder_type == "ECA":
|
54 |
+
attn_output = 1536
|
55 |
+
elif self.encoder_type == "ASP":
|
56 |
+
attn_output = 1
|
57 |
+
else:
|
58 |
+
raise ValueError("Undefined encoder")
|
59 |
+
|
60 |
+
self.attention = nn.Sequential(
|
61 |
+
nn.Conv1d(attn_input, 128, kernel_size=1),
|
62 |
+
nn.ReLU(),
|
63 |
+
nn.BatchNorm1d(128),
|
64 |
+
nn.Conv1d(128, attn_output, kernel_size=1),
|
65 |
+
nn.Softmax(dim=2),
|
66 |
+
)
|
67 |
+
|
68 |
+
self.bn5 = nn.BatchNorm1d(3072)
|
69 |
+
|
70 |
+
self.fc6 = nn.Linear(3072, nOut)
|
71 |
+
self.bn6 = nn.BatchNorm1d(nOut)
|
72 |
+
|
73 |
+
self.mp3 = nn.MaxPool1d(3)
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
"""
|
77 |
+
:param x: input mini-batch (bs, samp)
|
78 |
+
"""
|
79 |
+
|
80 |
+
with torch.cuda.amp.autocast(enabled=False):
|
81 |
+
x = self.preprocess(x)
|
82 |
+
x = torch.abs(self.conv1(x))
|
83 |
+
if self.log_sinc:
|
84 |
+
x = torch.log(x + 1e-6)
|
85 |
+
if self.norm_sinc == "mean":
|
86 |
+
x = x - torch.mean(x, dim=-1, keepdim=True)
|
87 |
+
elif self.norm_sinc == "mean_std":
|
88 |
+
m = torch.mean(x, dim=-1, keepdim=True)
|
89 |
+
s = torch.std(x, dim=-1, keepdim=True)
|
90 |
+
s[s < 0.001] = 0.001
|
91 |
+
x = (x - m) / s
|
92 |
+
|
93 |
+
if self.summed:
|
94 |
+
x1 = self.layer1(x)
|
95 |
+
x2 = self.layer2(x1)
|
96 |
+
x3 = self.layer3(self.mp3(x1) + x2)
|
97 |
+
else:
|
98 |
+
x1 = self.layer1(x)
|
99 |
+
x2 = self.layer2(x1)
|
100 |
+
x3 = self.layer3(x2)
|
101 |
+
|
102 |
+
x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
|
103 |
+
x = self.relu(x)
|
104 |
+
|
105 |
+
t = x.size()[-1]
|
106 |
+
|
107 |
+
if self.context:
|
108 |
+
global_x = torch.cat(
|
109 |
+
(
|
110 |
+
x,
|
111 |
+
torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
|
112 |
+
torch.sqrt(
|
113 |
+
torch.var(x, dim=2, keepdim=True).clamp(min=1e-4, max=1e4)
|
114 |
+
).repeat(1, 1, t),
|
115 |
+
),
|
116 |
+
dim=1,
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
global_x = x
|
120 |
+
|
121 |
+
w = self.attention(global_x)
|
122 |
+
|
123 |
+
mu = torch.sum(x * w, dim=2)
|
124 |
+
sg = torch.sqrt(
|
125 |
+
(torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
|
126 |
+
)
|
127 |
+
|
128 |
+
x = torch.cat((mu, sg), 1)
|
129 |
+
|
130 |
+
x = self.bn5(x)
|
131 |
+
|
132 |
+
x = self.fc6(x)
|
133 |
+
|
134 |
+
if self.out_bn:
|
135 |
+
x = self.bn6(x)
|
136 |
+
|
137 |
+
return x
|
138 |
+
|
139 |
+
|
140 |
+
def MainModel(**kwargs):
|
141 |
+
model = RawNet3(Bottle2neck, model_scale=8, context=True, summed=True, **kwargs)
|
142 |
+
return model
|
evaluation/metrics/similarity/models/__init__.py
ADDED
File without changes
|
evaluation/metrics/similarity/speaker_similarity.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import soundfile as sf
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from tqdm import tqdm
|
13 |
+
import librosa
|
14 |
+
|
15 |
+
from .models.RawNetModel import RawNet3
|
16 |
+
from .models.RawNetBasicBlock import Bottle2neck
|
17 |
+
|
18 |
+
|
19 |
+
def extract_speaker_embd(
|
20 |
+
model, fn: str, n_samples: int, n_segments: int = 10, gpu: bool = False
|
21 |
+
) -> np.ndarray:
|
22 |
+
audio, sample_rate = sf.read(fn)
|
23 |
+
if len(audio.shape) > 1:
|
24 |
+
raise ValueError(
|
25 |
+
f"RawNet3 supports mono input only. Input data has a shape of {audio.shape}."
|
26 |
+
)
|
27 |
+
|
28 |
+
if sample_rate != 16000:
|
29 |
+
# resample to 16000kHz
|
30 |
+
audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
|
31 |
+
# print("resample to 16000kHz!")
|
32 |
+
if len(audio) < n_samples: # RawNet3 was trained using utterances of 3 seconds
|
33 |
+
shortage = n_samples - len(audio) + 1
|
34 |
+
audio = np.pad(audio, (0, shortage), "wrap")
|
35 |
+
|
36 |
+
audios = []
|
37 |
+
startframe = np.linspace(0, len(audio) - n_samples, num=n_segments)
|
38 |
+
for asf in startframe:
|
39 |
+
audios.append(audio[int(asf) : int(asf) + n_samples])
|
40 |
+
|
41 |
+
audios = torch.from_numpy(np.stack(audios, axis=0).astype(np.float32))
|
42 |
+
if gpu:
|
43 |
+
audios = audios.to("cuda")
|
44 |
+
with torch.no_grad():
|
45 |
+
output = model(audios)
|
46 |
+
|
47 |
+
return output
|
48 |
+
|
49 |
+
|
50 |
+
def extract_speaker_similarity(target_path, reference_path):
|
51 |
+
model = RawNet3(
|
52 |
+
Bottle2neck,
|
53 |
+
model_scale=8,
|
54 |
+
context=True,
|
55 |
+
summed=True,
|
56 |
+
encoder_type="ECA",
|
57 |
+
nOut=256,
|
58 |
+
out_bn=False,
|
59 |
+
sinc_stride=10,
|
60 |
+
log_sinc=True,
|
61 |
+
norm_sinc="mean",
|
62 |
+
grad_mult=1,
|
63 |
+
)
|
64 |
+
|
65 |
+
gpu = False
|
66 |
+
model.load_state_dict(
|
67 |
+
torch.load(
|
68 |
+
"pretrained/rawnet3/model.pt",
|
69 |
+
map_location=lambda storage, loc: storage,
|
70 |
+
)["model"]
|
71 |
+
)
|
72 |
+
model.eval()
|
73 |
+
print("RawNet3 initialised & weights loaded!")
|
74 |
+
|
75 |
+
if torch.cuda.is_available():
|
76 |
+
print("Cuda available, conducting inference on GPU")
|
77 |
+
model = model.to("cuda")
|
78 |
+
gpu = True
|
79 |
+
# for target_path, reference_path in zip(target_paths, ref_paths):
|
80 |
+
# print(f"Extracting embeddings for target singers...")
|
81 |
+
|
82 |
+
target_embeddings = []
|
83 |
+
for file in tqdm(os.listdir(target_path)):
|
84 |
+
output = extract_speaker_embd(
|
85 |
+
model,
|
86 |
+
fn=os.path.join(target_path, file),
|
87 |
+
n_samples=48000,
|
88 |
+
n_segments=10,
|
89 |
+
gpu=gpu,
|
90 |
+
).mean(0)
|
91 |
+
target_embeddings.append(output.detach().cpu().numpy())
|
92 |
+
target_embeddings = np.array(target_embeddings)
|
93 |
+
target_embedding = np.mean(target_embeddings, axis=0)
|
94 |
+
|
95 |
+
# print(f"Extracting embeddings for reference singer...")
|
96 |
+
|
97 |
+
reference_embeddings = []
|
98 |
+
for file in tqdm(os.listdir(reference_path)):
|
99 |
+
output = extract_speaker_embd(
|
100 |
+
model,
|
101 |
+
fn=os.path.join(reference_path, file),
|
102 |
+
n_samples=48000,
|
103 |
+
n_segments=10,
|
104 |
+
gpu=gpu,
|
105 |
+
).mean(0)
|
106 |
+
reference_embeddings.append(output.detach().cpu().numpy())
|
107 |
+
reference_embeddings = np.array(reference_embeddings)
|
108 |
+
|
109 |
+
# print("Calculating cosine similarity...")
|
110 |
+
|
111 |
+
cos_sim = F.cosine_similarity(
|
112 |
+
torch.from_numpy(np.mean(target_embeddings, axis=0)).unsqueeze(0),
|
113 |
+
torch.from_numpy(np.mean(reference_embeddings, axis=0)).unsqueeze(0),
|
114 |
+
dim=1,
|
115 |
+
)
|
116 |
+
|
117 |
+
# print(f"Mean cosine similarity: {cos_sim.item()}")
|
118 |
+
|
119 |
+
return cos_sim.item()
|
evaluation/metrics/spectrogram/__init__.py
ADDED
File without changes
|
evaluation/metrics/spectrogram/frechet_distance.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from frechet_audio_distance import FrechetAudioDistance
|
7 |
+
|
8 |
+
|
9 |
+
def extract_fad(
|
10 |
+
audio_dir1,
|
11 |
+
audio_dir2,
|
12 |
+
mode="vggish",
|
13 |
+
use_pca=False,
|
14 |
+
use_activation=False,
|
15 |
+
verbose=False,
|
16 |
+
):
|
17 |
+
"""Extract Frechet Audio Distance for two given audio folders.
|
18 |
+
audio_dir1: path to the ground truth audio folder.
|
19 |
+
audio_dir2: path to the predicted audio folder.
|
20 |
+
mode: "vggish", "pann", "clap" for different models.
|
21 |
+
"""
|
22 |
+
frechet = FrechetAudioDistance(
|
23 |
+
model_name=mode,
|
24 |
+
use_pca=use_pca,
|
25 |
+
use_activation=use_activation,
|
26 |
+
verbose=verbose,
|
27 |
+
)
|
28 |
+
|
29 |
+
fad_score = frechet.score(audio_dir1, audio_dir2)
|
30 |
+
|
31 |
+
return fad_score
|
evaluation/metrics/spectrogram/mel_cepstral_distortion.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from pymcd.mcd import Calculate_MCD
|
7 |
+
|
8 |
+
|
9 |
+
def extract_mcd(audio_ref, audio_deg, fs=None, mode="dtw_sl"):
|
10 |
+
"""Extract Mel-Cepstral Distance for a two given audio.
|
11 |
+
Args:
|
12 |
+
audio_ref: The given reference audio. It is an audio path.
|
13 |
+
audio_deg: The given synthesized audio. It is an audio path.
|
14 |
+
mode: "plain", "dtw" and "dtw_sl".
|
15 |
+
"""
|
16 |
+
mcd_toolbox = Calculate_MCD(MCD_mode=mode)
|
17 |
+
if fs != None:
|
18 |
+
mcd_toolbox.SAMPLING_RATE = fs
|
19 |
+
mcd_value = mcd_toolbox.calculate_mcd(audio_ref, audio_deg)
|
20 |
+
|
21 |
+
return mcd_value
|
evaluation/metrics/spectrogram/multi_resolution_stft_distance.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import torch
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
|
12 |
+
def extract_mstft(
|
13 |
+
audio_ref,
|
14 |
+
audio_deg,
|
15 |
+
fs=None,
|
16 |
+
mid_freq=None,
|
17 |
+
high_freq=None,
|
18 |
+
method="cut",
|
19 |
+
version="pwg",
|
20 |
+
):
|
21 |
+
"""Compute Multi-Scale STFT Distance (mstft) between the predicted and the ground truth audio.
|
22 |
+
audio_ref: path to the ground truth audio.
|
23 |
+
audio_deg: path to the predicted audio.
|
24 |
+
fs: sampling rate.
|
25 |
+
med_freq: division frequency for mid frequency parts.
|
26 |
+
high_freq: division frequency for high frequency parts.
|
27 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
28 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
29 |
+
version: "pwg" will use the computational version provided by ParallelWaveGAN.
|
30 |
+
"encodec" will use the computational version provided by Encodec.
|
31 |
+
"""
|
32 |
+
# Load audio
|
33 |
+
if fs != None:
|
34 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
35 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
36 |
+
else:
|
37 |
+
audio_ref, fs = librosa.load(audio_ref)
|
38 |
+
audio_deg, fs = librosa.load(audio_deg)
|
39 |
+
|
40 |
+
# Automatically choose mid_freq and high_freq if they are not given
|
41 |
+
if mid_freq == None:
|
42 |
+
mid_freq = fs // 6
|
43 |
+
if high_freq == None:
|
44 |
+
high_freq = fs // 3
|
45 |
+
|
46 |
+
# Audio length alignment
|
47 |
+
if len(audio_ref) != len(audio_deg):
|
48 |
+
if method == "cut":
|
49 |
+
length = min(len(audio_ref), len(audio_deg))
|
50 |
+
audio_ref = audio_ref[:length]
|
51 |
+
audio_deg = audio_deg[:length]
|
52 |
+
elif method == "dtw":
|
53 |
+
_, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
|
54 |
+
audio_ref_new = []
|
55 |
+
audio_deg_new = []
|
56 |
+
for i in range(wp.shape[0]):
|
57 |
+
ref_index = wp[i][0]
|
58 |
+
deg_index = wp[i][1]
|
59 |
+
audio_ref_new.append(audio_ref[ref_index])
|
60 |
+
audio_deg_new.append(audio_deg[deg_index])
|
61 |
+
audio_ref = np.array(audio_ref_new)
|
62 |
+
audio_deg = np.array(audio_deg_new)
|
63 |
+
assert len(audio_ref) == len(audio_deg)
|
64 |
+
|
65 |
+
# Define loss function
|
66 |
+
l1Loss = torch.nn.L1Loss(reduction="mean")
|
67 |
+
l2Loss = torch.nn.MSELoss(reduction="mean")
|
68 |
+
|
69 |
+
# Compute distance
|
70 |
+
if version == "encodec":
|
71 |
+
n_fft = 1024
|
72 |
+
|
73 |
+
mstft = 0
|
74 |
+
mstft_low = 0
|
75 |
+
mstft_mid = 0
|
76 |
+
mstft_high = 0
|
77 |
+
|
78 |
+
freq_resolution = fs / n_fft
|
79 |
+
mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
|
80 |
+
high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))
|
81 |
+
|
82 |
+
for i in range(5, 11):
|
83 |
+
hop_length = 2**i // 4
|
84 |
+
win_length = 2**i
|
85 |
+
|
86 |
+
spec_ref = librosa.stft(
|
87 |
+
y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
|
88 |
+
)
|
89 |
+
spec_deg = librosa.stft(
|
90 |
+
y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
|
91 |
+
)
|
92 |
+
|
93 |
+
mag_ref = np.abs(spec_ref)
|
94 |
+
mag_deg = np.abs(spec_deg)
|
95 |
+
|
96 |
+
mag_ref = torch.from_numpy(mag_ref)
|
97 |
+
mag_deg = torch.from_numpy(mag_deg)
|
98 |
+
mstft += l1Loss(mag_ref, mag_deg) + l2Loss(mag_ref, mag_deg)
|
99 |
+
|
100 |
+
mag_ref_low = mag_ref[:mid_freq_index, :]
|
101 |
+
mag_deg_low = mag_deg[:mid_freq_index, :]
|
102 |
+
mstft_low += l1Loss(mag_ref_low, mag_deg_low) + l2Loss(
|
103 |
+
mag_ref_low, mag_deg_low
|
104 |
+
)
|
105 |
+
|
106 |
+
mag_ref_mid = mag_ref[mid_freq_index:high_freq_index, :]
|
107 |
+
mag_deg_mid = mag_deg[mid_freq_index:high_freq_index, :]
|
108 |
+
mstft_mid += l1Loss(mag_ref_mid, mag_deg_mid) + l2Loss(
|
109 |
+
mag_ref_mid, mag_deg_mid
|
110 |
+
)
|
111 |
+
|
112 |
+
mag_ref_high = mag_ref[high_freq_index:, :]
|
113 |
+
mag_deg_high = mag_deg[high_freq_index:, :]
|
114 |
+
mstft_high += l1Loss(mag_ref_high, mag_deg_high) + l2Loss(
|
115 |
+
mag_ref_high, mag_deg_high
|
116 |
+
)
|
117 |
+
|
118 |
+
mstft /= 6
|
119 |
+
mstft_low /= 6
|
120 |
+
mstft_mid /= 6
|
121 |
+
mstft_high /= 6
|
122 |
+
|
123 |
+
return mstft
|
124 |
+
elif version == "pwg":
|
125 |
+
fft_sizes = [1024, 2048, 512]
|
126 |
+
hop_sizes = [120, 240, 50]
|
127 |
+
win_sizes = [600, 1200, 240]
|
128 |
+
|
129 |
+
audio_ref = torch.from_numpy(audio_ref)
|
130 |
+
audio_deg = torch.from_numpy(audio_deg)
|
131 |
+
|
132 |
+
mstft_sc = 0
|
133 |
+
mstft_sc_low = 0
|
134 |
+
mstft_sc_mid = 0
|
135 |
+
mstft_sc_high = 0
|
136 |
+
|
137 |
+
mstft_mag = 0
|
138 |
+
mstft_mag_low = 0
|
139 |
+
mstft_mag_mid = 0
|
140 |
+
mstft_mag_high = 0
|
141 |
+
|
142 |
+
for n_fft, hop_length, win_length in zip(fft_sizes, hop_sizes, win_sizes):
|
143 |
+
spec_ref = torch.stft(
|
144 |
+
audio_ref, n_fft, hop_length, win_length, return_complex=False
|
145 |
+
)
|
146 |
+
spec_deg = torch.stft(
|
147 |
+
audio_deg, n_fft, hop_length, win_length, return_complex=False
|
148 |
+
)
|
149 |
+
|
150 |
+
real_ref = spec_ref[..., 0]
|
151 |
+
imag_ref = spec_ref[..., 1]
|
152 |
+
real_deg = spec_deg[..., 0]
|
153 |
+
imag_deg = spec_deg[..., 1]
|
154 |
+
|
155 |
+
mag_ref = torch.sqrt(
|
156 |
+
torch.clamp(real_ref**2 + imag_ref**2, min=1e-7)
|
157 |
+
).transpose(1, 0)
|
158 |
+
mag_deg = torch.sqrt(
|
159 |
+
torch.clamp(real_deg**2 + imag_deg**2, min=1e-7)
|
160 |
+
).transpose(1, 0)
|
161 |
+
sc_loss = torch.norm(mag_ref - mag_deg, p="fro") / torch.norm(
|
162 |
+
mag_ref, p="fro"
|
163 |
+
)
|
164 |
+
mag_loss = l1Loss(torch.log(mag_ref), torch.log(mag_deg))
|
165 |
+
|
166 |
+
mstft_sc += sc_loss
|
167 |
+
mstft_mag += mag_loss
|
168 |
+
|
169 |
+
freq_resolution = fs / n_fft
|
170 |
+
mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
|
171 |
+
high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))
|
172 |
+
|
173 |
+
mag_ref_low = mag_ref[:, :mid_freq_index]
|
174 |
+
mag_deg_low = mag_deg[:, :mid_freq_index]
|
175 |
+
sc_loss_low = torch.norm(mag_ref_low - mag_deg_low, p="fro") / torch.norm(
|
176 |
+
mag_ref_low, p="fro"
|
177 |
+
)
|
178 |
+
mag_loss_low = l1Loss(torch.log(mag_ref_low), torch.log(mag_deg_low))
|
179 |
+
|
180 |
+
mstft_sc_low += sc_loss_low
|
181 |
+
mstft_mag_low += mag_loss_low
|
182 |
+
|
183 |
+
mag_ref_mid = mag_ref[:, mid_freq_index:high_freq_index]
|
184 |
+
mag_deg_mid = mag_deg[:, mid_freq_index:high_freq_index]
|
185 |
+
sc_loss_mid = torch.norm(mag_ref_mid - mag_deg_mid, p="fro") / torch.norm(
|
186 |
+
mag_ref_mid, p="fro"
|
187 |
+
)
|
188 |
+
mag_loss_mid = l1Loss(torch.log(mag_ref_mid), torch.log(mag_deg_mid))
|
189 |
+
|
190 |
+
mstft_sc_mid += sc_loss_mid
|
191 |
+
mstft_mag_mid += mag_loss_mid
|
192 |
+
|
193 |
+
mag_ref_high = mag_ref[:, high_freq_index:]
|
194 |
+
mag_deg_high = mag_deg[:, high_freq_index:]
|
195 |
+
sc_loss_high = torch.norm(
|
196 |
+
mag_ref_high - mag_deg_high, p="fro"
|
197 |
+
) / torch.norm(mag_ref_high, p="fro")
|
198 |
+
mag_loss_high = l1Loss(torch.log(mag_ref_high), torch.log(mag_deg_high))
|
199 |
+
|
200 |
+
mstft_sc_high += sc_loss_high
|
201 |
+
mstft_mag_high += mag_loss_high
|
202 |
+
|
203 |
+
# Normalize distances
|
204 |
+
mstft_sc /= len(fft_sizes)
|
205 |
+
mstft_sc_low /= len(fft_sizes)
|
206 |
+
mstft_sc_mid /= len(fft_sizes)
|
207 |
+
mstft_sc_high /= len(fft_sizes)
|
208 |
+
|
209 |
+
mstft_mag /= len(fft_sizes)
|
210 |
+
mstft_mag_low /= len(fft_sizes)
|
211 |
+
mstft_mag_mid /= len(fft_sizes)
|
212 |
+
mstft_mag_high /= len(fft_sizes)
|
213 |
+
|
214 |
+
# return (
|
215 |
+
# mstft_sc.numpy().tolist(),
|
216 |
+
# mstft_sc_low.numpy().tolist(),
|
217 |
+
# mstft_sc_mid.numpy().tolist(),
|
218 |
+
# mstft_sc_high.numpy().tolist(),
|
219 |
+
# mstft_mag.numpy().tolist(),
|
220 |
+
# mstft_mag_low.numpy().tolist(),
|
221 |
+
# mstft_mag_mid.numpy().tolist(),
|
222 |
+
# mstft_mag_high.numpy().tolist(),
|
223 |
+
# )
|
224 |
+
|
225 |
+
return mstft_sc.numpy().tolist() + mstft_mag.numpy().tolist()
|
evaluation/metrics/spectrogram/pesq.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from pypesq import pesq
|
11 |
+
|
12 |
+
|
13 |
+
def extract_pesq(audio_ref, audio_deg, fs=None, method="cut"):
|
14 |
+
"""Extract PESQ for a two given audio.
|
15 |
+
audio1: the given reference audio. It is a numpy array.
|
16 |
+
audio2: the given synthesized audio. It is a numpy array.
|
17 |
+
fs: sampling rate.
|
18 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
19 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
20 |
+
"""
|
21 |
+
# Load audio
|
22 |
+
if fs != None:
|
23 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
24 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
25 |
+
else:
|
26 |
+
audio_ref, fs = librosa.load(audio_ref)
|
27 |
+
audio_deg, fs = librosa.load(audio_deg)
|
28 |
+
|
29 |
+
# Resample
|
30 |
+
if fs != 16000:
|
31 |
+
audio_ref = librosa.resample(audio_ref, orig_sr=fs, target_sr=16000)
|
32 |
+
audio_deg = librosa.resample(audio_deg, orig_sr=fs, target_sr=16000)
|
33 |
+
fs = 16000
|
34 |
+
|
35 |
+
# Audio length alignment
|
36 |
+
if len(audio_ref) != len(audio_deg):
|
37 |
+
if method == "cut":
|
38 |
+
length = min(len(audio_ref), len(audio_deg))
|
39 |
+
audio_ref = audio_ref[:length]
|
40 |
+
audio_deg = audio_deg[:length]
|
41 |
+
elif method == "dtw":
|
42 |
+
_, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
|
43 |
+
audio_ref_new = []
|
44 |
+
audio_deg_new = []
|
45 |
+
for i in range(wp.shape[0]):
|
46 |
+
ref_index = wp[i][0]
|
47 |
+
deg_index = wp[i][1]
|
48 |
+
audio_ref_new.append(audio_ref[ref_index])
|
49 |
+
audio_deg_new.append(audio_deg[deg_index])
|
50 |
+
audio_ref = np.array(audio_ref_new)
|
51 |
+
audio_deg = np.array(audio_deg_new)
|
52 |
+
assert len(audio_ref) == len(audio_deg)
|
53 |
+
|
54 |
+
# Compute pesq
|
55 |
+
score = pesq(audio_ref, audio_deg, fs)
|
56 |
+
return score
|
evaluation/metrics/spectrogram/scale_invariant_signal_to_distortion_ratio.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import librosa
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchmetrics import ScaleInvariantSignalDistortionRatio
|
12 |
+
|
13 |
+
|
14 |
+
def extract_si_sdr(audio_ref, audio_deg, fs=None, method="cut"):
|
15 |
+
si_sdr = ScaleInvariantSignalDistortionRatio()
|
16 |
+
|
17 |
+
if fs != None:
|
18 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
19 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
20 |
+
else:
|
21 |
+
audio_ref, fs = librosa.load(audio_ref)
|
22 |
+
audio_deg, fs = librosa.load(audio_deg)
|
23 |
+
|
24 |
+
if len(audio_ref) != len(audio_deg):
|
25 |
+
if method == "cut":
|
26 |
+
length = min(len(audio_ref), len(audio_deg))
|
27 |
+
audio_ref = audio_ref[:length]
|
28 |
+
audio_deg = audio_deg[:length]
|
29 |
+
elif method == "dtw":
|
30 |
+
_, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
|
31 |
+
audio_ref_new = []
|
32 |
+
audio_deg_new = []
|
33 |
+
for i in range(wp.shape[0]):
|
34 |
+
ref_index = wp[i][0]
|
35 |
+
deg_index = wp[i][1]
|
36 |
+
audio_ref_new.append(audio_ref[ref_index])
|
37 |
+
audio_deg_new.append(audio_deg[deg_index])
|
38 |
+
audio_ref = np.array(audio_ref_new)
|
39 |
+
audio_deg = np.array(audio_deg_new)
|
40 |
+
assert len(audio_ref) == len(audio_deg)
|
41 |
+
|
42 |
+
audio_ref = torch.from_numpy(audio_ref)
|
43 |
+
audio_deg = torch.from_numpy(audio_deg)
|
44 |
+
|
45 |
+
return si_sdr(audio_deg, audio_ref)
|
evaluation/metrics/spectrogram/scale_invariant_signal_to_noise_ratio.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import librosa
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchmetrics import ScaleInvariantSignalNoiseRatio
|
12 |
+
|
13 |
+
|
14 |
+
def extract_si_snr(audio_ref, audio_deg, fs=None, method="cut"):
|
15 |
+
si_snr = ScaleInvariantSignalNoiseRatio()
|
16 |
+
|
17 |
+
if fs != None:
|
18 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
19 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
20 |
+
else:
|
21 |
+
audio_ref, fs = librosa.load(audio_ref)
|
22 |
+
audio_deg, fs = librosa.load(audio_deg)
|
23 |
+
|
24 |
+
if len(audio_ref) != len(audio_deg):
|
25 |
+
if method == "cut":
|
26 |
+
length = min(len(audio_ref), len(audio_deg))
|
27 |
+
audio_ref = audio_ref[:length]
|
28 |
+
audio_deg = audio_deg[:length]
|
29 |
+
elif method == "dtw":
|
30 |
+
_, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
|
31 |
+
audio_ref_new = []
|
32 |
+
audio_deg_new = []
|
33 |
+
for i in range(wp.shape[0]):
|
34 |
+
ref_index = wp[i][0]
|
35 |
+
deg_index = wp[i][1]
|
36 |
+
audio_ref_new.append(audio_ref[ref_index])
|
37 |
+
audio_deg_new.append(audio_deg[deg_index])
|
38 |
+
audio_ref = np.array(audio_ref_new)
|
39 |
+
audio_deg = np.array(audio_deg_new)
|
40 |
+
assert len(audio_ref) == len(audio_deg)
|
41 |
+
|
42 |
+
audio_ref = torch.from_numpy(audio_ref)
|
43 |
+
audio_deg = torch.from_numpy(audio_deg)
|
44 |
+
|
45 |
+
return si_snr(audio_deg, audio_ref)
|
evaluation/metrics/spectrogram/short_time_objective_intelligibility.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import librosa
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
12 |
+
|
13 |
+
|
14 |
+
def extract_stoi(audio_ref, audio_deg, fs=None, extended=False, method="cut"):
|
15 |
+
"""Compute Short-Time Objective Intelligibility between the predicted and the ground truth audio.
|
16 |
+
audio_ref: path to the ground truth audio.
|
17 |
+
audio_deg: path to the predicted audio.
|
18 |
+
fs: sampling rate.
|
19 |
+
method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
|
20 |
+
"cut" will cut both audios into a same length according to the one with the shorter length.
|
21 |
+
"""
|
22 |
+
# Load audio
|
23 |
+
if fs != None:
|
24 |
+
audio_ref, _ = librosa.load(audio_ref, sr=fs)
|
25 |
+
audio_deg, _ = librosa.load(audio_deg, sr=fs)
|
26 |
+
else:
|
27 |
+
audio_ref, fs = librosa.load(audio_ref)
|
28 |
+
audio_deg, fs = librosa.load(audio_deg)
|
29 |
+
|
30 |
+
# Initialize method
|
31 |
+
stoi = ShortTimeObjectiveIntelligibility(fs, extended)
|
32 |
+
|
33 |
+
# Audio length alignment
|
34 |
+
if len(audio_ref) != len(audio_deg):
|
35 |
+
if method == "cut":
|
36 |
+
length = min(len(audio_ref), len(audio_deg))
|
37 |
+
audio_ref = audio_ref[:length]
|
38 |
+
audio_deg = audio_deg[:length]
|
39 |
+
elif method == "dtw":
|
40 |
+
_, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
|
41 |
+
audio_ref_new = []
|
42 |
+
audio_deg_new = []
|
43 |
+
for i in range(wp.shape[0]):
|
44 |
+
ref_index = wp[i][0]
|
45 |
+
deg_index = wp[i][1]
|
46 |
+
audio_ref_new.append(audio_ref[ref_index])
|
47 |
+
audio_deg_new.append(audio_deg[deg_index])
|
48 |
+
audio_ref = np.array(audio_ref_new)
|
49 |
+
audio_deg = np.array(audio_deg_new)
|
50 |
+
assert len(audio_ref) == len(audio_deg)
|
51 |
+
|
52 |
+
# Convert to tensor
|
53 |
+
audio_ref = torch.from_numpy(audio_ref)
|
54 |
+
audio_deg = torch.from_numpy(audio_deg)
|
55 |
+
|
56 |
+
return stoi(audio_deg, audio_ref).numpy().tolist()
|
models/tts/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# from .tts_inferece import TTSInference
|
7 |
+
from .tts_trainer import TTSTrainer
|
models/tts/base/tts_dataset.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import torchaudio
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from utils.data_utils import *
|
12 |
+
from torch.nn.utils.rnn import pad_sequence
|
13 |
+
from text import text_to_sequence
|
14 |
+
from text.text_token_collation import phoneIDCollation
|
15 |
+
from processors.acoustic_extractor import cal_normalized_mel
|
16 |
+
|
17 |
+
from models.base.base_dataset import (
|
18 |
+
BaseDataset,
|
19 |
+
BaseCollator,
|
20 |
+
BaseTestDataset,
|
21 |
+
BaseTestCollator,
|
22 |
+
)
|
23 |
+
|
24 |
+
from processors.content_extractor import (
|
25 |
+
ContentvecExtractor,
|
26 |
+
WenetExtractor,
|
27 |
+
WhisperExtractor,
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
class TTSDataset(BaseDataset):
|
32 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
cfg: config
|
36 |
+
dataset: dataset name
|
37 |
+
is_valid: whether to use train or valid dataset
|
38 |
+
"""
|
39 |
+
|
40 |
+
assert isinstance(dataset, str)
|
41 |
+
|
42 |
+
self.cfg = cfg
|
43 |
+
|
44 |
+
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
45 |
+
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
46 |
+
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
47 |
+
self.metadata = self.get_metadata()
|
48 |
+
|
49 |
+
"""
|
50 |
+
load spk2id and utt2spk from json file
|
51 |
+
spk2id: {spk1: 0, spk2: 1, ...}
|
52 |
+
utt2spk: {dataset_uid: spk1, ...}
|
53 |
+
"""
|
54 |
+
if cfg.preprocess.use_spkid:
|
55 |
+
dataset = self.metadata[0]["Dataset"]
|
56 |
+
|
57 |
+
spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
|
58 |
+
with open(spk2id_path, "r") as f:
|
59 |
+
self.spk2id = json.load(f)
|
60 |
+
|
61 |
+
utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
|
62 |
+
self.utt2spk = dict()
|
63 |
+
with open(utt2spk_path, "r") as f:
|
64 |
+
for line in f.readlines():
|
65 |
+
utt, spk = line.strip().split("\t")
|
66 |
+
self.utt2spk[utt] = spk
|
67 |
+
|
68 |
+
if cfg.preprocess.use_uv:
|
69 |
+
self.utt2uv_path = {}
|
70 |
+
for utt_info in self.metadata:
|
71 |
+
dataset = utt_info["Dataset"]
|
72 |
+
uid = utt_info["Uid"]
|
73 |
+
utt = "{}_{}".format(dataset, uid)
|
74 |
+
self.utt2uv_path[utt] = os.path.join(
|
75 |
+
cfg.preprocess.processed_dir,
|
76 |
+
dataset,
|
77 |
+
cfg.preprocess.uv_dir,
|
78 |
+
uid + ".npy",
|
79 |
+
)
|
80 |
+
|
81 |
+
if cfg.preprocess.use_frame_pitch:
|
82 |
+
self.utt2frame_pitch_path = {}
|
83 |
+
for utt_info in self.metadata:
|
84 |
+
dataset = utt_info["Dataset"]
|
85 |
+
uid = utt_info["Uid"]
|
86 |
+
utt = "{}_{}".format(dataset, uid)
|
87 |
+
|
88 |
+
self.utt2frame_pitch_path[utt] = os.path.join(
|
89 |
+
cfg.preprocess.processed_dir,
|
90 |
+
dataset,
|
91 |
+
cfg.preprocess.pitch_dir,
|
92 |
+
uid + ".npy",
|
93 |
+
)
|
94 |
+
|
95 |
+
if cfg.preprocess.use_frame_energy:
|
96 |
+
self.utt2frame_energy_path = {}
|
97 |
+
for utt_info in self.metadata:
|
98 |
+
dataset = utt_info["Dataset"]
|
99 |
+
uid = utt_info["Uid"]
|
100 |
+
utt = "{}_{}".format(dataset, uid)
|
101 |
+
|
102 |
+
self.utt2frame_energy_path[utt] = os.path.join(
|
103 |
+
cfg.preprocess.processed_dir,
|
104 |
+
dataset,
|
105 |
+
cfg.preprocess.energy_dir,
|
106 |
+
uid + ".npy",
|
107 |
+
)
|
108 |
+
|
109 |
+
if cfg.preprocess.use_mel:
|
110 |
+
self.utt2mel_path = {}
|
111 |
+
for utt_info in self.metadata:
|
112 |
+
dataset = utt_info["Dataset"]
|
113 |
+
uid = utt_info["Uid"]
|
114 |
+
utt = "{}_{}".format(dataset, uid)
|
115 |
+
|
116 |
+
self.utt2mel_path[utt] = os.path.join(
|
117 |
+
cfg.preprocess.processed_dir,
|
118 |
+
dataset,
|
119 |
+
cfg.preprocess.mel_dir,
|
120 |
+
uid + ".npy",
|
121 |
+
)
|
122 |
+
|
123 |
+
if cfg.preprocess.use_linear:
|
124 |
+
self.utt2linear_path = {}
|
125 |
+
for utt_info in self.metadata:
|
126 |
+
dataset = utt_info["Dataset"]
|
127 |
+
uid = utt_info["Uid"]
|
128 |
+
utt = "{}_{}".format(dataset, uid)
|
129 |
+
|
130 |
+
self.utt2linear_path[utt] = os.path.join(
|
131 |
+
cfg.preprocess.processed_dir,
|
132 |
+
dataset,
|
133 |
+
cfg.preprocess.linear_dir,
|
134 |
+
uid + ".npy",
|
135 |
+
)
|
136 |
+
|
137 |
+
if cfg.preprocess.use_audio:
|
138 |
+
self.utt2audio_path = {}
|
139 |
+
for utt_info in self.metadata:
|
140 |
+
dataset = utt_info["Dataset"]
|
141 |
+
uid = utt_info["Uid"]
|
142 |
+
utt = "{}_{}".format(dataset, uid)
|
143 |
+
|
144 |
+
if cfg.preprocess.extract_audio:
|
145 |
+
self.utt2audio_path[utt] = os.path.join(
|
146 |
+
cfg.preprocess.processed_dir,
|
147 |
+
dataset,
|
148 |
+
cfg.preprocess.audio_dir,
|
149 |
+
uid + ".wav",
|
150 |
+
)
|
151 |
+
else:
|
152 |
+
self.utt2audio_path[utt] = utt_info["Path"]
|
153 |
+
|
154 |
+
# self.utt2audio_path[utt] = os.path.join(
|
155 |
+
# cfg.preprocess.processed_dir,
|
156 |
+
# dataset,
|
157 |
+
# cfg.preprocess.audio_dir,
|
158 |
+
# uid + ".numpy",
|
159 |
+
# )
|
160 |
+
|
161 |
+
elif cfg.preprocess.use_label:
|
162 |
+
self.utt2label_path = {}
|
163 |
+
for utt_info in self.metadata:
|
164 |
+
dataset = utt_info["Dataset"]
|
165 |
+
uid = utt_info["Uid"]
|
166 |
+
utt = "{}_{}".format(dataset, uid)
|
167 |
+
|
168 |
+
self.utt2label_path[utt] = os.path.join(
|
169 |
+
cfg.preprocess.processed_dir,
|
170 |
+
dataset,
|
171 |
+
cfg.preprocess.label_dir,
|
172 |
+
uid + ".npy",
|
173 |
+
)
|
174 |
+
elif cfg.preprocess.use_one_hot:
|
175 |
+
self.utt2one_hot_path = {}
|
176 |
+
for utt_info in self.metadata:
|
177 |
+
dataset = utt_info["Dataset"]
|
178 |
+
uid = utt_info["Uid"]
|
179 |
+
utt = "{}_{}".format(dataset, uid)
|
180 |
+
|
181 |
+
self.utt2one_hot_path[utt] = os.path.join(
|
182 |
+
cfg.preprocess.processed_dir,
|
183 |
+
dataset,
|
184 |
+
cfg.preprocess.one_hot_dir,
|
185 |
+
uid + ".npy",
|
186 |
+
)
|
187 |
+
|
188 |
+
if cfg.preprocess.use_text or cfg.preprocess.use_phone:
|
189 |
+
self.utt2seq = {}
|
190 |
+
for utt_info in self.metadata:
|
191 |
+
dataset = utt_info["Dataset"]
|
192 |
+
uid = utt_info["Uid"]
|
193 |
+
utt = "{}_{}".format(dataset, uid)
|
194 |
+
|
195 |
+
if cfg.preprocess.use_text:
|
196 |
+
text = utt_info["Text"]
|
197 |
+
sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
|
198 |
+
elif cfg.preprocess.use_phone:
|
199 |
+
# load phoneme squence from phone file
|
200 |
+
phone_path = os.path.join(
|
201 |
+
processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
|
202 |
+
)
|
203 |
+
with open(phone_path, "r") as fin:
|
204 |
+
phones = fin.readlines()
|
205 |
+
assert len(phones) == 1
|
206 |
+
phones = phones[0].strip()
|
207 |
+
phones_seq = phones.split(" ")
|
208 |
+
|
209 |
+
phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
|
210 |
+
sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
|
211 |
+
|
212 |
+
self.utt2seq[utt] = sequence
|
213 |
+
|
214 |
+
def __getitem__(self, index):
|
215 |
+
utt_info = self.metadata[index]
|
216 |
+
|
217 |
+
dataset = utt_info["Dataset"]
|
218 |
+
uid = utt_info["Uid"]
|
219 |
+
utt = "{}_{}".format(dataset, uid)
|
220 |
+
|
221 |
+
single_feature = dict()
|
222 |
+
|
223 |
+
if self.cfg.preprocess.use_spkid:
|
224 |
+
single_feature["spk_id"] = np.array(
|
225 |
+
[self.spk2id[self.utt2spk[utt]]], dtype=np.int32
|
226 |
+
)
|
227 |
+
|
228 |
+
if self.cfg.preprocess.use_mel:
|
229 |
+
mel = np.load(self.utt2mel_path[utt])
|
230 |
+
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
|
231 |
+
if self.cfg.preprocess.use_min_max_norm_mel:
|
232 |
+
# do mel norm
|
233 |
+
mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
|
234 |
+
|
235 |
+
if "target_len" not in single_feature.keys():
|
236 |
+
single_feature["target_len"] = mel.shape[1]
|
237 |
+
single_feature["mel"] = mel.T # [T, n_mels]
|
238 |
+
|
239 |
+
if self.cfg.preprocess.use_linear:
|
240 |
+
linear = np.load(self.utt2linear_path[utt])
|
241 |
+
if "target_len" not in single_feature.keys():
|
242 |
+
single_feature["target_len"] = linear.shape[1]
|
243 |
+
single_feature["linear"] = linear.T # [T, n_linear]
|
244 |
+
|
245 |
+
if self.cfg.preprocess.use_frame_pitch:
|
246 |
+
frame_pitch_path = self.utt2frame_pitch_path[utt]
|
247 |
+
frame_pitch = np.load(frame_pitch_path)
|
248 |
+
if "target_len" not in single_feature.keys():
|
249 |
+
single_feature["target_len"] = len(frame_pitch)
|
250 |
+
aligned_frame_pitch = align_length(
|
251 |
+
frame_pitch, single_feature["target_len"]
|
252 |
+
)
|
253 |
+
single_feature["frame_pitch"] = aligned_frame_pitch
|
254 |
+
|
255 |
+
if self.cfg.preprocess.use_uv:
|
256 |
+
frame_uv_path = self.utt2uv_path[utt]
|
257 |
+
frame_uv = np.load(frame_uv_path)
|
258 |
+
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
|
259 |
+
aligned_frame_uv = [
|
260 |
+
0 if frame_uv else 1 for frame_uv in aligned_frame_uv
|
261 |
+
]
|
262 |
+
aligned_frame_uv = np.array(aligned_frame_uv)
|
263 |
+
single_feature["frame_uv"] = aligned_frame_uv
|
264 |
+
|
265 |
+
if self.cfg.preprocess.use_frame_energy:
|
266 |
+
frame_energy_path = self.utt2frame_energy_path[utt]
|
267 |
+
frame_energy = np.load(frame_energy_path)
|
268 |
+
if "target_len" not in single_feature.keys():
|
269 |
+
single_feature["target_len"] = len(frame_energy)
|
270 |
+
aligned_frame_energy = align_length(
|
271 |
+
frame_energy, single_feature["target_len"]
|
272 |
+
)
|
273 |
+
single_feature["frame_energy"] = aligned_frame_energy
|
274 |
+
|
275 |
+
if self.cfg.preprocess.use_audio:
|
276 |
+
audio, sr = torchaudio.load(self.utt2audio_path[utt])
|
277 |
+
audio = audio.cpu().numpy().squeeze()
|
278 |
+
single_feature["audio"] = audio
|
279 |
+
single_feature["audio_len"] = audio.shape[0]
|
280 |
+
|
281 |
+
if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
|
282 |
+
single_feature["phone_seq"] = np.array(self.utt2seq[utt])
|
283 |
+
single_feature["phone_len"] = len(self.utt2seq[utt])
|
284 |
+
|
285 |
+
return single_feature
|
286 |
+
|
287 |
+
def __len__(self):
|
288 |
+
return super().__len__()
|
289 |
+
|
290 |
+
def get_metadata(self):
|
291 |
+
return super().get_metadata()
|
292 |
+
|
293 |
+
|
294 |
+
class TTSCollator(BaseCollator):
|
295 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
296 |
+
|
297 |
+
def __init__(self, cfg):
|
298 |
+
super().__init__(cfg)
|
299 |
+
|
300 |
+
def __call__(self, batch):
|
301 |
+
parsed_batch_features = super().__call__(batch)
|
302 |
+
return parsed_batch_features
|
303 |
+
|
304 |
+
|
305 |
+
class TTSTestDataset(BaseTestDataset):
|
306 |
+
def __init__(self, args, cfg):
|
307 |
+
self.cfg = cfg
|
308 |
+
|
309 |
+
# inference from test list file
|
310 |
+
if args.test_list_file is not None:
|
311 |
+
# construst metadata
|
312 |
+
self.metadata = []
|
313 |
+
|
314 |
+
with open(args.test_list_file, "r") as fin:
|
315 |
+
for idx, line in enumerate(fin.readlines()):
|
316 |
+
utt_info = {}
|
317 |
+
|
318 |
+
utt_info["Dataset"] = "test"
|
319 |
+
utt_info["Text"] = line.strip()
|
320 |
+
utt_info["Uid"] = str(idx)
|
321 |
+
self.metadata.append(utt_info)
|
322 |
+
|
323 |
+
else:
|
324 |
+
assert args.testing_set
|
325 |
+
self.metafile_path = os.path.join(
|
326 |
+
cfg.preprocess.processed_dir,
|
327 |
+
args.dataset,
|
328 |
+
"{}.json".format(args.testing_set),
|
329 |
+
)
|
330 |
+
self.metadata = self.get_metadata()
|
331 |
+
|
332 |
+
def __getitem__(self, index):
|
333 |
+
single_feature = {}
|
334 |
+
|
335 |
+
return single_feature
|
336 |
+
|
337 |
+
def __len__(self):
|
338 |
+
return len(self.metadata)
|
339 |
+
|
340 |
+
|
341 |
+
class TTSTestCollator(BaseTestCollator):
|
342 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
343 |
+
|
344 |
+
def __init__(self, cfg):
|
345 |
+
self.cfg = cfg
|
346 |
+
|
347 |
+
def __call__(self, batch):
|
348 |
+
packed_batch_features = dict()
|
349 |
+
|
350 |
+
# mel: [b, T, n_mels]
|
351 |
+
# frame_pitch, frame_energy: [1, T]
|
352 |
+
# target_len: [1]
|
353 |
+
# spk_id: [b, 1]
|
354 |
+
# mask: [b, T, 1]
|
355 |
+
|
356 |
+
for key in batch[0].keys():
|
357 |
+
if key == "target_len":
|
358 |
+
packed_batch_features["target_len"] = torch.LongTensor(
|
359 |
+
[b["target_len"] for b in batch]
|
360 |
+
)
|
361 |
+
masks = [
|
362 |
+
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
|
363 |
+
]
|
364 |
+
packed_batch_features["mask"] = pad_sequence(
|
365 |
+
masks, batch_first=True, padding_value=0
|
366 |
+
)
|
367 |
+
elif key == "phone_len":
|
368 |
+
packed_batch_features["phone_len"] = torch.LongTensor(
|
369 |
+
[b["phone_len"] for b in batch]
|
370 |
+
)
|
371 |
+
masks = [
|
372 |
+
torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
|
373 |
+
]
|
374 |
+
packed_batch_features["phn_mask"] = pad_sequence(
|
375 |
+
masks, batch_first=True, padding_value=0
|
376 |
+
)
|
377 |
+
elif key == "audio_len":
|
378 |
+
packed_batch_features["audio_len"] = torch.LongTensor(
|
379 |
+
[b["audio_len"] for b in batch]
|
380 |
+
)
|
381 |
+
masks = [
|
382 |
+
torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
|
383 |
+
]
|
384 |
+
else:
|
385 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
386 |
+
packed_batch_features[key] = pad_sequence(
|
387 |
+
values, batch_first=True, padding_value=0
|
388 |
+
)
|
389 |
+
return packed_batch_features
|
models/tts/base/tts_inferece.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
import time
|
9 |
+
import accelerate
|
10 |
+
import random
|
11 |
+
import numpy as np
|
12 |
+
from tqdm import tqdm
|
13 |
+
from accelerate.logging import get_logger
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
|
16 |
+
|
17 |
+
from abc import abstractmethod
|
18 |
+
from pathlib import Path
|
19 |
+
from utils.io import save_audio
|
20 |
+
from utils.util import load_config
|
21 |
+
from models.vocoders.vocoder_inference import synthesis
|
22 |
+
|
23 |
+
|
24 |
+
class TTSInference(object):
|
25 |
+
def __init__(self, args=None, cfg=None):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
start = time.monotonic_ns()
|
29 |
+
self.args = args
|
30 |
+
self.cfg = cfg
|
31 |
+
self.infer_type = args.mode
|
32 |
+
|
33 |
+
# get exp_dir
|
34 |
+
if self.args.acoustics_dir is not None:
|
35 |
+
self.exp_dir = self.args.acoustics_dir
|
36 |
+
elif self.args.checkpoint_path is not None:
|
37 |
+
self.exp_dir = os.path.dirname(os.path.dirname(self.args.checkpoint_path))
|
38 |
+
|
39 |
+
# Init accelerator
|
40 |
+
self.accelerator = accelerate.Accelerator()
|
41 |
+
self.accelerator.wait_for_everyone()
|
42 |
+
self.device = self.accelerator.device
|
43 |
+
|
44 |
+
# Get logger
|
45 |
+
with self.accelerator.main_process_first():
|
46 |
+
self.logger = get_logger("inference", log_level=args.log_level)
|
47 |
+
|
48 |
+
# Log some info
|
49 |
+
self.logger.info("=" * 56)
|
50 |
+
self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
|
51 |
+
self.logger.info("=" * 56)
|
52 |
+
self.logger.info("\n")
|
53 |
+
|
54 |
+
self.acoustic_model_dir = args.acoustics_dir
|
55 |
+
self.logger.debug(f"Acoustic model dir: {args.acoustics_dir}")
|
56 |
+
|
57 |
+
if args.vocoder_dir is not None:
|
58 |
+
self.vocoder_dir = args.vocoder_dir
|
59 |
+
self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
|
60 |
+
|
61 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
62 |
+
|
63 |
+
# Set random seed
|
64 |
+
with self.accelerator.main_process_first():
|
65 |
+
start = time.monotonic_ns()
|
66 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
67 |
+
end = time.monotonic_ns()
|
68 |
+
self.logger.debug(
|
69 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
70 |
+
)
|
71 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
72 |
+
|
73 |
+
# Setup data loader
|
74 |
+
if self.infer_type == "batch":
|
75 |
+
with self.accelerator.main_process_first():
|
76 |
+
self.logger.info("Building dataset...")
|
77 |
+
start = time.monotonic_ns()
|
78 |
+
self.test_dataloader = self._build_test_dataloader()
|
79 |
+
end = time.monotonic_ns()
|
80 |
+
self.logger.info(
|
81 |
+
f"Building dataset done in {(end - start) / 1e6:.2f}ms"
|
82 |
+
)
|
83 |
+
|
84 |
+
# Build model
|
85 |
+
with self.accelerator.main_process_first():
|
86 |
+
self.logger.info("Building model...")
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self.model = self._build_model()
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
|
91 |
+
|
92 |
+
# Init with accelerate
|
93 |
+
self.logger.info("Initializing accelerate...")
|
94 |
+
start = time.monotonic_ns()
|
95 |
+
self.accelerator = accelerate.Accelerator()
|
96 |
+
self.model = self.accelerator.prepare(self.model)
|
97 |
+
if self.infer_type == "batch":
|
98 |
+
self.test_dataloader = self.accelerator.prepare(self.test_dataloader)
|
99 |
+
end = time.monotonic_ns()
|
100 |
+
self.accelerator.wait_for_everyone()
|
101 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
|
102 |
+
|
103 |
+
with self.accelerator.main_process_first():
|
104 |
+
self.logger.info("Loading checkpoint...")
|
105 |
+
start = time.monotonic_ns()
|
106 |
+
if args.acoustics_dir is not None:
|
107 |
+
self._load_model(
|
108 |
+
checkpoint_dir=os.path.join(args.acoustics_dir, "checkpoint")
|
109 |
+
)
|
110 |
+
elif args.checkpoint_path is not None:
|
111 |
+
self._load_model(checkpoint_path=args.checkpoint_path)
|
112 |
+
else:
|
113 |
+
print("Either checkpoint dir or checkpoint path should be provided.")
|
114 |
+
|
115 |
+
end = time.monotonic_ns()
|
116 |
+
self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
|
117 |
+
|
118 |
+
self.model.eval()
|
119 |
+
self.accelerator.wait_for_everyone()
|
120 |
+
|
121 |
+
def _build_test_dataset(self):
|
122 |
+
pass
|
123 |
+
|
124 |
+
def _build_model(self):
|
125 |
+
pass
|
126 |
+
|
127 |
+
# TODO: LEGACY CODE
|
128 |
+
def _build_test_dataloader(self):
|
129 |
+
datasets, collate = self._build_test_dataset()
|
130 |
+
self.test_dataset = datasets(self.args, self.cfg)
|
131 |
+
self.test_collate = collate(self.cfg)
|
132 |
+
self.test_batch_size = min(
|
133 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
134 |
+
)
|
135 |
+
test_dataloader = DataLoader(
|
136 |
+
self.test_dataset,
|
137 |
+
collate_fn=self.test_collate,
|
138 |
+
num_workers=1,
|
139 |
+
batch_size=self.test_batch_size,
|
140 |
+
shuffle=False,
|
141 |
+
)
|
142 |
+
return test_dataloader
|
143 |
+
|
144 |
+
def _load_model(
|
145 |
+
self,
|
146 |
+
checkpoint_dir: str = None,
|
147 |
+
checkpoint_path: str = None,
|
148 |
+
old_mode: bool = False,
|
149 |
+
):
|
150 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
151 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
152 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
153 |
+
method after** ``accelerator.prepare()``.
|
154 |
+
"""
|
155 |
+
|
156 |
+
if checkpoint_path is None:
|
157 |
+
assert checkpoint_dir is not None
|
158 |
+
# Load the latest accelerator state dicts
|
159 |
+
ls = [
|
160 |
+
str(i) for i in Path(checkpoint_dir).glob("*") if not "audio" in str(i)
|
161 |
+
]
|
162 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
163 |
+
checkpoint_path = ls[0]
|
164 |
+
|
165 |
+
self.accelerator.load_state(str(checkpoint_path))
|
166 |
+
return str(checkpoint_path)
|
167 |
+
|
168 |
+
def inference(self):
|
169 |
+
if self.infer_type == "single":
|
170 |
+
out_dir = os.path.join(self.args.output_dir, "single")
|
171 |
+
os.makedirs(out_dir, exist_ok=True)
|
172 |
+
|
173 |
+
pred_audio = self.inference_for_single_utterance()
|
174 |
+
save_path = os.path.join(out_dir, "test_pred.wav")
|
175 |
+
save_audio(save_path, pred_audio, self.cfg.preprocess.sample_rate)
|
176 |
+
|
177 |
+
elif self.infer_type == "batch":
|
178 |
+
out_dir = os.path.join(self.args.output_dir, "batch")
|
179 |
+
os.makedirs(out_dir, exist_ok=True)
|
180 |
+
|
181 |
+
pred_audio_list = self.inference_for_batches()
|
182 |
+
for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
|
183 |
+
uid = it["Uid"]
|
184 |
+
save_audio(
|
185 |
+
os.path.join(out_dir, f"{uid}.wav"),
|
186 |
+
wav.numpy(),
|
187 |
+
self.cfg.preprocess.sample_rate,
|
188 |
+
add_silence=True,
|
189 |
+
turn_up=True,
|
190 |
+
)
|
191 |
+
tmp_file = os.path.join(out_dir, f"{uid}.pt")
|
192 |
+
if os.path.exists(tmp_file):
|
193 |
+
os.remove(tmp_file)
|
194 |
+
print("Saved to: ", out_dir)
|
195 |
+
|
196 |
+
@torch.inference_mode()
|
197 |
+
def inference_for_batches(self):
|
198 |
+
y_pred = []
|
199 |
+
for i, batch in tqdm(enumerate(self.test_dataloader)):
|
200 |
+
y_pred, mel_lens, _ = self._inference_each_batch(batch)
|
201 |
+
y_ls = y_pred.chunk(self.test_batch_size)
|
202 |
+
tgt_ls = mel_lens.chunk(self.test_batch_size)
|
203 |
+
j = 0
|
204 |
+
for it, l in zip(y_ls, tgt_ls):
|
205 |
+
l = l.item()
|
206 |
+
it = it.squeeze(0)[:l].detach().cpu()
|
207 |
+
|
208 |
+
uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
|
209 |
+
torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
|
210 |
+
j += 1
|
211 |
+
|
212 |
+
vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
|
213 |
+
res = synthesis(
|
214 |
+
cfg=vocoder_cfg,
|
215 |
+
vocoder_weight_file=vocoder_ckpt,
|
216 |
+
n_samples=None,
|
217 |
+
pred=[
|
218 |
+
torch.load(
|
219 |
+
os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
|
220 |
+
).numpy()
|
221 |
+
for item in self.test_dataset.metadata
|
222 |
+
],
|
223 |
+
)
|
224 |
+
for it, wav in zip(self.test_dataset.metadata, res):
|
225 |
+
uid = it["Uid"]
|
226 |
+
save_audio(
|
227 |
+
os.path.join(self.args.output_dir, f"{uid}.wav"),
|
228 |
+
wav.numpy(),
|
229 |
+
22050,
|
230 |
+
add_silence=True,
|
231 |
+
turn_up=True,
|
232 |
+
)
|
233 |
+
|
234 |
+
@abstractmethod
|
235 |
+
@torch.inference_mode()
|
236 |
+
def _inference_each_batch(self, batch_data):
|
237 |
+
pass
|
238 |
+
|
239 |
+
def inference_for_single_utterance(self, text):
|
240 |
+
pass
|
241 |
+
|
242 |
+
def synthesis_by_vocoder(self, pred):
|
243 |
+
audios_pred = synthesis(
|
244 |
+
self.vocoder_cfg,
|
245 |
+
self.checkpoint_dir_vocoder,
|
246 |
+
len(pred),
|
247 |
+
pred,
|
248 |
+
)
|
249 |
+
|
250 |
+
return audios_pred
|
251 |
+
|
252 |
+
@staticmethod
|
253 |
+
def _parse_vocoder(vocoder_dir):
|
254 |
+
r"""Parse vocoder config"""
|
255 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
256 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
257 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
258 |
+
ckpt_path = str(ckpt_list[0])
|
259 |
+
vocoder_cfg = load_config(
|
260 |
+
os.path.join(vocoder_dir, "args.json"), lowercase=True
|
261 |
+
)
|
262 |
+
return vocoder_cfg, ckpt_path
|
263 |
+
|
264 |
+
def _set_random_seed(self, seed):
|
265 |
+
"""Set random seed for all possible random modules."""
|
266 |
+
random.seed(seed)
|
267 |
+
np.random.seed(seed)
|
268 |
+
torch.random.manual_seed(seed)
|
models/tts/base/tts_trainer.py
ADDED
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import shutil
|
9 |
+
import torch
|
10 |
+
import time
|
11 |
+
from pathlib import Path
|
12 |
+
import torch
|
13 |
+
from tqdm import tqdm
|
14 |
+
import re
|
15 |
+
import logging
|
16 |
+
import json5
|
17 |
+
import accelerate
|
18 |
+
from accelerate.logging import get_logger
|
19 |
+
from accelerate.utils import ProjectConfiguration
|
20 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
21 |
+
from accelerate import DistributedDataParallelKwargs
|
22 |
+
from schedulers.scheduler import Eden
|
23 |
+
from models.base.base_sampler import build_samplers
|
24 |
+
from models.base.new_trainer import BaseTrainer
|
25 |
+
|
26 |
+
|
27 |
+
class TTSTrainer(BaseTrainer):
|
28 |
+
r"""The base trainer for all TTS models. It inherits from BaseTrainer and implements
|
29 |
+
``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
|
30 |
+
class, and implement ``_build_model``, ``_forward_step``.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, args=None, cfg=None):
|
34 |
+
self.args = args
|
35 |
+
self.cfg = cfg
|
36 |
+
|
37 |
+
cfg.exp_name = args.exp_name
|
38 |
+
|
39 |
+
# init with accelerate
|
40 |
+
self._init_accelerator()
|
41 |
+
self.accelerator.wait_for_everyone()
|
42 |
+
|
43 |
+
with self.accelerator.main_process_first():
|
44 |
+
self.logger = get_logger(args.exp_name, log_level="INFO")
|
45 |
+
|
46 |
+
# Log some info
|
47 |
+
self.logger.info("=" * 56)
|
48 |
+
self.logger.info("||\t\t" + "New training process started." + "\t\t||")
|
49 |
+
self.logger.info("=" * 56)
|
50 |
+
self.logger.info("\n")
|
51 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
52 |
+
self.logger.info(f"Experiment name: {args.exp_name}")
|
53 |
+
self.logger.info(f"Experiment directory: {self.exp_dir}")
|
54 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
55 |
+
if self.accelerator.is_main_process:
|
56 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
57 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
58 |
+
|
59 |
+
# init counts
|
60 |
+
self.batch_count: int = 0
|
61 |
+
self.step: int = 0
|
62 |
+
self.epoch: int = 0
|
63 |
+
self.max_epoch = (
|
64 |
+
self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
|
65 |
+
)
|
66 |
+
self.logger.info(
|
67 |
+
"Max epoch: {}".format(
|
68 |
+
self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
|
69 |
+
)
|
70 |
+
)
|
71 |
+
|
72 |
+
# Check values
|
73 |
+
if self.accelerator.is_main_process:
|
74 |
+
self.__check_basic_configs()
|
75 |
+
# Set runtime configs
|
76 |
+
self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
|
77 |
+
self.checkpoints_path = [
|
78 |
+
[] for _ in range(len(self.save_checkpoint_stride))
|
79 |
+
]
|
80 |
+
self.keep_last = [
|
81 |
+
i if i > 0 else float("inf") for i in self.cfg.train.keep_last
|
82 |
+
]
|
83 |
+
self.run_eval = self.cfg.train.run_eval
|
84 |
+
|
85 |
+
# set random seed
|
86 |
+
with self.accelerator.main_process_first():
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
self.logger.debug(
|
91 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
92 |
+
)
|
93 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
94 |
+
|
95 |
+
# setup data_loader
|
96 |
+
with self.accelerator.main_process_first():
|
97 |
+
self.logger.info("Building dataset...")
|
98 |
+
start = time.monotonic_ns()
|
99 |
+
self.train_dataloader, self.valid_dataloader = self._build_dataloader()
|
100 |
+
end = time.monotonic_ns()
|
101 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
102 |
+
|
103 |
+
# save phone table to exp dir. Should be done before building model due to loading phone table in model
|
104 |
+
if cfg.preprocess.use_phone and cfg.preprocess.phone_extractor != "lexicon":
|
105 |
+
self._save_phone_symbols_file_to_exp_path()
|
106 |
+
|
107 |
+
# setup model
|
108 |
+
with self.accelerator.main_process_first():
|
109 |
+
self.logger.info("Building model...")
|
110 |
+
start = time.monotonic_ns()
|
111 |
+
self.model = self._build_model()
|
112 |
+
end = time.monotonic_ns()
|
113 |
+
self.logger.debug(self.model)
|
114 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
|
115 |
+
self.logger.info(
|
116 |
+
f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
|
117 |
+
)
|
118 |
+
|
119 |
+
# optimizer & scheduler
|
120 |
+
with self.accelerator.main_process_first():
|
121 |
+
self.logger.info("Building optimizer and scheduler...")
|
122 |
+
start = time.monotonic_ns()
|
123 |
+
self.optimizer = self._build_optimizer()
|
124 |
+
self.scheduler = self._build_scheduler()
|
125 |
+
end = time.monotonic_ns()
|
126 |
+
self.logger.info(
|
127 |
+
f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
|
128 |
+
)
|
129 |
+
|
130 |
+
# create criterion
|
131 |
+
with self.accelerator.main_process_first():
|
132 |
+
self.logger.info("Building criterion...")
|
133 |
+
start = time.monotonic_ns()
|
134 |
+
self.criterion = self._build_criterion()
|
135 |
+
end = time.monotonic_ns()
|
136 |
+
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
|
137 |
+
|
138 |
+
# Resume or Finetune
|
139 |
+
with self.accelerator.main_process_first():
|
140 |
+
self._check_resume()
|
141 |
+
|
142 |
+
# accelerate prepare
|
143 |
+
self.logger.info("Initializing accelerate...")
|
144 |
+
start = time.monotonic_ns()
|
145 |
+
self._accelerator_prepare()
|
146 |
+
end = time.monotonic_ns()
|
147 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
|
148 |
+
|
149 |
+
# save config file path
|
150 |
+
self.config_save_path = os.path.join(self.exp_dir, "args.json")
|
151 |
+
self.device = self.accelerator.device
|
152 |
+
|
153 |
+
if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
|
154 |
+
self.speakers = self._build_speaker_lut()
|
155 |
+
self.utt2spk_dict = self._build_utt2spk_dict()
|
156 |
+
|
157 |
+
# Only for TTS tasks
|
158 |
+
self.task_type = "TTS"
|
159 |
+
self.logger.info("Task type: {}".format(self.task_type))
|
160 |
+
|
161 |
+
def _check_resume(self):
|
162 |
+
# if args.resume:
|
163 |
+
if self.args.resume or (
|
164 |
+
self.cfg.model_type == "VALLE" and self.args.train_stage == 2
|
165 |
+
):
|
166 |
+
if self.cfg.model_type == "VALLE" and self.args.train_stage == 2:
|
167 |
+
self.args.resume_type = "finetune"
|
168 |
+
|
169 |
+
self.logger.info("Resuming from checkpoint...")
|
170 |
+
start = time.monotonic_ns()
|
171 |
+
self.ckpt_path = self._load_model(
|
172 |
+
self.checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
|
173 |
+
)
|
174 |
+
end = time.monotonic_ns()
|
175 |
+
self.logger.info(
|
176 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
177 |
+
)
|
178 |
+
self.checkpoints_path = json.load(
|
179 |
+
open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
|
180 |
+
)
|
181 |
+
|
182 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
183 |
+
if self.accelerator.is_main_process:
|
184 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
185 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
186 |
+
|
187 |
+
def _init_accelerator(self):
|
188 |
+
self.exp_dir = os.path.join(
|
189 |
+
os.path.abspath(self.cfg.log_dir), self.args.exp_name
|
190 |
+
)
|
191 |
+
project_config = ProjectConfiguration(
|
192 |
+
project_dir=self.exp_dir,
|
193 |
+
logging_dir=os.path.join(self.exp_dir, "log"),
|
194 |
+
)
|
195 |
+
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
196 |
+
self.accelerator = accelerate.Accelerator(
|
197 |
+
gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
|
198 |
+
log_with=self.cfg.train.tracker,
|
199 |
+
project_config=project_config,
|
200 |
+
kwargs_handlers=[kwargs],
|
201 |
+
)
|
202 |
+
if self.accelerator.is_main_process:
|
203 |
+
os.makedirs(project_config.project_dir, exist_ok=True)
|
204 |
+
os.makedirs(project_config.logging_dir, exist_ok=True)
|
205 |
+
with self.accelerator.main_process_first():
|
206 |
+
self.accelerator.init_trackers(self.args.exp_name)
|
207 |
+
|
208 |
+
def _accelerator_prepare(self):
|
209 |
+
(
|
210 |
+
self.train_dataloader,
|
211 |
+
self.valid_dataloader,
|
212 |
+
) = self.accelerator.prepare(
|
213 |
+
self.train_dataloader,
|
214 |
+
self.valid_dataloader,
|
215 |
+
)
|
216 |
+
|
217 |
+
if isinstance(self.model, dict):
|
218 |
+
for key in self.model.keys():
|
219 |
+
self.model[key] = self.accelerator.prepare(self.model[key])
|
220 |
+
else:
|
221 |
+
self.model = self.accelerator.prepare(self.model)
|
222 |
+
|
223 |
+
if isinstance(self.optimizer, dict):
|
224 |
+
for key in self.optimizer.keys():
|
225 |
+
self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
|
226 |
+
else:
|
227 |
+
self.optimizer = self.accelerator.prepare(self.optimizer)
|
228 |
+
|
229 |
+
if isinstance(self.scheduler, dict):
|
230 |
+
for key in self.scheduler.keys():
|
231 |
+
self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
|
232 |
+
else:
|
233 |
+
self.scheduler = self.accelerator.prepare(self.scheduler)
|
234 |
+
|
235 |
+
### Following are methods only for TTS tasks ###
|
236 |
+
def _build_dataset(self):
|
237 |
+
pass
|
238 |
+
|
239 |
+
def _build_criterion(self):
|
240 |
+
pass
|
241 |
+
|
242 |
+
def _build_model(self):
|
243 |
+
pass
|
244 |
+
|
245 |
+
def _build_dataloader(self):
|
246 |
+
"""Build dataloader which merges a series of datasets."""
|
247 |
+
# Build dataset instance for each dataset and combine them by ConcatDataset
|
248 |
+
Dataset, Collator = self._build_dataset()
|
249 |
+
|
250 |
+
# Build train set
|
251 |
+
datasets_list = []
|
252 |
+
for dataset in self.cfg.dataset:
|
253 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
254 |
+
datasets_list.append(subdataset)
|
255 |
+
train_dataset = ConcatDataset(datasets_list)
|
256 |
+
train_collate = Collator(self.cfg)
|
257 |
+
_, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
|
258 |
+
train_loader = DataLoader(
|
259 |
+
train_dataset,
|
260 |
+
collate_fn=train_collate,
|
261 |
+
batch_sampler=batch_sampler,
|
262 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
263 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
264 |
+
)
|
265 |
+
|
266 |
+
# Build test set
|
267 |
+
datasets_list = []
|
268 |
+
for dataset in self.cfg.dataset:
|
269 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
270 |
+
datasets_list.append(subdataset)
|
271 |
+
valid_dataset = ConcatDataset(datasets_list)
|
272 |
+
valid_collate = Collator(self.cfg)
|
273 |
+
_, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
|
274 |
+
valid_loader = DataLoader(
|
275 |
+
valid_dataset,
|
276 |
+
collate_fn=valid_collate,
|
277 |
+
batch_sampler=batch_sampler,
|
278 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
279 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
280 |
+
)
|
281 |
+
return train_loader, valid_loader
|
282 |
+
|
283 |
+
def _build_optimizer(self):
|
284 |
+
pass
|
285 |
+
|
286 |
+
def _build_scheduler(self):
|
287 |
+
pass
|
288 |
+
|
289 |
+
def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
|
290 |
+
"""Load model from checkpoint. If a folder is given, it will
|
291 |
+
load the latest checkpoint in checkpoint_dir. If a path is given
|
292 |
+
it will load the checkpoint specified by checkpoint_path.
|
293 |
+
**Only use this method after** ``accelerator.prepare()``.
|
294 |
+
"""
|
295 |
+
if checkpoint_path is None:
|
296 |
+
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
|
297 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
298 |
+
checkpoint_path = ls[0]
|
299 |
+
self.logger.info("Load model from {}".format(checkpoint_path))
|
300 |
+
print("Load model from {}".format(checkpoint_path))
|
301 |
+
if resume_type == "resume":
|
302 |
+
self.accelerator.load_state(checkpoint_path)
|
303 |
+
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
|
304 |
+
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
|
305 |
+
elif resume_type == "finetune":
|
306 |
+
self.model.load_state_dict(
|
307 |
+
torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
|
308 |
+
)
|
309 |
+
self.model.cuda(self.accelerator.device)
|
310 |
+
self.logger.info("Load model weights for finetune SUCCESS!")
|
311 |
+
else:
|
312 |
+
raise ValueError("Unsupported resume type: {}".format(resume_type))
|
313 |
+
|
314 |
+
return checkpoint_path
|
315 |
+
|
316 |
+
### THIS IS MAIN ENTRY ###
|
317 |
+
def train_loop(self):
|
318 |
+
r"""Training loop. The public entry of training process."""
|
319 |
+
# Wait everyone to prepare before we move on
|
320 |
+
self.accelerator.wait_for_everyone()
|
321 |
+
# dump config file
|
322 |
+
if self.accelerator.is_main_process:
|
323 |
+
self.__dump_cfg(self.config_save_path)
|
324 |
+
|
325 |
+
# self.optimizer.zero_grad()
|
326 |
+
# Wait to ensure good to go
|
327 |
+
|
328 |
+
self.accelerator.wait_for_everyone()
|
329 |
+
while self.epoch < self.max_epoch:
|
330 |
+
self.logger.info("\n")
|
331 |
+
self.logger.info("-" * 32)
|
332 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
333 |
+
|
334 |
+
# Do training & validating epoch
|
335 |
+
train_total_loss, train_losses = self._train_epoch()
|
336 |
+
if isinstance(train_losses, dict):
|
337 |
+
for key, loss in train_losses.items():
|
338 |
+
self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
|
339 |
+
self.accelerator.log(
|
340 |
+
{"Epoch/Train {} Loss".format(key): loss},
|
341 |
+
step=self.epoch,
|
342 |
+
)
|
343 |
+
|
344 |
+
valid_total_loss, valid_losses = self._valid_epoch()
|
345 |
+
if isinstance(valid_losses, dict):
|
346 |
+
for key, loss in valid_losses.items():
|
347 |
+
self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
|
348 |
+
self.accelerator.log(
|
349 |
+
{"Epoch/Train {} Loss".format(key): loss},
|
350 |
+
step=self.epoch,
|
351 |
+
)
|
352 |
+
|
353 |
+
self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
|
354 |
+
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
|
355 |
+
self.accelerator.log(
|
356 |
+
{
|
357 |
+
"Epoch/Train Loss": train_total_loss,
|
358 |
+
"Epoch/Valid Loss": valid_total_loss,
|
359 |
+
},
|
360 |
+
step=self.epoch,
|
361 |
+
)
|
362 |
+
|
363 |
+
self.accelerator.wait_for_everyone()
|
364 |
+
|
365 |
+
# Check if hit save_checkpoint_stride and run_eval
|
366 |
+
run_eval = False
|
367 |
+
if self.accelerator.is_main_process:
|
368 |
+
save_checkpoint = False
|
369 |
+
hit_dix = []
|
370 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
371 |
+
if self.epoch % num == 0:
|
372 |
+
save_checkpoint = True
|
373 |
+
hit_dix.append(i)
|
374 |
+
run_eval |= self.run_eval[i]
|
375 |
+
|
376 |
+
self.accelerator.wait_for_everyone()
|
377 |
+
if self.accelerator.is_main_process and save_checkpoint:
|
378 |
+
path = os.path.join(
|
379 |
+
self.checkpoint_dir,
|
380 |
+
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
381 |
+
self.epoch, self.step, train_total_loss
|
382 |
+
),
|
383 |
+
)
|
384 |
+
self.accelerator.save_state(path)
|
385 |
+
|
386 |
+
json.dump(
|
387 |
+
self.checkpoints_path,
|
388 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
389 |
+
ensure_ascii=False,
|
390 |
+
indent=4,
|
391 |
+
)
|
392 |
+
|
393 |
+
# Remove old checkpoints
|
394 |
+
to_remove = []
|
395 |
+
for idx in hit_dix:
|
396 |
+
self.checkpoints_path[idx].append(path)
|
397 |
+
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
|
398 |
+
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
|
399 |
+
|
400 |
+
# Search conflicts
|
401 |
+
total = set()
|
402 |
+
for i in self.checkpoints_path:
|
403 |
+
total |= set(i)
|
404 |
+
do_remove = set()
|
405 |
+
for idx, path in to_remove[::-1]:
|
406 |
+
if path in total:
|
407 |
+
self.checkpoints_path[idx].insert(0, path)
|
408 |
+
else:
|
409 |
+
do_remove.add(path)
|
410 |
+
|
411 |
+
# Remove old checkpoints
|
412 |
+
for path in do_remove:
|
413 |
+
shutil.rmtree(path, ignore_errors=True)
|
414 |
+
self.logger.debug(f"Remove old checkpoint: {path}")
|
415 |
+
|
416 |
+
self.accelerator.wait_for_everyone()
|
417 |
+
if run_eval:
|
418 |
+
# TODO: run evaluation
|
419 |
+
pass
|
420 |
+
|
421 |
+
# Update info for each epoch
|
422 |
+
self.epoch += 1
|
423 |
+
|
424 |
+
# Finish training and save final checkpoint
|
425 |
+
self.accelerator.wait_for_everyone()
|
426 |
+
if self.accelerator.is_main_process:
|
427 |
+
path = os.path.join(
|
428 |
+
self.checkpoint_dir,
|
429 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
430 |
+
self.epoch, self.step, valid_total_loss
|
431 |
+
),
|
432 |
+
)
|
433 |
+
self.accelerator.save_state(
|
434 |
+
os.path.join(
|
435 |
+
self.checkpoint_dir,
|
436 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
437 |
+
self.epoch, self.step, valid_total_loss
|
438 |
+
),
|
439 |
+
)
|
440 |
+
)
|
441 |
+
|
442 |
+
json.dump(
|
443 |
+
self.checkpoints_path,
|
444 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
445 |
+
ensure_ascii=False,
|
446 |
+
indent=4,
|
447 |
+
)
|
448 |
+
|
449 |
+
self.accelerator.end_training()
|
450 |
+
|
451 |
+
### Following are methods that can be used directly in child classes ###
|
452 |
+
def _train_epoch(self):
|
453 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
454 |
+
one epoch. See ``train_loop`` for usage.
|
455 |
+
"""
|
456 |
+
if isinstance(self.model, dict):
|
457 |
+
for key in self.model.keys():
|
458 |
+
self.model[key].train()
|
459 |
+
else:
|
460 |
+
self.model.train()
|
461 |
+
|
462 |
+
epoch_sum_loss: float = 0.0
|
463 |
+
epoch_losses: dict = {}
|
464 |
+
epoch_step: int = 0
|
465 |
+
for batch in tqdm(
|
466 |
+
self.train_dataloader,
|
467 |
+
desc=f"Training Epoch {self.epoch}",
|
468 |
+
unit="batch",
|
469 |
+
colour="GREEN",
|
470 |
+
leave=False,
|
471 |
+
dynamic_ncols=True,
|
472 |
+
smoothing=0.04,
|
473 |
+
disable=not self.accelerator.is_main_process,
|
474 |
+
):
|
475 |
+
# Do training step and BP
|
476 |
+
with self.accelerator.accumulate(self.model):
|
477 |
+
total_loss, train_losses, _ = self._train_step(batch)
|
478 |
+
self.batch_count += 1
|
479 |
+
|
480 |
+
# Update info for each step
|
481 |
+
# TODO: step means BP counts or batch counts?
|
482 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
483 |
+
if isinstance(self.scheduler, dict):
|
484 |
+
for key in self.scheduler.keys():
|
485 |
+
self.scheduler[key].step()
|
486 |
+
else:
|
487 |
+
if isinstance(self.scheduler, Eden):
|
488 |
+
self.scheduler.step_batch(self.step)
|
489 |
+
else:
|
490 |
+
self.scheduler.step()
|
491 |
+
|
492 |
+
epoch_sum_loss += total_loss
|
493 |
+
|
494 |
+
if isinstance(train_losses, dict):
|
495 |
+
for key, value in train_losses.items():
|
496 |
+
epoch_losses[key] += value
|
497 |
+
|
498 |
+
if isinstance(train_losses, dict):
|
499 |
+
for key, loss in train_losses.items():
|
500 |
+
self.accelerator.log(
|
501 |
+
{"Epoch/Train {} Loss".format(key): loss},
|
502 |
+
step=self.step,
|
503 |
+
)
|
504 |
+
|
505 |
+
self.step += 1
|
506 |
+
epoch_step += 1
|
507 |
+
|
508 |
+
self.accelerator.wait_for_everyone()
|
509 |
+
|
510 |
+
epoch_sum_loss = (
|
511 |
+
epoch_sum_loss
|
512 |
+
/ len(self.train_dataloader)
|
513 |
+
* self.cfg.train.gradient_accumulation_step
|
514 |
+
)
|
515 |
+
|
516 |
+
for key in epoch_losses.keys():
|
517 |
+
epoch_losses[key] = (
|
518 |
+
epoch_losses[key]
|
519 |
+
/ len(self.train_dataloader)
|
520 |
+
* self.cfg.train.gradient_accumulation_step
|
521 |
+
)
|
522 |
+
|
523 |
+
return epoch_sum_loss, epoch_losses
|
524 |
+
|
525 |
+
@torch.inference_mode()
|
526 |
+
def _valid_epoch(self):
|
527 |
+
r"""Testing epoch. Should return average loss of a batch (sample) over
|
528 |
+
one epoch. See ``train_loop`` for usage.
|
529 |
+
"""
|
530 |
+
if isinstance(self.model, dict):
|
531 |
+
for key in self.model.keys():
|
532 |
+
self.model[key].eval()
|
533 |
+
else:
|
534 |
+
self.model.eval()
|
535 |
+
|
536 |
+
epoch_sum_loss = 0.0
|
537 |
+
epoch_losses = dict()
|
538 |
+
for batch in tqdm(
|
539 |
+
self.valid_dataloader,
|
540 |
+
desc=f"Validating Epoch {self.epoch}",
|
541 |
+
unit="batch",
|
542 |
+
colour="GREEN",
|
543 |
+
leave=False,
|
544 |
+
dynamic_ncols=True,
|
545 |
+
smoothing=0.04,
|
546 |
+
disable=not self.accelerator.is_main_process,
|
547 |
+
):
|
548 |
+
total_loss, valid_losses, valid_stats = self._valid_step(batch)
|
549 |
+
epoch_sum_loss += total_loss
|
550 |
+
if isinstance(valid_losses, dict):
|
551 |
+
for key, value in valid_losses.items():
|
552 |
+
if key not in epoch_losses.keys():
|
553 |
+
epoch_losses[key] = value
|
554 |
+
else:
|
555 |
+
epoch_losses[key] += value
|
556 |
+
|
557 |
+
epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
|
558 |
+
for key in epoch_losses.keys():
|
559 |
+
epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
|
560 |
+
|
561 |
+
self.accelerator.wait_for_everyone()
|
562 |
+
|
563 |
+
return epoch_sum_loss, epoch_losses
|
564 |
+
|
565 |
+
def _train_step(self):
|
566 |
+
pass
|
567 |
+
|
568 |
+
def _valid_step(self, batch):
|
569 |
+
pass
|
570 |
+
|
571 |
+
def _inference(self):
|
572 |
+
pass
|
573 |
+
|
574 |
+
def _is_valid_pattern(self, directory_name):
|
575 |
+
directory_name = str(directory_name)
|
576 |
+
pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
|
577 |
+
return re.match(pattern, directory_name) is not None
|
578 |
+
|
579 |
+
def _check_basic_configs(self):
|
580 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
581 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
582 |
+
self.logger.error(
|
583 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
584 |
+
)
|
585 |
+
self.accelerator.end_training()
|
586 |
+
raise ValueError(
|
587 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
588 |
+
)
|
589 |
+
|
590 |
+
def __dump_cfg(self, path):
|
591 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
592 |
+
json5.dump(
|
593 |
+
self.cfg,
|
594 |
+
open(path, "w"),
|
595 |
+
indent=4,
|
596 |
+
sort_keys=True,
|
597 |
+
ensure_ascii=False,
|
598 |
+
quote_keys=True,
|
599 |
+
)
|
600 |
+
|
601 |
+
def __check_basic_configs(self):
|
602 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
603 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
604 |
+
self.logger.error(
|
605 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
606 |
+
)
|
607 |
+
self.accelerator.end_training()
|
608 |
+
raise ValueError(
|
609 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
610 |
+
)
|
611 |
+
# TODO: check other values
|
612 |
+
|
613 |
+
@staticmethod
|
614 |
+
def __count_parameters(model):
|
615 |
+
model_param = 0.0
|
616 |
+
if isinstance(model, dict):
|
617 |
+
for key, value in model.items():
|
618 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
619 |
+
else:
|
620 |
+
model_param = sum(p.numel() for p in model.parameters())
|
621 |
+
return model_param
|
622 |
+
|
623 |
+
def _build_speaker_lut(self):
|
624 |
+
# combine speakers
|
625 |
+
if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
|
626 |
+
speakers = {}
|
627 |
+
else:
|
628 |
+
with open(
|
629 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "r"
|
630 |
+
) as speaker_file:
|
631 |
+
speakers = json.load(speaker_file)
|
632 |
+
for dataset in self.cfg.dataset:
|
633 |
+
speaker_lut_path = os.path.join(
|
634 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
635 |
+
)
|
636 |
+
with open(speaker_lut_path, "r") as speaker_lut_path:
|
637 |
+
singer_lut = json.load(speaker_lut_path)
|
638 |
+
for singer in singer_lut.keys():
|
639 |
+
if singer not in speakers:
|
640 |
+
speakers[singer] = len(speakers)
|
641 |
+
with open(
|
642 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
|
643 |
+
) as speaker_file:
|
644 |
+
json.dump(speakers, speaker_file, indent=4, ensure_ascii=False)
|
645 |
+
print(
|
646 |
+
"speakers have been dumped to {}".format(
|
647 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
648 |
+
)
|
649 |
+
)
|
650 |
+
return speakers
|
651 |
+
|
652 |
+
def _build_utt2spk_dict(self):
|
653 |
+
# combine speakers
|
654 |
+
utt2spk = {}
|
655 |
+
if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)):
|
656 |
+
utt2spk = {}
|
657 |
+
else:
|
658 |
+
with open(
|
659 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "r"
|
660 |
+
) as utt2spk_file:
|
661 |
+
for line in utt2spk_file.readlines():
|
662 |
+
utt, spk = line.strip().split("\t")
|
663 |
+
utt2spk[utt] = spk
|
664 |
+
for dataset in self.cfg.dataset:
|
665 |
+
utt2spk_dict_path = os.path.join(
|
666 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.utt2spk
|
667 |
+
)
|
668 |
+
with open(utt2spk_dict_path, "r") as utt2spk_dict:
|
669 |
+
for line in utt2spk_dict.readlines():
|
670 |
+
utt, spk = line.strip().split("\t")
|
671 |
+
if utt not in utt2spk.keys():
|
672 |
+
utt2spk[utt] = spk
|
673 |
+
with open(
|
674 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "w"
|
675 |
+
) as utt2spk_file:
|
676 |
+
for utt, spk in utt2spk.items():
|
677 |
+
utt2spk_file.write(utt + "\t" + spk + "\n")
|
678 |
+
print(
|
679 |
+
"utterance and speaker mapper have been dumped to {}".format(
|
680 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)
|
681 |
+
)
|
682 |
+
)
|
683 |
+
return utt2spk
|
684 |
+
|
685 |
+
def _save_phone_symbols_file_to_exp_path(self):
|
686 |
+
phone_symbols_file = os.path.join(
|
687 |
+
self.cfg.preprocess.processed_dir,
|
688 |
+
self.cfg.dataset[0],
|
689 |
+
self.cfg.preprocess.symbols_dict,
|
690 |
+
)
|
691 |
+
phone_symbols_file_to_exp_path = os.path.join(
|
692 |
+
self.exp_dir, self.cfg.preprocess.symbols_dict
|
693 |
+
)
|
694 |
+
shutil.copy(phone_symbols_file, phone_symbols_file_to_exp_path)
|
695 |
+
print(
|
696 |
+
"phone symbols been dumped to {}".format(
|
697 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.symbols_dict)
|
698 |
+
)
|
699 |
+
)
|
models/tts/fastspeech2/__init__.py
ADDED
File without changes
|