Vijish commited on
Commit
56fd60f
·
verified ·
1 Parent(s): c838db5

Update voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +248 -248
voice_processing.py CHANGED
@@ -1,248 +1,248 @@
1
- import asyncio
2
- import datetime
3
- import logging
4
- import os
5
- import time
6
- import traceback
7
- import tempfile
8
-
9
- import edge_tts
10
- import librosa
11
- import torch
12
- from fairseq import checkpoint_utils
13
- import uuid
14
-
15
- from config import Config
16
- from lib.infer_pack.models import (
17
- SynthesizerTrnMs256NSFsid,
18
- SynthesizerTrnMs256NSFsid_nono,
19
- SynthesizerTrnMs768NSFsid,
20
- SynthesizerTrnMs768NSFsid_nono,
21
- )
22
- from rmvpe import RMVPE
23
- from vc_infer_pipeline import VC
24
-
25
- # Set logging levels
26
- logging.getLogger("fairseq").setLevel(logging.WARNING)
27
- logging.getLogger("numba").setLevel(logging.WARNING)
28
- logging.getLogger("markdown_it").setLevel(logging.WARNING)
29
- logging.getLogger("urllib3").setLevel(logging.WARNING)
30
- logging.getLogger("matplotlib").setLevel(logging.WARNING)
31
-
32
- limitation = os.getenv("SYSTEM") == "spaces"
33
-
34
- config = Config()
35
-
36
- # Edge TTS
37
- tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
38
- tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
39
-
40
- # RVC models
41
- model_root = "weights"
42
- models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
43
- models.sort()
44
-
45
- def get_unique_filename(extension):
46
- return f"{uuid.uuid4()}.{extension}"
47
-
48
-
49
- #edge_output_filename = get_unique_filename("mp3")
50
-
51
-
52
- def model_data(model_name):
53
- # global n_spk, tgt_sr, net_g, vc, cpt, version, index_file
54
- pth_path = [
55
- f"{model_root}/{model_name}/{f}"
56
- for f in os.listdir(f"{model_root}/{model_name}")
57
- if f.endswith(".pth")
58
- ][0]
59
- print(f"Loading {pth_path}")
60
- cpt = torch.load(pth_path, map_location="cpu")
61
- tgt_sr = cpt["config"][-1]
62
- cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
63
- if_f0 = cpt.get("f0", 1)
64
- version = cpt.get("version", "v1")
65
- if version == "v1":
66
- if if_f0 == 1:
67
- net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
68
- else:
69
- net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
70
- elif version == "v2":
71
- if if_f0 == 1:
72
- net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
73
- else:
74
- net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
75
- else:
76
- raise ValueError("Unknown version")
77
- del net_g.enc_q
78
- net_g.load_state_dict(cpt["weight"], strict=False)
79
- print("Model loaded")
80
- net_g.eval().to(config.device)
81
- if config.is_half:
82
- net_g = net_g.half()
83
- else:
84
- net_g = net_g.float()
85
- vc = VC(tgt_sr, config)
86
- # n_spk = cpt["config"][-3]
87
-
88
- index_files = [
89
- f"{model_root}/{model_name}/{f}"
90
- for f in os.listdir(f"{model_root}/{model_name}")
91
- if f.endswith(".index")
92
- ]
93
- if len(index_files) == 0:
94
- print("No index file found")
95
- index_file = ""
96
- else:
97
- index_file = index_files[0]
98
- print(f"Index file found: {index_file}")
99
-
100
- return tgt_sr, net_g, vc, version, index_file, if_f0
101
-
102
-
103
- def load_hubert():
104
- # global hubert_model
105
- models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
106
- ["hubert_base.pt"],
107
- suffix="",
108
- )
109
- hubert_model = models[0]
110
- hubert_model = hubert_model.to(config.device)
111
- if config.is_half:
112
- hubert_model = hubert_model.half()
113
- else:
114
- hubert_model = hubert_model.float()
115
- return hubert_model.eval()
116
-
117
- def get_model_names():
118
- model_root = "weights" # Assuming this is where your models are stored
119
- return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
120
-
121
- async def tts(
122
- model_name,
123
- tts_text,
124
- tts_voice,
125
- index_rate,
126
- use_uploaded_voice,
127
- uploaded_voice,
128
- ):
129
- # Default values for parameters used in EdgeTTS
130
- speed = 0 # Default speech speed
131
- f0_up_key = 0 # Default pitch adjustment
132
- f0_method = "rmvpe" # Default pitch extraction method
133
- protect = 0.33 # Default protect value
134
- filter_radius = 3
135
- resample_sr = 0
136
- rms_mix_rate = 0.25
137
- edge_time = 0 # Initialize edge_time
138
-
139
- edge_output_filename = get_unique_filename("mp3")
140
-
141
-
142
- try:
143
- if use_uploaded_voice:
144
- if uploaded_voice is None:
145
- return "No voice file uploaded.", None, None
146
-
147
- # Process the uploaded voice file
148
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
149
- tmp_file.write(uploaded_voice)
150
- uploaded_file_path = tmp_file.name
151
-
152
- #uploaded_file_path = uploaded_voice.name
153
- audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
154
- else:
155
- # EdgeTTS processing
156
- if limitation and len(tts_text) > 4000:
157
- return (
158
- f"Text characters should be at most 280 in this huggingface space, but got {len(tts_text)} characters.",
159
- None,
160
- None,
161
- )
162
-
163
- # Invoke Edge TTS
164
- t0 = time.time()
165
- speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
166
- await edge_tts.Communicate(
167
- tts_text, tts_voice, rate=speed_str
168
- ).save(edge_output_filename)
169
- t1 = time.time()
170
- edge_time = t1 - t0
171
-
172
- audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
173
-
174
- # Common processing after loading the audio
175
- duration = len(audio) / sr
176
- print(f"Audio duration: {duration}s")
177
- if limitation and duration >= 20:
178
- return (
179
- f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
180
- None,
181
- None,
182
- )
183
-
184
- f0_up_key = int(f0_up_key)
185
- tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
186
-
187
- # Setup for RMVPE or other pitch extraction methods
188
- if f0_method == "rmvpe":
189
- vc.model_rmvpe = rmvpe_model
190
-
191
- # Perform voice conversion pipeline
192
- times = [0, 0, 0]
193
- audio_opt = vc.pipeline(
194
- hubert_model,
195
- net_g,
196
- 0,
197
- audio,
198
- edge_output_filename if not use_uploaded_voice else uploaded_file_path,
199
- times,
200
- f0_up_key,
201
- f0_method,
202
- index_file,
203
- index_rate,
204
- if_f0,
205
- filter_radius,
206
- tgt_sr,
207
- resample_sr,
208
- rms_mix_rate,
209
- version,
210
- protect,
211
- None,
212
- )
213
-
214
- if tgt_sr != resample_sr and resample_sr >= 16000:
215
- tgt_sr = resample_sr
216
-
217
- info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
218
- print(info)
219
- return (
220
- info,
221
- edge_output_filename if not use_uploaded_voice else None,
222
- (tgt_sr, audio_opt),
223
- edge_output_filename
224
- )
225
-
226
- except EOFError:
227
- info = (
228
- "output not valid. This may occur when input text and speaker do not match."
229
- )
230
- print(info)
231
- return info, None, None
232
- except Exception as e:
233
- traceback_info = traceback.format_exc()
234
- print(traceback_info)
235
- return str(e), None, None
236
-
237
-
238
- voice_mapping = {
239
- "Mongolian Male": "mn-MN-BataaNeural",
240
- "Mongolian Female": "mn-MN-YesuiNeural"
241
- }
242
-
243
-
244
-
245
- hubert_model = load_hubert()
246
-
247
- rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
248
-
 
