MAZALA2024 committed on
Commit
0cbbac2
·
verified ·
1 Parent(s): fe5a693

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +89 -111
voice_processing.py CHANGED
@@ -1,18 +1,11 @@
1
- import asyncio
2
- import datetime
3
- import logging
4
  import os
5
  import time
6
  import traceback
7
- import tempfile
8
- from concurrent.futures import ThreadPoolExecutor
9
-
10
- import edge_tts
11
- import librosa
12
  import torch
 
 
13
  from fairseq import checkpoint_utils
14
- import uuid
15
-
16
  from config import Config
17
  from lib.infer_pack.models import (
18
  SynthesizerTrnMs256NSFsid,
@@ -20,39 +13,63 @@ from lib.infer_pack.models import (
20
  SynthesizerTrnMs768NSFsid,
21
  SynthesizerTrnMs768NSFsid_nono,
22
  )
23
- from rmvpe import RMVPE
24
  from vc_infer_pipeline import VC
25
-
26
- # Set logging levels
27
- logging.getLogger("fairseq").setLevel(logging.WARNING)
28
- logging.getLogger("numba").setLevel(logging.WARNING)
29
- logging.getLogger("markdown_it").setLevel(logging.WARNING)
30
- logging.getLogger("urllib3").setLevel(logging.WARNING)
31
- logging.getLogger("matplotlib").setLevel(logging.WARNING)
32
-
33
- limitation = os.getenv("SYSTEM") == "spaces"
34
 
35
  config = Config()
36
 
37
- # Edge TTS
38
- tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
39
- tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
 
40
 
41
- # RVC models
42
- model_root = "weights"
43
- models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
44
- models.sort()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def get_unique_filename(extension):
47
  return f"{uuid.uuid4()}.{extension}"
48
 
 
 
 
 
49
  def model_data(model_name):
50
- pth_path = [
51
- f"{model_root}/{model_name}/{f}"
52
- for f in os.listdir(f"{model_root}/{model_name}")
53
- if f.endswith(".pth")
54
- ][0]
55
- print(f"Loading {pth_path}")
 
 
 
 
 
 
 
56
  cpt = torch.load(pth_path, map_location="cpu")
57
  tgt_sr = cpt["config"][-1]
58
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
@@ -70,62 +87,33 @@ def model_data(model_name):
70
  net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
71
  else:
72
  raise ValueError("Unknown version")
 
73
  del net_g.enc_q
74
  net_g.load_state_dict(cpt["weight"], strict=False)
75
- print("Model loaded")
76
  net_g.eval().to(config.device)
77
  if config.is_half:
78
  net_g = net_g.half()
79
  else:
80
  net_g = net_g.float()
 
 
81
  vc = VC(tgt_sr, config)
82
 
83
  index_files = [
84
- f"{model_root}/{model_name}/{f}"
85
- for f in os.listdir(f"{model_root}/{model_name}")
86
- if f.endswith(".index")
87
  ]
88
- if len(index_files) == 0:
89
- print("No index file found")
90
- index_file = ""
91
- else:
92
- index_file = index_files[0]
93
  print(f"Index file found: {index_file}")
94
-
95
- return tgt_sr, net_g, vc, version, index_file, if_f0
96
-
97
- def load_hubert():
98
- models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
99
- ["hubert_base.pt"],
100
- suffix="",
101
- )
102
- hubert_model = models[0]
103
- hubert_model = hubert_model.to(config.device)
104
- if config.is_half:
105
- hubert_model = hubert_model.half()
106
  else:
107
- hubert_model = hubert_model.float()
108
- return hubert_model.eval()
109
 
110
- def get_model_names():
111
- model_root = "weights" # Assuming this is where your models are stored
112
- return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
113
 
