# Geamavc / infer-web.py
# Uploaded by Hev832 in commit c2301e9 ("Create infer-web.py", verified); file size 6.4 kB.
import os
import shutil
import subprocess
import sys
from urllib.parse import urlsplit

import gradio as gr
from mega import Mega
def download_from_url(url, model):
    """Download an RVC voice model from ``url`` and install it as ``model``.

    Direct ``.pth`` and ``.index`` links are fetched with ``wget`` straight
    into ``assets/weights`` and ``logs/<model>``.  Everything else (``.zip``
    links, Google Drive links via ``gdown``, MEGA links via ``Mega``, any
    other URL via ``wget``) is first saved under ``downloads/`` and then
    installed according to the extension of the downloaded file.

    Args:
        url: Link to the model.  ``/blob/main/`` is rewritten to
            ``/resolve/main/`` before downloading.
        model: Name to install the model under.  Directory components and
            the substrings ``.pth`` / ``.index`` / ``.zip`` are removed.

    Returns:
        A human-readable status string.  Failures are reported as
        ``"Error: ..."`` or ``"Failed to download file"``; exceptions raised
        while downloading are caught and converted to an ``"Error: ..."``
        string rather than propagated.
    """
    # The name ends up inside file paths, so keep only its final component.
    model = os.path.basename((model or "").strip())
    for extension in ('.pth', '.index', '.zip'):
        model = model.replace(extension, '')
    if not model or model in (".", ".."):
        return "Error: model name is empty or invalid"

    url = (url or "").replace('/blob/main/', '/resolve/main/').strip()
    if not url:
        return "Error: URL is empty"
    if url.startswith("-"):
        # wget/gdown would parse such a value as a command-line option.
        return "Error: URL must not start with '-'"

    weights_dir = "assets/weights"
    index_dir = f'logs/{model}'
    success_message = f"Successfully downloaded {model} voice models"

    try:
        # Look at the path component only, so "model.pth?download=true"
        # is still recognised as a .pth link.
        url_path = urlsplit(url).path

        for directory in ["downloads", "unzips", "zip"]:
            os.makedirs(directory, exist_ok=True)
        # wget -O does not create missing parent directories.
        os.makedirs(weights_dir, exist_ok=True)

        # check=True turns a failed download into CalledProcessError instead
        # of silently reporting success.
        if url_path.endswith('.pth'):
            subprocess.run(["wget", url, "-O", f'{weights_dir}/{model}.pth'], check=True)
            return success_message
        if url_path.endswith('.index'):
            os.makedirs(index_dir, exist_ok=True)
            subprocess.run(["wget", url, "-O", f'{index_dir}/added_{model}.index'], check=True)
            return success_message

        if url_path.endswith('.zip'):
            subprocess.run(["wget", url, "-O", f'downloads/{model}.zip'], check=True)
        elif "drive.google.com" in url:
            # Pass the full share link: --fuzzy extracts the file ID itself.
            # The trailing slash asks gdown to save into the directory under
            # the remote file name, so the extension is available below.
            subprocess.run(["gdown", url, "--fuzzy", "-O", "downloads/"], check=True)
        elif "mega.nz" in url:
            Mega().download_url(url, 'downloads')
        else:
            # NOTE(review): the file is saved without an extension, so the
            # type detection below cannot recognise it — confirm intent.
            subprocess.run(["wget", url, "-O", f'downloads/{model}'], check=True)

        downloaded_file = next(iter(os.listdir("downloads")), None)
        if downloaded_file is None:
            return "Failed to download file"
        downloaded_path = os.path.join("downloads", downloaded_file)

        if downloaded_file.endswith(".zip"):
            shutil.unpack_archive(downloaded_path, "unzips", 'zip')
            installed_any = False
            for root, _, files in os.walk('unzips'):
                for file in files:
                    file_path = os.path.join(root, file)
                    if file.endswith(".index"):
                        os.makedirs(index_dir, exist_ok=True)
                        shutil.copy2(file_path, index_dir)
                        installed_any = True
                    elif file.endswith(".pth") and "G_" not in file and "D_" not in file:
                        # Files with "G_" / "D_" in their name are skipped.
                        shutil.copy(file_path, f'{weights_dir}/{model}.pth')
                        installed_any = True
            if not installed_any:
                return f"Error: no .pth or .index file found in {downloaded_file}"
        elif downloaded_file.endswith(".pth"):
            shutil.copy(downloaded_path, f'{weights_dir}/{model}.pth')
        elif downloaded_file.endswith(".index"):
            os.makedirs(index_dir, exist_ok=True)
            shutil.copy(downloaded_path, f'{index_dir}/added_{model}.index')
        else:
            return "Failed to download file"

        return success_message
    except Exception as e:
        # UI boundary: report the problem as text instead of crashing.
        return f"Error: {str(e)}"
    finally:
        # Scratch directories are always removed, whatever the outcome.
        shutil.rmtree("downloads", ignore_errors=True)
        shutil.rmtree("unzips", ignore_errors=True)
        shutil.rmtree("zip", ignore_errors=True)