1
+ import asyncio
2
+ import datetime
3
+ import logging
4
+ import os
5
+ import time
6
+ import traceback
7
+ import tempfile
8
+
9
+ import edge_tts
10
+ import librosa
11
+ import torch
12
+ from fairseq import checkpoint_utils
13
+ import uuid
14
+
15
+ from config import Config
16
+ from lib.infer_pack.models import (
17
+ SynthesizerTrnMs256NSFsid,
18
+ SynthesizerTrnMs256NSFsid_nono,
19
+ SynthesizerTrnMs768NSFsid,
20
+ SynthesizerTrnMs768NSFsid_nono,
21
+ )
22
+ from rmvpe import RMVPE
23
+ from vc_infer_pipeline import VC
24
+
25
# Quiet noisy third-party loggers so app logs stay readable.
for _noisy_logger in ("fairseq", "numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

# True when running inside a constrained HF Spaces environment; enables limits.
limitation = os.getenv("SYSTEM") == "spaces"

config = Config()

# Edge TTS: fetch the full voice catalogue once at import time.
# asyncio.run replaces the deprecated get_event_loop().run_until_complete pattern.
tts_voice_list = asyncio.run(edge_tts.list_voices())
tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]  # Specific voices

# RVC models: each subdirectory of `weights/` is one selectable model.
model_root = "weights"
models = sorted(
    d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")
)
44
+
45
def get_unique_filename(extension):
    """Generate a collision-free random filename ending in the given extension."""
    unique_stem = uuid.uuid4()
    return f"{unique_stem}.{extension}"
