MAZALA2024 commited on
Commit
b2461b4
·
verified ·
1 Parent(s): 3ae57b7

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +24 -28
voice_processing.py CHANGED
@@ -34,11 +34,12 @@ limitation = os.getenv("SYSTEM") == "spaces"
34
 
35
  config = Config()
36
 
37
- # Edge TTS
38
- tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
39
- tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
 
40
 
41
- # RVC models
42
  model_root = "weights"
43
  models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
44
  models.sort()
@@ -46,7 +47,12 @@ models.sort()
46
  def get_unique_filename(extension):
47
  return f"{uuid.uuid4()}.{extension}"
48
 
 
 
49
  def model_data(model_name):
 
 
 
50
  pth_path = [
51
  f"{model_root}/{model_name}/{f}"
52
  for f in os.listdir(f"{model_root}/{model_name}")
@@ -92,7 +98,8 @@ def model_data(model_name):
92
  index_file = index_files[0]
93
  print(f"Index file found: {index_file}")
94
 
95
- return tgt_sr, net_g, vc, version, index_file, if_f0
 
96
 
97
  def load_hubert():
98
  models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
@@ -108,21 +115,14 @@ def load_hubert():
108
  return hubert_model.eval()
109
 
110
  def get_model_names():
111
- model_root = "weights" # Assuming this is where your models are stored
112
  return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
113
 
114
- # Add this helper function to ensure a new event loop is created if none exists
115
- def run_async_in_thread(fn, *args):
116
- loop = asyncio.new_event_loop()
117
- asyncio.set_event_loop(loop)
118
- result = loop.run_until_complete(fn(*args))
119
- loop.close()
120
- return result
121
 
122
  def parallel_tts(tasks):
123
- with ThreadPoolExecutor() as executor:
124
- futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
125
- results = [future.result() for future in futures]
126
  return results
127
 
128
  async def tts(
@@ -133,7 +133,7 @@ async def tts(
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
136
- # Default values for parameters used in EdgeTTS
137
  speed = 0 # Default speech speed
138
  f0_up_key = 0 # Default pitch adjustment
139
  f0_method = "rmvpe" # Default pitch extraction method
@@ -160,7 +160,7 @@ async def tts(
160
  # EdgeTTS processing
161
  if limitation and len(tts_text) > 12000:
162
  return (
163
- f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
164
  None,
165
  None,
166
  )
@@ -179,14 +179,15 @@ async def tts(
179
  # Common processing after loading the audio
180
  duration = len(audio) / sr
181
  print(f"Audio duration: {duration}s")
182
- if limitation and duration >= 20000:
183
  return (
184
- f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
185
  None,
186
  None,
187
  )
188
 
189
  f0_up_key = int(f0_up_key)
 
190
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
191
 
192
  # Setup for RMVPE or other pitch extraction methods
@@ -229,7 +230,7 @@ async def tts(
229
 
230
  except EOFError:
231
  info = (
232
- "output not valid. This may occur when input text and speaker do not match."
233
  )
234
  print(info)
235
  return info, None, None
@@ -238,11 +239,6 @@ async def tts(
238
  print(traceback_info)
239
  return str(e), None, None
240
 
241
- voice_mapping = {
242
- "Mongolian Male": "mn-MN-BataaNeural",
243
- "Mongolian Female": "mn-MN-YesuiNeural"
244
- }
245
-
246
  hubert_model = load_hubert()
247
-
248
- rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
 
34
 
35
  config = Config()
36
 
37
+ # Edge TTS voices
38
+ loop = asyncio.get_event_loop()
39
+ tts_voice_list = loop.run_until_complete(edge_tts.list_voices())
40
+ tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
41
 
42
+ # RVC models directory
43
  model_root = "weights"
44
  models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
45
  models.sort()
 
47
  def get_unique_filename(extension):
48
  return f"{uuid.uuid4()}.{extension}"
49
 
50
+ model_cache = {}
51
+
52
  def model_data(model_name):
53
+ if model_name in model_cache:
54
+ return model_cache[model_name]
55
+
56
  pth_path = [
57
  f"{model_root}/{model_name}/{f}"
58
  for f in os.listdir(f"{model_root}/{model_name}")
 
98
  index_file = index_files[0]
99
  print(f"Index file found: {index_file}")
100
 
101
+ model_cache[model_name] = (tgt_sr, net_g, vc, version, index_file, if_f0)
102
+ return model_cache[model_name]
103
 
104
  def load_hubert():
105
  models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
 
115
  return hubert_model.eval()
116
 
117
  def get_model_names():
 
118
  return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
119
 
120
+ # Initialize a global ThreadPoolExecutor
121
+ executor = ThreadPoolExecutor(max_workers=20) # Adjust based on your server
 
 
 
 
 
122
 
123
  def parallel_tts(tasks):
124
+ futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
125
+ results = [future.result() for future in futures]
 
126
  return results
127
 
128
  async def tts(
 
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
136
+ # Default values for parameters
137
  speed = 0 # Default speech speed
138
  f0_up_key = 0 # Default pitch adjustment
139
  f0_method = "rmvpe" # Default pitch extraction method
 
160
  # EdgeTTS processing
161
  if limitation and len(tts_text) > 12000:
162
  return (
163
+ f"Text characters should be at most 12000 in this Hugging Face Space, but got {len(tts_text)} characters.",
164
  None,
165
  None,
166
  )
 
179
  # Common processing after loading the audio
180
  duration = len(audio) / sr
181
  print(f"Audio duration: {duration}s")
182
+ if limitation and duration >= 20:
183
  return (
184
+ f"Audio should be less than 20 seconds in this Hugging Face Space, but got {duration}s.",
185
  None,
186
  None,
187
  )
188
 
189
  f0_up_key = int(f0_up_key)
190
+ # Load the model using cached data
191
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
192
 
193
  # Setup for RMVPE or other pitch extraction methods
 
230
 
231
  except EOFError:
232
  info = (
233
+ "Output not valid. This may occur when input text and speaker do not match."
234
  )
235
  print(info)
236
  return info, None, None
 
239
  print(traceback_info)
240
  return str(e), None, None
241
 
242
+ # Initialize the global models
 
 
 
 
243
  hubert_model = load_hubert()
244
+ rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)