import os
import shutil
import traceback
from glob import glob
from random import shuffle

import faiss
import gradio as gr
import numpy as np
from sklearn.cluster import MiniBatchKMeans

from infer.modules.train.train import train
from zero import zero


def _escape_backslashes(path: str) -> str:
    """Return *path* with every backslash doubled, as the filelist format expects."""
    return path.replace("\\", "\\\\")


def _stems(directory: str) -> set:
    """Return the set of file-name prefixes (text before the first ".") in *directory*."""
    return {name.split(".")[0] for name in os.listdir(directory)}


def write_filelist(
    exp_dir: str, if_f0: bool = True, spk_id: int = 0, sr: str = "40k"
) -> None:
    """Write ``<exp_dir>/filelist.txt`` listing every complete training sample.

    A sample name is listed only if it exists in the ground-truth wav folder and
    the feature folder, plus both pitch folders when *if_f0* is true. Two entries
    for the shared silent sample under ``<cwd>/logs/mute`` are appended, and the
    lines are shuffled before writing.

    Args:
        exp_dir: Experiment directory containing the preprocessed sub-folders.
        if_f0: Whether the f0 / f0nsf pitch columns are included in each line.
        spk_id: Speaker id written as the last column of every line.
        sr: Sample-rate tag selecting the mute wav (``mute<sr>.wav``).
    """
    fea_dim = 768
    gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
    feature_dir = "%s/3_feature%s" % (exp_dir, fea_dim)
    f0_dir = "%s/2a_f0" % exp_dir
    f0nsf_dir = "%s/2b-f0nsf" % exp_dir

    # Keep only samples for which every required artifact was produced.
    names = _stems(gt_wavs_dir) & _stems(feature_dir)
    if if_f0:
        names &= _stems(f0_dir) & _stems(f0nsf_dir)

    gt = _escape_backslashes(gt_wavs_dir)
    fea = _escape_backslashes(feature_dir)
    f0 = _escape_backslashes(f0_dir)
    f0nsf = _escape_backslashes(f0nsf_dir)

    opt = []
    for name in names:
        if if_f0:
            opt.append(
                "%s/%s.wav|%s/%s.npy|%s/%s.wav.npy|%s/%s.wav.npy|%s"
                % (gt, name, fea, name, f0, name, f0nsf, name, spk_id)
            )
        else:
            opt.append("%s/%s.wav|%s/%s.npy|%s" % (gt, name, fea, name, spk_id))

    mute = "%s/logs/mute" % os.getcwd()
    if if_f0:
        mute_line = (
            "%s/0_gt_wavs/mute%s.wav|%s/3_feature%s/mute.npy"
            "|%s/2a_f0/mute.wav.npy|%s/2b-f0nsf/mute.wav.npy|%s"
            % (mute, sr, mute, fea_dim, mute, mute, spk_id)
        )
    else:
        mute_line = "%s/0_gt_wavs/mute%s.wav|%s/3_feature%s/mute.npy|%s" % (
            mute,
            sr,
            mute,
            fea_dim,
            spk_id,
        )
    opt.extend([mute_line] * 2)

    shuffle(opt)
    # Explicit UTF-8 so non-ASCII paths do not depend on the platform's default codec.
    with open("%s/filelist.txt" % exp_dir, "w", encoding="utf-8") as f:
        f.write("\n".join(opt))


@zero(duration=240)
def train_model(exp_dir: str) -> str:
    """Run a training session for *exp_dir* and return the newest generator checkpoint.

    Copies ``config.json`` from the working directory into *exp_dir*, regenerates
    the filelist, calls ``train`` and then picks the most recently modified
    ``G_*.pth`` file.

    Raises:
        gr.Error: If no ``G_*.pth`` checkpoint exists after training.
    """
    shutil.copy("config.json", exp_dir)
    write_filelist(exp_dir)
    train(exp_dir)

    models = glob(f"{exp_dir}/G_*.pth")
    print(models)
    if not models:
        raise gr.Error("No model found")
    # getmtime rather than getctime: on Windows ctime is the creation time, which a
    # checkpoint overwritten in place keeps, so it cannot identify the newest file.
    return max(models, key=os.path.getmtime)


