ZeroRVC / app /train.py
JacobLinCool's picture
perf: set batch size to 128
a0da4cc
raw
history blame
6.37 kB
import os
import shutil
import traceback
import faiss
import gradio as gr
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from random import shuffle
from glob import glob
from infer.modules.train.train import train
from zero import zero
def write_filelist(
    exp_dir: str,
    if_f0: bool = True,
    spk_id: int = 0,
    sr: str = "40k",
    fea_dim: int = 768,
) -> None:
    """Write ``<exp_dir>/filelist.txt`` listing the training samples.

    A sample is listed only if its name is present in every required
    sub-directory of ``exp_dir``: ``0_gt_wavs`` and ``3_feature<fea_dim>``,
    plus ``2a_f0`` and ``2b-f0nsf`` when ``if_f0`` is true. Two identical
    entries pointing at the ``logs/mute`` files under the current working
    directory are appended. All lines are then shuffled and written
    newline-separated, with ``|``-separated columns.

    Args:
        exp_dir: Experiment directory containing the preprocessed data.
        if_f0: Include the two pitch columns (``2a_f0`` / ``2b-f0nsf``).
        spk_id: Speaker id written as the last column of every line.
        sr: Sample-rate tag selecting the mute wav (``mute<sr>.wav``).
        fea_dim: Feature dimension; selects the ``3_feature<fea_dim>`` dirs.

    Raises:
        FileNotFoundError: If a required sub-directory does not exist.
    """
    gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
    feature_dir = "%s/3_feature%s" % (exp_dir, fea_dim)
    f0_dir = "%s/2a_f0" % exp_dir
    f0nsf_dir = "%s/2b-f0nsf" % exp_dir

    def _stems(directory: str) -> set:
        # Text before the first "." -- pitch files are "<name>.wav.npy", so
        # every artifact of one sample maps to the same key.
        # NOTE(review): a sample name containing "." would be truncated;
        # assumes preprocessing produces dot-free names -- confirm.
        return {fn.split(".")[0] for fn in os.listdir(directory)}

    def _esc(path: str) -> str:
        # Backslashes are doubled as in the original script.
        # NOTE(review): presumably for Windows paths -- confirm the reader.
        return path.replace("\\", "\\\\")

    required = [gt_wavs_dir, feature_dir]
    if if_f0:
        required += [f0_dir, f0nsf_dir]
    names = set.intersection(*(_stems(d) for d in required))

    gt, fea, f0, f0nsf = map(_esc, (gt_wavs_dir, feature_dir, f0_dir, f0nsf_dir))
    opt = []
    for name in names:
        cols = [f"{gt}/{name}.wav", f"{fea}/{name}.npy"]
        if if_f0:
            cols += [f"{f0}/{name}.wav.npy", f"{f0nsf}/{name}.wav.npy"]
        cols.append(str(spk_id))
        opt.append("|".join(cols))

    # Two entries referencing the shared mute files under the working dir.
    mute = "%s/logs/mute" % os.getcwd()
    mute_cols = [f"{mute}/0_gt_wavs/mute{sr}.wav", f"{mute}/3_feature{fea_dim}/mute.npy"]
    if if_f0:
        mute_cols += [f"{mute}/2a_f0/mute.wav.npy", f"{mute}/2b-f0nsf/mute.wav.npy"]
    mute_cols.append(str(spk_id))
    opt.extend(["|".join(mute_cols)] * 2)

    shuffle(opt)
    with open("%s/filelist.txt" % exp_dir, "w", encoding="utf-8") as f:
        f.write("\n".join(opt))
@zero(duration=240)
def train_model(exp_dir: str) -> str:
    """Run one training session for ``exp_dir`` and return the newest checkpoint.

    Copies ``config.json`` from the current working directory into
    ``exp_dir``, regenerates ``filelist.txt`` via ``write_filelist``, then
    calls ``train(exp_dir)``. ``@zero(duration=240)`` is a project decorator;
    NOTE(review): presumably it requests a 240 s GPU slot -- confirm in ``zero``.

    Args:
        exp_dir: Experiment directory holding the preprocessed data.

    Returns:
        Path of the most recently modified ``G_*.pth`` file in ``exp_dir``.

    Raises:
        gr.Error: If no ``G_*.pth`` checkpoint exists after training.
    """
    shutil.copy("config.json", exp_dir)
    write_filelist(exp_dir)
    train(exp_dir)
    models = glob(f"{exp_dir}/G_*.pth")
    print(models)
    if not models:
        raise gr.Error("No model found")
    # Use mtime, not ctime: os.path.getctime is the *creation* time on
    # Windows, so a checkpoint file rewritten in place would keep its old
    # timestamp and might not be picked as the newest.
    return max(models, key=os.path.getmtime)