47
+
48
+
49
+ #edge_output_filename = get_unique_filename("mp3")
50
+
51
+
52
def model_data(model_name):
    """Load an RVC checkpoint and build its synthesizer and conversion pipeline.

    Looks under ``weights/<model_name>/`` for a ``.pth`` checkpoint and an
    optional ``.index`` file.

    Returns:
        (tgt_sr, net_g, vc, version, index_file, if_f0) where ``tgt_sr`` is the
        model's output sample rate, ``net_g`` the synthesizer network, ``vc``
        the conversion pipeline, ``version`` "v1"/"v2", ``index_file`` a path
        or "", and ``if_f0`` whether the model uses pitch (f0) conditioning.
    """
    model_dir = f"{model_root}/{model_name}"
    pth_path = [
        f"{model_dir}/{f}" for f in os.listdir(model_dir) if f.endswith(".pth")
    ][0]
    print(f"Loading {pth_path}")
    cpt = torch.load(pth_path, map_location="cpu")
    tgt_sr = cpt["config"][-1]
    # Patch the speaker count from the embedding table's actual shape.
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")

    # Pick the synthesizer class by (version, uses-f0).
    synth_cls = {
        ("v1", True): SynthesizerTrnMs256NSFsid,
        ("v1", False): SynthesizerTrnMs256NSFsid_nono,
        ("v2", True): SynthesizerTrnMs768NSFsid,
        ("v2", False): SynthesizerTrnMs768NSFsid_nono,
    }.get((version, if_f0 == 1))
    if synth_cls is None:
        raise ValueError("Unknown version")
    if if_f0 == 1:
        net_g = synth_cls(*cpt["config"], is_half=config.is_half)
    else:
        net_g = synth_cls(*cpt["config"])

    # enc_q is only needed for training; drop it before loading weights.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    print("Model loaded")
    net_g.eval().to(config.device)
    net_g = net_g.half() if config.is_half else net_g.float()
    vc = VC(tgt_sr, config)

    index_candidates = [
        f"{model_dir}/{f}" for f in os.listdir(model_dir) if f.endswith(".index")
    ]
    if not index_candidates:
        print("No index file found")
        index_file = ""
    else:
        index_file = index_candidates[0]
        print(f"Index file found: {index_file}")

    return tgt_sr, net_g, vc, version, index_file, if_f0
101
+
102
+
103
def load_hubert():
    """Load the HuBERT feature extractor from ``hubert_base.pt``.

    Returns the model moved to the configured device, cast to half or full
    precision per ``config.is_half``, in eval mode.
    """
    ensemble, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    model = ensemble[0].to(config.device)
    model = model.half() if config.is_half else model.float()
    return model.eval()
116
+
117
def get_model_names():
    """List the model subdirectories available under the local weights folder."""
    root = "weights"
    return [entry for entry in os.listdir(root) if os.path.isdir(f"{root}/{entry}")]
