ZeroRVC / app.py
JacobLinCool's picture
feat: infer
3a010aa
raw
history blame
13.7 kB
from typing import Tuple
from prelude import prelude
prelude()
import os
import traceback
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from random import shuffle
import gradio as gr
import zipfile
import tempfile
import shutil
import faiss
from glob import glob
from infer.modules.train.preprocess import PreProcess
from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
from infer.modules.train.extract_feature_print import HubertFeatureExtractor
from infer.modules.train.train import train
from infer.lib.train.process_ckpt import extract_small_model
from infer.modules.vc.modules import VC
from configs.config import Config
import demucs.separate
import soundfile as sf
from zero import zero
from model import device
def extract_audio_files(
    zip_file: str,
    target_dir: str,
    extensions: Tuple[str, ...] = (".wav", ".mp3", ".ogg"),
) -> list[str]:
    """Extract *zip_file* into *target_dir* and return its top-level audio files.

    Only regular files sitting directly in *target_dir* whose name ends with
    one of *extensions* are returned. The comparison is case-insensitive, so
    ``VOICE.WAV`` counts. Audio inside sub-folders of the archive is ignored.

    Args:
        zip_file: Path of the uploaded zip archive.
        target_dir: Directory to extract into (must already exist or be
            creatable by ``ZipFile.extractall``).
        extensions: Accepted file suffixes; defaults to wav/mp3/ogg.

    Returns:
        Paths (joined onto *target_dir*) of the matching audio files.

    Raises:
        gr.Error: if no matching audio file is found at the top level.
    """
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        zip_ref.extractall(target_dir)
    suffixes = tuple(ext.lower() for ext in extensions)
    audio_files = [
        os.path.join(target_dir, name)
        for name in os.listdir(target_dir)
        # Lower-case the name so upper-case extensions are accepted. Skip
        # directories, which may also carry an audio-like suffix.
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(target_dir, name))
    ]
    if not audio_files:
        raise gr.Error("No audio files found at the top level of the zip file")
    return audio_files
def preprocess(zip_file: str) -> Tuple[str, str]:
    """Start a new experiment from an uploaded zip of training audio.

    This creates a fresh temporary experiment directory and extracts the
    zip's top-level audio into ``<exp_dir>/_data``. It then runs
    ``PreProcess(40000, exp_dir, 3.0, False).pipeline_mp_inp_dir(data_dir, 4)``
    on it.

    Returns:
        ``(exp_dir, log)``, where *log* starts with the number of audio files
        found, followed by the contents of PreProcess's log file.

    Raises:
        gr.Error: propagated from ``extract_audio_files`` when the zip holds
            no top-level audio. On any failure the half-built experiment
            directory is removed before the exception propagates.
    """
    temp_dir = tempfile.mkdtemp()
    print(f"Using exp dir: {temp_dir}")
    try:
        data_dir = os.path.join(temp_dir, "_data")
        os.makedirs(data_dir)
        audio_files = extract_audio_files(zip_file, data_dir)
        pp = PreProcess(40000, temp_dir, 3.0, False)
        pp.pipeline_mp_inp_dir(data_dir, 4)
        pp.logfile.seek(0)
        log = pp.logfile.read()
    except Exception:
        # Don't leave an orphaned temp experiment behind on every failed upload.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir, f"Preprocessed {len(audio_files)} audio files.\n{log}"
@zero(duration=300)
def extract_features(exp_dir: str) -> str:
    """Run the two feature-extraction stages on *exp_dir* and return their logs.

    The stages are ``FeatureInput`` (from ``extract_f0_rmvpe``), then
    ``HubertFeatureExtractor``. Each stage is built with *exp_dir* and
    ``run()``. Its log file is rewound and appended to the returned text.
    If a stage raises, its exception is caught and the combined log so far
    is returned, prefixed with ``"Error: <exception>"``. Later stages are
    then skipped.
    """
    collected = ""
    for stage_cls in (FeatureInput, HubertFeatureExtractor):
        stage = stage_cls(exp_dir)
        failure = None
        try:
            stage.run()
        except Exception as exc:
            failure = exc
        stage.logfile.seek(0)
        collected += stage.logfile.read()
        if failure:
            return f"Error: {failure}\n{collected}"
    return collected
