MAZALA2024 committed
Commit 1d94066 · verified · 1 Parent(s): 282faa3

Update voice_processing.py

Files changed (1)
  1. voice_processing.py +37 -57
voice_processing.py CHANGED
@@ -14,6 +14,7 @@ from fairseq import checkpoint_utils
 import uuid
 
 from config import Config
+from config import Config, voice_mapping
 from lib.infer_pack.models import (
     SynthesizerTrnMs256NSFsid,
     SynthesizerTrnMs256NSFsid_nono,
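Note: this hunk keeps the existing from config import Config line and adds a second import of Config together with voice_mapping, so Config is now imported twice. Assuming config.py really exports both names (it is not shown in this diff), the two lines could be collapsed into one; a minimal sketch:

# Sketch: single consolidated import, assuming config.py exposes both names.
from config import Config, voice_mapping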
@@ -23,6 +24,9 @@ from lib.infer_pack.models import (
 from rmvpe import RMVPE
 from vc_infer_pipeline import VC
 
+model_cache = {}
+
+
 # Set logging levels
 logging.getLogger("fairseq").setLevel(logging.WARNING)
 logging.getLogger("numba").setLevel(logging.WARNING)
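Note: the module-level model_cache = {} added here is never read again once the caching logic is stripped out of model_data() below. If caching is still wanted, a thin wrapper could use it; a sketch with a hypothetical cached_model_data() helper, assuming model_data() keeps returning (tgt_sr, net_g, vc, version, index_file, if_f0):

# Sketch only: memoize model_data() results in the module-level dict.
def cached_model_data(model_name):
    if model_name not in model_cache:
        model_cache[model_name] = model_data(model_name)
    return model_cache[model_name]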
@@ -34,21 +38,9 @@ limitation = os.getenv("SYSTEM") == "spaces"
 
 config = Config()
 
-# Define voice_mapping first to ensure it's always available
-voice_mapping = {
-    "Mongolian Male": "mn-MN-BataaNeural",
-    "Mongolian Female": "mn-MN-YesuiNeural"
-}
-
 # Edge TTS voices
-try:
-    loop = asyncio.get_event_loop()
-    tts_voice_list = loop.run_until_complete(edge_tts.list_voices())
-    tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
-except Exception as e:
-    logging.error(f"Error loading Edge TTS voices: {e}")
-    tts_voice_list = []
-    tts_voices = []
+tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
+tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
 
 # RVC models directory
 model_root = "weights"
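Note: the replacement fetches the voice list with asyncio.get_event_loop().run_until_complete(...) at import time and drops the previous try/except, so a network failure (or an already-running event loop) now raises while the module is being imported. A guarded sketch built only from calls already used in this file:

# Sketch: guard the import-time voice listing, as the removed code did.
try:
    loop = asyncio.new_event_loop()
    tts_voice_list = loop.run_until_complete(edge_tts.list_voices())
    loop.close()
except Exception as e:
    logging.error(f"Error loading Edge TTS voices: {e}")
    tts_voice_list = []
tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]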
@@ -58,24 +50,15 @@ models.sort()
 def get_unique_filename(extension):
     return f"{uuid.uuid4()}.{extension}"
 
-model_cache = {}
-
 def model_data(model_name):
-    if model_name in model_cache:
-        return model_cache[model_name]
-
+    # We will not modify this function to cache models
     pth_path = [
         f"{model_root}/{model_name}/{f}"
         for f in os.listdir(f"{model_root}/{model_name}")
         if f.endswith(".pth")
     ][0]
     print(f"Loading {pth_path}")
-    try:
-        cpt = torch.load(pth_path, map_location="cpu")
-    except Exception as e:
-        logging.error(f"Error loading model {pth_path}: {e}")
-        raise e
-
+    cpt = torch.load(pth_path, map_location="cpu")
     tgt_sr = cpt["config"][-1]
     cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
     if_f0 = cpt.get("f0", 1)
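Note: the try/except around torch.load is removed, so a missing or corrupted .pth file now propagates as a raw exception out of model_data(). If the old logging behaviour is still wanted, the guard can be restored; a minimal sketch mirroring the removed lines:

# Sketch: log and re-raise on checkpoint load failure, as the removed code did.
try:
    cpt = torch.load(pth_path, map_location="cpu")
except Exception as e:
    logging.error(f"Error loading model {pth_path}: {e}")
    raise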
@@ -114,32 +97,34 @@ def model_data(model_name):
     index_file = index_files[0]
     print(f"Index file found: {index_file}")
 
-    model_cache[model_name] = (tgt_sr, net_g, vc, version, index_file, if_f0)
-    return model_cache[model_name]
+    return tgt_sr, net_g, vc, version, index_file, if_f0
 
 def load_hubert():
-    try:
-        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
-            ["hubert_base.pt"],
-            suffix="",
-        )
-        hubert_model = models[0]
-        hubert_model = hubert_model.to(config.device)
-        if config.is_half:
-            hubert_model = hubert_model.half()
-        else:
-            hubert_model = hubert_model.float()
-        return hubert_model.eval()
-    except Exception as e:
-        logging.error(f"Error loading HuBERT model: {e}")
-        raise e
+    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
+        ["hubert_base.pt"],
+        suffix="",
+    )
+    hubert_model = models[0]
+    hubert_model = hubert_model.to(config.device)
+    if config.is_half:
+        hubert_model = hubert_model.half()
+    else:
+        hubert_model = hubert_model.float()
+    return hubert_model.eval()
 
 def get_model_names():
     return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
 
-# Initialize a global ThreadPoolExecutor
-executor = ThreadPoolExecutor(max_workers=20)  # Adjust based on your server
+# Initialize the global models
+hubert_model = load_hubert()
+rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
 
+voice_mapping = {
+    "Mongolian Male": "mn-MN-BataaNeural",
+    "Mongolian Female": "mn-MN-YesuiNeural"
+}
+
+# Function to run async functions in a new event loop within a thread
 def run_async_in_thread(fn, *args):
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
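Note: hubert_model and rmvpe_model are now built unconditionally at import time, and the guarded initialisation block at the bottom of the file is deleted (last hunk). The voice_mapping dict defined here also shadows the name imported from config at the top. If import-time failures should not crash the process, the old guard can be kept around these two lines; a sketch (the None fallback is an assumption, not taken from the original):

# Sketch: wrap global model initialisation, mirroring the removed block.
try:
    hubert_model = load_hubert()
    rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
except Exception as e:
    logging.error(f"Failed to initialize global models: {e}")
    hubert_model = None  # assumption: let callers check for availability
    rmvpe_model = None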
@@ -148,8 +133,10 @@ def run_async_in_thread(fn, *args):
     return result
 
 def parallel_tts(tasks):
-    futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
-    results = [future.result() for future in futures]
+    # Increase max_workers to better utilize CPU and GPU resources
+    with ThreadPoolExecutor(max_workers=8) as executor:  # Adjust based on your server capacity
+        futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
+        results = [future.result() for future in futures]
     return results
 
 async def tts(
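For reference, parallel_tts() now builds a fresh ThreadPoolExecutor per call, and each worker thread runs the tts coroutine to completion on its own event loop via run_async_in_thread(). That helper is only partially visible in this diff; a sketch of the pattern it appears to follow (the run_until_complete/close lines are assumptions):

# Sketch of the new-loop-per-thread pattern used by parallel_tts().
def run_async_in_thread(fn, *args):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(fn(*args))
    finally:
        loop.close()
    return result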
@@ -187,7 +174,7 @@ async def tts(
     # EdgeTTS processing
     if limitation and len(tts_text) > 12000:
         return (
-            f"Text characters should be at most 12000 in this Hugging Face Space, but got {len(tts_text)} characters.",
+            f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
             None,
             None,
         )
@@ -206,15 +193,15 @@ async def tts(
     # Common processing after loading the audio
     duration = len(audio) / sr
     print(f"Audio duration: {duration}s")
-    if limitation and duration >= 20:
+    if limitation and duration >= 20000:
         return (
-            f"Audio should be less than 20 seconds in this Hugging Face Space, but got {duration}s.",
+            f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
             None,
             None,
         )
 
     f0_up_key = int(f0_up_key)
-    # Load the model using cached data
+    # Load the model
     tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
 
     # Setup for RMVPE or other pitch extraction methods
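Note: duration is len(audio) / sr, i.e. seconds, so the new check duration >= 20000 only fires after roughly 5.5 hours of audio, while the message still says 20 seconds. If a 20-second cap is the intent, the comparison stays in seconds; a sketch:

# Sketch: duration is in seconds, so compare against 20 to match the message.
if limitation and duration >= 20:
    return (
        f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
        None,
        None,
    )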
@@ -266,10 +253,3 @@ async def tts(
     print(traceback_info)
     return str(e), None, None
 
-# Initialize the global models
-try:
-    hubert_model = load_hubert()
-    rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
-except Exception as e:
-    logging.error(f"Failed to initialize global models: {e}")
-    # Optionally, you can exit or handle the error as needed
 