JacobLinCool committed on
Commit
d07b6f4
1 Parent(s): 8fe2bf8

feat: training

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.wav filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,10 +1,14 @@
 
1
  import gradio as gr
2
  import zipfile
3
  import os
4
  import tempfile
5
  import shutil
6
- from infer.modules.train.preprocess import PreProcess, preprocess_trainset
 
7
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
 
 
8
  from zero import zero
9
 
10
 
@@ -44,11 +48,6 @@ def preprocess(zip_file: str) -> str:
44
  return temp_dir, f"Preprocessed {len(audio_files)} audio files.\n{log}"
45
 
46
 
47
- def download_expdir(exp_dir: str) -> str:
48
- shutil.make_archive(exp_dir, "zip", exp_dir)
49
- return f"{exp_dir}.zip"
50
-
51
-
52
  @zero(duration=120)
53
  def extract_features(exp_dir: str) -> str:
54
  err = None
@@ -67,6 +66,108 @@ def extract_features(exp_dir: str) -> str:
67
  return log
68
 
69
 
70
  with gr.Blocks() as app:
71
  with gr.Row():
72
  with gr.Column():
@@ -89,6 +190,20 @@ with gr.Blocks() as app:
89
  label="Feature extraction output", lines=5
90
  )
91
 
92
  with gr.Row():
93
  with gr.Column():
94
  download_expdir_btn = gr.Button(
@@ -109,6 +224,18 @@ with gr.Blocks() as app:
109
  outputs=[extract_features_output],
110
  )
111
 
112
  download_expdir_btn.click(
113
  fn=download_expdir,
114
  inputs=[exp_dir],
 
1
+ from random import shuffle
2
  import gradio as gr
3
  import zipfile
4
  import os
5
  import tempfile
6
  import shutil
7
+ from glob import glob
8
+ from infer.modules.train.preprocess import PreProcess
9
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
10
+ from infer.modules.train.train import train
11
+ from infer.lib.train.process_ckpt import extract_small_model
12
  from zero import zero
13
 
14
 
 
48
  return temp_dir, f"Preprocessed {len(audio_files)} audio files.\n{log}"
49
 
50
 
51
  @zero(duration=120)
52
  def extract_features(exp_dir: str) -> str:
53
  err = None
 
66
  return log
67
 
68
 
69
+ def write_filelist(exp_dir: str) -> None:
70
+ if_f0_3 = True
71
+ spk_id5 = 0
72
+ gt_wavs_dir = "%s/0_gt_wavs" % (exp_dir)
73
+ feature_dir = "%s/3_feature768" % (exp_dir)
74
+
75
+ if if_f0_3:
76
+ f0_dir = "%s/2a_f0" % (exp_dir)
77
+ f0nsf_dir = "%s/2b-f0nsf" % (exp_dir)
78
+ names = (
79
+ set([name.split(".")[0] for name in os.listdir(gt_wavs_dir)])
80
+ & set([name.split(".")[0] for name in os.listdir(feature_dir)])
81
+ & set([name.split(".")[0] for name in os.listdir(f0_dir)])
82
+ & set([name.split(".")[0] for name in os.listdir(f0nsf_dir)])
83
+ )
84
+ else:
85
+ names = set([name.split(".")[0] for name in os.listdir(gt_wavs_dir)]) & set(
86
+ [name.split(".")[0] for name in os.listdir(feature_dir)]
87
+ )
88
+ opt = []
89
+ for name in names:
90
+ if if_f0_3:
91
+ opt.append(
92
+ "%s/%s.wav|%s/%s.npy|%s/%s.wav.npy|%s/%s.wav.npy|%s"
93
+ % (
94
+ gt_wavs_dir.replace("\\", "\\\\"),
95
+ name,
96
+ feature_dir.replace("\\", "\\\\"),
97
+ name,
98
+ f0_dir.replace("\\", "\\\\"),
99
+ name,
100
+ f0nsf_dir.replace("\\", "\\\\"),
101
+ name,
102
+ spk_id5,
103
+ )
104
+ )
105
+ else:
106
+ opt.append(
107
+ "%s/%s.wav|%s/%s.npy|%s"
108
+ % (
109
+ gt_wavs_dir.replace("\\", "\\\\"),
110
+ name,
111
+ feature_dir.replace("\\", "\\\\"),
112
+ name,
113
+ spk_id5,
114
+ )
115
+ )
116
+ fea_dim = 768
117
+
118
+ now_dir = os.getcwd()
119
+ sr2 = "40k"
120
+ if if_f0_3:
121
+ for _ in range(2):
122
+ opt.append(
123
+ "%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature%s/mute.npy|%s/logs/mute/2a_f0/mute.wav.npy|%s/logs/mute/2b-f0nsf/mute.wav.npy|%s"
124
+ % (now_dir, sr2, now_dir, fea_dim, now_dir, now_dir, spk_id5)
125
+ )
126
+ else:
127
+ for _ in range(2):
128
+ opt.append(
129
+ "%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature%s/mute.npy|%s"
130
+ % (now_dir, sr2, now_dir, fea_dim, spk_id5)
131
+ )
132
+ shuffle(opt)
133
+ with open("%s/filelist.txt" % exp_dir, "w") as f:
134
+ f.write("\n".join(opt))
135
+
136
+
137
+ @zero(duration=300)
138
+ def train_model(exp_dir: str) -> str:
139
+ shutil.copy("config.json", exp_dir)
140
+ write_filelist(exp_dir)
141
+ train(exp_dir)
142
+
143
+ models = glob(f"{exp_dir}/G_*.pth")
144
+ if not models:
145
+ raise gr.Error("No model found")
146
+
147
+ latest_model = max(models, key=os.path.getctime)
148
+ return latest_model
149
+
150
+
151
+ def download_weight(exp_dir: str) -> str:
152
+ models = glob(f"{exp_dir}/G_*.pth")
153
+ if not models:
154
+ raise gr.Error("No model found")
155
+
156
+ latest_model = max(models, key=os.path.getctime)
157
+
158
+ name = os.path.basename(exp_dir)
159
+ extract_small_model(
160
+ latest_model, name, "40k", True, "Model trained by ZeroGPU.", "v2"
161
+ )
162
+
163
+ return "assets/weights/%s.pth" % name
164
+
165
+
166
+ def download_expdir(exp_dir: str) -> str:
167
+ shutil.make_archive(exp_dir, "zip", exp_dir)
168
+ return f"{exp_dir}.zip"
169
+
170
+
171
  with gr.Blocks() as app:
172
  with gr.Row():
173
  with gr.Column():
 
190
  label="Feature extraction output", lines=5
191
  )
192
 
193
+ with gr.Row():
194
+ with gr.Column():
195
+ train_btn = gr.Button(value="Train", variant="primary")
196
+ with gr.Column():
197
+ latest_model = gr.File(label="Latest model")
198
+
199
+ with gr.Row():
200
+ with gr.Column():
201
+ download_weight_btn = gr.Button(
202
+ value="Download latest model", variant="primary"
203
+ )
204
+ with gr.Column():
205
+ download_weight_output = gr.File(label="Download latest model")
206
+
207
  with gr.Row():
208
  with gr.Column():
209
  download_expdir_btn = gr.Button(
 
224
  outputs=[extract_features_output],
225
  )
226
 
227
+ train_btn.click(
228
+ fn=train_model,
229
+ inputs=[exp_dir],
230
+ outputs=[latest_model],
231
+ )
232
+
233
+ download_weight_btn.click(
234
+ fn=download_weight,
235
+ inputs=[exp_dir],
236
+ outputs=[download_weight_output],
237
+ )
238
+
239
  download_expdir_btn.click(
240
  fn=download_expdir,
241
  inputs=[exp_dir],
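Taken together, the new pieces in app.py form a linear pipeline: preprocess the uploaded zip, run pitch/feature extraction, train, then export either the small weight file or the whole experiment directory. A minimal sketch of that flow, calling the functions from the diff above directly rather than through the Gradio buttons (the zip path is hypothetical):

# Hypothetical driver for the functions added in app.py above; assumes the
# module is importable and that "dataset.zip" contains the training audio.
from app import preprocess, extract_features, train_model, download_weight, download_expdir

exp_dir, log = preprocess("dataset.zip")   # unzip, slice and normalize the audio
print(log)

print(extract_features(exp_dir))           # f0 / feature extraction step

latest_ckpt = train_model(exp_dir)         # copies config.json, writes filelist.txt, trains
print("latest checkpoint:", latest_ckpt)

print(download_weight(exp_dir))            # small model -> assets/weights/<name>.pth
print(download_expdir(exp_dir))            # <exp_dir>.zip with the full experiment folder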
config.json ADDED
@@ -0,0 +1,79 @@
1
+ {
2
+ "data": {
3
+ "filter_length": 2048,
4
+ "hop_length": 400,
5
+ "max_wav_value": 32768.0,
6
+ "mel_fmax": null,
7
+ "mel_fmin": 0.0,
8
+ "n_mel_channels": 125,
9
+ "sampling_rate": 40000,
10
+ "win_length": 2048
11
+ },
12
+ "model": {
13
+ "filter_channels": 768,
14
+ "gin_channels": 256,
15
+ "hidden_channels": 192,
16
+ "inter_channels": 192,
17
+ "kernel_size": 3,
18
+ "n_heads": 2,
19
+ "n_layers": 6,
20
+ "p_dropout": 0,
21
+ "resblock": "1",
22
+ "resblock_dilation_sizes": [
23
+ [
24
+ 1,
25
+ 3,
26
+ 5
27
+ ],
28
+ [
29
+ 1,
30
+ 3,
31
+ 5
32
+ ],
33
+ [
34
+ 1,
35
+ 3,
36
+ 5
37
+ ]
38
+ ],
39
+ "resblock_kernel_sizes": [
40
+ 3,
41
+ 7,
42
+ 11
43
+ ],
44
+ "spk_embed_dim": 109,
45
+ "upsample_initial_channel": 512,
46
+ "upsample_kernel_sizes": [
47
+ 16,
48
+ 16,
49
+ 4,
50
+ 4
51
+ ],
52
+ "upsample_rates": [
53
+ 10,
54
+ 10,
55
+ 2,
56
+ 2
57
+ ],
58
+ "use_spectral_norm": false
59
+ },
60
+ "train": {
61
+ "batch_size": 4,
62
+ "betas": [
63
+ 0.8,
64
+ 0.99
65
+ ],
66
+ "c_kl": 1.0,
67
+ "c_mel": 45,
68
+ "epochs": 20000,
69
+ "eps": 1e-09,
70
+ "fp16_run": false,
71
+ "init_lr_ratio": 1,
72
+ "learning_rate": 0.0001,
73
+ "log_interval": 200,
74
+ "lr_decay": 0.999875,
75
+ "seed": 1234,
76
+ "segment_size": 12800,
77
+ "warmup_epochs": 0
78
+ }
79
+ }
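The data block above pins the 40k front end: 40 000 Hz audio, a 2048-sample FFT/window, and a 400-sample hop, with the generator's upsample_rates multiplying back out to that hop. A small sanity check, assuming config.json sits in the working directory:

# Sanity-check the 40k config above (values are the ones shown in the diff).
import json

with open("config.json") as f:
    cfg = json.load(f)

sr = cfg["data"]["sampling_rate"]             # 40000
hop = cfg["data"]["hop_length"]               # 400

total_upsample = 1
for r in cfg["model"]["upsample_rates"]:      # [10, 10, 2, 2]
    total_upsample *= r

assert total_upsample == hop                  # generator rebuilds one hop per feature frame
print(f"{sr / hop:.0f} feature frames per second")  # 100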
infer-web.py ADDED
@@ -0,0 +1,1265 @@
1
+ import os
2
+ import sys
3
+ from dotenv import load_dotenv
4
+
5
+ now_dir = os.getcwd()
6
+ sys.path.append(now_dir)
7
+ load_dotenv()
8
+ from infer.modules.vc.modules import VC
9
+ from infer.modules.uvr5.modules import uvr
10
+ from infer.lib.train.process_ckpt import (
11
+ change_info,
12
+ extract_small_model,
13
+ merge,
14
+ show_info,
15
+ )
16
+ from i18n.i18n import I18nAuto
17
+ from configs.config import Config
18
+ from sklearn.cluster import MiniBatchKMeans
19
+ import torch, platform
20
+ import numpy as np
21
+ import gradio as gr
22
+ import faiss
23
+ import fairseq
24
+ import pathlib
25
+ import json
26
+ from time import sleep
27
+ from subprocess import Popen
28
+ from random import shuffle
29
+ import warnings
30
+ import traceback
31
+ import threading
32
+ import shutil
33
+ import logging
34
+
35
+
36
+ logging.getLogger("numba").setLevel(logging.WARNING)
37
+ logging.getLogger("httpx").setLevel(logging.WARNING)
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ tmp = os.path.join(now_dir, "TEMP")
42
+ shutil.rmtree(tmp, ignore_errors=True)
43
+ shutil.rmtree("%s/runtime/Lib/site-packages/infer_pack" % (now_dir), ignore_errors=True)
44
+ shutil.rmtree("%s/runtime/Lib/site-packages/uvr5_pack" % (now_dir), ignore_errors=True)
45
+ os.makedirs(tmp, exist_ok=True)
46
+ os.makedirs(os.path.join(now_dir, "logs"), exist_ok=True)
47
+ os.makedirs(os.path.join(now_dir, "assets/weights"), exist_ok=True)
48
+ os.environ["TEMP"] = tmp
49
+ warnings.filterwarnings("ignore")
50
+ torch.manual_seed(114514)
51
+
52
+
53
+ config = Config()
54
+ vc = VC(config)
55
+
56
+
57
+ if config.dml == True:
58
+
59
+ def forward_dml(ctx, x, scale):
60
+ ctx.scale = scale
61
+ res = x.clone().detach()
62
+ return res
63
+
64
+ fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
65
+ i18n = I18nAuto()
66
+ logger.info(i18n)
67
+ # 判断是否有能用来训练和加速推理的N卡
68
+ ngpu = torch.cuda.device_count()
69
+ gpu_infos = []
70
+ mem = []
71
+ if_gpu_ok = False
72
+
73
+ if torch.cuda.is_available() or ngpu != 0:
74
+ for i in range(ngpu):
75
+ gpu_name = torch.cuda.get_device_name(i)
76
+ if any(
77
+ value in gpu_name.upper()
78
+ for value in [
79
+ "10",
80
+ "16",
81
+ "20",
82
+ "30",
83
+ "40",
84
+ "A2",
85
+ "A3",
86
+ "A4",
87
+ "P4",
88
+ "A50",
89
+ "500",
90
+ "A60",
91
+ "70",
92
+ "80",
93
+ "90",
94
+ "M4",
95
+ "T4",
96
+ "TITAN",
97
+ "4060",
98
+ "L",
99
+ "6000",
100
+ ]
101
+ ):
102
+ # A10#A100#V100#A40#P40#M40#K80#A4500
103
+ if_gpu_ok = True # 至少有一张能用的N卡
104
+ gpu_infos.append("%s\t%s" % (i, gpu_name))
105
+ mem.append(
106
+ int(
107
+ torch.cuda.get_device_properties(i).total_memory
108
+ / 1024
109
+ / 1024
110
+ / 1024
111
+ + 0.4
112
+ )
113
+ )
114
+ if if_gpu_ok and len(gpu_infos) > 0:
115
+ gpu_info = "\n".join(gpu_infos)
116
+ default_batch_size = min(mem) // 2
117
+ else:
118
+ gpu_info = i18n("很遗憾您这没有能用的显卡来支持您训练")
119
+ default_batch_size = 1
120
+ gpus = "-".join([i[0] for i in gpu_infos])
121
+
122
+
123
+ class ToolButton(gr.Button, gr.components.FormComponent):
124
+ """Small button with single emoji as text, fits inside gradio forms"""
125
+
126
+ def __init__(self, **kwargs):
127
+ super().__init__(variant="tool", **kwargs)
128
+
129
+ def get_block_name(self):
130
+ return "button"
131
+
132
+
133
+ weight_root = os.getenv("weight_root")
134
+ weight_uvr5_root = os.getenv("weight_uvr5_root")
135
+ index_root = os.getenv("index_root")
136
+ outside_index_root = os.getenv("outside_index_root")
137
+
138
+ names = []
139
+ for name in os.listdir(weight_root):
140
+ if name.endswith(".pth"):
141
+ names.append(name)
142
+ index_paths = []
143
+
144
+
145
+ def lookup_indices(index_root):
146
+ global index_paths
147
+ for root, dirs, files in os.walk(index_root, topdown=False):
148
+ for name in files:
149
+ if name.endswith(".index") and "trained" not in name:
150
+ index_paths.append("%s/%s" % (root, name))
151
+
152
+
153
+ lookup_indices(index_root)
154
+ lookup_indices(outside_index_root)
155
+ uvr5_names = []
156
+ for name in os.listdir(weight_uvr5_root):
157
+ if name.endswith(".pth") or "onnx" in name:
158
+ uvr5_names.append(name.replace(".pth", ""))
159
+
160
+
161
+ def change_choices():
162
+ names = []
163
+ for name in os.listdir(weight_root):
164
+ if name.endswith(".pth"):
165
+ names.append(name)
166
+ index_paths = []
167
+ for root, dirs, files in os.walk(index_root, topdown=False):
168
+ for name in files:
169
+ if name.endswith(".index") and "trained" not in name:
170
+ index_paths.append("%s/%s" % (root, name))
171
+ return {"choices": sorted(names), "__type__": "update"}, {
172
+ "choices": sorted(index_paths),
173
+ "__type__": "update",
174
+ }
175
+
176
+
177
+ def clean():
178
+ return {"value": "", "__type__": "update"}
179
+
180
+
181
+ def export_onnx(ModelPath, ExportedPath):
182
+ from infer.modules.onnx.export import export_onnx as eo
183
+
184
+ eo(ModelPath, ExportedPath)
185
+
186
+
187
+ sr_dict = {
188
+ "32k": 32000,
189
+ "40k": 40000,
190
+ "48k": 48000,
191
+ }
192
+
193
+
194
+ def if_done(done, p):
195
+ while 1:
196
+ if p.poll() is None:
197
+ sleep(0.5)
198
+ else:
199
+ break
200
+ done[0] = True
201
+
202
+
203
+ def if_done_multi(done, ps):
204
+ while 1:
205
+ # poll==None代表进程未结束
206
+ # 只要有一个进程未结束都不停
207
+ flag = 1
208
+ for p in ps:
209
+ if p.poll() is None:
210
+ flag = 0
211
+ sleep(0.5)
212
+ break
213
+ if flag == 1:
214
+ break
215
+ done[0] = True
216
+
217
+
218
+ def preprocess_dataset(trainset_dir, exp_dir, sr, n_p):
219
+ sr = sr_dict[sr]
220
+ os.makedirs("%s/logs/%s" % (now_dir, exp_dir), exist_ok=True)
221
+ f = open("%s/logs/%s/preprocess.log" % (now_dir, exp_dir), "w")
222
+ f.close()
223
+ cmd = '"%s" infer/modules/train/preprocess.py "%s" %s %s "%s/logs/%s" %s %.1f' % (
224
+ config.python_cmd,
225
+ trainset_dir,
226
+ sr,
227
+ n_p,
228
+ now_dir,
229
+ exp_dir,
230
+ config.noparallel,
231
+ config.preprocess_per,
232
+ )
233
+ logger.info("Execute: " + cmd)
234
+ # , stdin=PIPE, stdout=PIPE,stderr=PIPE,cwd=now_dir
235
+ p = Popen(cmd, shell=True)
236
+ # 煞笔gr, popen read都非得全跑完了再一次性读取, 不用gr就正常读一句输出一句;只能额外弄出一个文本流定时读
237
+ done = [False]
238
+ threading.Thread(
239
+ target=if_done,
240
+ args=(
241
+ done,
242
+ p,
243
+ ),
244
+ ).start()
245
+ while 1:
246
+ with open("%s/logs/%s/preprocess.log" % (now_dir, exp_dir), "r") as f:
247
+ yield (f.read())
248
+ sleep(1)
249
+ if done[0]:
250
+ break
251
+ with open("%s/logs/%s/preprocess.log" % (now_dir, exp_dir), "r") as f:
252
+ log = f.read()
253
+ logger.info(log)
254
+ yield log
255
+
256
+
257
+ # but2.click(extract_f0,[gpus6,np7,f0method8,if_f0_3,trainset_dir4],[info2])
258
+ def extract_f0_feature(gpus, n_p, f0method, if_f0, exp_dir, version19, gpus_rmvpe):
259
+ gpus = gpus.split("-")
260
+ os.makedirs("%s/logs/%s" % (now_dir, exp_dir), exist_ok=True)
261
+ f = open("%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "w")
262
+ f.close()
263
+ if if_f0:
264
+ if f0method != "rmvpe_gpu":
265
+ cmd = (
266
+ '"%s" infer/modules/train/extract/extract_f0_print.py "%s/logs/%s" %s %s'
267
+ % (
268
+ config.python_cmd,
269
+ now_dir,
270
+ exp_dir,
271
+ n_p,
272
+ f0method,
273
+ )
274
+ )
275
+ logger.info("Execute: " + cmd)
276
+ p = Popen(
277
+ cmd, shell=True, cwd=now_dir
278
+ ) # , stdin=PIPE, stdout=PIPE,stderr=PIPE
279
+ # 煞笔gr, popen read都非得全跑完了再一次性读取, 不用gr就正常读一句输出一句;只能额外弄出一个文本流定时读
280
+ done = [False]
281
+ threading.Thread(
282
+ target=if_done,
283
+ args=(
284
+ done,
285
+ p,
286
+ ),
287
+ ).start()
288
+ else:
289
+ if gpus_rmvpe != "-":
290
+ gpus_rmvpe = gpus_rmvpe.split("-")
291
+ leng = len(gpus_rmvpe)
292
+ ps = []
293
+ for idx, n_g in enumerate(gpus_rmvpe):
294
+ cmd = (
295
+ '"%s" infer/modules/train/extract/extract_f0_rmvpe.py %s %s %s "%s/logs/%s" %s '
296
+ % (
297
+ config.python_cmd,
298
+ leng,
299
+ idx,
300
+ n_g,
301
+ now_dir,
302
+ exp_dir,
303
+ config.is_half,
304
+ )
305
+ )
306
+ logger.info("Execute: " + cmd)
307
+ p = Popen(
308
+ cmd, shell=True, cwd=now_dir
309
+ ) # , shell=True, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=now_dir
310
+ ps.append(p)
311
+ # 煞笔gr, popen read都非得全跑完了再一次性读取, 不用gr就正常读一句输出一句;只能额外弄出一个文本流定时读
312
+ done = [False]
313
+ threading.Thread(
314
+ target=if_done_multi, #
315
+ args=(
316
+ done,
317
+ ps,
318
+ ),
319
+ ).start()
320
+ else:
321
+ cmd = (
322
+ config.python_cmd
323
+ + ' infer/modules/train/extract/extract_f0_rmvpe_dml.py "%s/logs/%s" '
324
+ % (
325
+ now_dir,
326
+ exp_dir,
327
+ )
328
+ )
329
+ logger.info("Execute: " + cmd)
330
+ p = Popen(
331
+ cmd, shell=True, cwd=now_dir
332
+ ) # , shell=True, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=now_dir
333
+ p.wait()
334
+ done = [True]
335
+ while 1:
336
+ with open(
337
+ "%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "r"
338
+ ) as f:
339
+ yield (f.read())
340
+ sleep(1)
341
+ if done[0]:
342
+ break
343
+ with open("%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "r") as f:
344
+ log = f.read()
345
+ logger.info(log)
346
+ yield log
347
+ # 对不同part分别开多进程
348
+ """
349
+ n_part=int(sys.argv[1])
350
+ i_part=int(sys.argv[2])
351
+ i_gpu=sys.argv[3]
352
+ exp_dir=sys.argv[4]
353
+ os.environ["CUDA_VISIBLE_DEVICES"]=str(i_gpu)
354
+ """
355
+ leng = len(gpus)
356
+ ps = []
357
+ for idx, n_g in enumerate(gpus):
358
+ cmd = (
359
+ '"%s" infer/modules/train/extract_feature_print.py %s %s %s %s "%s/logs/%s" %s %s'
360
+ % (
361
+ config.python_cmd,
362
+ config.device,
363
+ leng,
364
+ idx,
365
+ n_g,
366
+ now_dir,
367
+ exp_dir,
368
+ version19,
369
+ config.is_half,
370
+ )
371
+ )
372
+ logger.info("Execute: " + cmd)
373
+ p = Popen(
374
+ cmd, shell=True, cwd=now_dir
375
+ ) # , shell=True, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=now_dir
376
+ ps.append(p)
377
+ # 煞笔gr, popen read都非得全跑完了再一次性读取, 不用gr就正常读一句输出一句;只能额外弄出一个文本流定时读
378
+ done = [False]
379
+ threading.Thread(
380
+ target=if_done_multi,
381
+ args=(
382
+ done,
383
+ ps,
384
+ ),
385
+ ).start()
386
+ while 1:
387
+ with open("%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "r") as f:
388
+ yield (f.read())
389
+ sleep(1)
390
+ if done[0]:
391
+ break
392
+ with open("%s/logs/%s/extract_f0_feature.log" % (now_dir, exp_dir), "r") as f:
393
+ log = f.read()
394
+ logger.info(log)
395
+ yield log
396
+
397
+
398
+ def get_pretrained_models(path_str, f0_str, sr2):
399
+ if_pretrained_generator_exist = os.access(
400
+ "assets/pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2), os.F_OK
401
+ )
402
+ if_pretrained_discriminator_exist = os.access(
403
+ "assets/pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2), os.F_OK
404
+ )
405
+ if not if_pretrained_generator_exist:
406
+ logger.warning(
407
+ "assets/pretrained%s/%sG%s.pth not exist, will not use pretrained model",
408
+ path_str,
409
+ f0_str,
410
+ sr2,
411
+ )
412
+ if not if_pretrained_discriminator_exist:
413
+ logger.warning(
414
+ "assets/pretrained%s/%sD%s.pth not exist, will not use pretrained model",
415
+ path_str,
416
+ f0_str,
417
+ sr2,
418
+ )
419
+ return (
420
+ (
421
+ "assets/pretrained%s/%sG%s.pth" % (path_str, f0_str, sr2)
422
+ if if_pretrained_generator_exist
423
+ else ""
424
+ ),
425
+ (
426
+ "assets/pretrained%s/%sD%s.pth" % (path_str, f0_str, sr2)
427
+ if if_pretrained_discriminator_exist
428
+ else ""
429
+ ),
430
+ )
431
+
432
+
433
+ def change_sr2(sr2, if_f0_3, version19):
434
+ path_str = "" if version19 == "v1" else "_v2"
435
+ f0_str = "f0" if if_f0_3 else ""
436
+ return get_pretrained_models(path_str, f0_str, sr2)
437
+
438
+
439
+ def change_version19(sr2, if_f0_3, version19):
440
+ path_str = "" if version19 == "v1" else "_v2"
441
+ if sr2 == "32k" and version19 == "v1":
442
+ sr2 = "40k"
443
+ to_return_sr2 = (
444
+ {"choices": ["40k", "48k"], "__type__": "update", "value": sr2}
445
+ if version19 == "v1"
446
+ else {"choices": ["40k", "48k", "32k"], "__type__": "update", "value": sr2}
447
+ )
448
+ f0_str = "f0" if if_f0_3 else ""
449
+ return (
450
+ *get_pretrained_models(path_str, f0_str, sr2),
451
+ to_return_sr2,
452
+ )
453
+
454
+
455
+ def change_f0(if_f0_3, sr2, version19): # f0method8,pretrained_G14,pretrained_D15
456
+ path_str = "" if version19 == "v1" else "_v2"
457
+ return (
458
+ {"visible": if_f0_3, "__type__": "update"},
459
+ {"visible": if_f0_3, "__type__": "update"},
460
+ *get_pretrained_models(path_str, "f0" if if_f0_3 == True else "", sr2),
461
+ )
462
+
463
+
464
+ # but3.click(click_train,[exp_dir1,sr2,if_f0_3,save_epoch10,total_epoch11,batch_size12,if_save_latest13,pretrained_G14,pretrained_D15,gpus16])
465
+ def click_train(
466
+ exp_dir1,
467
+ sr2,
468
+ if_f0_3,
469
+ spk_id5,
470
+ save_epoch10,
471
+ total_epoch11,
472
+ batch_size12,
473
+ if_save_latest13,
474
+ pretrained_G14,
475
+ pretrained_D15,
476
+ gpus16,
477
+ if_cache_gpu17,
478
+ if_save_every_weights18,
479
+ version19,
480
+ ):
481
+ # 生成filelist
482
+ exp_dir = "%s/logs/%s" % (now_dir, exp_dir1)
483
+ os.makedirs(exp_dir, exist_ok=True)
484
+ gt_wavs_dir = "%s/0_gt_wavs" % (exp_dir)
485
+ feature_dir = (
486
+ "%s/3_feature256" % (exp_dir)
487
+ if version19 == "v1"
488
+ else "%s/3_feature768" % (exp_dir)
489
+ )
490
+ if if_f0_3:
491
+ f0_dir = "%s/2a_f0" % (exp_dir)
492
+ f0nsf_dir = "%s/2b-f0nsf" % (exp_dir)
493
+ names = (
494
+ set([name.split(".")[0] for name in os.listdir(gt_wavs_dir)])
495
+ & set([name.split(".")[0] for name in os.listdir(feature_dir)])
496
+ & set([name.split(".")[0] for name in os.listdir(f0_dir)])
497
+ & set([name.split(".")[0] for name in os.listdir(f0nsf_dir)])
498
+ )
499
+ else:
500
+ names = set([name.split(".")[0] for name in os.listdir(gt_wavs_dir)]) & set(
501
+ [name.split(".")[0] for name in os.listdir(feature_dir)]
502
+ )
503
+ opt = []
504
+ for name in names:
505
+ if if_f0_3:
506
+ opt.append(
507
+ "%s/%s.wav|%s/%s.npy|%s/%s.wav.npy|%s/%s.wav.npy|%s"
508
+ % (
509
+ gt_wavs_dir.replace("\\", "\\\\"),
510
+ name,
511
+ feature_dir.replace("\\", "\\\\"),
512
+ name,
513
+ f0_dir.replace("\\", "\\\\"),
514
+ name,
515
+ f0nsf_dir.replace("\\", "\\\\"),
516
+ name,
517
+ spk_id5,
518
+ )
519
+ )
520
+ else:
521
+ opt.append(
522
+ "%s/%s.wav|%s/%s.npy|%s"
523
+ % (
524
+ gt_wavs_dir.replace("\\", "\\\\"),
525
+ name,
526
+ feature_dir.replace("\\", "\\\\"),
527
+ name,
528
+ spk_id5,
529
+ )
530
+ )
531
+ fea_dim = 256 if version19 == "v1" else 768
532
+ if if_f0_3:
533
+ for _ in range(2):
534
+ opt.append(
535
+ "%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature%s/mute.npy|%s/logs/mute/2a_f0/mute.wav.npy|%s/logs/mute/2b-f0nsf/mute.wav.npy|%s"
536
+ % (now_dir, sr2, now_dir, fea_dim, now_dir, now_dir, spk_id5)
537
+ )
538
+ else:
539
+ for _ in range(2):
540
+ opt.append(
541
+ "%s/logs/mute/0_gt_wavs/mute%s.wav|%s/logs/mute/3_feature%s/mute.npy|%s"
542
+ % (now_dir, sr2, now_dir, fea_dim, spk_id5)
543
+ )
544
+ shuffle(opt)
545
+ with open("%s/filelist.txt" % exp_dir, "w") as f:
546
+ f.write("\n".join(opt))
547
+ logger.debug("Write filelist done")
548
+ # 生成config#无需生成config
549
+ # cmd = python_cmd + " train_nsf_sim_cache_sid_load_pretrain.py -e mi-test -sr 40k -f0 1 -bs 4 -g 0 -te 10 -se 5 -pg pretrained/f0G40k.pth -pd pretrained/f0D40k.pth -l 1 -c 0"
550
+ logger.info("Use gpus: %s", str(gpus16))
551
+ if pretrained_G14 == "":
552
+ logger.info("No pretrained Generator")
553
+ if pretrained_D15 == "":
554
+ logger.info("No pretrained Discriminator")
555
+ if version19 == "v1" or sr2 == "40k":
556
+ config_path = "v1/%s.json" % sr2
557
+ else:
558
+ config_path = "v2/%s.json" % sr2
559
+ config_save_path = os.path.join(exp_dir, "config.json")
560
+ if not pathlib.Path(config_save_path).exists():
561
+ with open(config_save_path, "w", encoding="utf-8") as f:
562
+ json.dump(
563
+ config.json_config[config_path],
564
+ f,
565
+ ensure_ascii=False,
566
+ indent=4,
567
+ sort_keys=True,
568
+ )
569
+ f.write("\n")
570
+ if gpus16:
571
+ cmd = (
572
+ '"%s" infer/modules/train/train.py -e "%s" -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s'
573
+ % (
574
+ config.python_cmd,
575
+ exp_dir1,
576
+ sr2,
577
+ 1 if if_f0_3 else 0,
578
+ batch_size12,
579
+ gpus16,
580
+ total_epoch11,
581
+ save_epoch10,
582
+ "-pg %s" % pretrained_G14 if pretrained_G14 != "" else "",
583
+ "-pd %s" % pretrained_D15 if pretrained_D15 != "" else "",
584
+ 1 if if_save_latest13 == i18n("是") else 0,
585
+ 1 if if_cache_gpu17 == i18n("是") else 0,
586
+ 1 if if_save_every_weights18 == i18n("是") else 0,
587
+ version19,
588
+ )
589
+ )
590
+ else:
591
+ cmd = (
592
+ '"%s" infer/modules/train/train.py -e "%s" -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s'
593
+ % (
594
+ config.python_cmd,
595
+ exp_dir1,
596
+ sr2,
597
+ 1 if if_f0_3 else 0,
598
+ batch_size12,
599
+ total_epoch11,
600
+ save_epoch10,
601
+ "-pg %s" % pretrained_G14 if pretrained_G14 != "" else "",
602
+ "-pd %s" % pretrained_D15 if pretrained_D15 != "" else "",
603
+ 1 if if_save_latest13 == i18n("是") else 0,
604
+ 1 if if_cache_gpu17 == i18n("是") else 0,
605
+ 1 if if_save_every_weights18 == i18n("是") else 0,
606
+ version19,
607
+ )
608
+ )
609
+ logger.info("Execute: " + cmd)
610
+ p = Popen(cmd, shell=True, cwd=now_dir)
611
+ p.wait()
612
+ return "训练结束, 您可查看控制台训练日志或实验文件夹下的train.log"
613
+
614
+
615
+ # but4.click(train_index, [exp_dir1], info3)
616
+ def train_index(exp_dir1, version19):
617
+ # exp_dir = "%s/logs/%s" % (now_dir, exp_dir1)
618
+ exp_dir = "logs/%s" % (exp_dir1)
619
+ os.makedirs(exp_dir, exist_ok=True)
620
+ feature_dir = (
621
+ "%s/3_feature256" % (exp_dir)
622
+ if version19 == "v1"
623
+ else "%s/3_feature768" % (exp_dir)
624
+ )
625
+ if not os.path.exists(feature_dir):
626
+ return "请先进行特征提取!"
627
+ listdir_res = list(os.listdir(feature_dir))
628
+ if len(listdir_res) == 0:
629
+ return "请先进行特征提取!"
630
+ infos = []
631
+ npys = []
632
+ for name in sorted(listdir_res):
633
+ phone = np.load("%s/%s" % (feature_dir, name))
634
+ npys.append(phone)
635
+ big_npy = np.concatenate(npys, 0)
636
+ big_npy_idx = np.arange(big_npy.shape[0])
637
+ np.random.shuffle(big_npy_idx)
638
+ big_npy = big_npy[big_npy_idx]
639
+ if big_npy.shape[0] > 2e5:
640
+ infos.append("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0])
641
+ yield "\n".join(infos)
642
+ try:
643
+ big_npy = (
644
+ MiniBatchKMeans(
645
+ n_clusters=10000,
646
+ verbose=True,
647
+ batch_size=256 * config.n_cpu,
648
+ compute_labels=False,
649
+ init="random",
650
+ )
651
+ .fit(big_npy)
652
+ .cluster_centers_
653
+ )
654
+ except:
655
+ info = traceback.format_exc()
656
+ logger.info(info)
657
+ infos.append(info)
658
+ yield "\n".join(infos)
659
+
660
+ np.save("%s/total_fea.npy" % exp_dir, big_npy)
661
+ n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
662
+ infos.append("%s,%s" % (big_npy.shape, n_ivf))
663
+ yield "\n".join(infos)
664
+ index = faiss.index_factory(256 if version19 == "v1" else 768, "IVF%s,Flat" % n_ivf)
665
+ # index = faiss.index_factory(256if version19=="v1"else 768, "IVF%s,PQ128x4fs,RFlat"%n_ivf)
666
+ infos.append("training")
667
+ yield "\n".join(infos)
668
+ index_ivf = faiss.extract_index_ivf(index) #
669
+ index_ivf.nprobe = 1
670
+ index.train(big_npy)
671
+ faiss.write_index(
672
+ index,
673
+ "%s/trained_IVF%s_Flat_nprobe_%s_%s_%s.index"
674
+ % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
675
+ )
676
+ infos.append("adding")
677
+ yield "\n".join(infos)
678
+ batch_size_add = 8192
679
+ for i in range(0, big_npy.shape[0], batch_size_add):
680
+ index.add(big_npy[i : i + batch_size_add])
681
+ faiss.write_index(
682
+ index,
683
+ "%s/added_IVF%s_Flat_nprobe_%s_%s_%s.index"
684
+ % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
685
+ )
686
+ infos.append(
687
+ "成功构建索引 added_IVF%s_Flat_nprobe_%s_%s_%s.index"
688
+ % (n_ivf, index_ivf.nprobe, exp_dir1, version19)
689
+ )
690
+ try:
691
+ link = os.link if platform.system() == "Windows" else os.symlink
692
+ link(
693
+ "%s/added_IVF%s_Flat_nprobe_%s_%s_%s.index"
694
+ % (exp_dir, n_ivf, index_ivf.nprobe, exp_dir1, version19),
695
+ "%s/%s_IVF%s_Flat_nprobe_%s_%s_%s.index"
696
+ % (
697
+ outside_index_root,
698
+ exp_dir1,
699
+ n_ivf,
700
+ index_ivf.nprobe,
701
+ exp_dir1,
702
+ version19,
703
+ ),
704
+ )
705
+ infos.append("链接索引到外部-%s" % (outside_index_root))
706
+ except:
707
+ infos.append("链接索引到外部-%s失败" % (outside_index_root))
708
+
709
+ # faiss.write_index(index, '%s/added_IVF%s_Flat_FastScan_%s.index'%(exp_dir,n_ivf,version19))
710
+ # infos.append("成功构建索引,added_IVF%s_Flat_FastScan_%s.index"%(n_ivf,version19))
711
+ yield "\n".join(infos)
712
+
713
+
714
+ # but5.click(train1key, [exp_dir1, sr2, if_f0_3, trainset_dir4, spk_id5, gpus6, np7, f0method8, save_epoch10, total_epoch11, batch_size12, if_save_latest13, pretrained_G14, pretrained_D15, gpus16, if_cache_gpu17], info3)
715
+ def train1key(
716
+ exp_dir1,
717
+ sr2,
718
+ if_f0_3,
719
+ trainset_dir4,
720
+ spk_id5,
721
+ np7,
722
+ f0method8,
723
+ save_epoch10,
724
+ total_epoch11,
725
+ batch_size12,
726
+ if_save_latest13,
727
+ pretrained_G14,
728
+ pretrained_D15,
729
+ gpus16,
730
+ if_cache_gpu17,
731
+ if_save_every_weights18,
732
+ version19,
733
+ gpus_rmvpe,
734
+ ):
735
+ infos = []
736
+
737
+ def get_info_str(strr):
738
+ infos.append(strr)
739
+ return "\n".join(infos)
740
+
741
+ # step1:处理数据
742
+ yield get_info_str(i18n("step1:正在处理数据"))
743
+ [get_info_str(_) for _ in preprocess_dataset(trainset_dir4, exp_dir1, sr2, np7)]
744
+
745
+ # step2a:提取音高
746
+ yield get_info_str(i18n("step2:正在提取音高&正在提取特征"))
747
+ [
748
+ get_info_str(_)
749
+ for _ in extract_f0_feature(
750
+ gpus16, np7, f0method8, if_f0_3, exp_dir1, version19, gpus_rmvpe
751
+ )
752
+ ]
753
+
754
+ # step3a:训练模型
755
+ yield get_info_str(i18n("step3a:正在训练模型"))
756
+ click_train(
757
+ exp_dir1,
758
+ sr2,
759
+ if_f0_3,
760
+ spk_id5,
761
+ save_epoch10,
762
+ total_epoch11,
763
+ batch_size12,
764
+ if_save_latest13,
765
+ pretrained_G14,
766
+ pretrained_D15,
767
+ gpus16,
768
+ if_cache_gpu17,
769
+ if_save_every_weights18,
770
+ version19,
771
+ )
772
+ yield get_info_str(
773
+ i18n("训练结束, 您可查看控制台训练日志或实验文件夹下的train.log")
774
+ )
775
+
776
+ # step3b:训练索引
777
+ [get_info_str(_) for _ in train_index(exp_dir1, version19)]
778
+ yield get_info_str(i18n("全流程结束!"))
779
+
780
+
781
+ # ckpt_path2.change(change_info_,[ckpt_path2],[sr__,if_f0__])
782
+ def change_info_(ckpt_path):
783
+ if not os.path.exists(ckpt_path.replace(os.path.basename(ckpt_path), "train.log")):
784
+ return {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
785
+ try:
786
+ with open(
787
+ ckpt_path.replace(os.path.basename(ckpt_path), "train.log"), "r"
788
+ ) as f:
789
+ info = eval(f.read().strip("\n").split("\n")[0].split("\t")[-1])
790
+ sr, f0 = info["sample_rate"], info["if_f0"]
791
+ version = "v2" if ("version" in info and info["version"] == "v2") else "v1"
792
+ return sr, str(f0), version
793
+ except:
794
+ traceback.print_exc()
795
+ return {"__type__": "update"}, {"__type__": "update"}, {"__type__": "update"}
796
+
797
+
798
+ F0GPUVisible = config.dml == False
799
+
800
+
801
+ def change_f0_method(f0method8):
802
+ if f0method8 == "rmvpe_gpu":
803
+ visible = F0GPUVisible
804
+ else:
805
+ visible = False
806
+ return {"visible": visible, "__type__": "update"}
807
+
808
+
809
+ with gr.Blocks(title="RVC WebUI") as app:
810
+ gr.Markdown("## RVC WebUI")
811
+ gr.Markdown(
812
+ value=i18n(
813
+ "本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
814
+ )
815
+ )
816
+ with gr.Tabs():
817
+ with gr.TabItem(i18n("训练")):
818
+ gr.Markdown(
819
+ value=i18n(
820
+ "step1: 填写实验配置. 实验数据放在logs下, 每个实验一个文件夹, 需手工输入实验名路径, 内含实验配置, 日志, 训练得到的模型文件. "
821
+ )
822
+ )
823
+ with gr.Row():
824
+ exp_dir1 = gr.Textbox(label=i18n("输入实验名"), value="mi-test")
825
+ sr2 = gr.Radio(
826
+ label=i18n("目标采样率"),
827
+ choices=["40k", "48k"],
828
+ value="40k",
829
+ interactive=True,
830
+ )
831
+ if_f0_3 = gr.Radio(
832
+ label=i18n("模型是否带音高指导(唱歌一定要, 语音可以不要)"),
833
+ choices=[True, False],
834
+ value=True,
835
+ interactive=True,
836
+ )
837
+ version19 = gr.Radio(
838
+ label=i18n("版本"),
839
+ choices=["v1", "v2"],
840
+ value="v2",
841
+ interactive=True,
842
+ visible=True,
843
+ )
844
+ np7 = gr.Slider(
845
+ minimum=0,
846
+ maximum=config.n_cpu,
847
+ step=1,
848
+ label=i18n("提取音高和处理数据使用的CPU进程数"),
849
+ value=int(np.ceil(config.n_cpu / 1.5)),
850
+ interactive=True,
851
+ )
852
+ with gr.Group(): # 暂时单人的, 后面支持最多4人的#数据处理
853
+ gr.Markdown(
854
+ value=i18n(
855
+ "step2a: 自动遍历训练文件夹下所有可解码成音频的文件并进行切片归一化, 在实验目录下生成2个wav文件夹; 暂时只支持单人训练. "
856
+ )
857
+ )
858
+ with gr.Row():
859
+ trainset_dir4 = gr.Textbox(
860
+ label=i18n("输入训练文件夹路径"),
861
+ value=i18n("E:\\语音音频+标注\\米津玄师\\src"),
862
+ )
863
+ spk_id5 = gr.Slider(
864
+ minimum=0,
865
+ maximum=4,
866
+ step=1,
867
+ label=i18n("请指定说话人id"),
868
+ value=0,
869
+ interactive=True,
870
+ )
871
+ but1 = gr.Button(i18n("处理数据"), variant="primary")
872
+ info1 = gr.Textbox(label=i18n("输出信息"), value="")
873
+ but1.click(
874
+ preprocess_dataset,
875
+ [trainset_dir4, exp_dir1, sr2, np7],
876
+ [info1],
877
+ api_name="train_preprocess",
878
+ )
879
+ with gr.Group():
880
+ gr.Markdown(
881
+ value=i18n(
882
+ "step2b: 使用CPU提取音高(如果模型带音高), 使用GPU提取特征(选择卡号)"
883
+ )
884
+ )
885
+ with gr.Row():
886
+ with gr.Column():
887
+ gpus6 = gr.Textbox(
888
+ label=i18n(
889
+ "以-分隔输入使用的卡号, 例如 0-1-2 使用卡0和卡1和卡2"
890
+ ),
891
+ value=gpus,
892
+ interactive=True,
893
+ visible=F0GPUVisible,
894
+ )
895
+ gpu_info9 = gr.Textbox(
896
+ label=i18n("显卡信息"), value=gpu_info, visible=F0GPUVisible
897
+ )
898
+ with gr.Column():
899
+ f0method8 = gr.Radio(
900
+ label=i18n(
901
+ "选择音高提取算法:输入歌声可用pm提速,高质量语音但CPU差可用dio提速,harvest质量更好但慢,rmvpe效果最好且微吃CPU/GPU"
902
+ ),
903
+ choices=["pm", "harvest", "dio", "rmvpe", "rmvpe_gpu"],
904
+ value="rmvpe_gpu",
905
+ interactive=True,
906
+ )
907
+ gpus_rmvpe = gr.Textbox(
908
+ label=i18n(
909
+ "rmvpe卡号配置:以-分隔输入使用的不同进程卡号,例如0-0-1使用在卡0上跑2个进程并在卡1上跑1个进程"
910
+ ),
911
+ value="%s-%s" % (gpus, gpus),
912
+ interactive=True,
913
+ visible=F0GPUVisible,
914
+ )
915
+ but2 = gr.Button(i18n("特征提取"), variant="primary")
916
+ info2 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8)
917
+ f0method8.change(
918
+ fn=change_f0_method,
919
+ inputs=[f0method8],
920
+ outputs=[gpus_rmvpe],
921
+ )
922
+ but2.click(
923
+ extract_f0_feature,
924
+ [
925
+ gpus6,
926
+ np7,
927
+ f0method8,
928
+ if_f0_3,
929
+ exp_dir1,
930
+ version19,
931
+ gpus_rmvpe,
932
+ ],
933
+ [info2],
934
+ api_name="train_extract_f0_feature",
935
+ )
936
+ with gr.Group():
937
+ gr.Markdown(value=i18n("step3: 填写训练设置, 开始训练模型和索引"))
938
+ with gr.Row():
939
+ save_epoch10 = gr.Slider(
940
+ minimum=1,
941
+ maximum=50,
942
+ step=1,
943
+ label=i18n("保存频率save_every_epoch"),
944
+ value=5,
945
+ interactive=True,
946
+ )
947
+ total_epoch11 = gr.Slider(
948
+ minimum=2,
949
+ maximum=1000,
950
+ step=1,
951
+ label=i18n("总训练轮数total_epoch"),
952
+ value=20,
953
+ interactive=True,
954
+ )
955
+ batch_size12 = gr.Slider(
956
+ minimum=1,
957
+ maximum=40,
958
+ step=1,
959
+ label=i18n("每张显卡的batch_size"),
960
+ value=default_batch_size,
961
+ interactive=True,
962
+ )
963
+ if_save_latest13 = gr.Radio(
964
+ label=i18n("是否仅保存最新的ckpt文件以节省硬盘空间"),
965
+ choices=[i18n("是"), i18n("否")],
966
+ value=i18n("否"),
967
+ interactive=True,
968
+ )
969
+ if_cache_gpu17 = gr.Radio(
970
+ label=i18n(
971
+ "是否缓存所有训练集至显存. 10min以下小数据可缓存以加速训练, 大数据缓存会炸显存也加不了多少速"
972
+ ),
973
+ choices=[i18n("是"), i18n("否")],
974
+ value=i18n("否"),
975
+ interactive=True,
976
+ )
977
+ if_save_every_weights18 = gr.Radio(
978
+ label=i18n(
979
+ "是否在每次保存时间点将最终小模型保存至weights文件夹"
980
+ ),
981
+ choices=[i18n("是"), i18n("否")],
982
+ value=i18n("否"),
983
+ interactive=True,
984
+ )
985
+ with gr.Row():
986
+ pretrained_G14 = gr.Textbox(
987
+ label=i18n("加载预训练底模G路径"),
988
+ value="assets/pretrained_v2/f0G40k.pth",
989
+ interactive=True,
990
+ )
991
+ pretrained_D15 = gr.Textbox(
992
+ label=i18n("加载预训练底模D路径"),
993
+ value="assets/pretrained_v2/f0D40k.pth",
994
+ interactive=True,
995
+ )
996
+ sr2.change(
997
+ change_sr2,
998
+ [sr2, if_f0_3, version19],
999
+ [pretrained_G14, pretrained_D15],
1000
+ )
1001
+ version19.change(
1002
+ change_version19,
1003
+ [sr2, if_f0_3, version19],
1004
+ [pretrained_G14, pretrained_D15, sr2],
1005
+ )
1006
+ if_f0_3.change(
1007
+ change_f0,
1008
+ [if_f0_3, sr2, version19],
1009
+ [f0method8, gpus_rmvpe, pretrained_G14, pretrained_D15],
1010
+ )
1011
+ gpus16 = gr.Textbox(
1012
+ label=i18n(
1013
+ "以-分隔输入使用的卡号, 例如 0-1-2 使用卡0和卡1和卡2"
1014
+ ),
1015
+ value=gpus,
1016
+ interactive=True,
1017
+ )
1018
+ but3 = gr.Button(i18n("训练模型"), variant="primary")
1019
+ but4 = gr.Button(i18n("训练特征索引"), variant="primary")
1020
+ but5 = gr.Button(i18n("一键训练"), variant="primary")
1021
+ info3 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=10)
1022
+ but3.click(
1023
+ click_train,
1024
+ [
1025
+ exp_dir1,
1026
+ sr2,
1027
+ if_f0_3,
1028
+ spk_id5,
1029
+ save_epoch10,
1030
+ total_epoch11,
1031
+ batch_size12,
1032
+ if_save_latest13,
1033
+ pretrained_G14,
1034
+ pretrained_D15,
1035
+ gpus16,
1036
+ if_cache_gpu17,
1037
+ if_save_every_weights18,
1038
+ version19,
1039
+ ],
1040
+ info3,
1041
+ api_name="train_start",
1042
+ )
1043
+ but4.click(train_index, [exp_dir1, version19], info3)
1044
+ but5.click(
1045
+ train1key,
1046
+ [
1047
+ exp_dir1,
1048
+ sr2,
1049
+ if_f0_3,
1050
+ trainset_dir4,
1051
+ spk_id5,
1052
+ np7,
1053
+ f0method8,
1054
+ save_epoch10,
1055
+ total_epoch11,
1056
+ batch_size12,
1057
+ if_save_latest13,
1058
+ pretrained_G14,
1059
+ pretrained_D15,
1060
+ gpus16,
1061
+ if_cache_gpu17,
1062
+ if_save_every_weights18,
1063
+ version19,
1064
+ gpus_rmvpe,
1065
+ ],
1066
+ info3,
1067
+ api_name="train_start_all",
1068
+ )
1069
+
1070
+ with gr.TabItem(i18n("ckpt处理")):
1071
+ with gr.Group():
1072
+ gr.Markdown(value=i18n("模型融合, 可用于测试音色融合"))
1073
+ with gr.Row():
1074
+ ckpt_a = gr.Textbox(
1075
+ label=i18n("A模型路径"), value="", interactive=True
1076
+ )
1077
+ ckpt_b = gr.Textbox(
1078
+ label=i18n("B模型路径"), value="", interactive=True
1079
+ )
1080
+ alpha_a = gr.Slider(
1081
+ minimum=0,
1082
+ maximum=1,
1083
+ label=i18n("A模型权重"),
1084
+ value=0.5,
1085
+ interactive=True,
1086
+ )
1087
+ with gr.Row():
1088
+ sr_ = gr.Radio(
1089
+ label=i18n("目标采样率"),
1090
+ choices=["40k", "48k"],
1091
+ value="40k",
1092
+ interactive=True,
1093
+ )
1094
+ if_f0_ = gr.Radio(
1095
+ label=i18n("模型是否带音高指导"),
1096
+ choices=[i18n("是"), i18n("否")],
1097
+ value=i18n("是"),
1098
+ interactive=True,
1099
+ )
1100
+ info__ = gr.Textbox(
1101
+ label=i18n("要置入的模型信息"),
1102
+ value="",
1103
+ max_lines=8,
1104
+ interactive=True,
1105
+ )
1106
+ name_to_save0 = gr.Textbox(
1107
+ label=i18n("保存的模型名不带后缀"),
1108
+ value="",
1109
+ max_lines=1,
1110
+ interactive=True,
1111
+ )
1112
+ version_2 = gr.Radio(
1113
+ label=i18n("模型版本型号"),
1114
+ choices=["v1", "v2"],
1115
+ value="v1",
1116
+ interactive=True,
1117
+ )
1118
+ with gr.Row():
1119
+ but6 = gr.Button(i18n("融合"), variant="primary")
1120
+ info4 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8)
1121
+ but6.click(
1122
+ merge,
1123
+ [
1124
+ ckpt_a,
1125
+ ckpt_b,
1126
+ alpha_a,
1127
+ sr_,
1128
+ if_f0_,
1129
+ info__,
1130
+ name_to_save0,
1131
+ version_2,
1132
+ ],
1133
+ info4,
1134
+ api_name="ckpt_merge",
1135
+ ) # def merge(path1,path2,alpha1,sr,f0,info):
1136
+ with gr.Group():
1137
+ gr.Markdown(
1138
+ value=i18n("修改模型信息(仅支持weights文件夹下提取的小模型文件)")
1139
+ )
1140
+ with gr.Row():
1141
+ ckpt_path0 = gr.Textbox(
1142
+ label=i18n("模型路径"), value="", interactive=True
1143
+ )
1144
+ info_ = gr.Textbox(
1145
+ label=i18n("要改的模型信息"),
1146
+ value="",
1147
+ max_lines=8,
1148
+ interactive=True,
1149
+ )
1150
+ name_to_save1 = gr.Textbox(
1151
+ label=i18n("保存的文件名, 默认空为和源文件同名"),
1152
+ value="",
1153
+ max_lines=8,
1154
+ interactive=True,
1155
+ )
1156
+ with gr.Row():
1157
+ but7 = gr.Button(i18n("修改"), variant="primary")
1158
+ info5 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8)
1159
+ but7.click(
1160
+ change_info,
1161
+ [ckpt_path0, info_, name_to_save1],
1162
+ info5,
1163
+ api_name="ckpt_modify",
1164
+ )
1165
+ with gr.Group():
1166
+ gr.Markdown(
1167
+ value=i18n("查看模型信息(仅支持weights文件夹下提取的小模型文件)")
1168
+ )
1169
+ with gr.Row():
1170
+ ckpt_path1 = gr.Textbox(
1171
+ label=i18n("模型路径"), value="", interactive=True
1172
+ )
1173
+ but8 = gr.Button(i18n("查看"), variant="primary")
1174
+ info6 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8)
1175
+ but8.click(show_info, [ckpt_path1], info6, api_name="ckpt_show")
1176
+ with gr.Group():
1177
+ gr.Markdown(
1178
+ value=i18n(
1179
+ "模型提取(输入logs文件夹下大文件模型路径),适用于训一半不想训了模型没有自动提取保存小文件模型,或者想测试中间模型的情况"
1180
+ )
1181
+ )
1182
+ with gr.Row():
1183
+ ckpt_path2 = gr.Textbox(
1184
+ label=i18n("模型路径"),
1185
+ value="E:\\codes\\py39\\logs\\mi-test_f0_48k\\G_23333.pth",
1186
+ interactive=True,
1187
+ )
1188
+ save_name = gr.Textbox(
1189
+ label=i18n("保存名"), value="", interactive=True
1190
+ )
1191
+ sr__ = gr.Radio(
1192
+ label=i18n("目标采样率"),
1193
+ choices=["32k", "40k", "48k"],
1194
+ value="40k",
1195
+ interactive=True,
1196
+ )
1197
+ if_f0__ = gr.Radio(
1198
+ label=i18n("模型是否带音高指导,1是0否"),
1199
+ choices=["1", "0"],
1200
+ value="1",
1201
+ interactive=True,
1202
+ )
1203
+ version_1 = gr.Radio(
1204
+ label=i18n("模型版本型号"),
1205
+ choices=["v1", "v2"],
1206
+ value="v2",
1207
+ interactive=True,
1208
+ )
1209
+ info___ = gr.Textbox(
1210
+ label=i18n("要置入的模型信息"),
1211
+ value="",
1212
+ max_lines=8,
1213
+ interactive=True,
1214
+ )
1215
+ but9 = gr.Button(i18n("提取"), variant="primary")
1216
+ info7 = gr.Textbox(label=i18n("输出信息"), value="", max_lines=8)
1217
+ ckpt_path2.change(
1218
+ change_info_, [ckpt_path2], [sr__, if_f0__, version_1]
1219
+ )
1220
+ but9.click(
1221
+ extract_small_model,
1222
+ [ckpt_path2, save_name, sr__, if_f0__, info___, version_1],
1223
+ info7,
1224
+ api_name="ckpt_extract",
1225
+ )
1226
+
1227
+ with gr.TabItem(i18n("Onnx导出")):
1228
+ with gr.Row():
1229
+ ckpt_dir = gr.Textbox(
1230
+ label=i18n("RVC模型路径"), value="", interactive=True
1231
+ )
1232
+ with gr.Row():
1233
+ onnx_dir = gr.Textbox(
1234
+ label=i18n("Onnx输出路径"), value="", interactive=True
1235
+ )
1236
+ with gr.Row():
1237
+ infoOnnx = gr.Label(label="info")
1238
+ with gr.Row():
1239
+ butOnnx = gr.Button(i18n("导出Onnx模型"), variant="primary")
1240
+ butOnnx.click(
1241
+ export_onnx, [ckpt_dir, onnx_dir], infoOnnx, api_name="export_onnx"
1242
+ )
1243
+
1244
+ tab_faq = i18n("常见问题解答")
1245
+ with gr.TabItem(tab_faq):
1246
+ try:
1247
+ if tab_faq == "常见问题解答":
1248
+ with open("docs/cn/faq.md", "r", encoding="utf8") as f:
1249
+ info = f.read()
1250
+ else:
1251
+ with open("docs/en/faq_en.md", "r", encoding="utf8") as f:
1252
+ info = f.read()
1253
+ gr.Markdown(value=info)
1254
+ except:
1255
+ gr.Markdown(traceback.format_exc())
1256
+
1257
+ if config.iscolab:
1258
+ app.queue(concurrency_count=511, max_size=1022).launch(share=True)
1259
+ else:
1260
+ app.queue(concurrency_count=511, max_size=1022).launch(
1261
+ server_name="0.0.0.0",
1262
+ inbrowser=not config.noautoopen,
1263
+ server_port=config.listen_port,
1264
+ quiet=True,
1265
+ )
infer/modules/train/train.py CHANGED
@@ -11,9 +11,9 @@ import datetime
11
 
12
  from infer.lib.train import utils
13
 
14
- hps = utils.get_hparams()
15
- os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
16
- n_gpus = len(hps.gpus.split("-"))
17
  from random import randint, shuffle
18
 
19
  import torch
@@ -54,18 +54,18 @@ from infer.lib.train.data_utils import (
54
  TextAudioLoaderMultiNSFsid,
55
  )
56
 
57
- if hps.version == "v1":
58
- from infer.lib.infer_pack.models import MultiPeriodDiscriminator
59
- from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
60
- from infer.lib.infer_pack.models import (
61
- SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
62
- )
63
- else:
64
- from infer.lib.infer_pack.models import (
65
- SynthesizerTrnMs768NSFsid as RVC_Model_f0,
66
- SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
67
- MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
68
- )
69
 
70
  from infer.lib.train.losses import (
71
  discriminator_loss,
@@ -76,8 +76,6 @@ from infer.lib.train.losses import (
76
  from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
77
  from infer.lib.train.process_ckpt import savee
78
 
79
- global_step = 0
80
-
81
 
82
  class EpochRecorder:
83
  def __init__(self):
@@ -92,33 +90,31 @@ class EpochRecorder:
92
  return f"[{current_time}] | ({elapsed_time_str})"
93
 
94
 
95
- def main():
96
- n_gpus = torch.cuda.device_count()
 
98
- if torch.cuda.is_available() == False and torch.backends.mps.is_available() == True:
99
- n_gpus = 1
100
- if n_gpus < 1:
101
- # patch to unblock people without gpus. there is probably a better way.
102
- print("NO GPU DETECTED: falling back to CPU - this may take a while")
103
- n_gpus = 1
104
- os.environ["MASTER_ADDR"] = "localhost"
105
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
106
- children = []
107
  logger = utils.get_logger(hps.model_dir)
108
- for i in range(n_gpus):
109
- subproc = mp.Process(
110
- target=run,
111
- args=(i, n_gpus, hps, logger),
112
- )
113
- children.append(subproc)
114
- subproc.start()
115
-
116
- for i in range(n_gpus):
117
- children[i].join()
118
 
119
 
120
- def run(rank, n_gpus, hps, logger: logging.Logger):
121
- global global_step
122
  if rank == 0:
123
  # logger = utils.get_logger(hps.model_dir)
124
  logger.info(hps)
@@ -215,13 +211,14 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
215
  _, _, _, epoch_str = utils.load_checkpoint(
216
  utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
217
  )
218
- global_step = (epoch_str - 1) * len(train_loader)
 
219
  # epoch_str = 1
220
  # global_step = 0
221
  except: # 如果首次不能加载,加载pretrain
222
  # traceback.print_exc()
223
  epoch_str = 1
224
- global_step = 0
225
  if hps.pretrainG != "":
226
  if rank == 0:
227
  logger.info("loaded pretrained %s" % (hps.pretrainG))
@@ -252,6 +249,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
252
  torch.load(hps.pretrainD, map_location="cpu")["model"]
253
  )
254
  )
 
255
 
256
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
257
  optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
@@ -264,6 +262,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
264
 
265
  cache = []
266
  for epoch in range(epoch_str, hps.train.epochs + 1):
 
267
  if rank == 0:
268
  train_and_evaluate(
269
  rank,
@@ -277,6 +276,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
277
  logger,
278
  [writer, writer_eval],
279
  cache,
 
280
  )
281
  else:
282
  train_and_evaluate(
@@ -291,13 +291,25 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
291
  None,
292
  None,
293
  cache,
 
294
  )
295
  scheduler_g.step()
296
  scheduler_d.step()
297
 
298
 
299
  def train_and_evaluate(
300
- rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
301
  ):
302
  net_g, net_d = nets
303
  optim_g, optim_d = optims
@@ -306,7 +318,6 @@ def train_and_evaluate(
306
  writer, writer_eval = writers
307
 
308
  train_loader.batch_sampler.set_epoch(epoch)
309
- global global_step
310
 
311
  net_g.train()
312
  net_d.train()
@@ -500,7 +511,7 @@ def train_and_evaluate(
500
  scaler.update()
501
 
502
  if rank == 0:
503
- if global_step % hps.train.log_interval == 0:
504
  lr = optim_g.param_groups[0]["lr"]
505
  logger.info(
506
  "Train Epoch: {} [{:.0f}%]".format(
@@ -513,7 +524,7 @@ def train_and_evaluate(
513
  if loss_kl > 9:
514
  loss_kl = 9
515
 
516
- logger.info([global_step, lr])
517
  logger.info(
518
  f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
519
  )
@@ -554,11 +565,11 @@ def train_and_evaluate(
554
  }
555
  utils.summarize(
556
  writer=writer,
557
- global_step=global_step,
558
  images=image_dict,
559
  scalars=scalar_dict,
560
  )
561
- global_step += 1
562
  # /Run steps
563
 
564
  if epoch % hps.save_every_epoch == 0 and rank == 0:
@@ -568,14 +579,14 @@ def train_and_evaluate(
568
  optim_g,
569
  hps.train.learning_rate,
570
  epoch,
571
- os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
572
  )
573
  utils.save_checkpoint(
574
  net_d,
575
  optim_d,
576
  hps.train.learning_rate,
577
  epoch,
578
- os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
579
  )
580
  else:
581
  utils.save_checkpoint(
@@ -606,7 +617,7 @@ def train_and_evaluate(
606
  ckpt,
607
  hps.sample_rate,
608
  hps.if_f0,
609
- hps.name + "_e%s_s%s" % (epoch, global_step),
610
  epoch,
611
  hps.version,
612
  hps,
@@ -633,8 +644,3 @@ def train_and_evaluate(
633
  )
634
  sleep(1)
635
  os._exit(2333333)
636
-
637
-
638
- if __name__ == "__main__":
639
- torch.multiprocessing.set_start_method("spawn")
640
- main()
 
11
 
12
  from infer.lib.train import utils
13
 
14
+ # hps = utils.get_hparams()
15
+ # os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
16
+ # n_gpus = len(hps.gpus.split("-"))
17
  from random import randint, shuffle
18
 
19
  import torch
 
54
  TextAudioLoaderMultiNSFsid,
55
  )
56
 
57
+ # if hps.version == "v1":
58
+ # from infer.lib.infer_pack.models import MultiPeriodDiscriminator
59
+ # from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
60
+ # from infer.lib.infer_pack.models import (
61
+ # SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
62
+ # )
63
+ # else:
64
+ from infer.lib.infer_pack.models import (
65
+ SynthesizerTrnMs768NSFsid as RVC_Model_f0,
66
+ SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
67
+ MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
68
+ )
69
 
70
  from infer.lib.train.losses import (
71
  discriminator_loss,
 
76
  from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
77
  from infer.lib.train.process_ckpt import savee
78
 
 
 
79
 
80
  class EpochRecorder:
81
  def __init__(self):
 
90
  return f"[{current_time}] | ({elapsed_time_str})"
91
 
92
 
93
+ def train(exp_dir: str):
94
+ state = {"global_step": 0}
95
+
96
+ hps = utils.get_hparams_from_dir(exp_dir)
97
+ hps.experiment_dir = exp_dir
98
+ hps.save_every_epoch = False
99
+ hps.name = os.path.basename(exp_dir)
100
+ hps.total_epoch = 100
101
+ hps.pretrainG = ""
102
+ hps.pretrainD = ""
103
+ hps.version = "v2"
104
+ hps.gpus = "0"
105
+ hps.train.batch_size = 8
106
+ hps.sample_rate = "40k"
107
+ hps.if_f0 = 1
108
+ hps.if_latest = 1
109
+ hps.save_every_weights = "0"
110
+ hps.if_cache_data_in_gpu = True
111
+ hps.data.training_files = "%s/filelist.txt" % exp_dir
112
 
 
113
  logger = utils.get_logger(hps.model_dir)
114
+ run(0, 1, hps, logger, state)
115
 
116
 
117
+ def run(rank, n_gpus, hps, logger: logging.Logger, state):
 
118
  if rank == 0:
119
  # logger = utils.get_logger(hps.model_dir)
120
  logger.info(hps)
 
211
  _, _, _, epoch_str = utils.load_checkpoint(
212
  utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
213
  )
214
+ state["global_step"] = (epoch_str - 1) * len(train_loader)
215
+ print("loaded", epoch_str)
216
  # epoch_str = 1
217
  # global_step = 0
218
  except: # 如果首次不能加载,加载pretrain
219
  # traceback.print_exc()
220
  epoch_str = 1
221
+ state["global_step"] = 0
222
  if hps.pretrainG != "":
223
  if rank == 0:
224
  logger.info("loaded pretrained %s" % (hps.pretrainG))
 
249
  torch.load(hps.pretrainD, map_location="cpu")["model"]
250
  )
251
  )
252
+ print("new")
253
 
254
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
255
  optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
 
262
 
263
  cache = []
264
  for epoch in range(epoch_str, hps.train.epochs + 1):
265
+ print("epoch", epoch)
266
  if rank == 0:
267
  train_and_evaluate(
268
  rank,
 
276
  logger,
277
  [writer, writer_eval],
278
  cache,
279
+ state,
280
  )
281
  else:
282
  train_and_evaluate(
 
291
  None,
292
  None,
293
  cache,
294
+ state,
295
  )
296
  scheduler_g.step()
297
  scheduler_d.step()
298
 
299
 
300
  def train_and_evaluate(
301
+ rank,
302
+ epoch,
303
+ hps,
304
+ nets,
305
+ optims,
306
+ schedulers,
307
+ scaler,
308
+ loaders,
309
+ logger,
310
+ writers,
311
+ cache,
312
+ state,
313
  ):
314
  net_g, net_d = nets
315
  optim_g, optim_d = optims
 
318
  writer, writer_eval = writers
319
 
320
  train_loader.batch_sampler.set_epoch(epoch)
 
321
 
322
  net_g.train()
323
  net_d.train()
 
511
  scaler.update()
512
 
513
  if rank == 0:
514
+ if state["global_step"] % hps.train.log_interval == 0:
515
  lr = optim_g.param_groups[0]["lr"]
516
  logger.info(
517
  "Train Epoch: {} [{:.0f}%]".format(
 
524
  if loss_kl > 9:
525
  loss_kl = 9
526
 
527
+ logger.info([state["global_step"], lr])
528
  logger.info(
529
  f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
530
  )
 
565
  }
566
  utils.summarize(
567
  writer=writer,
568
+ global_step=state["global_step"],
569
  images=image_dict,
570
  scalars=scalar_dict,
571
  )
572
+ state["global_step"] += 1
573
  # /Run steps
574
 
575
  if epoch % hps.save_every_epoch == 0 and rank == 0:
 
579
  optim_g,
580
  hps.train.learning_rate,
581
  epoch,
582
+ os.path.join(hps.model_dir, "G_{}.pth".format(state["global_step"])),
583
  )
584
  utils.save_checkpoint(
585
  net_d,
586
  optim_d,
587
  hps.train.learning_rate,
588
  epoch,
589
+ os.path.join(hps.model_dir, "D_{}.pth".format(state["global_step"])),
590
  )
591
  else:
592
  utils.save_checkpoint(
 
617
  ckpt,
618
  hps.sample_rate,
619
  hps.if_f0,
620
+ hps.name + "_e%s_s%s" % (epoch, state["global_step"]),
621
  epoch,
622
  hps.version,
623
  hps,
 
644
  )
645
  sleep(1)
646
  os._exit(2333333)
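Net effect of the train.py changes: the CLI entry point (main() plus one mp.Process per GPU) is gone, hyperparameters are loaded from the experiment directory with hard-coded overrides, and the module-level global_step becomes a state dict threaded through run() and train_and_evaluate(). A minimal direct call, assuming the experiment directory already contains config.json and filelist.txt (app.py's train_model() prepares both before calling train()):

# Hypothetical direct invocation of the refactored entry point; the
# experiment directory name is made up for illustration.
from infer.modules.train.train import train

train("logs/my-exp")  # single process, GPU 0, settings hard-coded inside train()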
logs/mute/0_gt_wavs/mute32k.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9edcf85ec77e88bd01edf3d887bdc418d3596d573f7ad2694da546f41dae6baf
3
+ size 192078
logs/mute/0_gt_wavs/mute40k.spec.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bbca900dcff32be4d664383a705d53ebc6829027ac8a07d78308d472c9087a1
3
+ size 1230339
logs/mute/0_gt_wavs/mute40k.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67a816e77b50cb9f016e49e5c01f07e080c4e3b82b7a8ac3e64bcb143f90f31b
3
+ size 240078
logs/mute/0_gt_wavs/mute48k.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f2bb4daaa106e351aebb001e5a25de985c0b472f22e8d60676bc924a79056ee
3
+ size 288078
logs/mute/1_16k_wavs/mute.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e233e86ba1be365e1133f157d56b61110086b89650ecfbdfc013c759e466250
3
+ size 96078
logs/mute/2a_f0/mute.wav.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b9acf9ab7facdb032e1d687fe35182670b0b94566c4b209ae48c239d19956a6
3
+ size 1332
logs/mute/2b-f0nsf/mute.wav.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30792849c8e72d67e6691754077f2888b101cb741e9c7f193c91dd9692870c87
3
+ size 2536
logs/mute/3_feature256/mute.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64d5abbac078e19a3f649c0d78a02cb33a71407ded3ddf2db78e6b803d0c0126
3
+ size 152704
logs/mute/3_feature768/mute.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16ef62b957887ac9f0913aa5158f18983afff1ef5a3e4c5fd067ac20fc380d54
3
+ size 457856