import asyncio
import datetime
import logging
import os
import tempfile
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor

import librosa
import numpy as np
import torch
from fairseq import checkpoint_utils
from scipy.io import wavfile
from torch.nn.utils.parametrizations import weight_norm

from config import Config
from lib.infer_pack.models import (
    SynthesizerTrnMs256NSFsid,
    SynthesizerTrnMs256NSFsid_nono,
    SynthesizerTrnMs768NSFsid,
    SynthesizerTrnMs768NSFsid_nono,
)
from rmvpe import RMVPE
from vc_infer_pipeline import VC

# Cache of loaded models, keyed by model name; reserved for reuse (see the
# memoization sketch after load_model below).
model_cache = {}

logger = logging.getLogger("voice_processing")


def load_model(model_name):
    """
    Load an RVC model with error handling and logging.

    Args:
        model_name (str): Name of the model directory under `weights`
            (e.g. 'mongolian7-male').

    Returns:
        The initialized model, or None if loading fails.
    """
    try:
        logger.info(f"Loading model: {model_name}")

        # Construct the model path.
        model_dir = "weights"
        model_path = os.path.join(model_dir, model_name)

        # Find the .pth checkpoint file.
        pth_files = [f for f in os.listdir(model_path) if f.endswith(".pth")]
        if not pth_files:
            logger.error(f"No .pth file found in {model_path}")
            return None
        pth_path = os.path.join(model_path, pth_files[0])
        logger.info(f"Found model file: {pth_path}")

        # Load model weights.
        cpt = torch.load(pth_path, map_location="cpu", weights_only=True)
        logger.info("Model weights loaded successfully")

        # Read the checkpoint configuration.
        tgt_sr = cpt["config"][-1]
        cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
        if_f0 = cpt.get("f0", 1)
        version = cpt.get("version", "v1")
        logger.info(f"Model config: sr={tgt_sr}, if_f0={if_f0}, version={version}")

        # Initialize the synthesizer matching the checkpoint's version and f0
        # flag (mirrors the selection logic in model_data below).
        if version == "v1":
            if if_f0 == 1:
                model = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=False)
            else:
                model = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
        else:
            if if_f0 == 1:
                model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=False)
            else:
                model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])

        # Load the weights, then switch to inference mode.
        model.load_state_dict(cpt["weight"], strict=False)
        model.eval()
        logger.info("Model initialized successfully")
        return model
    except Exception as e:
        logger.error(f"Error loading model {model_name}: {str(e)}")
        logger.error(traceback.format_exc())
        return None
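
# Sketch: load_model could memoize through the module-level model_cache above;
# illustrative only, not wired into the callers below:
#
#   def load_model_cached(model_name):
#       if model_name not in model_cache:
#           model_cache[model_name] = load_model(model_name)
#       return model_cache[model_name]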


def process_audio(model, audio_file, logger, index_rate=0, use_uploaded_voice=True, uploaded_voice=None):
    """Process audio through the model."""
    try:
        logger.info("Starting audio processing")
        if model is None:
            logger.error("No model provided for processing")
            return None

        # Load the audio file.
        sr, audio = wavfile.read(audio_file)
        logger.info(f"Loaded audio: sr={sr}Hz, shape={audio.shape}")

        # Convert to mono if needed.
        if len(audio.shape) > 1:
            audio = np.mean(audio, axis=1)

        # Convert integer PCM to float32 in [-1, 1] (wavfile returns int16 for
        # most WAVs); float input is passed through unchanged.
        if np.issubdtype(audio.dtype, np.integer):
            audio = audio.astype(np.float32) / 32768.0
        else:
            audio = audio.astype(np.float32)

        # Prepare the input tensor.
        input_tensor = torch.FloatTensor(audio)
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()
            model = model.cuda()

        # Run the model. Note: the synthesizer's `phone` argument is designed
        # for HuBERT feature frames; feeding raw samples here follows the
        # original approach but may need a feature-extraction step in practice.
        with torch.no_grad():
            phone = input_tensor.unsqueeze(0)  # add batch dimension: [1, sequence_length]
            phone_lengths = torch.LongTensor([len(input_tensor)]).to(input_tensor.device)
            # Coarse pitch indices feed an embedding layer, so they must be long.
            pitch = torch.zeros(1, len(input_tensor), dtype=torch.long).to(input_tensor.device)
            nsff0 = torch.zeros(1, len(input_tensor)).to(input_tensor.device)  # default f0 contour
            sid = torch.LongTensor([0]).to(input_tensor.device)  # speaker ID

            # infer() returns a tuple; the synthesized waveform is its first element.
            output = model.infer(
                phone=phone,
                phone_lengths=phone_lengths,
                pitch=pitch,
                nsff0=nsff0,
                sid=sid,
            )[0]

        if torch.cuda.is_available():
            output = output.cpu()
        output = output.numpy()
        logger.info(f"Processing complete, output shape: {output.shape}")
        return (None, None, (sr, output))
    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        logger.error(traceback.format_exc())
        return None


# Quiet down noisy third-party loggers.
logging.getLogger("fairseq").setLevel(logging.WARNING)
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

limitation = os.getenv("SYSTEM") == "spaces"
config = Config()

# Edge TTS voices
# tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
# tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]

# RVC models directory
model_root = "weights"
models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
models.sort()


def get_unique_filename(extension):
    """Return a random filename with the given extension."""
    return f"{uuid.uuid4()}.{extension}"
def model_data(model_name):
    """Load an RVC checkpoint plus its companion index file for inference."""
    pth_files = [
        f"{model_root}/{model_name}/{f}"
        for f in os.listdir(f"{model_root}/{model_name}")
        if f.endswith(".pth")
    ]
    if not pth_files:
        raise FileNotFoundError(f"No .pth file found in {model_root}/{model_name}")
    pth_path = pth_files[0]
    print(f"Loading {pth_path}")
    # weights_only=True addresses torch.load's unpickling deprecation warning.
    cpt = torch.load(pth_path, map_location="cpu", weights_only=True)
    tgt_sr = cpt["config"][-1]
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")
    if version == "v1":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
        else:
            net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
    elif version == "v2":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
        else:
            net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
    else:
        raise ValueError("Unknown version")
    # The posterior encoder is only needed for training; drop it for inference.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    print("Model loaded")
    net_g.eval().to(config.device)
    if config.is_half:
        net_g = net_g.half()
    else:
        net_g = net_g.float()
    vc = VC(tgt_sr, config)
    index_files = [
        f"{model_root}/{model_name}/{f}"
        for f in os.listdir(f"{model_root}/{model_name}")
        if f.endswith(".index")
    ]
    if len(index_files) == 0:
        print("No index file found")
        index_file = ""
    else:
        index_file = index_files[0]
        print(f"Index file found: {index_file}")
    return tgt_sr, net_g, vc, version, index_file, if_f0
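
# Usage sketch (any entry of the `models` list defined above works):
#
#   tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(models[0])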


def load_hubert():
    """Load the HuBERT content encoder from hubert_base.pt."""
    # Local name avoids shadowing the module-level `models` list.
    loaded, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    hubert_model = loaded[0]
    hubert_model = hubert_model.to(config.device)
    if config.is_half:
        hubert_model = hubert_model.half()
    else:
        hubert_model = hubert_model.float()
    return hubert_model.eval()
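
# Note: hubert_base.pt is given as a relative path, so fairseq resolves it
# against the working directory; it must exist before load_hubert() runs at
# import time below.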


def get_model_names():
    """Return the model directory names under `weights`."""
    return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]


# Initialize the global models.
hubert_model = load_hubert()
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)

# voice_mapping = {
#     "Mongolian Male": "mn-MN-BataaNeural",
#     "Mongolian Female": "mn-MN-YesuiNeural"
# }


def run_async_in_thread(fn, *args):
    """Run an async function to completion in a fresh event loop within a thread."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(fn(*args))
    finally:
        # Close the loop even if the coroutine raises.
        loop.close()
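
# Sketch: this pairs with the ThreadPoolExecutor imported above to fan out
# several coroutines at once (coro_fn stands in for any coroutine function):
#
#   with ThreadPoolExecutor(max_workers=4) as pool:
#       futures = [pool.submit(run_async_in_thread, coro_fn, arg) for arg in args]
#       results = [f.result() for f in futures]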


def parallel_tts(tasks):
    """Process multiple TTS tasks sequentially, returning one result per task."""
    logger.info(f"Received {len(tasks)} tasks for processing")
    results = []
    for i, task in enumerate(tasks):
        try:
            logger.info(f"Processing task {i + 1}/{len(tasks)}")
            model_name, _, _, slang_rate, use_uploaded_voice, audio_file = task
            logger.info(f"Model: {model_name}, Slang rate: {slang_rate}")

            model = load_model(model_name)
            if model is None:
                logger.error(f"Failed to load model {model_name}")
                results.append(None)
                continue

            result = process_audio(
                model=model,
                audio_file=audio_file,
                logger=logger,
                index_rate=0,
                use_uploaded_voice=use_uploaded_voice,
                uploaded_voice=None,
            )
            logger.info(f"Processing completed for task {i + 1}")
            results.append(result)
        except Exception as e:
            logger.error(f"Error processing task {i + 1}: {str(e)}")
            logger.error(traceback.format_exc())
            results.append(None)
    return results
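

# Minimal usage sketch (the WAV filename is illustrative; the second and third
# task fields are unused by parallel_tts):
#
#   if __name__ == "__main__":
#       tasks = [
#           # (model_name, unused, unused, slang_rate, use_uploaded_voice, audio_file)
#           (models[0], None, None, 0.0, True, "input.wav"),
#       ]
#       for result in parallel_tts(tasks):
#           print("ok" if result is not None else "failed")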