def write_filelist(exp_dir: str, if_f0: bool = True, spk_id: int = 0) -> None:
    """Write ``<exp_dir>/filelist.txt`` listing every usable training sample.

    A sample name is included only when a file for it exists in every
    required sub-directory. Names are matched on the part before the first
    ``.``, so ``x.wav``, ``x.npy`` and ``x.wav.npy`` all map to ``x``. Each
    line holds ``|``-separated paths followed by the speaker id. Two
    identical lines for the shared "mute" sample under ``<cwd>/logs/mute``
    are appended, and the whole list is shuffled before writing.

    Args:
        exp_dir: Experiment directory containing ``0_gt_wavs`` and
            ``3_feature768``, plus ``2a_f0`` / ``2b-f0nsf`` when *if_f0*.
        if_f0: Include the two pitch columns. Defaults to True, which is
            the only mode this app previously produced.
        spk_id: Value written in the last column of every line (default 0).
    """
    fea_dim = 768
    sr2 = "40k"
    now_dir = os.getcwd()

    def escape(path: str) -> str:
        # Backslashes in sample paths are doubled (kept from the original).
        return path.replace("\\", "\\\\")

    def stems(directory: str) -> set[str]:
        # Split on the FIRST dot so "x.wav.npy" and "x.npy" both yield "x".
        return {name.split(".")[0] for name in os.listdir(directory)}

    gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
    feature_dir = "%s/3_feature%d" % (exp_dir, fea_dim)
    f0_dir = "%s/2a_f0" % exp_dir
    f0nsf_dir = "%s/2b-f0nsf" % exp_dir

    required = [gt_wavs_dir, feature_dir]
    if if_f0:
        required += [f0_dir, f0nsf_dir]
    names = set.intersection(*(stems(d) for d in required))

    opt = []
    for name in names:
        fields = [
            "%s/%s.wav" % (escape(gt_wavs_dir), name),
            "%s/%s.npy" % (escape(feature_dir), name),
        ]
        if if_f0:
            fields += [
                "%s/%s.wav.npy" % (escape(f0_dir), name),
                "%s/%s.wav.npy" % (escape(f0nsf_dir), name),
            ]
        fields.append(str(spk_id))
        opt.append("|".join(fields))

    mute_dir = "%s/logs/mute" % now_dir
    mute = [
        "%s/0_gt_wavs/mute%s.wav" % (mute_dir, sr2),
        "%s/3_feature%s/mute.npy" % (mute_dir, fea_dim),
    ]
    if if_f0:
        mute += [
            "%s/2a_f0/mute.wav.npy" % mute_dir,
            "%s/2b-f0nsf/mute.wav.npy" % mute_dir,
        ]
    mute.append(str(spk_id))
    opt.extend(["|".join(mute)] * 2)

    shuffle(opt)
    # Explicit UTF-8 so non-ASCII paths don't depend on the locale encoding.
    with open("%s/filelist.txt" % exp_dir, "w", encoding="utf-8") as f:
        f.write("\n".join(opt))
@zero(duration=300)
def train_model(exp_dir: str) -> str:
    """Train on *exp_dir* and return the most recent ``G_*.pth`` checkpoint.

    This copies ``./config.json`` into the experiment, regenerates
    ``filelist.txt`` and calls ``train(exp_dir)``. The ``G_*.pth`` file with
    the newest ctime is then returned.

    Raises:
        gr.Error: if no ``G_*.pth`` checkpoint exists after training.
    """
    shutil.copy("config.json", exp_dir)
    write_filelist(exp_dir)
    train(exp_dir)

    checkpoints = glob(f"{exp_dir}/G_*.pth")
    print(checkpoints)
    if not checkpoints:
        raise gr.Error("No model found")
    return max(checkpoints, key=os.path.getctime)
