Spaces:
Sleeping
Sleeping
import torch, os, traceback, sys, warnings, shutil, numpy as np | |
import gradio as gr | |
import librosa | |
import asyncio | |
import rarfile | |
import edge_tts | |
import yt_dlp | |
import ffmpeg | |
import gdown | |
import subprocess | |
import wave | |
import soundfile as sf | |
from scipy.io import wavfile | |
from datetime import datetime | |
from urllib.parse import urlparse | |
from mega import Mega | |
from flask import Flask, request, jsonify, send_file,session,render_template | |
import base64 | |
import tempfile | |
import threading | |
import hashlib | |
import os | |
import werkzeug | |
from pydub import AudioSegment | |
import uuid | |
from threading import Semaphore | |
from threading import Lock | |
from multiprocessing import Process, SimpleQueue, set_start_method,get_context | |
from queue import Empty | |
from pydub import AudioSegment | |
from flask_dance.contrib.google import make_google_blueprint, google | |
import io | |
from space import ensure_model_in_weights_dir,upload_to_do | |
import boto3 | |
import os | |
import ffmpeg | |
import os | |
app = Flask(__name__)
# NOTE(review): the literal fallbacks for the Flask secret key and the Google
# OAuth client secret below are credentials committed to source control. They
# are kept only so existing deployments keep working; rotate them and supply
# the values through the environment instead.
app.secret_key = os.getenv("FLASK_SECRET_KEY", 'smjain_6789')
now_dir = os.getcwd()
# Checkpoints loaded by get_vc, keyed by model filename.
cpt = {}
# Scratch directory for uploaded audio; also exported as the TEMP env variable.
tmp = os.path.join(now_dir, "TEMP")
from multiprocessing import current_process as _current_process

# Conversion workers are started with the "spawn" method, which re-imports this
# module in the child. Only the top-level process may wipe TEMP; otherwise each
# new worker would delete uploads that concurrent requests have saved there but
# not yet processed.
if _current_process().name == "MainProcess":
    shutil.rmtree(tmp, ignore_errors=True)
os.makedirs(tmp, exist_ok=True)
os.environ["TEMP"] = tmp
# Demucs model used by cut_vocal_and_inst for vocal/instrumental separation.
split_model = "htdemucs"
convert_voice_lock = Lock()
# Maximum number of conversions served at once; extra requests get HTTP 429.
MAX_CONCURRENT_REQUESTS = 10
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
# Progress of each conversion job, keyed by audio_id.
task_status_tracker = {}
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"  # ONLY FOR TESTING, REMOVE IN PRODUCTION
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
app.config["GOOGLE_OAUTH_CLIENT_ID"] = os.getenv(
    "GOOGLE_OAUTH_CLIENT_ID",
    "144930881143-n3e3ubers3vkq7jc9doe4iirasgimdt2.apps.googleusercontent.com",
)
app.config["GOOGLE_OAUTH_CLIENT_SECRET"] = os.getenv(
    "GOOGLE_OAUTH_CLIENT_SECRET",
    "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",
)
google_blueprint = make_google_blueprint(scope=["profile", "email"])
app.register_blueprint(google_blueprint, url_prefix="/login")
# DigitalOcean Spaces credentials used by download_file.
ACCESS_ID = os.getenv('ACCESS_ID', '')
SECRET_KEY = os.getenv('SECRET_KEY', '')
from lib.infer_pack.models import ( | |
SynthesizerTrnMs256NSFsid, | |
SynthesizerTrnMs256NSFsid_nono, | |
SynthesizerTrnMs768NSFsid, | |
SynthesizerTrnMs768NSFsid_nono, | |
) | |
from fairseq import checkpoint_utils | |
from vc_infer_pipeline import VC | |
from config import Config | |
config = Config()
# Edge-TTS voice catalogue, fetched once at import time. asyncio.run() creates
# and closes its own event loop; calling asyncio.get_event_loop() with no
# running loop is deprecated (Python 3.10+).
tts_voice_list = asyncio.run(edge_tts.list_voices())
voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
# Populated by load_hubert().
hubert_model = None
f0method_mode = ["pm", "harvest", "crepe"]
f0method_info = "PM is fast, Harvest is good but extremely slow, and Crepe effect is good but requires GPU (Default: PM)"
def index():
    """Serve the main single-page UI.

    The Google-login check that used to choose between templates is
    disabled, so every visitor receives ``ui.html``.
    """
    template_name = "ui.html"
    return render_template(template_name)
