MAZALA2024 committed on
Commit
6f00c3d
·
verified ·
1 Parent(s): d5183ee

Update voice_processing.py

Browse files
Files changed (1)
  1. voice_processing.py +30 -45
voice_processing.py CHANGED
@@ -23,11 +23,7 @@ from lib.infer_pack.models import (
23
  from rmvpe import RMVPE
24
  from vc_infer_pipeline import VC
25
 
26
- # Set up logging
27
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
28
- logger = logging.getLogger(__name__)
29
-
30
- # Set logging levels for other libraries
31
  logging.getLogger("fairseq").setLevel(logging.WARNING)
32
  logging.getLogger("numba").setLevel(logging.WARNING)
33
  logging.getLogger("markdown_it").setLevel(logging.WARNING)
@@ -56,7 +52,7 @@ def model_data(model_name):
56
  for f in os.listdir(f"{model_root}/{model_name}")
57
  if f.endswith(".pth")
58
  ][0]
59
- logger.info(f"Loading {pth_path}")
60
  cpt = torch.load(pth_path, map_location="cpu")
61
  tgt_sr = cpt["config"][-1]
62
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
@@ -76,7 +72,7 @@ def model_data(model_name):
76
  raise ValueError("Unknown version")
77
  del net_g.enc_q
78
  net_g.load_state_dict(cpt["weight"], strict=False)
79
- logger.info("Model loaded")
80
  net_g.eval().to(config.device)
81
  if config.is_half:
82
  net_g = net_g.half()
@@ -90,11 +86,11 @@ def model_data(model_name):
90
  if f.endswith(".index")
91
  ]
92
  if len(index_files) == 0:
93
- logger.info("No index file found")
94
  index_file = ""
95
  else:
96
  index_file = index_files[0]
97
- logger.info(f"Index file found: {index_file}")
98
 
99
  return tgt_sr, net_g, vc, version, index_file, if_f0
100
 
@@ -123,8 +119,6 @@ def run_async_in_thread(fn, *args):
123
  loop.close()
124
  return result
125
 
126
- executor = ThreadPoolExecutor(max_workers=config.n_cpu)
127
-
128
  def parallel_tts(tasks):
129
  with ThreadPoolExecutor() as executor:
130
  futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
@@ -139,24 +133,21 @@ async def tts(
139
  use_uploaded_voice,
140
  uploaded_voice,
141
  ):
142
- try:
143
- # Default values for parameters used in EdgeTTS
144
- speed = 0 # Default speech speed
145
- f0_up_key = 0 # Default pitch adjustment
146
- f0_method = "rmvpe" # Default pitch extraction method
147
- protect = 0.33 # Default protect value
148
- filter_radius = 3
149
- resample_sr = 0
150
- rms_mix_rate = 0.25
151
- edge_time = 0 # Initialize edge_time
152
-
153
- edge_output_filename = get_unique_filename("mp3")
154
-
155
- logger.info(f"Starting TTS process for text: {tts_text[:50]}...")
156
 
 
157
  if use_uploaded_voice:
158
  if uploaded_voice is None:
159
- logger.error("No voice file uploaded.")
160
  return "No voice file uploaded.", None, None
161
 
162
  # Process the uploaded voice file