def train_index(exp_dir: str) -> str:
    """Build a FAISS ``IVF<n>,Flat`` index over the features of ``exp_dir``.

    Steps:

    1. Load every ``*.npy`` file in ``<exp_dir>/3_feature768`` (sorted by
       name), concatenate the arrays along axis 0 and shuffle the rows.
    2. If there are more than 200k rows, replace them with 10k MiniBatchKMeans
       cluster centres.
    3. Save the resulting matrix as ``total_fea.npy``.
    4. Train the index on it and write ``trained_IVF<n>_Flat_nprobe_1.index``.
    5. Add the rows in batches of 8192 and write
       ``added_IVF<n>_Flat_nprobe_1.index``.

    Args:
        exp_dir: Experiment directory containing ``3_feature768``.

    Returns:
        Path of the ``added_...`` index file.

    Raises:
        gr.Error: If no feature files exist, or if k-means fails (the message
            is the formatted traceback).
    """
    feature_dir = "%s/3_feature768" % (exp_dir)
    if not os.path.exists(feature_dir):
        raise gr.Error("Please extract features first.")
    # Only .npy files are feature matrices; ignore anything else in the dir.
    npy_names = sorted(
        name for name in os.listdir(feature_dir) if name.endswith(".npy")
    )
    if not npy_names:
        raise gr.Error("Please extract features first.")
    big_npy = np.concatenate(
        [np.load("%s/%s" % (feature_dir, name)) for name in npy_names], 0
    )
    big_npy_idx = np.arange(big_npy.shape[0])
    np.random.shuffle(big_npy_idx)
    big_npy = big_npy[big_npy_idx]
    if big_npy.shape[0] > 2e5:
        print("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0])
        try:
            big_npy = (
                MiniBatchKMeans(
                    n_clusters=10000,
                    verbose=True,
                    batch_size=256 * 8,
                    compute_labels=False,
                    init="random",
                )
                .fit(big_npy)
                .cluster_centers_
            )
        except Exception as err:
            info = traceback.format_exc()
            print(info)
            raise gr.Error(info) from err
    np.save("%s/total_fea.npy" % exp_dir, big_npy)
    n_rows, dim = big_npy.shape
    # IVF list-count heuristic, clamped to >= 1 so that fewer than 39 rows
    # does not yield the invalid factory string "IVF0,Flat".
    n_ivf = max(1, min(int(16 * np.sqrt(n_rows)), n_rows // 39))
    print("%s,%s" % (big_npy.shape, n_ivf))
    index = faiss.index_factory(dim, "IVF%s,Flat" % n_ivf)
    print("training")
    index_ivf = faiss.extract_index_ivf(index)
    index_ivf.nprobe = 1
    index.train(big_npy)
    suffix = "IVF%s_Flat_nprobe_%s.index" % (n_ivf, index_ivf.nprobe)
    faiss.write_index(index, "%s/trained_%s" % (exp_dir, suffix))
    print("adding")
    batch_size_add = 8192
    for start in range(0, n_rows, batch_size_add):
        index.add(big_npy[start : start + batch_size_add])
    added_path = "%s/added_%s" % (exp_dir, suffix)
    faiss.write_index(index, added_path)
    print("built added_%s" % suffix)
    return added_path
class TrainTab:
    """Gradio tab with buttons that train the model and its FAISS index.

    ``ui()`` creates the components. ``build()`` wires their click handlers
    and must be called after ``ui()``, because it uses the components that
    ``ui()`` stores on ``self``.
    """

    def __init__(self):
        # All components are created in ui(); nothing to initialise here.
        pass

    def ui(self):
        """Create the Markdown description, the buttons and the file outputs."""
        gr.Markdown("# Training")
        gr.Markdown(
            "You can start training the model by clicking the button below. "
            "Each time you click the button, the model will train for 20 epochs, which takes about 10 minutes on ZeroGPU (A100). "
            "The latest *training checkpoint* will be available below."
        )
        with gr.Row():
            self.train_btn = gr.Button(value="Train", variant="primary")
            self.latest_checkpoint = gr.File(label="Latest checkpoint")
        with gr.Row():
            self.train_index_btn = gr.Button(value="Train index", variant="primary")
            self.trained_index = gr.File(label="Trained index")

    def build(self, exp_dir: gr.Textbox):
        """Connect the buttons to ``train_model`` and ``train_index``.

        Args:
            exp_dir: Textbox whose value (the experiment directory) is passed
                to both handlers.
        """
        # Each click runs train_model once, and then a second time via
        # .success() if the first run succeeded.
        # NOTE(review): presumably this chains two GPU sessions to train
        # longer than a single `zero` duration allows -- confirm.
        self.train_btn.click(
            fn=train_model,
            inputs=[exp_dir],
            outputs=[self.latest_checkpoint],
        ).success(
            fn=train_model,
            inputs=[exp_dir],
            outputs=[self.latest_checkpoint],
        )
        self.train_index_btn.click(
            fn=train_index,
            inputs=[exp_dir],
            outputs=[self.trained_index],
        )