114
- # Add this helper function to ensure a new event loop is created if none exists
115
- def run_async_in_thread(fn, *args):
116
- loop = asyncio.new_event_loop()
117
- asyncio.set_event_loop(loop)
118
- result = loop.run_until_complete(fn(*args))
119
- loop.close()
120
- return result
121
-
122
- def parallel_tts(tasks):
123
- with ThreadPoolExecutor() as executor:
124
- futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks]
125
- results = [future.result() for future in futures]
126
- return results
127
-
128
- async def tts(
129
  model_name,
130
  tts_text,
131
  tts_voice,
@@ -133,8 +121,11 @@ async def tts(
133
  use_uploaded_voice,
134
  uploaded_voice,
135
  ):
 
 
 
 
136
  # Default values for parameters used in EdgeTTS
137
- speed = 0 # Default speech speed
138
  f0_up_key = 0 # Default pitch adjustment
139
  f0_method = "rmvpe" # Default pitch extraction method
140
  protect = 0.33 # Default protect value
@@ -144,52 +135,42 @@ async def tts(
144
  edge_time = 0 # Initialize edge_time
145
 
146
  edge_output_filename = get_unique_filename("mp3")
 
 
147
 
148
  try:
149
  if use_uploaded_voice:
150
  if uploaded_voice is None:
151
  return "No voice file uploaded.", None, None
152
-
153
  # Process the uploaded voice file
154
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
155
  tmp_file.write(uploaded_voice)
156
  uploaded_file_path = tmp_file.name
157
 
158
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
 
159
  else:
160
  # EdgeTTS processing
161
- if limitation and len(tts_text) > 12000:
162
- return (
163
- f"Text characters should be at most 12000 in this huggingface space, but got {len(tts_text)} characters.",
164
- None,
165
- None,
166
- )
167
-
168
- # Invoke Edge TTS
169
  t0 = time.time()
 
170
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
171
- await edge_tts.Communicate(
172
  tts_text, tts_voice, rate=speed_str
173
- ).save(edge_output_filename)
 
174
  t1 = time.time()
175
  edge_time = t1 - t0
176
 
177
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
 
178
 
179
- # Common processing after loading the audio
180
- duration = len(audio) / sr
181
- print(f"Audio duration: {duration}s")
182
- if limitation and duration >= 20000:
183
- return (
184
- f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
185
- None,
186
- None,
187
- )
188
-
189
- f0_up_key = int(f0_up_key)
190
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
191
 
192
- # Setup for RMVPE or other pitch extraction methods
193
  if f0_method == "rmvpe":
194
  vc.model_rmvpe = rmvpe_model
195
 
@@ -198,9 +179,9 @@ async def tts(
198
  audio_opt = vc.pipeline(
199
  hubert_model,
200
  net_g,
201
- 0,
202
  audio,
203
- edge_output_filename if not use_uploaded_voice else uploaded_file_path,
204
  times,
205
  f0_up_key,
206
  f0_method,
@@ -218,31 +199,28 @@ async def tts(
218
 
219
  if tgt_sr != resample_sr and resample_sr >= 16000:
220
  tgt_sr = resample_sr
221
-
222
- info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
223
  print(info)
224
  return (
225
- info,
226
- edge_output_filename if not use_uploaded_voice else None,
227
  (tgt_sr, audio_opt),
228
  )
229
 
230
  except EOFError:
231
  info = (
232
- "output not valid. This may occur when input text and speaker do not match."
233
  )
234
  print(info)
235
- return info, None, None
236
  except Exception as e:
237
  traceback_info = traceback.format_exc()
238
  print(traceback_info)
239
- return str(e), None, None
240
 
 
241
  voice_mapping = {
242
  "Mongolian Male": "mn-MN-BataaNeural",
243
  "Mongolian Female": "mn-MN-YesuiNeural"
244
  }
245
-
246
- hubert_model = load_hubert()
247
-
248
- rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
 
 
 
 
import asyncio
import os
import tempfile
import time
import traceback

import librosa
import numpy as np
import torch
from fairseq import checkpoint_utils

from config import Config
from rmvpe import RMVPE
10
  from lib.infer_pack.models import (
11
  SynthesizerTrnMs256NSFsid,
 
13
  SynthesizerTrnMs768NSFsid,
14
  SynthesizerTrnMs768NSFsid_nono,
15
  )
 
16
  from vc_infer_pipeline import VC
17
+ import uuid
 
 
 
 
 
 
 
 
18
 
19
config = Config()

# Global models loaded once and reused across calls.
hubert_model = None  # populated lazily by load_hubert()
rmvpe_model = None  # populated lazily by load_rmvpe()
model_cache = {}  # Cache for RVC models, keyed by model name
25
 
26
def load_hubert():
    """Return the shared Hubert feature-extraction model, loading it on first use.

    The model is read from ``hubert_base.pt`` via fairseq, moved to the
    configured device, cast to half precision when ``config.is_half`` is set,
    put into eval mode, and cached in the module-level ``hubert_model`` global
    so subsequent calls are free.
    """
    global hubert_model
    if hubert_model is not None:
        return hubert_model

    print("Loading Hubert model...")
    ensemble, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"], suffix=""
    )
    model = ensemble[0].to(config.device)
    model = model.half() if config.is_half else model.float()
    model.eval()
    hubert_model = model
    print("Hubert model loaded.")
    return hubert_model
43
+
44
def load_rmvpe():
    """Return the shared RMVPE pitch-extraction model, loading it on first use.

    Reads ``rmvpe.pt`` and caches the instance in the module-level
    ``rmvpe_model`` global so it is constructed at most once.
    """
    global rmvpe_model
    if rmvpe_model is not None:
        return rmvpe_model

    print("Loading RMVPE model...")
    rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
    print("RMVPE model loaded.")
    return rmvpe_model
51
 
52
def get_unique_filename(extension):
    """Return a collision-safe filename of the form ``<uuid4>.<extension>``."""
    return ".".join([str(uuid.uuid4()), extension])
54
 
55
def get_model_names(model_root="weights"):
    """Return the RVC model names available under *model_root*.

    Each immediate subdirectory of *model_root* is treated as one model.
    The result is sorted so callers (e.g. UI dropdowns) see a stable,
    deterministic order regardless of filesystem listing order.

    Args:
        model_root: Directory holding one subdirectory per model.
            Defaults to ``"weights"``, the project's model store.

    Returns:
        list[str]: Alphabetically sorted model directory names.
    """
    return sorted(
        d
        for d in os.listdir(model_root)
        if os.path.isdir(os.path.join(model_root, d))
    )
58
+
59
  def model_data(model_name):
60
+ global model_cache
61
+ if model_name in model_cache:
62
+ # Return cached model data
63
+ return model_cache[model_name]
64
+
65
+ model_root = "weights"
66
+ pth_files = [
67
+ f for f in os.listdir(f"{model_root}/{model_name}") if f.endswith(".pth")
68
+ ]
69
+ if not pth_files:
70
+ raise FileNotFoundError(f"No .pth file found for model '{model_name}'")
71
+ pth_path = f"{model_root}/{model_name}/{pth_files[0]}"
72
+ print(f"Loading model from {pth_path}")
73
  cpt = torch.load(pth_path, map_location="cpu")
74
  tgt_sr = cpt["config"][-1]
75
  cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
 
87
  net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
88
  else:
89
  raise ValueError("Unknown version")
90
+
91
  del net_g.enc_q
92
  net_g.load_state_dict(cpt["weight"], strict=False)
 
93
  net_g.eval().to(config.device)
94
  if config.is_half:
95
  net_g = net_g.half()
96
  else:
97
  net_g = net_g.float()
98
+ print(f"Model '{model_name}' loaded.")
99
+
100
  vc = VC(tgt_sr, config)
101
 
102
  index_files = [
103
+ f for f in os.listdir(f"{model_root}/{model_name}") if f.endswith(".index")
 
 
104
  ]
105
+ if index_files:
106
+ index_file = f"{model_root}/{model_name}/{index_files[0]}"
 
 
 
107
  print(f"Index file found: {index_file}")
 
 
 
 
 
 
 
 
 
 
 
 
108
  else:
109
+ index_file = ""
110
+ print("No index file found.")
111
 
112
+ # Cache the loaded model data
113
+ model_cache[model_name] = (tgt_sr, net_g, vc, version, index_file, if_f0)
114
+ return tgt_sr, net_g, vc, version, index_file, if_f0
115
 
116
+ def tts(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  model_name,
118
  tts_text,
119
  tts_voice,
 
121
  use_uploaded_voice,
122
  uploaded_voice,
123
  ):
124
+ # Load models if not already loaded
125
+ load_hubert()
126
+ load_rmvpe()
127
+
128
  # Default values for parameters used in EdgeTTS
 
129
  f0_up_key = 0 # Default pitch adjustment
130
  f0_method = "rmvpe" # Default pitch extraction method
131
  protect = 0.33 # Default protect value
 
135
  edge_time = 0 # Initialize edge_time
136
 
137
  edge_output_filename = get_unique_filename("mp3")
138
+ audio = None
139
+ sr = 16000 # Default sample rate
140
 
141
  try:
142
  if use_uploaded_voice:
143
  if uploaded_voice is None:
144
  return "No voice file uploaded.", None, None
145
+
146
  # Process the uploaded voice file
147
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
148
  tmp_file.write(uploaded_voice)
149
  uploaded_file_path = tmp_file.name
150
 
151
  audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
152
+ input_audio_path = uploaded_file_path
153
  else:
154
  # EdgeTTS processing
155
+ # Note: EdgeTTS code may need to be adjusted based on your implementation
156
+ import edge_tts
 
 
 
 
 
 
157
  t0 = time.time()
158
+ speed = 0 # Default speech speed
159
  speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
160
+ communicate = edge_tts.Communicate(
161
  tts_text, tts_voice, rate=speed_str
162
+ )
163
+ asyncio.run(communicate.save(edge_output_filename))
164
  t1 = time.time()
165
  edge_time = t1 - t0
166
 
167
  audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
168
+ input_audio_path = edge_output_filename
169
 
170
+ # Load the specified RVC model
 
 
 
 
 
 
 
 
 
 
171
  tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
172
 
173
+ # Set RMVPE model for pitch extraction
174
  if f0_method == "rmvpe":
175
  vc.model_rmvpe = rmvpe_model
176
 
 
179
  audio_opt = vc.pipeline(
180
  hubert_model,
181
  net_g,
182
+ 0, # Speaker ID
183
  audio,
184
+ input_audio_path,
185
  times,
186
  f0_up_key,
187
  f0_method,
 
199
 
200
  if tgt_sr != resample_sr and resample_sr >= 16000:
201
  tgt_sr = resample_sr
202
+
203
+ info = f"Success. Time: tts: {edge_time:.2f}s, npy: {times[0]:.2f}s, f0: {times[1]:.2f}s, infer: {times[2]:.2f}s"
204
  print(info)
205
  return (
206
+ {"info": info},
207
+ None, # Return None for edge_output_filename as it's not needed
208
  (tgt_sr, audio_opt),
209
  )
210
 
211
  except EOFError:
212
  info = (
213
+ "Output not valid. This may occur when input text and speaker do not match."
214
  )
215
  print(info)
216
+ return {"error": info}, None, None
217
  except Exception as e:
218
  traceback_info = traceback.format_exc()
219
  print(traceback_info)
220
+ return {"error": str(e)}, None, None
221
 
222
# Maps human-readable speaker labels to Edge TTS voice identifiers.
voice_mapping = {
    "Mongolian Male": "mn-MN-BataaNeural",
    "Mongolian Female": "mn-MN-YesuiNeural",
}