@@ -165,11 +156,9 @@ async def tts(
165
  uploaded_file_path = tmp_file.name
166
 
167
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
168
- logger.info(f"Uploaded voice file loaded. Shape: {audio.shape}, SR: {sr}")
169
  else:
170
  # EdgeTTS processing
171
  if limitation and len(tts_text) > 12000:
172
- logger.error(f"Text characters exceed limit: {len(tts_text)} characters.")
173
  return (
174
  f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
175
  None,
@@ -186,13 +175,11 @@ async def tts(
186
  edge_time = t1 - t0
187
 
188
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
189
- logger.info(f"Edge TTS audio generated. Shape: {audio.shape}, SR: {sr}")
190
 
191
  # Common processing after loading the audio
192
  duration = len(audio) / sr
193
- logger.info(f"Audio duration: {duration}s")
194
  if limitation and duration >= 20000:
195
- logger.error(f"Audio duration exceeds limit: {duration}s")
196
  return (
197
  f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
198
  None,
@@ -208,7 +195,6 @@ async def tts(
208
 
209
  # Perform voice conversion pipeline
210
  times = [0, 0, 0]
211
- logger.info(f"Starting voice conversion with audio shape: {audio.shape}")
212
  audio_opt = vc.pipeline(
213
  hubert_model,
214
  net_g,
@@ -229,22 +215,28 @@ async def tts(
229
  protect,
230
  None,
231
  )
232
- logger.info(f"Voice conversion completed. Output shape: {audio_opt.shape}")
233
 
234
  if tgt_sr != resample_sr and resample_sr >= 16000:
235
  tgt_sr = resample_sr
236
 
237
- info = f"Success. Time: tts: {edge_time:.2f}s, npy: {times[0]:.2f}s, f0: {times[1]:.2f}s, infer: {times[2]:.2f}s"
238
- logger.info(info)
239
  return (
240
  info,
241
  edge_output_filename if not use_uploaded_voice else None,
242
  (tgt_sr, audio_opt),
243
  )
244
 
 
 
 
 
 
 
245
  except Exception as e:
246
- logger.exception("Error in TTS processing")
247
- return str(e), None, (None, None)
 
248
 
249
  voice_mapping = {
250
  "Mongolian Male": "mn-MN-BataaNeural",
@@ -294,11 +286,4 @@ async def parallel_tts_processor(tasks):
294
 
295
  def parallel_tts_wrapper(tasks):
296
  loop = asyncio.get_event_loop()
297
- return loop.run_until_complete(parallel_tts_processor(tasks))
298
-
299
- # Keep the original parallel_tts function
300
- # def parallel_tts(tasks):
301
- # with ThreadPoolExecutor() as executor:
302
- # futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
303
- # results = [future.result() for future in futures]
304
- # return results
 
23
  from rmvpe import RMVPE
24
  from vc_infer_pipeline import VC
25
 
26
+ # Set logging levels
 
 
 
 
27
  logging.getLogger("fairseq").setLevel(logging.WARNING)
28
  logging.getLogger("numba").setLevel(logging.WARNING)
29
  logging.getLogger("markdown_it").setLevel(logging.WARNING)
 
52
  for f in os.listdir(f"{model_root}/{model_name}")
53
  if f.endswith(".pth")
54
  ][0]
55
+ print(f"Loading {pth_path}")
56
  cpt = torch.load(pth_path, map_location="cpu")
57
  tgt_sr = cpt["config"][-1]
58
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
 
72
  raise ValueError("Unknown version")
73
  del net_g.enc_q
74
  net_g.load_state_dict(cpt["weight"], strict=False)
75
+ print("Model loaded")
76
  net_g.eval().to(config.device)
77
  if config.is_half:
78
  net_g = net_g.half()
 
86
  if f.endswith(".index")
87
  ]
88
  if len(index_files) == 0:
89
+ print("No index file found")
90
  index_file = ""
91
  else:
92
  index_file = index_files[0]
93
+ print(f"Index file found: {index_file}")
94
 
95
  return tgt_sr, net_g, vc, version, index_file, if_f0
96
 
 
119
  loop.close()
120
  return result
121
 
 
 
122
  def parallel_tts(tasks):
123
  with ThreadPoolExecutor() as executor:
124
  futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
 
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
136
+ # Default values for parameters used in EdgeTTS
137
+ speed = 0 # Default speech speed
138
+ f0_up_key = 0 # Default pitch adjustment
139
+ f0_method = "rmvpe" # Default pitch extraction method
140
+ protect = 0.33 # Default protect value
141
+ filter_radius = 3
142
+ resample_sr = 0
143
+ rms_mix_rate = 0.25
144
+ edge_time = 0 # Initialize edge_time
145
+
146
+ edge_output_filename = get_unique_filename("mp3")
 
 
 
147
 
148
+ try:
149
  if use_uploaded_voice:
150
  if uploaded_voice is None:
 
151
  return "No voice file uploaded.", None, None
152
 
153
  # Process the uploaded voice file
 
156
  uploaded_file_path = tmp_file.name
157
 
158
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
 
159
  else:
160
  # EdgeTTS processing
161
  if limitation and len(tts_text) > 12000:
 
162
  return (
163
  f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
164
  None,
 
175
  edge_time = t1 - t0
176
 
177
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
 
178
 
179
  # Common processing after loading the audio
180
  duration = len(audio) / sr
181
+ print(f"Audio duration: {duration}s")
182
  if limitation and duration >= 20000:
 
183
  return (
184
  f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
185
  None,
 
195
 
196
  # Perform voice conversion pipeline
197
  times = [0, 0, 0]
 
198
  audio_opt = vc.pipeline(
199
  hubert_model,
200
  net_g,
 
215
  protect,
216
  None,
217
  )
 
218
 
219
  if tgt_sr != resample_sr and resample_sr >= 16000:
220
  tgt_sr = resample_sr
221
 
222
+ info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
223
+ print(info)
224
  return (
225
  info,
226
  edge_output_filename if not use_uploaded_voice else None,
227
  (tgt_sr, audio_opt),
228
  )
229
 
230
+ except EOFError:
231
+ info = (
232
+ "output not valid. This may occur when input text and speaker do not match."
233
+ )
234
+ print(info)
235
+ return info, None, None
236
  except Exception as e:
237
+ traceback_info = traceback.format_exc()
238
+ print(traceback_info)
239
+ return str(e), None, None
240
 
241
  voice_mapping = {
242
  "Mongolian Male": "mn-MN-BataaNeural",
 
286
 
287
  def parallel_tts_wrapper(tasks):
288
  loop = asyncio.get_event_loop()
289
+ return loop.run_until_complete(parallel_tts_processor(tasks))