def download_weight(exp_dir: str) -> str:
    """Export the newest ``G_*.pth`` checkpoint as a standalone model file.

    The checkpoint with the latest ctime is passed to ``extract_small_model``
    with the fixed arguments ``"40k"``, ``True``, ``"Model trained by
    ZeroGPU."`` and ``"v2"``. The result is written to
    ``<exp_dir>/<basename(exp_dir)>.pth``, the path ``infer`` looks for.

    Returns:
        Path of the exported model file.

    Raises:
        gr.Error: if the experiment has no ``G_*.pth`` checkpoint.
    """
    checkpoints = glob(f"{exp_dir}/G_*.pth")
    if not checkpoints:
        raise gr.Error("No model found")
    newest = max(checkpoints, key=os.path.getctime)
    print(f"Latest model: {newest}")

    target = os.path.join(exp_dir, os.path.basename(exp_dir) + ".pth")
    extract_small_model(
        newest,
        target,
        "40k",
        True,
        "Model trained by ZeroGPU.",
        "v2",
    )
    return target
def train_index(exp_dir: str) -> str:
    """Build a FAISS ``IVF<n>,Flat`` index over the extracted 768-d features.

    Steps:
    1. Every file in ``<exp_dir>/3_feature768`` is loaded with ``np.load``
       (in sorted order), concatenated along axis 0, and row-shuffled.
    2. If there are more than 200k rows, they are replaced by 10k
       MiniBatchKMeans centroids.
    3. The rows are saved to ``total_fea.npy``.
    4. The index is trained and written as ``trained_IVF..._Flat_nprobe_1.index``.
    5. The rows are added in batches of 8192 and the index is written as
       ``added_IVF..._Flat_nprobe_1.index``.

    Returns:
        Path of the ``added_...index`` file.

    Raises:
        gr.Error: if features have not been extracted, or k-means fails
            (the message carries the traceback).
    """
    feature_dir = "%s/3_feature768" % (exp_dir)
    if not os.path.exists(feature_dir):
        raise gr.Error("Please extract features first.")
    names = sorted(os.listdir(feature_dir))
    if not names:
        raise gr.Error("Please extract features first.")

    big_npy = np.concatenate(
        [np.load("%s/%s" % (feature_dir, name)) for name in names], 0
    )
    big_npy_idx = np.arange(big_npy.shape[0])
    np.random.shuffle(big_npy_idx)
    big_npy = big_npy[big_npy_idx]

    if big_npy.shape[0] > 2e5:
        print("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0])
        try:
            big_npy = (
                MiniBatchKMeans(
                    n_clusters=10000,
                    verbose=True,
                    batch_size=256 * 8,
                    compute_labels=False,
                    init="random",
                )
                .fit(big_npy)
                .cluster_centers_
            )
        except Exception as e:
            info = traceback.format_exc()
            print(info)
            raise gr.Error(info) from e

    np.save("%s/total_fea.npy" % exp_dir, big_npy)
    n_rows = big_npy.shape[0]
    # With fewer than 39 rows, n_rows // 39 is 0, which would request
    # "IVF0,Flat" (zero inverted lists). Keep at least one list so tiny
    # datasets still produce an index.
    n_ivf = max(1, min(int(16 * np.sqrt(n_rows)), n_rows // 39))
    print("%s,%s" % (big_npy.shape, n_ivf))
    index = faiss.index_factory(768, "IVF%s,Flat" % n_ivf)
    print("training")
    index_ivf = faiss.extract_index_ivf(index)
    index_ivf.nprobe = 1
    index.train(big_npy)

    suffix = "IVF%s_Flat_nprobe_%s.index" % (n_ivf, index_ivf.nprobe)
    faiss.write_index(index, "%s/trained_%s" % (exp_dir, suffix))

    print("adding")
    batch_size_add = 8192
    for i in range(0, n_rows, batch_size_add):
        index.add(big_npy[i : i + batch_size_add])
    added_path = "%s/added_%s" % (exp_dir, suffix)
    faiss.write_index(index, added_path)
    print("built added_%s" % suffix)
    return added_path
def download_expdir(exp_dir: str) -> str:
    """Zip the whole experiment directory and return the archive path.

    The archive is created next to the directory as ``<exp_dir>.zip``, with
    entries stored relative to *exp_dir*. Any existing archive is replaced.
    """
    archive_path = f"{exp_dir}.zip"
    shutil.make_archive(exp_dir, "zip", exp_dir)
    return archive_path
def restore_expdir(zip: str) -> str:
    """Unpack a previously downloaded experiment archive into a new temp dir.

    Args:
        zip: Path of the archive; its format is chosen by
            ``shutil.unpack_archive`` from the file extension.

    Returns:
        The newly created directory holding the unpacked experiment.
    """
    restored_dir = tempfile.mkdtemp()
    shutil.unpack_archive(zip, extract_dir=restored_dir)
    return restored_dir
@zero(duration=120)
def infer(exp_dir: str, original_audio: str, f0add: int) -> Tuple[int, np.ndarray]:
    """Convert the vocals of *original_audio* with the experiment's model.

    Steps:
    1. The audio is split with demucs (``htdemucs``, ``--two-stems vocals``,
       run on ``device``) into ``separated/htdemucs/<base>/``.
    2. The vocal stem is converted with ``VC.vc_single`` using the exported
       ``<exp_dir>/<basename>.pth`` model, the newest
       ``added_*.index`` file, the ``"rmvpe"`` method and a pitch shift of
       *f0add*.

    Returns:
        ``(sample_rate, audio)`` as produced by ``vc_single``.

    Raises:
        gr.Error: if the exported model or index is missing, or the
            conversion did not produce audio.
    """
    name = os.path.basename(exp_dir)
    model = os.path.join(exp_dir, f"{name}.pth")
    if not os.path.exists(model):
        raise gr.Error("Model not found")
    indexes = glob(f"{exp_dir}/added_*.index")
    if not indexes:
        raise gr.Error("Index not found")
    # glob returns a list, but vc_single takes one index path. Use the most
    # recently built index when there are several.
    index = max(indexes, key=os.path.getctime)

    base = os.path.splitext(os.path.basename(original_audio))[0]
    demucs.separate.main(
        ["--two-stems", "vocals", "-d", str(device), "-n", "htdemucs", original_audio]
    )
    vocals = os.path.join("separated", "htdemucs", base, "vocals.wav")

    cfg = Config()
    vc = VC(cfg)
    vc.get_vc(model)
    info, wav_opt = vc.vc_single(
        0,
        vocals,
        f0add,
        None,
        "rmvpe",
        index,
        None,
        0.5,
        3,
        0,
        1,
        0.33,
    )
    sr = wav_opt[0]
    data = wav_opt[1]
    # NOTE(review): upstream RVC's vc_single reports failures as
    # (info, (None, None)) rather than raising. Surface that to the UI
    # instead of returning empty audio.
    if sr is None or data is None:
        raise gr.Error(f"Inference failed: {info}")
    return sr, data
def merge(exp_dir: str, original_audio: str, vocal: Tuple[int, np.ndarray]) -> str:
    """Mix the converted vocal back over the original accompaniment.

    The converted vocal ``(sample_rate, data)`` is written to
    ``<exp_dir>/tmp.wav``. ffmpeg then mixes it, with volume x2, over the
    demucs accompaniment stem of *original_audio*. The output keeps the
    accompaniment's duration.

    Returns:
        Path of the mixed file, ``<exp_dir>/tmp.wav.merged.mp3``.

    Raises:
        gr.Error: if the accompaniment stem is missing or ffmpeg fails.
    """
    import subprocess

    base = os.path.splitext(os.path.basename(original_audio))[0]
    # `demucs --two-stems vocals` names the complementary stem "no_vocals".
    music = os.path.join("separated", "htdemucs", base, "no_vocals.wav")
    if not os.path.exists(music):
        raise gr.Error(f"Accompaniment not found: {music}")
    tmp = os.path.join(exp_dir, "tmp.wav")
    sf.write(tmp, vocal[1], vocal[0])
    out = f"{tmp}.merged.mp3"
    # Use an argument list, not a shell string: the file name comes from a
    # user upload. "-y" overwrites a previous result instead of blocking on
    # ffmpeg's interactive overwrite prompt.
    result = subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-i",
            music,
            "-i",
            tmp,
            "-filter_complex",
            "[1]volume=2[a];[0][a]amix=inputs=2:duration=first:dropout_transition=2",
            out,
        ],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise gr.Error(f"ffmpeg failed:\n{result.stderr[-2000:]}")
    return out
# Gradio UI: one shared experiment-directory textbox plus a tab per pipeline
# stage (new/restore, feature extraction, training, download, inference).
# The order in which components are created below defines the page layout.
with gr.Blocks() as app:
    # allow user to manually select the experiment directory
    # (every handler below reads it; preprocess and restore write it)
    exp_dir = gr.Textbox(
        label="Experiment directory (don't touch it unless you know what you are doing)",
        visible=True,
        interactive=True,
    )
    with gr.Tabs():
        with gr.Tab(label="New / Restore"):
            # Start a new experiment from a zip of training audio.
            with gr.Row():
                with gr.Column():
                    zip_file = gr.File(
                        label="Upload a zip file containing audio files for training",
                        file_types=["zip"],
                    )
                    preprocess_output = gr.Textbox(
                        label="Preprocessing output", lines=5
                    )
                    preprocess_btn = gr.Button(
                        value="Start New Experiment", variant="primary"
                    )
            # Or restore one from a zip produced by "Download experiment directory".
            with gr.Row():
                restore_zip_file = gr.File(
                    label="Upload the experiment directory zip file",
                    file_types=["zip"],
                )
                restore_btn = gr.Button(value="Restore Experiment", variant="primary")
        with gr.Tab(label="Extract features"):
            with gr.Row():
                extract_features_btn = gr.Button(
                    value="Extract features", variant="primary"
                )
            with gr.Row():
                extract_features_output = gr.Textbox(
                    label="Feature extraction output", lines=10
                )
        with gr.Tab(label="Train"):
            with gr.Row():
                train_btn = gr.Button(value="Train", variant="primary")
                latest_model = gr.File(label="Latest checkpoint")
            with gr.Row():
                train_index_btn = gr.Button(value="Train index", variant="primary")
                trained_index = gr.File(label="Trained index")
        with gr.Tab(label="Download"):
            with gr.Row():
                download_weight_btn = gr.Button(
                    value="Download latest model", variant="primary"
                )
                download_weight_output = gr.File(label="Download latest model")
            with gr.Row():
                download_expdir_btn = gr.Button(
                    value="Download experiment directory", variant="primary"
                )
                download_expdir_output = gr.File(label="Download experiment directory")
        with gr.Tab(label="Inference"):
            with gr.Row():
                # type="filepath": handlers receive the uploaded file's path.
                original_audio = gr.Audio(
                    label="Upload original audio",
                    type="filepath",
                    show_download_button=True,
                )
                # Pitch shift forwarded to infer() as f0add.
                f0add = gr.Slider(
                    label="F0 add",
                    minimum=-16,
                    maximum=16,
                    step=1,
                    value=0,
                )
                infer_btn = gr.Button(value="Infer", variant="primary")
            with gr.Row():
                infer_output = gr.Audio(label="Inferred audio")
            with gr.Row():
                merge_output = gr.Audio(label="Merged audio")
    # ---- Event wiring ----
    # preprocess returns (exp_dir, log), filling both outputs.
    preprocess_btn.click(
        fn=preprocess,
        inputs=[zip_file],
        outputs=[exp_dir, preprocess_output],
    )
    extract_features_btn.click(
        fn=extract_features,
        inputs=[exp_dir],
        outputs=[extract_features_output],
    )
    # NOTE(review): train_model is chained to run a second time after a
    # successful first run. Presumably this continues training beyond one
    # @zero(duration=300) window; confirm this is intended.
    train_btn.click(
        fn=train_model,
        inputs=[exp_dir],
        outputs=[latest_model],
    ).success(
        fn=train_model,
        inputs=[exp_dir],
        outputs=[latest_model],
    )
    train_index_btn.click(
        fn=train_index,
        inputs=[exp_dir],
        outputs=[trained_index],
    )
    download_weight_btn.click(
        fn=download_weight,
        inputs=[exp_dir],
        outputs=[download_weight_output],
    )
    download_expdir_btn.click(
        fn=download_expdir,
        inputs=[exp_dir],
        outputs=[download_expdir_output],
    )
    # Restoring replaces the experiment directory shown in the textbox.
    restore_btn.click(
        fn=restore_expdir,
        inputs=[restore_zip_file],
        outputs=[exp_dir],
    )
    # Once infer succeeds, merge mixes its output (infer_output) back over
    # the original's accompaniment.
    infer_btn.click(
        fn=infer,
        inputs=[exp_dir, original_audio, f0add],
        outputs=[infer_output],
    ).success(
        fn=merge,
        inputs=[exp_dir, original_audio, infer_output],
        outputs=[merge_output],
    )
# Start the Gradio server.
app.launch()