MAZALA2024 committed on
Commit
58f1964
·
verified ·
1 Parent(s): 62c32a4

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +68 -59
voice_processing.py CHANGED
@@ -6,8 +6,6 @@ import time
6
  import traceback
7
  import tempfile
8
  from concurrent.futures import ThreadPoolExecutor
9
- import base64
10
-
11
 
12
  import edge_tts
13
  import librosa
@@ -45,13 +43,6 @@ model_root = "weights"
45
  models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
46
  models.sort()
47
 
48
def get_voices():
    """Return the display names of all configured TTS voices."""
    # Keys of voice_mapping are the human-readable labels shown in the UI.
    return [voice_name for voice_name in voice_mapping]
50
-
51
- def get_model_names():
52
- model_root = "weights" # Adjust this path if your models are stored elsewhere
53
- return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
54
-
55
def get_unique_filename(extension):
    """Build a collision-free filename carrying the given extension."""
    unique_stem = uuid.uuid4()
    return f"{unique_stem}.{extension}"
57
 
@@ -116,6 +107,10 @@ def load_hubert():
116
  hubert_model = hubert_model.float()
117
  return hubert_model.eval()
118
 
 
 
 
 
119
  # Add this helper function to ensure a new event loop is created if none exists
120
  def run_async_in_thread(fn, *args):
121
  loop = asyncio.new_event_loop()
@@ -138,47 +133,67 @@ async def tts(
138
  use_uploaded_voice,
139
  uploaded_voice,
140
  ):
 
 
 
 
 
 
 
 
 
 
141
  edge_output_filename = get_unique_filename("mp3")
142
- try:
143
- # Default values for parameters
144
- speed = 0
145
- f0_up_key = 0
146
- f0_method = "rmvpe"
147
- protect = 0.33
148
- filter_radius = 3
149
- resample_sr = 0
150
- rms_mix_rate = 0.25
151
- edge_time = 0
152
 
 
153
  if use_uploaded_voice:
154
  if uploaded_voice is None:
155
- raise ValueError("No voice file uploaded.")
156
 
 
157
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
158
  tmp_file.write(uploaded_voice)
159
  uploaded_file_path = tmp_file.name
 
160
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
161
  else:
 
162
  if limitation and len(tts_text) > 12000:
