Spaces:
Paused
Paused
from glob import glob | |
import os | |
from typing import Tuple | |
from demucs.separate import main as demucs | |
import gradio as gr | |
import numpy as np | |
import soundfile as sf | |
from configs.config import Config | |
from infer.modules.vc.modules import VC | |
from zero import zero | |
from model import device | |
def infer(
    exp_dir: str, original_audio: str, f0add: int, index_rate: float, protect: float
) -> Tuple[int, np.ndarray]:
    """Isolate the vocals of a song and convert them with the trained model.

    The song is split with demucs (``htdemucs``, two stems), which is expected
    to write its result to ``separated/htdemucs/<name>/`` relative to the
    current working directory. The ``vocals.wav`` stem is then passed through
    ``VC.vc_single`` using ``<exp_dir>/model.pth`` and, when one exists, the
    first ``added_*.index`` file found in ``exp_dir``.

    Args:
        exp_dir: Experiment directory holding ``model.pth`` and optionally an
            ``added_*.index`` file.
        original_audio: Path of the uploaded song.
        f0add: Value of the "F0 +/-" slider, forwarded to ``vc_single``.
        index_rate: Value of the "Index rate" slider, forwarded to ``vc_single``.
        protect: Value of the "Protect" slider, forwarded to ``vc_single``.

    Returns:
        ``(sample_rate, samples)`` as produced by ``vc_single``.

    Raises:
        gr.Error: If the model file is missing, no audio was supplied, the
            separated vocal stem cannot be found, or the conversion reports
            no audio.
    """
    model = os.path.join(exp_dir, "model.pth")
    if not os.path.exists(model):
        raise gr.Error("Model not found")

    # Gradio passes None when nothing was uploaded; report that clearly
    # instead of letting os.path.basename fail with a TypeError.
    if not original_audio:
        raise gr.Error("Please upload an audio file first")

    index_files = glob(f"{exp_dir}/added_*.index")
    index = index_files[0] if index_files else None

    base = os.path.splitext(os.path.basename(original_audio))[0]

    demucs(
        ["--two-stems", "vocals", "-d", str(device), "-n", "htdemucs", original_audio]
    )
    out = os.path.join("separated", "htdemucs", base, "vocals.wav")
    if not os.path.exists(out):
        raise gr.Error("Vocal separation failed: vocals.wav was not produced")

    vc = VC(Config())
    vc.get_vc(model)

    # NOTE(review): the positional arguments are assumed to follow upstream
    # RVC's VC.vc_single(sid, input_audio_path, f0_up_key, f0_file, f0_method,
    # file_index, file_index2, index_rate, filter_radius, resample_sr,
    # rms_mix_rate, protect) -- confirm against the vendored version.
    info, wav_opt = vc.vc_single(
        0,
        out,
        f0add,
        None,
        "rmvpe",
        index,
        None,
        index_rate,
        3,  # this only has effect when f0_method is "harvest"
        0,
        1,
        protect,
    )
    sr = wav_opt[0]
    data = wav_opt[1]

    # Upstream RVC signals failure by returning (message, (None, None))
    # rather than raising; surface that message to the user.
    if sr is None or data is None:
        raise gr.Error(f"Voice conversion failed: {info}")

    return sr, data
def merge(exp_dir: str, original_audio: str, vocal: Tuple[int, np.ndarray]) -> str:
    """Mix the converted vocals back over the separated accompaniment.

    The converted vocal is written to ``<exp_dir>/tmp.wav`` and mixed by
    ffmpeg with ``separated/htdemucs/<name>/no_vocals.wav`` (the vocal is
    boosted with ``volume=2`` and the output length follows the
    accompaniment).

    Args:
        exp_dir: Experiment directory; temporary and merged files are written
            here.
        original_audio: Path of the uploaded song, used only to derive the
            name of the demucs output folder.
        vocal: ``(sample_rate, samples)`` of the converted vocal.

    Returns:
        Path of the merged file, ``<exp_dir>/tmp.wav.merged.mp3``.

    Raises:
        gr.Error: If ffmpeg is not installed or exits with a non-zero status.
    """
    # Imported here so this function is self-contained with respect to the
    # module's import block.
    import subprocess

    base = os.path.splitext(os.path.basename(original_audio))[0]
    music = os.path.join("separated", "htdemucs", base, "no_vocals.wav")
    tmp = os.path.join(exp_dir, "tmp.wav")
    merged = f"{tmp}.merged.mp3"

    sf.write(tmp, vocal[1], vocal[0])

    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handed to ffmpeg verbatim and cannot be interpreted
    # as extra commands.
    command = [
        "ffmpeg",
        "-i",
        music,
        "-i",
        tmp,
        "-filter_complex",
        "[1]volume=2[a];[0][a]amix=inputs=2:duration=first:dropout_transition=2",
        "-ac",
        "2",
        "-y",
        merged,
    ]
    try:
        subprocess.run(command, check=True)
    except FileNotFoundError as err:
        raise gr.Error("ffmpeg executable not found") from err
    except subprocess.CalledProcessError as err:
        # Without this check a failed mix would return a path to a file that
        # was never written.
        raise gr.Error(f"ffmpeg failed with exit code {err.returncode}") from err

    return merged
class InferenceTab:
    """Gradio tab for running voice conversion on an uploaded song.

    ``ui`` creates the components and stores them on the instance; ``build``
    must be called afterwards to attach the click handlers.
    """

    def __init__(self):
        """Create the tab; no components exist until ``ui`` is called."""

    def ui(self):
        """Create the tab's widgets and keep references to them on ``self``."""
        intro = (
            "After trained model is pruned, you can use it to infer on new music. \n"
            "Upload the original audio and adjust the F0 add value to generate the inferred audio."
        )
        # Settings shared by the two 0..1 sliders and by both result players.
        unit_range = {"minimum": 0, "maximum": 1, "step": 0.01}
        mp3_player = {"show_download_button": True, "format": "mp3"}

        gr.Markdown("# Inference")
        gr.Markdown(intro)

        # Upload widget on the left, conversion settings stacked on the right.
        with gr.Row():
            self.original_audio = gr.Audio(
                label="Upload original audio",
                type="filepath",
                show_download_button=True,
            )

            with gr.Column():
                self.f0add = gr.Slider(
                    label="F0 +/-", minimum=-16, maximum=16, step=1, value=0
                )
                self.index_rate = gr.Slider(
                    label="Index rate", value=0.5, **unit_range
                )
                self.protect = gr.Slider(label="Protect", value=0.33, **unit_range)

        self.infer_btn = gr.Button(value="Infer", variant="primary")

        with gr.Row():
            self.infer_output = gr.Audio(label="Inferred audio", **mp3_player)

        with gr.Row():
            self.merge_output = gr.Audio(label="Merged audio", **mp3_player)

    def build(self, exp_dir: gr.Textbox):
        """Attach the handlers: ``infer`` on click, then ``merge`` on success.

        Args:
            exp_dir: Textbox component whose value is passed to both handlers
                as the experiment directory.
        """
        infer_inputs = [
            exp_dir,
            self.original_audio,
            self.f0add,
            self.index_rate,
            self.protect,
        ]
        merge_inputs = [exp_dir, self.original_audio, self.infer_output]

        conversion = self.infer_btn.click(
            fn=infer,
            inputs=infer_inputs,
            outputs=[self.infer_output],
        )
        # Mixing only runs when the conversion step finished without error.
        conversion.success(
            fn=merge,
            inputs=merge_inputs,
            outputs=[self.merge_output],
        )