# RMVPE is offered (inserted before "crepe") only when its weights are on disk.
if os.path.isfile("rmvpe.pt"):
    f0method_mode.insert(2, "rmvpe")
    f0method_info = (
        "PM is fast, Harvest is good but extremely slow, Rvmpe is alternative to harvest "
        "(might be better), and Crepe effect is good but requires GPU (Default: PM)"
    )
def load_hubert():
    """Load the HuBERT checkpoint ``hubert_base.pt`` into global ``hubert_model``.

    The model is moved to ``config.device``, cast to half precision when
    ``config.is_half`` is set (float otherwise) and switched to eval mode.
    """
    global hubert_model
    ensemble, _cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"], suffix=""
    )
    hubert_model = ensemble[0]
    hubert_model = hubert_model.to(config.device)
    hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
    hubert_model.eval()
load_hubert()
weight_root = "weights"
index_root = "weights/index"
# Bare filenames of every *.pth found anywhere under weight_root (walk order).
weights_model = [
    name
    for _dirpath, _dirnames, filenames in os.walk(weight_root)
    for name in filenames
    if name.endswith(".pth")
]
# *.index files under index_root, skipping "trained" ones; each entry is
# joined onto index_root.
weights_index = [
    os.path.join(index_root, name)
    for _dirpath, _dirnames, filenames in os.walk(index_root)
    for name in filenames
    if name.endswith('.index') and "trained" not in name
]
def check_models():
    """Rescan the weights folders and build refreshed dropdown updates.

    Returns:
        A pair of ``gr.Dropdown.update`` objects:
        - the model dropdown, with the sorted ``*.pth`` filenames as choices
          and the first file found as the value (None if there are none);
        - the index dropdown, with the sorted ``*.index`` paths (excluding
          names containing "trained") as choices.
    """
    found_models = []
    found_indexes = []
    for _, _, model_files in os.walk(weight_root):
        for file in model_files:
            if file.endswith(".pth"):
                found_models.append(file)
    for _, _, index_files in os.walk(index_root):
        for file in index_files:
            if file.endswith('.index') and "trained" not in file:
                found_indexes.append(os.path.join(index_root, file))
    # An empty weights folder used to raise IndexError on found_models[0].
    default_model = found_models[0] if found_models else None
    return (
        gr.Dropdown.update(choices=sorted(found_models), value=default_model),
        gr.Dropdown.update(choices=sorted(found_indexes)),
    )
def clean():
    """Reset the UI: blank the dropdown value and hide the slider."""
    blank_dropdown = gr.Dropdown.update(value="")
    hidden_slider = gr.Slider.update(visible=False)
    return blank_dropdown, hidden_slider
def cleanup_files(file_paths):
    """Best-effort deletion of every path in *file_paths*.

    Each outcome is reported on stdout. A failure to delete one path is
    printed rather than raised, so the remaining paths are still attempted.
    """
    for target in file_paths:
        try:
            os.remove(target)
        except Exception as err:
            print(f"Error deleting {target}: {err}")
        else:
            print(f"Deleted {target}")
def create_song():
    """Render the UI for a Google-authenticated user.

    Visitors without a Google OAuth session are redirected to the login flow.
    When the Google userinfo request fails, a 502 JSON error is returned.
    """
    # redirect/url_for are not part of the module-level flask import.
    from flask import redirect, url_for

    if not google.authorized:
        return redirect(url_for("google.login"))
    resp = google.get("/oauth2/v2/userinfo")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not resp.ok:
        print("Google userinfo request failed:", resp.text)
        return jsonify({"error": "Failed to fetch Google user info"}), 502
    user_info = resp.json()
    email = user_info["email"]
    # TODO: per-user bookkeeping (user_exists / add_user / get_user_models)
    # is currently disabled.
    return render_template("ui.html", email=email)
