Update voice_processing.py
Browse files- voice_processing.py +68 -59
voice_processing.py
CHANGED
@@ -6,8 +6,6 @@ import time
|
|
6 |
import traceback
|
7 |
import tempfile
|
8 |
from concurrent.futures import ThreadPoolExecutor
|
9 |
-
import base64
|
10 |
-
|
11 |
|
12 |
import edge_tts
|
13 |
import librosa
|
@@ -45,13 +43,6 @@ model_root = "weights"
|
|
45 |
models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
|
46 |
models.sort()
|
47 |
|
48 |
-
def get_voices():
    """Return the human-readable voice names available for TTS selection."""
    # Iterating a dict yields its keys, so this equals list(voice_mapping.keys()).
    return list(voice_mapping)
|
50 |
-
|
51 |
-
def get_model_names():
    """List the names of the voice-conversion models found on disk.

    A model is any subdirectory of the ``weights`` folder; the directory
    name doubles as the model name.
    """
    weights_dir = "weights"  # Adjust this path if your models are stored elsewhere
    entries = os.listdir(weights_dir)
    return [name for name in entries if os.path.isdir(os.path.join(weights_dir, name))]
|
54 |
-
|
55 |
def get_unique_filename(extension):
|
56 |
return f"{uuid.uuid4()}.{extension}"
|
57 |
|
@@ -116,6 +107,10 @@ def load_hubert():
|
|
116 |
hubert_model = hubert_model.float()
|
117 |
return hubert_model.eval()
|
118 |
|
|
|
|
|
|
|
|
|
119 |
# Add this helper function to ensure a new event loop is created if none exists
|
120 |
def run_async_in_thread(fn, *args):
|
121 |
loop = asyncio.new_event_loop()
|
@@ -138,47 +133,67 @@ async def tts(
|
|
138 |
use_uploaded_voice,
|
139 |
uploaded_voice,
|
140 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
edge_output_filename = get_unique_filename("mp3")
|
142 |
-
try:
|
143 |
-
# Default values for parameters
|
144 |
-
speed = 0
|
145 |
-
f0_up_key = 0
|
146 |
-
f0_method = "rmvpe"
|
147 |
-
protect = 0.33
|
148 |
-
filter_radius = 3
|
149 |
-
resample_sr = 0
|
150 |
-
rms_mix_rate = 0.25
|
151 |
-
edge_time = 0
|
152 |
|
|
|
153 |
if use_uploaded_voice:
|
154 |
if uploaded_voice is None:
|
155 |
-
|
156 |
|
|
|
157 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
158 |
tmp_file.write(uploaded_voice)
|
159 |
uploaded_file_path = tmp_file.name
|
|
|
160 |
audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
|
161 |
else:
|
|
|
162 |
if limitation and len(tts_text) > 12000:
|
163 |
-
|
|
|
|
|
|
|
|
|
164 |
|
|
|
165 |
t0 = time.time()
|
166 |
speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
|
167 |
-
await edge_tts.Communicate(
|
168 |
-
|
|
|
|
|
|
|
|
|
169 |
audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
|
170 |
|
|
|
171 |
duration = len(audio) / sr
|
172 |
print(f"Audio duration: {duration}s")
|
173 |
if limitation and duration >= 20000:
|
174 |
-
|
|
|
|
|
|
|
|
|
175 |
|
176 |
f0_up_key = int(f0_up_key)
|
177 |
tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
|
178 |
|
|
|
179 |
if f0_method == "rmvpe":
|
180 |
vc.model_rmvpe = rmvpe_model
|
181 |
|
|
|
182 |
times = [0, 0, 0]
|
183 |
audio_opt = vc.pipeline(
|
184 |
hubert_model,
|
@@ -204,49 +219,40 @@ async def tts(
|
|
204 |
if tgt_sr != resample_sr and resample_sr >= 16000:
|
205 |
tgt_sr = resample_sr
|
206 |
|
207 |
-
info = f"Success. Time: tts: {edge_time
|
208 |
print(info)
|
209 |
-
|
210 |
-
# Convert audio to base64
|
211 |
-
with open(edge_output_filename, "rb") as audio_file:
|
212 |
-
audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
|
213 |
-
|
214 |
-
audio_data_uri = f"data:audio/mp3;base64,{audio_base64}"
|
215 |
-
|
216 |
return (
|
217 |
info,
|
218 |
-
|
219 |
-
(tgt_sr, audio_opt)
|
220 |
)
|
221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
except Exception as e:
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
if os.path.exists(edge_output_filename):
|
227 |
-
os.remove(edge_output_filename)
|
228 |
-
return (str(e), None, None)
|
229 |
|
230 |
voice_mapping = {
|
231 |
"Mongolian Male": "mn-MN-BataaNeural",
|
232 |
"Mongolian Female": "mn-MN-YesuiNeural"
|
233 |
-
# Add more mappings as needed
|
234 |
}
|
235 |
|
236 |
hubert_model = load_hubert()
|
237 |
|
238 |
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
|
239 |
|
240 |
-
#
|
241 |
-
max_concurrent_tasks = 16 # Adjust based on server capacity
|
242 |
-
semaphore = asyncio.Semaphore(max_concurrent_tasks)
|
243 |
-
|
244 |
-
# Global ThreadPoolExecutor
|
245 |
-
executor = ThreadPoolExecutor(max_workers=max_concurrent_tasks)
|
246 |
-
|
247 |
class TTSProcessor:
|
248 |
def __init__(self, config):
|
249 |
self.config = config
|
|
|
|
|
250 |
self.queue = asyncio.Queue()
|
251 |
self.is_processing = False
|
252 |
|
@@ -260,28 +266,31 @@ class TTSProcessor:
|
|
260 |
return await task
|
261 |
|
262 |
async def _tts_task(self, model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
|
263 |
-
async with semaphore:
|
264 |
return await tts(model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice)
|
265 |
|
266 |
async def _process_queue(self):
|
267 |
self.is_processing = True
|
268 |
while not self.queue.empty():
|
269 |
task = await self.queue.get()
|
270 |
-
|
271 |
-
|
272 |
-
except asyncio.CancelledError:
|
273 |
-
print("Task was cancelled")
|
274 |
-
except Exception as e:
|
275 |
-
print(f"Task failed with error: {e}")
|
276 |
-
finally:
|
277 |
-
self.queue.task_done()
|
278 |
self.is_processing = False
|
279 |
|
280 |
# Initialize the TTSProcessor
|
281 |
tts_processor = TTSProcessor(config)
|
282 |
|
|
|
283 |
async def parallel_tts_processor(tasks):
|
284 |
return await asyncio.gather(*(tts_processor.tts(*task) for task in tasks))
|
285 |
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import traceback
|
7 |
import tempfile
|
8 |
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
9 |
|
10 |
import edge_tts
|
11 |
import librosa
|
|
|
43 |
# Discover the available model folders under the weights directory.
# sorted() on the generator yields the same final value as building the
# unsorted list and then calling .sort() in place.
models = sorted(
    entry
    for entry in os.listdir(model_root)
    if os.path.isdir(f"{model_root}/{entry}")
)
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def get_unique_filename(extension):
    """Return a collision-safe random filename with the given extension.

    The stem is a freshly generated UUID4, so concurrent callers will not
    clash on output paths.
    """
    unique_stem = uuid.uuid4()
    return "{}.{}".format(unique_stem, extension)
|
48 |
|
|
|
107 |
hubert_model = hubert_model.float()
|
108 |
return hubert_model.eval()
|
109 |
|
110 |
+
def get_model_names():
    """Return the model names: one per subdirectory of the weights folder."""
    weights_dir = "weights"  # Assuming this is where your models are stored
    # os.scandir avoids a second stat per entry compared to listdir+isdir.
    return [entry.name for entry in os.scandir(weights_dir) if entry.is_dir()]
|
113 |
+
|
114 |
# Add this helper function to ensure a new event loop is created if none exists
|
115 |
def run_async_in_thread(fn, *args):
|
116 |
loop = asyncio.new_event_loop()
|
|
|
133 |
use_uploaded_voice,
|
134 |
uploaded_voice,
|
135 |
):
|
136 |
+
# Default values for parameters used in EdgeTTS
|
137 |
+
speed = 0 # Default speech speed
|
138 |
+
f0_up_key = 0 # Default pitch adjustment
|
139 |
+
f0_method = "rmvpe" # Default pitch extraction method
|
140 |
+
protect = 0.33 # Default protect value
|
141 |
+
filter_radius = 3
|
142 |
+
resample_sr = 0
|
143 |
+
rms_mix_rate = 0.25
|
144 |
+
edge_time = 0 # Initialize edge_time
|
145 |
+
|
146 |
edge_output_filename = get_unique_filename("mp3")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
+
try:
|
149 |
if use_uploaded_voice:
|
150 |
if uploaded_voice is None:
|
151 |
+
return "No voice file uploaded.", None, None
|
152 |
|
153 |
+
# Process the uploaded voice file
|
154 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
155 |
tmp_file.write(uploaded_voice)
|
156 |
uploaded_file_path = tmp_file.name
|
157 |
+
|
158 |
audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
|
159 |
else:
|
160 |
+
# EdgeTTS processing
|
161 |
if limitation and len(tts_text) > 12000:
|
162 |
+
return (
|
163 |
+
f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
|
164 |
+
None,
|
165 |
+
None,
|
166 |
+
)
|
167 |
|
168 |
+
# Invoke Edge TTS
|
169 |
t0 = time.time()
|
170 |
speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
|
171 |
+
await edge_tts.Communicate(
|
172 |
+
tts_text, tts_voice, rate=speed_str
|
173 |
+
).save(edge_output_filename)
|
174 |
+
t1 = time.time()
|
175 |
+
edge_time = t1 - t0
|
176 |
+
|
177 |
audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
|
178 |
|
179 |
+
# Common processing after loading the audio
|
180 |
duration = len(audio) / sr
|
181 |
print(f"Audio duration: {duration}s")
|
182 |
if limitation and duration >= 20000:
|
183 |
+
return (
|
184 |
+
f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
|
185 |
+
None,
|
186 |
+
None,
|
187 |
+
)
|
188 |
|
189 |
f0_up_key = int(f0_up_key)
|
190 |
tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
|
191 |
|
192 |
+
# Setup for RMVPE or other pitch extraction methods
|
193 |
if f0_method == "rmvpe":
|
194 |
vc.model_rmvpe = rmvpe_model
|
195 |
|
196 |
+
# Perform voice conversion pipeline
|
197 |
times = [0, 0, 0]
|
198 |
audio_opt = vc.pipeline(
|
199 |
hubert_model,
|
|
|
219 |
if tgt_sr != resample_sr and resample_sr >= 16000:
|
220 |
tgt_sr = resample_sr
|
221 |
|
222 |
+
info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
|
223 |
print(info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
return (
|
225 |
info,
|
226 |
+
edge_output_filename if not use_uploaded_voice else None,
|
227 |
+
(tgt_sr, audio_opt),
|
228 |
)
|
229 |
|
230 |
+
except EOFError:
|
231 |
+
info = (
|
232 |
+
"output not valid. This may occur when input text and speaker do not match."
|
233 |
+
)
|
234 |
+
print(info)
|
235 |
+
return info, None, None
|
236 |
except Exception as e:
|
237 |
+
traceback_info = traceback.format_exc()
|
238 |
+
print(traceback_info)
|
239 |
+
return str(e), None, None
|
|
|
|
|
|
|
240 |
|
241 |
# Friendly display name -> Edge TTS neural voice identifier.
voice_mapping = {
    "Mongolian Male": "mn-MN-BataaNeural",
    "Mongolian Female": "mn-MN-YesuiNeural",
}
|
245 |
|
246 |
hubert_model = load_hubert()
|
247 |
|
248 |
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
|
249 |
|
250 |
+
# Add the optimized TTSProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
class TTSProcessor:
|
252 |
def __init__(self, config):
    """Set up the processor's concurrency primitives from *config*.

    config is the project configuration object; this reads its ``n_cpu``
    and ``max_concurrent_tts`` attributes — confirm their presence against
    the Config class.
    """
    self.config = config
    # Queue of pending TTS coroutines plus a flag showing whether the
    # drain loop is currently running.
    self.queue = asyncio.Queue()
    self.is_processing = False
    # Bound thread-pool work by CPU count and concurrent TTS calls by the
    # configured limit.
    self.executor = ThreadPoolExecutor(max_workers=config.n_cpu)
    self.semaphore = asyncio.Semaphore(config.max_concurrent_tts)
|
258 |
|
|
|
266 |
return await task
|
267 |
|
268 |
async def _tts_task(self, model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
    """Run one tts() call, gated by the instance's concurrency semaphore."""
    # Manual acquire/release is equivalent to `async with self.semaphore:`.
    await self.semaphore.acquire()
    try:
        result = await tts(
            model_name,
            tts_text,
            tts_voice,
            index_rate,
            use_uploaded_voice,
            uploaded_voice,
        )
        return result
    finally:
        self.semaphore.release()
|
271 |
|
272 |
async def _process_queue(self):
    """Drain the queue, awaiting each queued TTS task in turn.

    A failing or cancelled task must not abort the drain loop: without
    per-task handling, one raised exception leaves ``is_processing`` stuck
    at True and skips ``queue.task_done()``, which permanently breaks
    ``Queue.join()`` accounting for every later producer.
    """
    self.is_processing = True
    try:
        while not self.queue.empty():
            task = await self.queue.get()
            try:
                await task
            except asyncio.CancelledError:
                print("Task was cancelled")
            except Exception as e:
                print(f"Task failed with error: {e}")
            finally:
                # Always balance the preceding get(), even on failure.
                self.queue.task_done()
    finally:
        self.is_processing = False
|
279 |
|
280 |
# Initialize the TTSProcessor
|
281 |
tts_processor = TTSProcessor(config)
|
282 |
|
283 |
+
# Update parallel_tts to use TTSProcessor
|
284 |
async def parallel_tts_processor(tasks):
    """Run every TTS task concurrently through the shared TTSProcessor.

    Each element of *tasks* is the positional-argument tuple expected by
    ``tts_processor.tts``; results come back in input order.
    """
    pending = [tts_processor.tts(*task_args) for task_args in tasks]
    return await asyncio.gather(*pending)
|
286 |
|
287 |
+
def parallel_tts_wrapper(tasks):
    """Synchronous entry point for running a batch of TTS tasks.

    ``asyncio.get_event_loop()`` raises RuntimeError (and is deprecated
    since Python 3.10) when called from a thread that has no current event
    loop — the usual situation for this sync wrapper when invoked from a
    worker thread — so fall back to creating and installing a fresh loop.
    NOTE(review): this still cannot be called while the chosen loop is
    already running; callers inside async code should await
    parallel_tts_processor(tasks) directly.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(parallel_tts_processor(tasks))
|
290 |
+
|
291 |
+
# Keep the original parallel_tts function
|
292 |
+
# def parallel_tts(tasks):
|
293 |
+
# with ThreadPoolExecutor() as executor:
|
294 |
+
# futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
|
295 |
+
# results = [future.result() for future in futures]
|
296 |
+
# return results
|