def train_index(exp_dir: str) -> str:
    """Build a faiss IVF-Flat index from the extracted features in *exp_dir*.

    All ``.npy`` files in ``<exp_dir>/3_feature768`` are concatenated and shuffled.
    More than 200k vectors are reduced to 10k MiniBatchKMeans centers. The result
    is saved as ``total_fea.npy``. Both the trained (empty) index and the populated
    index are written to *exp_dir*.

    Returns:
        Path of the populated ``added_IVF*_Flat_nprobe_*.index`` file.

    Raises:
        gr.Error: If no feature files exist, or if k-means fails.
    """
    feature_dir = "%s/3_feature768" % exp_dir
    if not os.path.exists(feature_dir):
        raise gr.Error("Please extract features first.")
    # Ignore stray non-array files so np.load does not fail on them.
    feature_files = sorted(
        name for name in os.listdir(feature_dir) if name.endswith(".npy")
    )
    if not feature_files:
        raise gr.Error("Please extract features first.")

    big_npy = np.concatenate(
        [np.load("%s/%s" % (feature_dir, name)) for name in feature_files], 0
    )
    big_npy_idx = np.arange(big_npy.shape[0])
    np.random.shuffle(big_npy_idx)
    big_npy = big_npy[big_npy_idx]

    if big_npy.shape[0] > 2e5:
        print("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0])
        try:
            big_npy = (
                MiniBatchKMeans(
                    n_clusters=10000,
                    verbose=True,
                    batch_size=256 * 8,
                    compute_labels=False,
                    init="random",
                )
                .fit(big_npy)
                .cluster_centers_
            )
        except Exception as e:
            info = traceback.format_exc()
            print(info)
            raise gr.Error(info) from e

    np.save("%s/total_fea.npy" % exp_dir, big_npy)

    n_samples, dim = big_npy.shape
    # An IVF index needs at least one inverted list; the heuristic alone yields 0
    # when there are fewer than 39 vectors.
    n_ivf = max(1, min(int(16 * np.sqrt(n_samples)), n_samples // 39))
    print("%s,%s" % (big_npy.shape, n_ivf))
    index = faiss.index_factory(dim, "IVF%s,Flat" % n_ivf)

    print("training")
    index_ivf = faiss.extract_index_ivf(index)
    index.train(big_npy)
    index_name = "IVF%s_Flat_nprobe_%s.index" % (n_ivf, index_ivf.nprobe)
    faiss.write_index(index, "%s/trained_%s" % (exp_dir, index_name))

    print("adding")
    batch_size_add = 8192
    for i in range(0, n_samples, batch_size_add):
        index.add(big_npy[i : i + batch_size_add])
    added_path = "%s/added_%s" % (exp_dir, index_name)
    faiss.write_index(index, added_path)
    print("built added_%s" % index_name)
    return added_path


class TrainTab:
    """Gradio tab with buttons to train the model and its retrieval index."""

    def __init__(self):
        pass

    def ui(self):
        """Create the tab's widgets.

        Presumably called by the app inside an enclosing ``gr.Blocks``/tab
        context; verify against the caller.
        """
        gr.Markdown("# Training")
        gr.Markdown(
            "You can start training the model by clicking the button below. "
            "Each time you click the button, the model will train for 10 epochs, which takes about 3 minutes on ZeroGPU (A100). "
            "The latest *training checkpoint* will be available below."
        )

        with gr.Row():
            self.train_btn = gr.Button(value="Train", variant="primary")
            self.latest_checkpoint = gr.File(label="Latest checkpoint")

        with gr.Row():
            self.train_index_btn = gr.Button(value="Train index", variant="primary")
            self.trained_index = gr.File(label="Trained index")

    def build(self, exp_dir: gr.Textbox):
        """Connect the buttons created by ``ui`` to the training callbacks.

        Args:
            exp_dir: Component whose value (the experiment directory path) is
                passed to ``train_model`` and ``train_index``.
        """
        self.train_btn.click(
            fn=train_model,
            inputs=[exp_dir],
            outputs=[self.latest_checkpoint],
        )
        self.train_index_btn.click(
            fn=train_index,
            inputs=[exp_dir],
            outputs=[self.trained_index],
        )