def download_file(filename):
    """Download object *filename* from the 'sing' Spaces bucket into ``weights/``.

    Returns:
        A Flask JSON response: success with the local file path; 404 when the
        object does not exist in the bucket; 500 with the error text otherwise.
    """
    # Named s3_session so it does not shadow Flask's `session` proxy.
    s3_session = boto3.session.Session()
    client = s3_session.client(
        's3',
        region_name='nyc3',
        endpoint_url='https://nyc3.digitaloceanspaces.com',
        aws_access_key_id=ACCESS_ID,
        aws_secret_access_key=SECRET_KEY,
    )
    bucket_name = 'sing'
    object_key = f'{filename}'
    # NOTE(review): if *filename* comes from the request, a value containing
    # "../" would write outside weights/; consider secure_filename().
    local_file_path = os.path.join('weights', filename)
    try:
        client.download_file(bucket_name, object_key, local_file_path)
    except client.exceptions.ClientError as err:
        # download_file() issues a HEAD request first, so a missing key comes
        # back as a generic ClientError with code "404" rather than NoSuchKey.
        error_code = err.response.get('Error', {}).get('Code')
        if error_code in ('404', 'NoSuchKey', 'NotFound'):
            return jsonify({'error': 'File not found in the bucket'}), 404
        return jsonify({'error': str(err)}), 500
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    return jsonify({'success': True, 'message': 'File downloaded successfully', 'file_path': local_file_path})
def list_weights():
    """Return, as JSON, the names (extension stripped) of files in ``weights/``.

    Query parameters:
        email: required; a 400 JSON error is returned when missing or empty.
    """
    email = request.args.get('email', default='')
    # Validate before touching the filesystem.
    if not email:
        return jsonify({"error": "Email parameter is required"}), 400
    # NOTE(review): `list_models` is neither defined nor imported in this
    # module as shown, so this call raises NameError; confirm where it lives.
    # Its result is unused, presumably it is called for a side effect (e.g.
    # syncing the user's models), so the directory is listed only afterwards.
    list_models(email)
    directory = 'weights'
    names = [
        os.path.splitext(entry)[0]
        for entry in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, entry))
    ]
    return jsonify(names)
def logout():
    """Clear the whole Flask session (including any stored OAuth token) and go to the index page."""
    # redirect/url_for are not part of the module-level flask import.
    from flask import redirect, url_for

    session.clear()
    return redirect(url_for("index"))
def get_status(audio_id):
    """Report the progress of conversion job *audio_id* as JSON.

    Unknown IDs are reported with status "Unknown ID" and percentage 0.
    """
    print(audio_id)
    fallback = {"status": "Unknown ID", "percentage": 0}
    info = task_status_tracker.get(audio_id, fallback)
    payload = {
        "audio_id": audio_id,
        "status": info["status"],
        "percentage": info["percentage"],
    }
    return jsonify(payload)
def merge_audio_image(mp3_path, image_path, output_dir, unique_id):
    """Render a still-image video: *image_path* shown for the length of *mp3_path*.

    The image is scaled to fit inside 1080x1080 (aspect ratio kept) and padded
    to exactly 1080x1080. It is centred, because pad x/y = -1. Video is
    libx264/yuv420p and audio is AAC.

    Args:
        mp3_path: audio track; its duration sets the video length.
        image_path: still image to show.
        output_dir: directory for the result.
        unique_id: base name of the result file.

    Returns:
        The path ``<output_dir>/<unique_id>.mp4``. An existing file at that
        path is overwritten.

    Raises:
        FileNotFoundError: if the image or the audio file does not exist.
        Errors from ffmpeg probing/encoding propagate.
    """
    output_path = os.path.join(output_dir, f"{unique_id}.mp4")
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")
    if not os.path.isfile(mp3_path):
        raise FileNotFoundError(f"Audio file not found: {mp3_path}")

    probe = ffmpeg.probe(mp3_path)
    audio_duration = float(probe['format']['duration'])

    input_image = ffmpeg.input(image_path, loop=1, t=audio_duration)
    input_audio = ffmpeg.input(mp3_path)
    video_stream = (
        input_image
        .filter('scale', size='1080x1080', force_original_aspect_ratio='decrease')
        .filter('pad', 1080, 1080, -1, -1)
    )
    # overwrite_output() adds -y. Without it, ffmpeg stops to ask
    # "Overwrite? [y/N]" when the file already exists (e.g. the same upload
    # name is processed twice), and the request then hangs or fails.
    (
        ffmpeg
        .output(
            video_stream,
            input_audio,
            output_path,
            vcodec='libx264',
            acodec='aac',
            strict='experimental',
            pix_fmt='yuv420p',
        )
        .overwrite_output()
        .run()
    )
    return output_path
