Vijish committed on
Commit
7dd7200
·
verified ·
1 Parent(s): 90e73a5

Upload voice_processing.py

Browse files
Files changed (1) hide show
  1. voice_processing.py +231 -0
voice_processing.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import datetime
3
+ import logging
4
+ import os
5
+ import time
6
+ import traceback
7
+ import tempfile
8
+ from concurrent.futures import ThreadPoolExecutor
9
+
10
+ # import edge_tts # Commented out as we're not using Edge TTS
11
+ import librosa
12
+ import torch
13
+ from fairseq import checkpoint_utils
14
+ import uuid
15
+
16
+ from config import Config
17
+ from lib.infer_pack.models import (
18
+ SynthesizerTrnMs256NSFsid,
19
+ SynthesizerTrnMs256NSFsid_nono,
20
+ SynthesizerTrnMs768NSFsid,
21
+ SynthesizerTrnMs768NSFsid_nono,
22
+ )
23
+ from rmvpe import RMVPE
24
+ from vc_infer_pipeline import VC
25
+
26
+ model_cache = {}
27
+
28
+
29
+ # Set logging levels
30
+ logging.getLogger("fairseq").setLevel(logging.WARNING)
31
+ logging.getLogger("numba").setLevel(logging.WARNING)
32
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
33
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
34
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
35
+
36
+ limitation = os.getenv("SYSTEM") == "spaces"
37
+
38
+ config = Config()
39
+
40
+ # Edge TTS voices
41
+ # tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
42
+ # tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"]
43
+
44
+ # RVC models directory
45
+ model_root = "weights"
46
+ models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
47
+ models.sort()
48
+
49
+ def get_unique_filename(extension):
50
+ return f"{uuid.uuid4()}.{extension}"
51
+
52
+ def model_data(model_name):
53
+ # We will not modify this function to cache models
54
+ pth_path = [
55
+ f"{model_root}/{model_name}/{f}"
56
+ for f in os.listdir(f"{model_root}/{model_name}")
57
+ if f.endswith(".pth")
58
+ ][0]
59
+ print(f"Loading {pth_path}")
60
+ cpt = torch.load(pth_path, map_location="cpu")
61
+ tgt_sr = cpt["config"][-1]
62
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
63
+ if_f0 = cpt.get("f0", 1)
64
+ version = cpt.get("version", "v1")
65
+ if version == "v1":
66
+ if if_f0 == 1:
67
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
68
+ else:
69
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
70
+ elif version == "v2":
71
+ if if_f0 == 1:
72
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
73
+ else:
74
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
75
+ else:
76
+ raise ValueError("Unknown version")
77
+ del net_g.enc_q
78
+ net_g.load_state_dict(cpt["weight"], strict=False)
79
+ print("Model loaded")
80
+ net_g.eval().to(config.device)
81
+ if config.is_half:
82
+ net_g = net_g.half()
83
+ else:
84
+ net_g = net_g.float()
85
+ vc = VC(tgt_sr, config)
86
+
87
+ index_files = [
88
+ f"{model_root}/{model_name}/{f}"
89
+ for f in os.listdir(f"{model_root}/{model_name}")
90
+ if f.endswith(".index")
91
+ ]
92
+ if len(index_files) == 0:
93
+ print("No index file found")
94
+ index_file = ""
95
+ else:
96
+ index_file = index_files[0]
97
+ print(f"Index file found: {index_file}")
98
+
99
+ return tgt_sr, net_g, vc, version, index_file, if_f0
100
+
101
+ def load_hubert():
102
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
103
+ ["hubert_base.pt"],
104
+ suffix="",
105
+ )
106
+ hubert_model = models[0]
107
+ hubert_model = hubert_model.to(config.device)
108
+ if config.is_half:
109
+ hubert_model = hubert_model.half()
110
+ else:
111
+ hubert_model = hubert_model.float()
112
+ return hubert_model.eval()
113
+
114
+ def get_model_names():
115
+ return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
116
+
117
+ # Initialize the global models
118
+ hubert_model = load_hubert()
119
+ rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
120
+
121
+ # voice_mapping = {
122
+ # "Mongolian Male": "mn-MN-BataaNeural",
123
+ # "Mongolian Female": "mn-MN-YesuiNeural"
124
+ # }
125
+
126
+ # Function to run async functions in a new event loop within a thread
127
+ def run_async_in_thread(fn, *args):
128
+ loop = asyncio.new_event_loop()
129
+ asyncio.set_event_loop(loop)
130
+ result = loop.run_until_complete(fn(*args))
131
+ loop.close()
132
+ return result
133
+
134
+ def parallel_tts(tasks):
135
+ with ThreadPoolExecutor(max_workers=10) as executor:
136
+ # futures = [executor.submit(run_async_in_thread, tts, *task) for task in tasks] # Original line
137
+ futures = [executor.submit(run_async_in_thread, process_audio, *task) for task in tasks] # New line
138
+ results = [future.result() for future in futures]
139
+ return results
140
+
141
+ # Keep the original tts function but commented out
142
+ '''
143
+ async def tts(
144
+ model_name,
145
+ tts_text,
146
+ tts_voice,
147
+ index_rate,
148
+ use_uploaded_voice,
149
+ uploaded_voice,
150
+ ):
151
+ # Original TTS function code here
152
+ ...
153
+ '''
154
+
155
+ # New function for audio processing only
156
+ async def process_audio(
157
+ model_name,
158
+ text_placeholder,
159
+ voice_placeholder,
160
+ index_rate,
161
+ use_uploaded_voice,
162
+ uploaded_voice,
163
+ ):
164
+ # Default values for parameters
165
+ f0_up_key = 0
166
+ f0_method = "rmvpe"
167
+ protect = 0.33
168
+ filter_radius = 3
169
+ resample_sr = 0
170
+ rms_mix_rate = 0.25
171
+
172
+ try:
173
+ if uploaded_voice is None:
174
+ return "No voice file uploaded.", None, None
175
+
176
+ # Process the uploaded voice file - read the file instead of writing it
177
+ audio, sr = librosa.load(uploaded_voice, sr=16000, mono=True) # Load directly from filepath
178
+
179
+ duration = len(audio) / sr
180
+ print(f"Audio duration: {duration}s")
181
+ if limitation and duration >= 20000:
182
+ return (
183
+ f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
184
+ None,
185
+ None,
186
+ )
187
+
188
+ # Load the model and process audio
189
+ tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
190
+
191
+ if f0_method == "rmvpe":
192
+ vc.model_rmvpe = rmvpe_model
193
+
194
+ times = [0, 0, 0]
195
+ audio_opt = vc.pipeline(
196
+ hubert_model,
197
+ net_g,
198
+ 0,
199
+ audio,
200
+ uploaded_voice, # Use the filepath directly
201
+ times,
202
+ f0_up_key,
203
+ f0_method,
204
+ index_file,
205
+ index_rate,
206
+ if_f0,
207
+ filter_radius,
208
+ tgt_sr,
209
+ resample_sr,
210
+ rms_mix_rate,
211
+ version,
212
+ protect,
213
+ None,
214
+ )
215
+
216
+ if tgt_sr != resample_sr and resample_sr >= 16000:
217
+ tgt_sr = resample_sr
218
+
219
+ info = f"Success. Time: npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
220
+ print(info)
221
+ return (
222
+ info,
223
+ None,
224
+ (tgt_sr, audio_opt),
225
+ )
226
+
227
+ except Exception as e:
228
+ traceback_info = traceback.format_exc()
229
+ print(traceback_info)
230
+ return str(e), None, None
231
+