MAZALA2024 committed on
Commit
e88a126
·
verified ·
1 Parent(s): 49d1e2e

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +124 -118
voice_processing.py CHANGED
@@ -53,7 +53,7 @@ def model_data(model_name):
53
  if f.endswith(".pth")
54
  ][0]
55
  print(f"Loading {pth_path}")
56
- cpt = torch.load(pth_path, map_location="cpu", weights_only=True)
57
  tgt_sr = cpt["config"][-1]
58
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
59
  if_f0 = cpt.get("f0", 1)
@@ -108,135 +108,141 @@ def load_hubert():
108
  return hubert_model.eval()
109
 
110
  def get_model_names():
 
111
  return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
112
 
113
- import nest_asyncio
114
- nest_asyncio.apply()
115
-
116
- def get_unique_filename(extension):
117
- return f"{uuid.uuid4().hex[:8]}.{extension}"
118
-
119
- class TTSProcessor:
120
- def __init__(self, config):
121
- self.config = config
122
- self.executor = ThreadPoolExecutor(max_workers=config.n_cpu)
123
- self.semaphore = asyncio.Semaphore(config.max_concurrent_tts)
124
- self.last_request_time = time.time()
125
- self.rate_limit = config.tts_rate_limit
126
- self.temp_dir = tempfile.mkdtemp()
127
-
128
- async def tts(self, model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
129
- async with self.semaphore:
130
- current_time = time.time()
131
- time_since_last_request = current_time - self.last_request_time
132
- if time_since_last_request < 1 / self.rate_limit:
133
- await asyncio.sleep(1 / self.rate_limit - time_since_last_request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- self.last_request_time = time.time()
136
-
137
- loop = asyncio.get_running_loop()
138
- return await loop.run_in_executor(
139
- self.executor,
140
- self._tts_process,
141
- model_name,
142
- tts_text,
143
- tts_voice,
144
- index_rate,
145
- use_uploaded_voice,
146
- uploaded_voice
147
- )
148
 
149
- def _tts_process(self, model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
150
- try:
151
- edge_output_filename = os.path.join(self.temp_dir, get_unique_filename("mp3"))
152
-
153
- if use_uploaded_voice:
154
- if uploaded_voice is None:
155
- return "No voice file uploaded.", None, None
156
-
157
- uploaded_file_path = os.path.join(self.temp_dir, get_unique_filename("wav"))
158
- with open(uploaded_file_path, "wb") as f:
159
- f.write(uploaded_voice)
160
-
161
- audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
162
- else:
163
- if limitation and len(tts_text) > 12000:
164
- return (
165
- f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
166
- None,
167
- None,
168
- )
169
-
170
- speed = 0 # Default speech speed
171
- speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
172
-
173
- # Use synchronous approach for Edge TTS
174
- communicate = edge_tts.Communicate(tts_text, tts_voice, rate=speed_str)
175
- asyncio.get_event_loop().run_until_complete(communicate.save(edge_output_filename))
176
-
177
- audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
178
-
179
- duration = len(audio) / sr
180
- if limitation and duration >= 20000:
181
  return (
182
- f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
183
  None,
184
  None,
185
  )
186
-
187
- f0_up_key = 0
188
- tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
189
-
190
- if hasattr(self, 'model_rmvpe'):
191
- vc.model_rmvpe = self.model_rmvpe
192
-
193
- times = [0, 0, 0]
194
- audio_opt = vc.pipeline(
195
- hubert_model,
196
- net_g,
197
- 0,
198
- audio,
199
- edge_output_filename if not use_uploaded_voice else uploaded_file_path,
200
- times,
201
- f0_up_key,
202
- "rmvpe",
203
- index_file,
204
- index_rate,
205
- if_f0,
206
- 3, # filter_radius
207
- tgt_sr,
208
- 0, # resample_sr
209
- 0.25, # rms_mix_rate
210
- version,
211
- 0.33, # protect
212
- None,
213
- )
214
-
215
- info = f"Success. Time: tts: {times[0]}s, npy: {times[1]}s, f0: {times[2]}s"
216
- print(info)
217
  return (
218
- info,
219
- edge_output_filename if not use_uploaded_voice else None,
220
- (tgt_sr, audio_opt),
221
  )
222
 
223
- except Exception as e:
224
- logging.error(f"Error in TTS processing: {str(e)}")
225
- logging.error(traceback.format_exc())
226
- return str(e), None, None
227
-
228
- def __del__(self):
229
- # Clean up temporary directory
230
- import shutil
231
- shutil.rmtree(self.temp_dir, ignore_errors=True)
232
-
233
- # Initialize global variables
234
- tts_processor = TTSProcessor(config)
235
- hubert_model = load_hubert()
236
- rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
237
- tts_processor.model_rmvpe = rmvpe_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  voice_mapping = {
240
  "Mongolian Male": "mn-MN-BataaNeural",
241
  "Mongolian Female": "mn-MN-YesuiNeural"
242
- }
 
 
 
 
 
53
  if f.endswith(".pth")
54
  ][0]
55
  print(f"Loading {pth_path}")
56
+ cpt = torch.load(pth_path, map_location="cpu")
57
  tgt_sr = cpt["config"][-1]
58
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
59
  if_f0 = cpt.get("f0", 1)
 
108
  return hubert_model.eval()
109
 
110
def get_model_names():
    """Return the names of the locally available voice models.

    A model is any subdirectory of the local ``weights`` directory.
    Returns an empty list when ``weights`` does not exist yet instead of
    raising ``FileNotFoundError`` (e.g. on a fresh checkout before any
    model has been downloaded).
    """
    model_root = "weights"  # Assuming this is where your models are stored
    if not os.path.isdir(model_root):
        return []
    return [
        d
        for d in os.listdir(model_root)
        if os.path.isdir(os.path.join(model_root, d))
    ]
113
 
114
# Helper to run a coroutine from a worker thread, which has no running loop.
def run_async_in_thread(fn, *args):
    """Run the coroutine function ``fn(*args)`` to completion on a fresh loop.

    Intended for use inside ``ThreadPoolExecutor`` workers. A new event loop
    is created per call; it is always closed — even when ``fn`` raises — so
    loop resources are not leaked, and the thread's current-loop slot is
    cleared afterwards so later code cannot pick up a closed loop.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(fn(*args))
    finally:
        asyncio.set_event_loop(None)
        loop.close()
121
+
122
def parallel_tts(tasks):
    """Run a batch of TTS jobs concurrently and collect their results.

    Each item of *tasks* is an argument tuple for the ``tts`` coroutine;
    every job runs ``tts`` on its own event loop in a worker thread via
    ``run_async_in_thread``. Results are returned in task order.
    """
    with ThreadPoolExecutor() as pool:
        pending = []
        for task in tasks:
            pending.append(pool.submit(run_async_in_thread, tts, *task))
        return [job.result() for job in pending]
127
+
128
async def tts(
    model_name,
    tts_text,
    tts_voice,
    index_rate,
    use_uploaded_voice,
    uploaded_voice,
):
    """Synthesize speech and run it through the RVC voice-conversion pipeline.

    Either renders *tts_text* with Edge TTS using the *tts_voice* speaker, or
    (when *use_uploaded_voice* is true) uses the raw bytes in *uploaded_voice*
    as the source audio. The audio is then converted with the model named
    *model_name*.

    Returns a 3-tuple ``(info, edge_output_filename_or_None, (tgt_sr, audio))``
    on success, or ``(error_message, None, None)`` on failure.
    """
    # Default values for parameters used in EdgeTTS / the RVC pipeline.
    speed = 0  # Default speech speed
    f0_up_key = 0  # Default pitch adjustment
    f0_method = "rmvpe"  # Default pitch extraction method
    protect = 0.33  # Default protect value
    filter_radius = 3
    resample_sr = 0
    rms_mix_rate = 0.25
    edge_time = 0  # Initialize edge_time

    # NOTE(review): this path is relative to the CWD, not a temp dir — TODO
    # confirm the working directory is writable in deployment.
    edge_output_filename = get_unique_filename("mp3")

    try:
        if use_uploaded_voice:
            if uploaded_voice is None:
                return "No voice file uploaded.", None, None

            # Persist the uploaded bytes so librosa can read them from disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                tmp_file.write(uploaded_voice)
                uploaded_file_path = tmp_file.name

            audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
        else:
            # EdgeTTS processing
            if limitation and len(tts_text) > 12000:
                return (
                    f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
                    None,
                    None,
                )

            # Invoke Edge TTS
            t0 = time.time()
            speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
            await edge_tts.Communicate(
                tts_text, tts_voice, rate=speed_str
            ).save(edge_output_filename)
            t1 = time.time()
            edge_time = t1 - t0

            audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)

        # Common processing after loading the audio
        duration = len(audio) / sr
        print(f"Audio duration: {duration}s")
        # BUGFIX: the limit was written as `duration >= 20000` (seconds),
        # which could never trigger; the message promises a 20-second cap.
        if limitation and duration >= 20:
            return (
                f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
                None,
                None,
            )

        f0_up_key = int(f0_up_key)
        tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)

        # Setup for RMVPE or other pitch extraction methods
        if f0_method == "rmvpe":
            vc.model_rmvpe = rmvpe_model

        # Perform voice conversion pipeline
        times = [0, 0, 0]
        audio_opt = vc.pipeline(
            hubert_model,
            net_g,
            0,
            audio,
            edge_output_filename if not use_uploaded_voice else uploaded_file_path,
            times,
            f0_up_key,
            f0_method,
            index_file,
            index_rate,
            if_f0,
            filter_radius,
            tgt_sr,
            resample_sr,
            rms_mix_rate,
            version,
            protect,
            None,
        )

        if tgt_sr != resample_sr and resample_sr >= 16000:
            tgt_sr = resample_sr

        info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
        print(info)
        return (
            info,
            edge_output_filename if not use_uploaded_voice else None,
            (tgt_sr, audio_opt),
        )

    except EOFError:
        info = (
            "output not valid. This may occur when input text and speaker do not match."
        )
        print(info)
        return info, None, None
    except Exception as e:
        traceback_info = traceback.format_exc()
        print(traceback_info)
        return str(e), None, None
240
 
241
# Friendly display names mapped to Edge TTS neural voice identifiers.
voice_mapping = {
    "Mongolian Male": "mn-MN-BataaNeural",
    "Mongolian Female": "mn-MN-YesuiNeural",
}
245
# Module-level model initialisation: load the shared HuBERT feature extractor
# once so every request reuses it (it is passed into vc.pipeline by tts()).
hubert_model = load_hubert()

# RMVPE pitch estimator, loaded from a local checkpoint; assigned to
# vc.model_rmvpe inside tts() when f0_method == "rmvpe".
# NOTE(review): assumes "rmvpe.pt" exists in the CWD — confirm at deploy time.
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)