# Final mixed MP3 path of each finished job, keyed by audio_id
# (served by get_processed_audio).
processed_audio_storage = {}

# Largest accepted upload, in bytes.
MAX_UPLOAD_BYTES = 10 * 1024 * 1024


def _run_conversion_worker(spk_id, vocal_path, voice_transform, unique_id):
    """Run `worker` in a freshly spawned process and return the converted path.

    Raises:
        RuntimeError: if the worker reports a failure, returns a non-path
            result, or exits without producing anything.
    """
    from queue import Empty

    ctx = get_context('spawn')
    output_queue = ctx.Queue()
    proc = ctx.Process(
        target=worker,
        args=(spk_id, vocal_path, voice_transform, unique_id, output_queue),
    )
    proc.start()
    try:
        # Read the result before join(). A child that has queued data may not
        # terminate until the data is consumed, so join-then-get can deadlock.
        # Polling also stops us from waiting forever if the child dies without
        # putting anything on the queue.
        while True:
            try:
                result = output_queue.get(timeout=5)
                break
            except Empty:
                if proc.is_alive():
                    continue
                try:
                    # The child may have queued its result just before exiting.
                    result = output_queue.get(timeout=1)
                    break
                except Empty:
                    raise RuntimeError(
                        f"conversion process exited with code {proc.exitcode} "
                        "without a result"
                    ) from None
    finally:
        proc.join()
    if isinstance(result, Exception):
        raise RuntimeError(f"voice conversion failed: {result}") from result
    if not isinstance(result, str):
        # vc_single reports failures as (traceback_text, (None, None)).
        raise RuntimeError(f"voice conversion failed: {result!r}")
    return result