163
- raise ValueError(f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.")
 
 
 
 
164
 
 
165
  t0 = time.time()
166
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
167
- await edge_tts.Communicate(tts_text, tts_voice, rate=speed_str).save(edge_output_filename)
168
- edge_time = time.time() - t0
 
 
 
 
169
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
170
 
 
171
  duration = len(audio) / sr
172
  print(f"Audio duration: {duration}s")
173
  if limitation and duration >= 20000:
174
- raise ValueError(f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.")
 
 
 
 
175
 
176
  f0_up_key = int(f0_up_key)
177
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
178
 
 
179
  if f0_method == "rmvpe":
180
  vc.model_rmvpe = rmvpe_model
181
 
 
182
  times = [0, 0, 0]
183
  audio_opt = vc.pipeline(
184
  hubert_model,
@@ -204,49 +219,40 @@ async def tts(
204
  if tgt_sr != resample_sr and resample_sr >= 16000:
205
  tgt_sr = resample_sr
206
 
207
- info = f"Success. Time: tts: {edge_time:.2f}s, npy: {times[0]:.2f}s, f0: {times[1]:.2f}s, infer: {times[2]:.2f}s"
208
  print(info)
209
-
210
- # Convert audio to base64
211
- with open(edge_output_filename, "rb") as audio_file:
212
- audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')
213
-
214
- audio_data_uri = f"data:audio/mp3;base64,{audio_base64}"
215
-
216
  return (
217
  info,
218
- audio_data_uri,
219
- (tgt_sr, audio_opt) # Return the target sample rate and audio output
220
  )
221
 
 
 
 
 
 
 
222
  except Exception as e:
223
- print(f"Error in TTS task: {str(e)}")
224
- import traceback
225
- traceback.print_exc()
226
- if os.path.exists(edge_output_filename):
227
- os.remove(edge_output_filename)
228
- return (str(e), None, None)
229
 
230
  voice_mapping = {
231
  "Mongolian Male": "mn-MN-BataaNeural",
232
  "Mongolian Female": "mn-MN-YesuiNeural"
233
- # Add more mappings as needed
234
  }
235
 
236
  hubert_model = load_hubert()
237
 
238
  rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
239
 
240
- # Global semaphore to control concurrency
241
- max_concurrent_tasks = 16 # Adjust based on server capacity
242
- semaphore = asyncio.Semaphore(max_concurrent_tasks)
243
-
244
- # Global ThreadPoolExecutor
245
- executor = ThreadPoolExecutor(max_workers=max_concurrent_tasks)
246
-
247
  class TTSProcessor:
248
  def __init__(self, config):
249
  self.config = config
 
 
250
  self.queue = asyncio.Queue()
251
  self.is_processing = False
252
 
@@ -260,28 +266,31 @@ class TTSProcessor:
260
  return await task
261
 
262
  async def _tts_task(self, model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
263
- async with semaphore:
264
  return await tts(model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice)
265
 
266
  async def _process_queue(self):
267
  self.is_processing = True
268
  while not self.queue.empty():
269
  task = await self.queue.get()
270
- try:
271
- await task
272
- except asyncio.CancelledError:
273
- print("Task was cancelled")
274
- except Exception as e:
275
- print(f"Task failed with error: {e}")
276
- finally:
277
- self.queue.task_done()
278
  self.is_processing = False
279
 
280
  # Initialize the TTSProcessor
281
  tts_processor = TTSProcessor(config)
282
 
 
283
  async def parallel_tts_processor(tasks):
284
  return await asyncio.gather(*(tts_processor.tts(*task) for task in tasks))
285
 
286
- async def parallel_tts_wrapper(tasks):
287
- return await parallel_tts_processor(tasks)
 
 
 
 
 
 
 
 
 
6
  import traceback
7
  import tempfile
8
  from concurrent.futures import ThreadPoolExecutor
 
 
9
 
10
  import edge_tts
11
  import librosa
 
43
  models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
44
  models.sort()
45
 
 
 
 
 
 
 
 
46
  def get_unique_filename(extension):
47
  return f"{uuid.uuid4()}.{extension}"
48
 
 
107
  hubert_model = hubert_model.float()
108
  return hubert_model.eval()
109
 
110
def get_model_names(model_root="weights"):
    """Return the names of the model subdirectories under *model_root*.

    Args:
        model_root: Directory containing one subdirectory per model.
            Defaults to the project's local ``weights`` folder, keeping
            the original zero-argument call working unchanged.

    Returns:
        list[str]: Immediate subdirectory names, in ``os.listdir`` order.
    """
    return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
113
+
114
  # Add this helper function to ensure a new event loop is created if none exists
115
  def run_async_in_thread(fn, *args):
116
  loop = asyncio.new_event_loop()
 
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
136
+ # Default values for parameters used in EdgeTTS
137
+ speed = 0 # Default speech speed
138
+ f0_up_key = 0 # Default pitch adjustment
139
+ f0_method = "rmvpe" # Default pitch extraction method
140
+ protect = 0.33 # Default protect value
141
+ filter_radius = 3
142
+ resample_sr = 0
143
+ rms_mix_rate = 0.25
144
+ edge_time = 0 # Initialize edge_time
145
+
146
  edge_output_filename = get_unique_filename("mp3")
 
 
 
 
 
 
 
 
 
 
147
 
148
+ try:
149
  if use_uploaded_voice:
150
  if uploaded_voice is None:
151
+ return "No voice file uploaded.", None, None
152
 
153
+ # Process the uploaded voice file
154
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
155
  tmp_file.write(uploaded_voice)
156
  uploaded_file_path = tmp_file.name
157
+
158
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
159
  else:
160
+ # EdgeTTS processing
161
  if limitation and len(tts_text) > 12000:
162
+ return (
163
+ f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
164
+ None,
165
+ None,
166
+ )
167
 
168
+ # Invoke Edge TTS
169
  t0 = time.time()
170
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
171
+ await edge_tts.Communicate(
172
+ tts_text, tts_voice, rate=speed_str
173
+ ).save(edge_output_filename)
174
+ t1 = time.time()
175
+ edge_time = t1 - t0
176
+
177
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
178
 
179
+ # Common processing after loading the audio
180
  duration = len(audio) / sr
181
  print(f"Audio duration: {duration}s")
182
  if limitation and duration >= 20000:
183
+ return (
184
+ f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
185
+ None,
186
+ None,
187
+ )
188
 
189
  f0_up_key = int(f0_up_key)
190
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
191
 
192
+ # Setup for RMVPE or other pitch extraction methods
193
  if f0_method == "rmvpe":
194
  vc.model_rmvpe = rmvpe_model
195
 
196
+ # Perform voice conversion pipeline
197
  times = [0, 0, 0]
198
  audio_opt = vc.pipeline(
199
  hubert_model,
 
219
  if tgt_sr != resample_sr and resample_sr >= 16000:
220
  tgt_sr = resample_sr
221
 
222
+ info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
223
  print(info)
 
 
 
 
 
 
 
224
  return (
225
  info,
226
+ edge_output_filename if not use_uploaded_voice else None,
227
+ (tgt_sr, audio_opt),
228
  )
229
 
230
+ except EOFError:
231
+ info = (
232
+ "output not valid. This may occur when input text and speaker do not match."
233
+ )
234
+ print(info)
235
+ return info, None, None
236
  except Exception as e:
237
+ traceback_info = traceback.format_exc()
238
+ print(traceback_info)
239
+ return str(e), None, None
 
 
 
240
 
241
# Maps UI-facing voice labels to Microsoft Edge TTS voice identifiers.
voice_mapping = {
    "Mongolian Male": "mn-MN-BataaNeural",
    "Mongolian Female": "mn-MN-YesuiNeural"
}
245
 
246
  hubert_model = load_hubert()
247
 
248
  rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
249
 
250
+ # Add the optimized TTSProcessor
 
 
 
 
 
 
251
  class TTSProcessor:
252
  def __init__(self, config):
253
  self.config = config
254
+ self.executor = ThreadPoolExecutor(max_workers=config.n_cpu)
255
+ self.semaphore = asyncio.Semaphore(config.max_concurrent_tts)
256
  self.queue = asyncio.Queue()
257
  self.is_processing = False
258
 
 
266
  return await task
267
 
268
  async def _tts_task(self, model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
269
+ async with self.semaphore:
270
  return await tts(model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice)
271
 
272
  async def _process_queue(self):
273
  self.is_processing = True
274
  while not self.queue.empty():
275
  task = await self.queue.get()
276
+ await task
277
+ self.queue.task_done()
 
 
 
 
 
 
278
  self.is_processing = False
279
 
280
  # Initialize the TTSProcessor
281
  tts_processor = TTSProcessor(config)
282
 
283
+ # Update parallel_tts to use TTSProcessor
284
async def parallel_tts_processor(tasks):
    """Fan out several TTS jobs through the shared processor and await them all."""
    pending = [tts_processor.tts(*task_args) for task_args in tasks]
    return await asyncio.gather(*pending)
286
 
287
def parallel_tts_wrapper(tasks):
    """Synchronously run a batch of TTS jobs and return their results.

    Args:
        tasks: Iterable of argument tuples forwarded to ``tts_processor.tts``.

    Returns:
        list: One result per task, in input order.
    """
    try:
        # Reuse the thread's current event loop when one is installed.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current loop (worker thread, or Python 3.12+ outside asyncio):
        # create and install a fresh one instead of crashing.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(parallel_tts_processor(tasks))
290
+
291
+ # Keep the original parallel_tts function
292
+ # def parallel_tts(tasks):
293
+ # with ThreadPoolExecutor() as executor:
294
+ # futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
295
+ # results = [future.result() for future in futures]
296
+ # return results