Update voice_processing.py
Browse files
- voice_processing.py +24 -28
voice_processing.py
CHANGED
@@ -34,11 +34,12 @@ limitation = os.getenv("SYSTEM") == "spaces"
|
|
34 |
|
35 |
config = Config()
|
36 |
|
37 |
-
# Edge TTS
|
38 |
-
|
39 |
-
|
|
|
40 |
|
41 |
-
# RVC models
|
42 |
model_root = "weights"
|
43 |
models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
|
44 |
models.sort()
|
@@ -46,7 +47,12 @@ models.sort()
|
|
46 |
def get_unique_filename(extension):
|
47 |
return f"{uuid.uuid4()}.{extension}"
|
48 |
|
|
|
|
|
49 |
def model_data(model_name):
|
|
|
|
|
|
|
50 |
pth_path = [
|
51 |
f"{model_root}/{model_name}/{f}"
|
52 |
for f in os.listdir(f"{model_root}/{model_name}")
|
@@ -92,7 +98,8 @@ def model_data(model_name):
|
|
92 |
index_file = index_files[0]
|
93 |
print(f"Index file found: {index_file}")
|
94 |
|
95 |
-
|
|
|
96 |
|
97 |
def load_hubert():
|
98 |
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
@@ -108,21 +115,14 @@ def load_hubert():
|
|
108 |
return hubert_model.eval()
|
109 |
|
110 |
def get_model_names():
|
111 |
-
model_root = "weights" # Assuming this is where your models are stored
|
112 |
return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
|
113 |
|
114 |
-
#
|
115 |
-
|
116 |
-
loop = asyncio.new_event_loop()
|
117 |
-
asyncio.set_event_loop(loop)
|
118 |
-
result = loop.run_until_complete(fn(*args))
|
119 |
-
loop.close()
|
120 |
-
return result
|
121 |
|
122 |
def parallel_tts(tasks):
|
123 |
-
|
124 |
-
|
125 |
-
results = [future.result() for future in futures]
|
126 |
return results
|
127 |
|
128 |
async def tts(
|
@@ -133,7 +133,7 @@ async def tts(
|
|
133 |
use_uploaded_voice,
|
134 |
uploaded_voice,
|
135 |
):
|
136 |
-
# Default values for parameters
|
137 |
speed = 0 # Default speech speed
|
138 |
f0_up_key = 0 # Default pitch adjustment
|
139 |
f0_method = "rmvpe" # Default pitch extraction method
|
@@ -160,7 +160,7 @@ async def tts(
|
|
160 |
# EdgeTTS processing
|
161 |
if limitation and len(tts_text) > 12000:
|
162 |
return (
|
163 |
-
f"Text characters should be at most 12000 in this
|
164 |
None,
|
165 |
None,
|
166 |
)
|
@@ -179,14 +179,15 @@ async def tts(
|
|
179 |
# Common processing after loading the audio
|
180 |
duration = len(audio) / sr
|
181 |
print(f"Audio duration: {duration}s")
|
182 |
-
if limitation and duration >=
|
183 |
return (
|
184 |
-
f"Audio should be less than 20 seconds in this
|
185 |
None,
|
186 |
None,
|
187 |
)
|
188 |
|
189 |
f0_up_key = int(f0_up_key)
|
|
|
190 |
tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
|
191 |
|
192 |
# Setup for RMVPE or other pitch extraction methods
|
@@ -229,7 +230,7 @@ async def tts(
|
|
229 |
|
230 |
except EOFError:
|
231 |
info = (
|
232 |
-
"
|
233 |
)
|
234 |
print(info)
|
235 |
return info, None, None
|
@@ -238,11 +239,6 @@ async def tts(
|
|
238 |
print(traceback_info)
|
239 |
return str(e), None, None
|
240 |
|
241 |
-
|
242 |
-
"Mongolian Male": "mn-MN-BataaNeural",
|
243 |
-
"Mongolian Female": "mn-MN-YesuiNeural"
|
244 |
-
}
|
245 |
-
|
246 |
hubert_model = load_hubert()
|
247 |
-
|
248 |
-
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
|
|
|
34 |
|
35 |
config = Config()
|
36 |
|
37 |
+
# Edge TTS voices
# asyncio.get_event_loop() is deprecated when no loop is running and raises
# RuntimeError on newer Python versions, so the loop is created explicitly and
# registered as the current one. The module-level name `loop` is kept.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Fetch the Edge TTS voice list once, at import time.
tts_voice_list = loop.run_until_complete(edge_tts.list_voices())
# Voice identifiers offered by this module.
tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
|
41 |
|
42 |
+
# RVC models directory
model_root = "weights"
# Each sub-directory of model_root counts as one model; names are kept sorted.
models = sorted(
    entry
    for entry in os.listdir(model_root)
    if os.path.isdir(f"{model_root}/{entry}")
)
|
|
|
47 |
def get_unique_filename(extension):
    """Return a random file name of the form ``<uuid4>.<extension>``."""
    unique_stem = uuid.uuid4()
    return "{}.{}".format(unique_stem, extension)
|
49 |
|
50 |
+
model_cache = {}
|
51 |
+
|
52 |
def model_data(model_name):
|
53 |
+
if model_name in model_cache:
|
54 |
+
return model_cache[model_name]
|
55 |
+
|
56 |
pth_path = [
|
57 |
f"{model_root}/{model_name}/{f}"
|
58 |
for f in os.listdir(f"{model_root}/{model_name}")
|
|
|
98 |
index_file = index_files[0]
|
99 |
print(f"Index file found: {index_file}")
|
100 |
|
101 |
+
model_cache[model_name] = (tgt_sr, net_g, vc, version, index_file, if_f0)
|
102 |
+
return model_cache[model_name]
|
103 |
|
104 |
def load_hubert():
|
105 |
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
|
|
115 |
return hubert_model.eval()
|
116 |
|
117 |
def get_model_names():
    """Return the names of the directories located directly under ``model_root``.

    The directory is re-scanned on every call, and the names are returned in
    ``os.listdir`` order (not sorted).
    """
    names = []
    for entry in os.listdir(model_root):
        if os.path.isdir(f"{model_root}/{entry}"):
            names.append(entry)
    return names
|
119 |
|
120 |
+
# Initialize a global ThreadPoolExecutor
|
121 |
+
executor = ThreadPoolExecutor(max_workers=20) # Worker pool used by parallel_tts; tune max_workers to your server's capacity
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
def parallel_tts(tasks):
    """Run several ``tts`` calls concurrently and return their results in order.

    Args:
        tasks: Iterable of argument tuples; each one is unpacked as
            ``tts(*task)``.

    Returns:
        A list holding one ``tts`` result per task, in the same order as
        ``tasks``. An exception raised inside a task propagates to the caller
        when its result is collected.
    """

    def _run_tts(task):
        # Self-contained: does not rely on a separate run_async_in_thread
        # helper. asyncio.run creates a fresh event loop in the worker
        # thread, runs the coroutine to completion, and closes the loop.
        return asyncio.run(tts(*task))

    # Submit everything first so the tasks overlap, then collect in input order.
    futures = [executor.submit(_run_tts, task) for task in tasks]
    return [future.result() for future in futures]
|
127 |
|
128 |
async def tts(
|
|
|
133 |
use_uploaded_voice,
|
134 |
uploaded_voice,
|
135 |
):
|
136 |
+
# Default values for parameters
|
137 |
speed = 0 # Default speech speed
|
138 |
f0_up_key = 0 # Default pitch adjustment
|
139 |
f0_method = "rmvpe" # Default pitch extraction method
|
|
|
160 |
# EdgeTTS processing
|
161 |
if limitation and len(tts_text) > 12000:
|
162 |
return (
|
163 |
+
f"Text characters should be at most 12000 in this Hugging Face Space, but got {len(tts_text)} characters.",
|
164 |
None,
|
165 |
None,
|
166 |
)
|
|
|
179 |
# Common processing after loading the audio
|
180 |
duration = len(audio) / sr
|
181 |
print(f"Audio duration: {duration}s")
|
182 |
+
if limitation and duration >= 20:
|
183 |
return (
|
184 |
+
f"Audio should be less than 20 seconds in this Hugging Face Space, but got {duration}s.",
|
185 |
None,
|
186 |
None,
|
187 |
)
|
188 |
|
189 |
f0_up_key = int(f0_up_key)
|
190 |
+
# Load the model using cached data
|
191 |
tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
|
192 |
|
193 |
# Setup for RMVPE or other pitch extraction methods
|
|
|
230 |
|
231 |
except EOFError:
|
232 |
info = (
|
233 |
+
"Output not valid. This may occur when input text and speaker do not match."
|
234 |
)
|
235 |
print(info)
|
236 |
return info, None, None
|
|
|
239 |
print(traceback_info)
|
240 |
return str(e), None, None
|
241 |
|
242 |
+
# Load the HuBERT and RMVPE models once, at module import time
|
|
|
|
|
|
|
|
|
243 |
hubert_model = load_hubert()
|
244 |
+
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
|
|