def api_convert_voice():
    """Convert the singing voice in an uploaded song to speaker model ``spk_id``.

    Form fields:
        spk_id: model name; ``.pth`` is appended to find it under weights/.
        voice_transform: pitch shift forwarded to the conversion.
        file: the audio upload (at most 10 MB and 5 minutes long).

    The pipeline has four stages:
    1. Split vocals and instrumental with demucs.
    2. Convert the vocals in a spawned worker process.
    3. Remix the converted vocals with the instrumental.
    4. Render an MP4 with a still image and upload it.

    Progress is published in ``task_status_tracker[audio_id]``.

    Returns:
        200 with ``audio_id`` on success; 400 for invalid input; 429 when
        MAX_CONCURRENT_REQUESTS conversions are already running; 500 when the
        voice conversion step fails.
    """
    if not request_semaphore.acquire(blocking=False):
        return jsonify({"error": "Too many requests, please try again later"}), 429
    try:
        spk_id = request.form['spk_id'] + '.pth'
        voice_transform = request.form['voice_transform']
        # spk_id names a file under weights/ and is embedded in output paths;
        # only accept a bare filename.
        if os.path.basename(spk_id) != spk_id:
            return jsonify({"error": "Invalid spk_id"}), 400

        if 'file' not in request.files:
            return jsonify({"error": "No file part"}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({"error": "No selected file"}), 400

        # FileStorage.content_length only reflects a per-part Content-Length
        # header, which browsers normally omit (it is then 0). Check the bytes
        # actually received instead.
        audio_bytes = file.read()
        if len(audio_bytes) > MAX_UPLOAD_BYTES:
            return jsonify({"error": "File size exceeds 10 MB"}), 400

        # The job id comes from the client-supplied filename and ends up in
        # filesystem paths, so strip separators and other unsafe characters.
        stem = os.path.splitext(file.filename)[0]
        unique_id = werkzeug.utils.secure_filename(stem) or uuid.uuid4().hex
        print("check if model is there in weights dir or not")
        ensure_model_in_weights_dir(spk_id)

        content_type_format_map = {
            'audio/mpeg': 'mp3',
            'audio/wav': 'wav',
            'audio/x-wav': 'wav',
            'audio/mp4': 'mp4',
            'audio/x-m4a': 'mp4',
        }
        # Unknown content types are decoded as MP3.
        audio_format = content_type_format_map.get(file.content_type, 'mp3')
        audio = AudioSegment.from_file(io.BytesIO(audio_bytes), format=audio_format)
        # pydub lengths are in milliseconds.
        if len(audio) / 60000.0 > 5:
            return jsonify({"error": "Audio length exceeds 5 minutes"}), 400

        filename = werkzeug.utils.secure_filename(file.filename)
        input_audio_path = os.path.join(
            tmp, f"{spk_id}_input_audio_{unique_id}.{filename.split('.')[-1]}"
        )
        file.seek(0)  # rewind: the stream was consumed by read() above
        file.save(input_audio_path)

        task_status_tracker[unique_id] = {"status": "Processing: Step 1", "percentage": 30}
        cut_vocal_and_inst(input_audio_path, spk_id, unique_id)
        stem_dir = f"output/{spk_id}_{unique_id}/{split_model}/{spk_id}_input_audio_{unique_id}"
        vocal_path = f"{stem_dir}/vocals.wav"
        inst = f"{stem_dir}/no_vocals.wav"

        try:
            output_path = _run_conversion_worker(spk_id, vocal_path, voice_transform, unique_id)
        except RuntimeError as err:
            print(err)
            task_status_tracker[unique_id] = {"status": "Failed", "percentage": 0}
            return jsonify({"error": "Voice conversion failed"}), 500
        task_status_tracker[unique_id] = {"status": "Processing: Step 2", "percentage": 80}

        output_path1 = combine_vocal_and_inst(output_path, inst, unique_id)
        processed_audio_storage[unique_id] = output_path1
        session['processed_audio_id'] = unique_id
        task_status_tracker[unique_id] = {"status": "Finalizing", "percentage": 100}
        print(output_path1)

        output_dir = "output/result"
        os.makedirs(output_dir, exist_ok=True)
        mp4_path = merge_audio_image(output_path1, 'singer.jpg', output_dir, unique_id)
        upload_to_do(mp4_path)
        task_status_tracker[unique_id]["status"] = "Completed"
        print("file uploaded to Digital ocean space")
        return jsonify({"message": "File processed successfully", "audio_id": unique_id}), 200
    finally:
        request_semaphore.release()
def convert_voice_thread_safe(spk_id, vocal_path, voice_transform, unique_id):
    """Run convert_voice while holding convert_voice_lock.

    This serialises conversions between threads of this process.
    """
    convert_voice_lock.acquire()
    try:
        return convert_voice(spk_id, vocal_path, voice_transform, unique_id)
    finally:
        convert_voice_lock.release()
def get_vc_safe(sid, to_return_protect0):
    """Run get_vc while holding convert_voice_lock.

    This is the same non-reentrant lock used by convert_voice_thread_safe, so
    the two must not be nested.
    """
    convert_voice_lock.acquire()
    try:
        return get_vc(sid, to_return_protect0)
    finally:
        convert_voice_lock.release()
def upload_form():
    """Serve the upload page (the ``ui.html`` template)."""
    page = 'ui.html'
    return render_template(page)
def get_processed_audio(audio_id):
    """Send the mixed MP3 recorded for *audio_id* as a download.

    Unknown IDs get a 404 JSON error.
    """
    if audio_id not in processed_audio_storage:
        return jsonify({"error": "File not found."}), 404
    return send_file(processed_audio_storage[audio_id], as_attachment=True)
def worker(spk_id, input_audio_path, voice_transform, unique_id, output_queue):
    """Entry point of the spawned conversion process.

    Puts one item on *output_queue*: the converted audio path on success, or
    a RuntimeError carrying the formatted traceback on failure.
    """
    try:
        output_audio_path = convert_voice(spk_id, input_audio_path, voice_transform, unique_id)
    except Exception:
        details = traceback.format_exc()
        print("exception in adding to queue")
        print(details)
        # Send a plain RuntimeError with the traceback text. Arbitrary
        # exception objects may not pickle, and an unpicklable item is dropped
        # by the queue's feeder thread, which leaves the parent waiting forever.
        output_queue.put(RuntimeError(details))
        return
    print("output in worker for audio file", output_audio_path)
    output_queue.put(output_audio_path)
    print("added to output queue")
def convert_voice(spk_id, input_audio_path, voice_transform, unique_id):
    """Convert the vocals in *input_audio_path* using speaker model *spk_id*.

    The model is loaded with get_vc. vc_single then runs with RMVPE pitch
    extraction and fixed index/filter/mix settings. *voice_transform* is the
    pitch shift passed as f0_up_key.

    Returns:
        Path of the converted WAV written by vc_single.

    Raises:
        RuntimeError: if vc_single reports a failure.
    """
    get_vc(spk_id, 0.5)
    print("*****before makinf call to vc ", unique_id)
    result = vc_single(
        sid=0,
        input_audio_path=input_audio_path,
        f0_up_key=voice_transform,
        f0_file=None,
        f0_method="rmvpe",
        # NOTE(review): this passes the model filename (e.g. "name.pth") as the
        # feature-index path. Unless a file of that name exists in the working
        # directory, the index is presumably never used; confirm the intended
        # weights/index/*.index path.
        file_index=spk_id,
        index_rate=0.75,
        filter_radius=3,
        resample_sr=0,
        rms_mix_rate=0.25,
        protect=0.33,
        unique_id=unique_id,
    )
    print(result)
    # vc_single signals failure by returning (traceback_text, (None, None))
    # instead of raising. Turn that into an exception so callers never treat
    # the tuple as an output path.
    if not isinstance(result, str):
        details = result[0] if isinstance(result, tuple) and result else result
        raise RuntimeError(f"vc_single failed for {unique_id}: {details}")
    return result
def cut_vocal_and_inst(audio_path, spk_id, unique_id):
    """Split *audio_path* into vocal and accompaniment stems with demucs.

    Runs ``demucs --two-stems=vocals -n <split_model> <audio_path>`` with
    output under ``output/<spk_id>_<unique_id>``.

    Raises:
        RuntimeError: if demucs exits with a non-zero status.
    """
    os.makedirs("output/result", exist_ok=True)
    print("before executing splitter")
    # Build an argument list (not a split string) so that paths containing
    # spaces stay intact.
    command = [
        "demucs",
        "--two-stems=vocals",
        "-n", split_model,
        audio_path,
        "-o", f"output/{spk_id}_{unique_id}",
    ]
    env = os.environ.copy()
    # Pin demucs to the first GPU; the env mapping must be passed to
    # subprocess.run to take effect.
    env["CUDA_VISIBLE_DEVICES"] = "0"
    result = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        env=env,
    )
    print("after executing splitter")
    print(result.stdout)
    if result.returncode != 0:
        print("Demucs process failed:", result.stderr)
        # Fail here rather than letting later stages trip over missing stems.
        raise RuntimeError(
            f"Demucs failed with exit code {result.returncode}: {result.stderr}"
        )
    print("Demucs process completed successfully.")
