Update voice_processing.py
Browse files- voice_processing.py +37 -57
voice_processing.py
CHANGED
@@ -14,6 +14,7 @@ from fairseq import checkpoint_utils
|
|
14 |
import uuid
|
15 |
|
16 |
from config import Config
|
|
|
17 |
from lib.infer_pack.models import (
|
18 |
SynthesizerTrnMs256NSFsid,
|
19 |
SynthesizerTrnMs256NSFsid_nono,
|
@@ -23,6 +24,9 @@ from lib.infer_pack.models import (
|
|
23 |
from rmvpe import RMVPE
|
24 |
from vc_infer_pipeline import VC
|
25 |
|
|
|
|
|
|
|
26 |
# Set logging levels
|
27 |
logging.getLogger("fairseq").setLevel(logging.WARNING)
|
28 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
@@ -34,21 +38,9 @@ limitation = os.getenv("SYSTEM") == "spaces"
|
|
34 |
|
35 |
config = Config()
|
36 |
|
37 |
-
# Define voice_mapping first to ensure it's always available
|
38 |
-
voice_mapping = {
|
39 |
-
"Mongolian Male": "mn-MN-BataaNeural",
|
40 |
-
"Mongolian Female": "mn-MN-YesuiNeural"
|
41 |
-
}
|
42 |
-
|
43 |
# Edge TTS voices
|
44 |
-
|
45 |
-
|
46 |
-
tts_voice_list = loop.run_until_complete(edge_tts.list_voices())
|
47 |
-
tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
|
48 |
-
except Exception as e:
|
49 |
-
logging.error(f"Error loading Edge TTS voices: {e}")
|
50 |
-
tts_voice_list = []
|
51 |
-
tts_voices = []
|
52 |
|
53 |
# RVC models directory
|
54 |
model_root = "weights"
|
@@ -58,24 +50,15 @@ models.sort()
|
|
58 |
def get_unique_filename(extension):
    """Return a collision-resistant file name of the form ``<uuid4>.<extension>``.

    Args:
        extension: File extension, normally without the dot (e.g. ``"wav"``).
            A leading dot is tolerated and stripped so that ``".wav"`` does
            not produce a double-dotted name.

    Returns:
        A string made of a random UUID4, a single dot, and the extension.
    """
    # Normalise ".wav" -> "wav"; inputs without a leading dot are unchanged.
    suffix = str(extension).lstrip(".")
    return f"{uuid.uuid4()}.{suffix}"
|
60 |
|
61 |
-
model_cache = {}
|
62 |
-
|
63 |
def model_data(model_name):
|
64 |
-
|
65 |
-
return model_cache[model_name]
|
66 |
-
|
67 |
pth_path = [
|
68 |
f"{model_root}/{model_name}/{f}"
|
69 |
for f in os.listdir(f"{model_root}/{model_name}")
|
70 |
if f.endswith(".pth")
|
71 |
][0]
|
72 |
print(f"Loading {pth_path}")
|
73 |
-
|
74 |
-
cpt = torch.load(pth_path, map_location="cpu")
|
75 |
-
except Exception as e:
|
76 |
-
logging.error(f"Error loading model {pth_path}: {e}")
|
77 |
-
raise e
|
78 |
-
|
79 |
tgt_sr = cpt["config"][-1]
|
80 |
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
81 |
if_f0 = cpt.get("f0", 1)
|
@@ -114,32 +97,34 @@ def model_data(model_name):
|
|
114 |
index_file = index_files[0]
|
115 |
print(f"Index file found: {index_file}")
|
116 |
|
117 |
-
|
118 |
-
return model_cache[model_name]
|
119 |
|
120 |
def load_hubert():
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
return hubert_model.eval()
|
133 |
-
except Exception as e:
|
134 |
-
logging.error(f"Error loading HuBERT model: {e}")
|
135 |
-
raise e
|
136 |
|
137 |
def get_model_names():
    """Return the names of the directories located directly under ``model_root``.

    Plain files inside ``model_root`` are ignored. The order of the result is
    whatever ``os.listdir`` yields; no sorting is applied here.

    Returns:
        A list of directory names (not full paths).
    """
    return [
        entry
        for entry in os.listdir(model_root)
        if os.path.isdir(os.path.join(model_root, entry))
    ]
|
139 |
|
140 |
-
# Initialize
|
141 |
-
|
|
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
def run_async_in_thread(fn, *args):
|
144 |
loop = asyncio.new_event_loop()
|
145 |
asyncio.set_event_loop(loop)
|
@@ -148,8 +133,10 @@ def run_async_in_thread(fn, *args):
|
|
148 |
return result
|
149 |
|
150 |
def parallel_tts(tasks):
|
151 |
-
|
152 |
-
|
|
|
|
|
153 |
return results
|
154 |
|
155 |
async def tts(
|
@@ -187,7 +174,7 @@ async def tts(
|
|
187 |
# EdgeTTS processing
|
188 |
if limitation and len(tts_text) > 12000:
|
189 |
return (
|
190 |
-
f"Text characters should be at most 12000 in this
|
191 |
None,
|
192 |
None,
|
193 |
)
|
@@ -206,15 +193,15 @@ async def tts(
|
|
206 |
# Common processing after loading the audio
|
207 |
duration = len(audio) / sr
|
208 |
print(f"Audio duration: {duration}s")
|
209 |
-
if limitation and duration >=
|
210 |
return (
|
211 |
-
f"Audio should be less than 20 seconds in this
|
212 |
None,
|
213 |
None,
|
214 |
)
|
215 |
|
216 |
f0_up_key = int(f0_up_key)
|
217 |
-
# Load the model
|
218 |
tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
|
219 |
|
220 |
# Setup for RMVPE or other pitch extraction methods
|
@@ -266,10 +253,3 @@ async def tts(
|
|
266 |
print(traceback_info)
|
267 |
return str(e), None, None
|
268 |
|
269 |
-
# Initialize the global models
|
270 |
-
try:
|
271 |
-
hubert_model = load_hubert()
|
272 |
-
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
|
273 |
-
except Exception as e:
|
274 |
-
logging.error(f"Failed to initialize global models: {e}")
|
275 |
-
# Optionally, you can exit or handle the error as needed
|
|
|
14 |
import uuid
|
15 |
|
16 |
from config import Config
|
17 |
+
from config import Config, voice_mapping
|
18 |
from lib.infer_pack.models import (
|
19 |
SynthesizerTrnMs256NSFsid,
|
20 |
SynthesizerTrnMs256NSFsid_nono,
|
|
|
24 |
from rmvpe import RMVPE
|
25 |
from vc_infer_pipeline import VC
|
26 |
|
27 |
+
model_cache = {}
|
28 |
+
|
29 |
+
|
30 |
# Set logging levels
|
31 |
logging.getLogger("fairseq").setLevel(logging.WARNING)
|
32 |
logging.getLogger("numba").setLevel(logging.WARNING)
|
|
|
38 |
|
39 |
config = Config()
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
# Edge TTS voices
# The voice catalogue is fetched from the Edge TTS service once, at import
# time. The loop is obtained outside the ``try`` so that only the network call
# itself is guarded.
# NOTE(review): ``asyncio.get_event_loop()`` without a running loop is
# deprecated on recent Python versions - confirm the target interpreter.
_voice_loop = asyncio.get_event_loop()
try:
    tts_voice_list = _voice_loop.run_until_complete(edge_tts.list_voices())
except Exception as e:
    # A failed catalogue lookup should not prevent the module from importing;
    # log it and continue with an empty catalogue.
    logging.error(f"Error loading Edge TTS voices: {e}")
    tts_voice_list = []

# The voices offered are fixed and do not depend on the fetched catalogue.
tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
# RVC models directory
|
46 |
model_root = "weights"
|
|
|
50 |
def get_unique_filename(extension):
    """Return a collision-resistant file name of the form ``<uuid4>.<extension>``.

    Args:
        extension: File extension, normally without the dot (e.g. ``"wav"``).
            A leading dot is tolerated and stripped so that ``".wav"`` does
            not produce a double-dotted name.

    Returns:
        A string made of a random UUID4, a single dot, and the extension.
    """
    # Normalise ".wav" -> "wav"; inputs without a leading dot are unchanged.
    suffix = str(extension).lstrip(".")
    return f"{uuid.uuid4()}.{suffix}"
|
52 |
|
|
|
|
|
53 |
def model_data(model_name):
|
54 |
+
# We will not modify this function to cache models
|
|
|
|
|
55 |
pth_path = [
|
56 |
f"{model_root}/{model_name}/{f}"
|
57 |
for f in os.listdir(f"{model_root}/{model_name}")
|
58 |
if f.endswith(".pth")
|
59 |
][0]
|
60 |
print(f"Loading {pth_path}")
|
61 |
+
cpt = torch.load(pth_path, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
|
62 |
tgt_sr = cpt["config"][-1]
|
63 |
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
64 |
if_f0 = cpt.get("f0", 1)
|
|
|
97 |
index_file = index_files[0]
|
98 |
print(f"Index file found: {index_file}")
|
99 |
|
100 |
+
return tgt_sr, net_g, vc, version, index_file, if_f0
|
|
|
101 |
|
102 |
def load_hubert():
    """Load the HuBERT checkpoint and prepare it for inference.

    The checkpoint ``hubert_base.pt`` is read through fairseq's
    ``checkpoint_utils.load_model_ensemble_and_task``; the first model of the
    returned ensemble is moved to ``config.device`` and cast to half or full
    precision according to ``config.is_half``.

    Returns:
        The model, switched to evaluation mode via ``.eval()``.
    """
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    # Only the first ensemble member is used.
    hubert_model = models[0].to(config.device)
    # Precision follows the global configuration.
    if config.is_half:
        hubert_model = hubert_model.half()
    else:
        hubert_model = hubert_model.float()
    return hubert_model.eval()
|
|
|
|
|
|
|
|
|
114 |
|
115 |
def get_model_names():
    """Return the names of the directories located directly under ``model_root``.

    Plain files inside ``model_root`` are ignored. The order of the result is
    whatever ``os.listdir`` yields; no sorting is applied here.

    Returns:
        A list of directory names (not full paths).
    """
    return [
        entry
        for entry in os.listdir(model_root)
        if os.path.isdir(os.path.join(model_root, entry))
    ]
|
117 |
|
118 |
+
# Initialize the global models
# Both models are created once, when this module is imported. Any exception
# raised here propagates and aborts the import.
hubert_model = load_hubert()
# RMVPE is built from "rmvpe.pt" with the same precision/device settings as
# the HuBERT model above.
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
|
121 |
|
122 |
+
# Display label -> Edge TTS voice identifier.
# NOTE(review): `voice_mapping` is also imported from `config` near the top of
# the file; this assignment rebinds the name, so the imported value is not the
# one in effect after this point - confirm which definition is intended.
voice_mapping = {
    "Mongolian Male": "mn-MN-BataaNeural",
    "Mongolian Female": "mn-MN-YesuiNeural",
}
|
126 |
+
|
127 |
+
# Function to run async functions in a new event loop within a thread
|
128 |
def run_async_in_thread(fn, *args):
|
129 |
loop = asyncio.new_event_loop()
|
130 |
asyncio.set_event_loop(loop)
|
|
|
133 |
return result
|
134 |
|
135 |
def parallel_tts(tasks):
    """Run several ``tts`` jobs concurrently on a thread pool.

    Args:
        tasks: Iterable of argument sequences. Each one is unpacked and passed
            as the positional arguments of ``tts``, via ``run_async_in_thread``.

    Returns:
        A list with one result per task, in the same order as ``tasks``.

    Raises:
        Whatever a job raised: ``future.result()`` re-raises the exception of
        the first failing job encountered while collecting results.
    """
    # Increase max_workers to better utilize CPU and GPU resources
    with ThreadPoolExecutor(max_workers=8) as executor:  # Adjust based on your server capacity
        futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
        # Collect in submission order so results line up with the input tasks.
        results = [future.result() for future in futures]
    return results
|
141 |
|
142 |
async def tts(
|
|
|
174 |
# EdgeTTS processing
|
175 |
if limitation and len(tts_text) > 12000:
|
176 |
return (
|
177 |
+
f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
|
178 |
None,
|
179 |
None,
|
180 |
)
|
|
|
193 |
# Common processing after loading the audio
|
194 |
duration = len(audio) / sr
|
195 |
print(f"Audio duration: {duration}s")
|
196 |
+
if limitation and duration >= 20000:
|
197 |
return (
|
198 |
+
f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
|
199 |
None,
|
200 |
None,
|
201 |
)
|
202 |
|
203 |
f0_up_key = int(f0_up_key)
|
204 |
+
# Load the model
|
205 |
tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
|
206 |
|
207 |
# Setup for RMVPE or other pitch extraction methods
|
|
|
253 |
print(traceback_info)
|
254 |
return str(e), None, None
|
255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|