def listen_to_model(model_path, index_path, pitch, input_path, f0_method, save_as, index_rate, volume_normalization, consonant_protection):
    """Run ``tools/infer_cli.py`` to convert an audio file with an RVC model.

    Args:
        model_path: Path to the model weights file; must exist.
        index_path: Path to the index file; must exist.
        pitch: Value passed as ``--f0up_key``.
        input_path: Path to the audio file to convert; must exist.
        f0_method: Value passed as ``--f0method``.
        save_as: Path passed as ``--opt_path``.
        index_rate: Value passed as ``--index_rate``.
        volume_normalization: Value passed as ``--rms_mix_rate``.
        consonant_protection: Value passed as ``--protect``.

    Returns:
        ``save_as`` when the subprocess exits successfully; otherwise a
        message string describing the missing path or the subprocess error.
    """
    if not os.path.exists(model_path):
        return "Model path not found"
    for required_path in (index_path, input_path):
        if not os.path.exists(required_path):
            return f"{required_path} was not found"

    # The CLI receives the two directories through environment variables and
    # only the file names as arguments.  They are handed to the child process
    # alone, so this process's own environment is left untouched between
    # requests.
    child_env = dict(
        os.environ,
        index_root=os.path.dirname(index_path),
        weight_root=os.path.dirname(model_path),
    )

    command = [
        # Use the interpreter running this app rather than whatever
        # "python" happens to be first on PATH.
        sys.executable or "python", "tools/infer_cli.py",
        "--f0up_key", str(pitch),
        "--input_path", input_path,
        "--index_path", os.path.basename(index_path),
        "--f0method", f0_method,
        "--opt_path", save_as,
        "--model_name", os.path.basename(model_path),
        "--index_rate", str(index_rate),
        "--device", "cuda:0",
        "--is_half", "True",
        "--filter_radius", "3",
        "--resample_sr", "0",
        "--rms_mix_rate", str(volume_normalization),
        "--protect", str(consonant_protection),
    ]

    try:
        subprocess.run(command, check=True, env=child_env)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a launcher that cannot be started at all.
        return f"Error: {str(e)}"
    return save_as
# Gradio front end: one tab to download a model, one tab to run inference.
with gr.Blocks() as demo:
    gr.Markdown("# RVC V2 Web UI")

    with gr.Tabs():
        # --- Tab 1: fetch a model from a URL -------------------------------
        with gr.TabItem("Download Model"):
            gr.Markdown("### Download RVC Model")
            url_input = gr.Textbox(
                placeholder="Enter the model URL here",
                label="Model URL",
            )
            model_input = gr.Textbox(
                placeholder="Enter the model name here",
                label="Model Name",
            )
            download_button = gr.Button("Download")
            download_output = gr.Textbox(label="Download Status")

            download_button.click(
                download_from_url,
                inputs=[url_input, model_input],
                outputs=download_output,
            )

        # --- Tab 2: convert an audio file with a model ----------------------
        with gr.TabItem("Listen to Model"):
            gr.Markdown("### Listen to Your Model")
            model_path_input = gr.Textbox(
                value="/content/RVC/assets/weights/Sonic.pth",
                label="Model Path",
            )
            index_path_input = gr.Textbox(
                value="/content/RVC/logs/Sonic/added_IVF905_Flat_nprobe_1_Sonic_v2.index",
                label="Index Path",
            )
            input_path_input = gr.Textbox(
                value="/content/RVC/audios/astronauts.mp3",
                label="Input Audio Path",
            )
            save_as_input = gr.Textbox(
                value="/content/RVC/audios/cli_output.wav",
                label="Save Output As",
            )
            f0_method_input = gr.Radio(
                choices=["rmvpe", "pm", "harvest"],
                value="rmvpe",
                label="F0 Method",
            )

            # Numeric tuning controls share a single row.
            with gr.Row():
                pitch_input = gr.Slider(
                    minimum=-12, maximum=12, step=1, value=0, label="Pitch"
                )
                index_rate_input = gr.Slider(
                    minimum=0, maximum=1, step=0.01, value=0.5, label="Index Rate"
                )
                volume_normalization_input = gr.Slider(
                    minimum=0, maximum=1, step=0.01, value=0, label="Volume Normalization"
                )
                consonant_protection_input = gr.Slider(
                    minimum=0, maximum=1, step=0.01, value=0.5, label="Consonant Protection"
                )

            listen_button = gr.Button("Generate and Listen")
            audio_output = gr.Audio(label="Output Audio")

            # Order must match the positional parameters of listen_to_model.
            inference_inputs = [
                model_path_input,
                index_path_input,
                pitch_input,
                input_path_input,
                f0_method_input,
                save_as_input,
                index_rate_input,
                volume_normalization_input,
                consonant_protection_input,
            ]
            listen_button.click(
                listen_to_model,
                inputs=inference_inputs,
                outputs=audio_output,
            )

demo.launch()