def combine_vocal_and_inst(vocal_path, inst_path, output_path):
    """Mix the converted vocals over the instrumental and export an MP3.

    Args:
        vocal_path: the vocal track.
        inst_path: the accompaniment track.
        output_path: base name for the result (callers pass the job id); the
            file is written to ``output/result/<output_path>.mp3``.

    Returns:
        Path of the exported MP3.
    """
    os.makedirs("output/result", exist_ok=True)
    mp3_path = f"output/result/{output_path}.mp3"
    print(vocal_path)
    print(inst_path)
    vocal = AudioSegment.from_file(vocal_path)
    instrumental = AudioSegment.from_file(inst_path)
    # AudioSegment.overlay() keeps the length of the segment it is called on,
    # so mix onto the longer track to avoid cutting off the other one (the
    # converted vocals need not match the instrumental's length exactly).
    if len(instrumental) > len(vocal):
        combined = instrumental.overlay(vocal)
    else:
        combined = vocal.overlay(instrumental)
    # export() returns the opened output file; close it to release the handle.
    combined.export(mp3_path, format="mp3").close()
    return mp3_path
def vc_single(
    sid,
    input_audio_path,
    f0_up_key,
    f0_file,
    f0_method,
    file_index,
    index_rate,
    filter_radius,
    resample_sr,
    rms_mix_rate,
    protect,
    unique_id,
):  # spk_item, input_audio0, vc_transform0, f0_file, f0method0
    """Convert one audio file with the voice model currently loaded by get_vc.

    Uses the module globals set by get_vc (``net_g``, ``vc``, ``tgt_sr``,
    ``version``) and ``hubert_model``, which is reloaded if it is None. The
    input is loaded as 16 kHz mono and run through ``vc.pipeline`` with
    ``if_f0 = 1``.

    Args:
        sid: used in the output filename.
        f0_up_key: pitch shift; converted with ``int()``.
        resample_sr: output sample rate if >= 16000 and different from the
            model's ``tgt_sr``; otherwise the model's rate is used.
        unique_id: job id, used in the output filename.
        The remaining arguments are forwarded to ``vc.pipeline`` unchanged.

    Returns:
        On success, ``output/converted_audio_<sid>_<unique_id>.wav``.
        On any exception, the tuple ``(traceback_text, (None, None))`` instead
        (kept for backward compatibility), so callers must check the type.
    """
    global tgt_sr, net_g, vc, hubert_model, version, cpt
    print("***** in vc ", unique_id)
    try:
        print("Converting...")
        audio, _sr = librosa.load(input_audio_path, sr=16000, mono=True)
        print("found audio ")
        f0_up_key = int(f0_up_key)
        times = [0, 0, 0]
        if hubert_model is None:
            load_hubert()
        print("loaded hubert")
        if_f0 = 1
        audio_opt = vc.pipeline(
            hubert_model,
            net_g,
            0,
            audio,
            input_audio_path,
            times,
            f0_up_key,
            f0_method,
            file_index,
            index_rate,
            if_f0,
            filter_radius,
            tgt_sr,
            resample_sr,
            rms_mix_rate,
            version,
            protect,
            f0_file=f0_file,
        )
        # Choose the output rate locally. Assigning the global tgt_sr here
        # would silently change the model's rate for every later call.
        if resample_sr >= 16000 and tgt_sr != resample_sr:
            out_sr = resample_sr
        else:
            out_sr = tgt_sr
        index_info = (
            "Using index:%s." % file_index
            if os.path.exists(file_index)
            else "Index not used."
        )
        print(index_info)
        print("writing to FS")
        output_file_path = os.path.join("output", f"converted_audio_{sid}_{unique_id}.wav")
        print("******* output file path ", output_file_path)
        os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
        sf.write(output_file_path, audio_opt, out_sr)
        print("wrote to FS")
        return output_file_path
    except Exception:
        # `except Exception` (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate.
        info = traceback.format_exc()
        print(info)
        return info, (None, None)
