MAZALA2024 committed on
Commit
8bc67af
·
verified ·
1 Parent(s): 65c8072

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +18 -23
voice_processing.py CHANGED
@@ -23,9 +23,6 @@ from lib.infer_pack.models import (
23
  from rmvpe import RMVPE
24
  from vc_infer_pipeline import VC
25
 
26
- model_cache = {}
27
-
28
-
29
  # Set logging levels
30
  logging.getLogger("fairseq").setLevel(logging.WARNING)
31
  logging.getLogger("numba").setLevel(logging.WARNING)
@@ -37,11 +34,11 @@ limitation = os.getenv("SYSTEM") == "spaces"
37
 
38
  config = Config()
39
 
40
- # Edge TTS voices
41
  tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
42
- tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
43
 
44
- # RVC models directory
45
  model_root = "weights"
46
  models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
47
  models.sort()
@@ -50,7 +47,6 @@ def get_unique_filename(extension):
50
  return f"{uuid.uuid4()}.{extension}"
51
 
52
  def model_data(model_name):
53
- # We will not modify this function to cache models
54
  pth_path = [
55
  f"{model_root}/{model_name}/{f}"
56
  for f in os.listdir(f"{model_root}/{model_name}")
@@ -112,18 +108,10 @@ def load_hubert():
112
  return hubert_model.eval()
113
 
114
  def get_model_names():
 
115
  return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
116
 
117
- # Initialize the global models
118
- hubert_model = load_hubert()
119
- rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
120
-
121
- voice_mapping = {
122
- "Mongolian Male": "mn-MN-BataaNeural",
123
- "Mongolian Female": "mn-MN-YesuiNeural"
124
- }
125
-
126
- # Function to run async functions in a new event loop within a thread
127
  def run_async_in_thread(fn, *args):
128
  loop = asyncio.new_event_loop()
129
  asyncio.set_event_loop(loop)
@@ -132,8 +120,7 @@ def run_async_in_thread(fn, *args):
132
  return result
133
 
134
  def parallel_tts(tasks):
135
- # Increase max_workers to better utilize CPU and GPU resources
136
- with ThreadPoolExecutor(max_workers=8) as executor: # Adjust based on your server capacity
137
  futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
138
  results = [future.result() for future in futures]
139
  return results
@@ -146,7 +133,7 @@ async def tts(
146
  use_uploaded_voice,
147
  uploaded_voice,
148
  ):
149
- # Default values for parameters
150
  speed = 0 # Default speech speed
151
  f0_up_key = 0 # Default pitch adjustment
152
  f0_method = "rmvpe" # Default pitch extraction method
@@ -200,7 +187,6 @@ async def tts(
200
  )
201
 
202
  f0_up_key = int(f0_up_key)
203
- # Load the model
204
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
205
 
206
  # Setup for RMVPE or other pitch extraction methods
@@ -243,11 +229,20 @@ async def tts(
243
 
244
  except EOFError:
245
  info = (
246
- "Output not valid. This may occur when input text and speaker do not match."
247
  )
248
  print(info)
249
  return info, None, None
250
  except Exception as e:
251
  traceback_info = traceback.format_exc()
252
  print(traceback_info)
253
- return str(e), None, None
 
 
 
 
 
 
 
 
 
 
23
  from rmvpe import RMVPE
24
  from vc_infer_pipeline import VC
25
 
 
 
 
26
  # Set logging levels
27
  logging.getLogger("fairseq").setLevel(logging.WARNING)
28
  logging.getLogger("numba").setLevel(logging.WARNING)
 
34
 
35
  config = Config()
36
 
37
+ # Edge TTS
38
  tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
39
+ tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
40
 
41
+ # RVC models
42
  model_root = "weights"
43
  models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
44
  models.sort()
 
47
  return f"{uuid.uuid4()}.{extension}"
48
 
49
  def model_data(model_name):
 
50
  pth_path = [
51
  f"{model_root}/{model_name}/{f}"
52
  for f in os.listdir(f"{model_root}/{model_name}")
 
108
  return hubert_model.eval()
109
 
110
  def get_model_names():
111
+ model_root = "weights" # Assuming this is where your models are stored
112
  return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
113
 
114
+ # Add this helper function to ensure a new event loop is created if none exists
 
 
 
 
 
 
 
 
 
115
  def run_async_in_thread(fn, *args):
116
  loop = asyncio.new_event_loop()
117
  asyncio.set_event_loop(loop)
 
120
  return result
121
 
122
  def parallel_tts(tasks):
123
+ with ThreadPoolExecutor() as executor:
 
124
  futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
125
  results = [future.result() for future in futures]
126
  return results
 
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
136
+ # Default values for parameters used in EdgeTTS
137
  speed = 0 # Default speech speed
138
  f0_up_key = 0 # Default pitch adjustment
139
  f0_method = "rmvpe" # Default pitch extraction method
 
187
  )
188
 
189
  f0_up_key = int(f0_up_key)
 
190
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
191
 
192
  # Setup for RMVPE or other pitch extraction methods
 
229
 
230
  except EOFError:
231
  info = (
232
+ "output not valid. This may occur when input text and speaker do not match."
233
  )
234
  print(info)
235
  return info, None, None
236
  except Exception as e:
237
  traceback_info = traceback.format_exc()
238
  print(traceback_info)
239
+ return str(e), None, None
240
+
241
+ voice_mapping = {
242
+ "Mongolian Male": "mn-MN-BataaNeural",
243
+ "Mongolian Female": "mn-MN-YesuiNeural"
244
+ }
245
+
246
+ hubert_model = load_hubert()
247
+
248
+ rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)