120
+
121
async def tts(
    model_name,
    tts_text,
    tts_voice,
    index_rate,
    use_uploaded_voice,
    uploaded_voice,
):
    """Synthesize speech (Edge TTS) or use an uploaded clip, then run RVC conversion.

    Args:
        model_name: Name of an RVC model directory under ``weights/``.
        tts_text: Text to synthesize when not using an uploaded voice.
        tts_voice: Edge TTS voice identifier (e.g. "mn-MN-BataaNeural").
        index_rate: Feature-index blend rate passed to the RVC pipeline.
        use_uploaded_voice: If True, convert ``uploaded_voice`` instead of TTS output.
        uploaded_voice: Raw audio bytes of the uploaded clip (or None).

    Returns:
        On success: (info, edge_mp3_path_or_None, (sample_rate, audio), edge_mp3_path).
        On failure: (error_message, None, None).
    """
    # Fixed conversion settings (not exposed to the caller).
    speed = 0  # Edge TTS speech-rate offset in percent
    f0_up_key = 0  # pitch shift in semitones
    f0_method = "rmvpe"  # pitch extraction method
    protect = 0.33  # consonant/breath protection
    filter_radius = 3
    resample_sr = 0
    rms_mix_rate = 0.25
    edge_time = 0  # seconds spent inside Edge TTS

    # Limits enforced only under `limitation` (HF Spaces). Named constants keep
    # the checks and the user-facing messages in sync (messages previously
    # claimed 280 chars / 20 s while the checks used different values).
    max_text_chars = 12000
    # NOTE(review): a 20000-second cap effectively disables the duration
    # limit — confirm this was the commit's intent.
    max_audio_seconds = 20000

    edge_output_filename = get_unique_filename("mp3")

    try:
        if use_uploaded_voice:
            if uploaded_voice is None:
                return "No voice file uploaded.", None, None

            # Persist the uploaded bytes so librosa / the pipeline can read a path.
            # NOTE(review): the temp file is never removed — consider cleanup.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                tmp_file.write(uploaded_voice)
                uploaded_file_path = tmp_file.name

            audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
        else:
            # Reject oversized text before spending time on synthesis.
            if limitation and len(tts_text) > max_text_chars:
                return (
                    f"Text characters should be at most {max_text_chars} in this huggingface space, but got {len(tts_text)} characters.",
                    None,
                    None,
                )

            # Invoke Edge TTS and time it for the success report.
            t0 = time.time()
            speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
            await edge_tts.Communicate(
                tts_text, tts_voice, rate=speed_str
            ).save(edge_output_filename)
            t1 = time.time()
            edge_time = t1 - t0

            audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)

        # Common processing after loading the audio.
        duration = len(audio) / sr
        print(f"Audio duration: {duration}s")
        if limitation and duration >= max_audio_seconds:
            return (
                f"Audio should be less than {max_audio_seconds} seconds in this huggingface space, but got {duration}s.",
                None,
                None,
            )

        f0_up_key = int(f0_up_key)
        tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)

        # RMVPE needs its preloaded model attached to the pipeline object.
        if f0_method == "rmvpe":
            vc.model_rmvpe = rmvpe_model

        # Run the RVC conversion pipeline; `times` collects per-stage timings.
        times = [0, 0, 0]
        audio_opt = vc.pipeline(
            hubert_model,
            net_g,
            0,  # speaker id
            audio,
            edge_output_filename if not use_uploaded_voice else uploaded_file_path,
            times,
            f0_up_key,
            f0_method,
            index_file,
            index_rate,
            if_f0,
            filter_radius,
            tgt_sr,
            resample_sr,
            rms_mix_rate,
            version,
            protect,
            None,
        )

        # Honor a requested output resample rate if it is a valid audio rate.
        if tgt_sr != resample_sr and resample_sr >= 16000:
            tgt_sr = resample_sr

        info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
        print(info)
        return (
            info,
            edge_output_filename if not use_uploaded_voice else None,
            (tgt_sr, audio_opt),
            edge_output_filename,
        )

    except EOFError:
        # Raised by downstream decoding when text and speaker are incompatible.
        info = (
            "output not valid. This may occur when input text and speaker do not match."
        )
        print(info)
        return info, None, None
    except Exception as e:
        traceback_info = traceback.format_exc()
        print(traceback_info)
        return str(e), None, None
236
+
237
+
238
# Friendly display names mapped to Edge TTS voice identifiers.
voice_mapping = {
    "Mongolian Male": "mn-MN-BataaNeural",
    "Mongolian Female": "mn-MN-YesuiNeural"
}


# Preload the heavyweight models once at import time so every request reuses them.
hubert_model = load_hubert()

rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
248
+