def get_vc(sid, to_return_protect0):
    """Load speaker checkpoint *sid* from ``weight_root`` and build the synthesizer.

    Sets the module globals used by vc_single:
    - ``net_g``: generator on ``config.device``, half or float per ``config.is_half``;
    - ``vc``: the ``VC`` pipeline;
    - ``tgt_sr``, ``n_spk`` and ``version``.

    The checkpoint is cached in ``cpt[sid]``, and ``weights_index`` is
    rescanned from ``index_root``. An empty *sid* ("" or []) unloads the
    current models instead.

    Args:
        sid: checkpoint filename inside weight_root, e.g. ``"name.pth"``.
        to_return_protect0: protect value echoed into the returned update.

    Returns:
        A 4-tuple of Gradio updates: speaker slider, protect control, index
        dropdown and title markdown.

    Raises:
        ValueError: if the checkpoint declares an unknown model version.
    """
    global n_spk, tgt_sr, net_g, vc, cpt, version, weights_index, hubert_model
    if sid == "" or sid == []:
        # With polling, sid can switch from "a model" to "no model"; only
        # release resources when something is actually loaded.
        if hubert_model is not None:
            print("clean_empty_cache")
            # Drop every reference so the models can be garbage-collected.
            # Assigning (rather than `del`) works even if a global was never
            # set. cpt goes back to an empty dict (not None) so later loads can
            # still store checkpoints in it.
            hubert_model = net_g = n_spk = vc = tgt_sr = None
            cpt = {}
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return (
            gr.Slider.update(maximum=2333, visible=False),
            gr.Slider.update(visible=True),
            gr.Dropdown.update(choices=sorted(weights_index), value=""),
            gr.Markdown.update(value="# <center> No model selected"),
        )

    print(f"Loading {sid} model...")
    selected_model = sid[:-4]  # drop the ".pth" extension
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from a trusted source.
    checkpoint = torch.load(os.path.join(weight_root, sid), map_location="cpu")
    cpt[sid] = checkpoint
    tgt_sr = checkpoint["config"][-1]
    # config[-3] is read back as n_spk below; take it from the number of rows
    # in the speaker-embedding weight.
    checkpoint["config"][-3] = checkpoint["weight"]["emb_g.weight"].shape[0]
    if_f0 = checkpoint.get("f0", 1)
    if if_f0 == 0:
        to_return_protect0 = {"visible": False, "value": 0.5, "__type__": "update"}
    else:
        to_return_protect0 = {"visible": True, "value": to_return_protect0, "__type__": "update"}

    version = checkpoint.get("version", "v1")
    if version == "v1":
        synth_cls = SynthesizerTrnMs256NSFsid if if_f0 == 1 else SynthesizerTrnMs256NSFsid_nono
    elif version == "v2":
        synth_cls = SynthesizerTrnMs768NSFsid if if_f0 == 1 else SynthesizerTrnMs768NSFsid_nono
    else:
        # Previously fell through and reused a stale net_g from an earlier model.
        raise ValueError(f"Unsupported RVC model version: {version!r}")
    if if_f0 == 1:
        net_g = synth_cls(*checkpoint["config"], is_half=config.is_half)
    else:
        net_g = synth_cls(*checkpoint["config"])
    # enc_q is removed from the model; strict=False below ignores the
    # checkpoint's now-unexpected enc_q weights.
    del net_g.enc_q
    print(net_g.load_state_dict(checkpoint["weight"], strict=False))
    net_g.eval().to(config.device)
    net_g = net_g.half() if config.is_half else net_g.float()
    vc = VC(tgt_sr, config)
    n_spk = checkpoint["config"][-3]

    weights_index = [
        os.path.join(index_root, name)
        for _, _, index_files in os.walk(index_root)
        for name in index_files
        if name.endswith('.index') and "trained" not in name
    ]
    if not weights_index:
        selected_index = gr.Dropdown.update(value="")
    else:
        # Prefer an index whose path mentions the model name, else the first.
        match = next(
            (path for path in weights_index if selected_model in path),
            weights_index[0],
        )
        selected_index = gr.Dropdown.update(value=match)
    return (
        gr.Slider.update(maximum=n_spk, visible=True),
        to_return_protect0,
        selected_index,
        gr.Markdown.update(
            f'## <center> {selected_model}\n'
            + f'### <center> RVC {version} Model'
        ),
    )
if __name__ == '__main__':
    # Serve on every interface, port 5000, with Flask debug mode off.
    app.run(host='0.0.0.0', port=5000, debug=False)