smjain commited on
Commit
058946e
·
verified ·
1 Parent(s): 750eed3

Upload infertest.py

Browse files
Files changed (1) hide show
  1. infertest.py +507 -0
infertest.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, os, traceback, sys, warnings, shutil, numpy as np
2
+ import gradio as gr
3
+ import librosa
4
+ import asyncio
5
+ import rarfile
6
+ import edge_tts
7
+ import yt_dlp
8
+ import ffmpeg
9
+ import gdown
10
+ import subprocess
11
+ import wave
12
+ import soundfile as sf
13
+ from scipy.io import wavfile
14
+ from datetime import datetime
15
+ from urllib.parse import urlparse
16
+ from mega import Mega
17
+ from flask import Flask, request, jsonify
18
+ app = Flask(__name__)
19
+
20
+ now_dir = os.getcwd()
21
+ tmp = os.path.join(now_dir, "TEMP")
22
+ shutil.rmtree(tmp, ignore_errors=True)
23
+ os.makedirs(tmp, exist_ok=True)
24
+ os.environ["TEMP"] = tmp
25
+ from lib.infer_pack.models import (
26
+ SynthesizerTrnMs256NSFsid,
27
+ SynthesizerTrnMs256NSFsid_nono,
28
+ SynthesizerTrnMs768NSFsid,
29
+ SynthesizerTrnMs768NSFsid_nono,
30
+ )
31
+ from fairseq import checkpoint_utils
32
+ from vc_infer_pipeline import VC
33
+ from config import Config
34
+ config = Config()
35
+
36
+ tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
37
+ voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
38
+
39
+ hubert_model = None
40
+
41
+ f0method_mode = ["pm", "harvest", "crepe"]
42
+ f0method_info = "PM is fast, Harvest is good but extremely slow, and Crepe effect is good but requires GPU (Default: PM)"
43
+
44
+ if os.path.isfile("rmvpe.pt"):
45
+ f0method_mode.insert(2, "rmvpe")
46
+ f0method_info = "PM is fast, Harvest is good but extremely slow, Rvmpe is alternative to harvest (might be better), and Crepe effect is good but requires GPU (Default: PM)"
47
+
48
+ def load_hubert():
49
+ global hubert_model
50
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
51
+ ["hubert_base.pt"],
52
+ suffix="",
53
+ )
54
+ hubert_model = models[0]
55
+ hubert_model = hubert_model.to(config.device)
56
+ if config.is_half:
57
+ hubert_model = hubert_model.half()
58
+ else:
59
+ hubert_model = hubert_model.float()
60
+ hubert_model.eval()
61
+
62
+ load_hubert()
63
+
64
+ weight_root = "weights"
65
+ index_root = "weights/index"
66
+ weights_model = []
67
+ weights_index = []
68
+ for _, _, model_files in os.walk(weight_root):
69
+ for file in model_files:
70
+ if file.endswith(".pth"):
71
+ weights_model.append(file)
72
+ for _, _, index_files in os.walk(index_root):
73
+ for file in index_files:
74
+ if file.endswith('.index') and "trained" not in file:
75
+ weights_index.append(os.path.join(index_root, file))
76
+
77
+ def check_models():
78
+ weights_model = []
79
+ weights_index = []
80
+ for _, _, model_files in os.walk(weight_root):
81
+ for file in model_files:
82
+ if file.endswith(".pth"):
83
+ weights_model.append(file)
84
+ for _, _, index_files in os.walk(index_root):
85
+ for file in index_files:
86
+ if file.endswith('.index') and "trained" not in file:
87
+ weights_index.append(os.path.join(index_root, file))
88
+ return (
89
+ gr.Dropdown.update(choices=sorted(weights_model), value=weights_model[0]),
90
+ gr.Dropdown.update(choices=sorted(weights_index))
91
+ )
92
+
93
+ def clean():
94
+ return (
95
+ gr.Dropdown.update(value=""),
96
+ gr.Slider.update(visible=False)
97
+ )
98
+
99
+
100
+
101
+
102
+ @app.route('/convert_voice', methods=['POST'])
103
+ def convert_voice(spk_id, input_audio_path, voice_transform):
104
+
105
+ output_audio_path = vc_single(
106
+ sid=spk_id,
107
+ input_audio_path=input_audio_path,
108
+ f0_up_key=voice_transform, # Assuming voice_transform corresponds to f0_up_key
109
+ f0_file=None if not f0_file else f0_file,
110
+ f0_method="rmvpe",
111
+ file_index=None, # Assuming file_index_path corresponds to file_index
112
+ index_rate=0.75,
113
+ filter_radius=3,
114
+ resample_sr=0,
115
+ rms_mix_rate=0.25,
116
+ protect=0.33 # Adjusted from protect_rate to protect to match the function signature
117
+ )
118
+ return output_audio_path
119
+
120
+
121
+ def vc_single(
122
+ sid,
123
+ input_audio_path,
124
+ f0_up_key,
125
+ f0_file,
126
+ f0_method,
127
+ file_index,
128
+ index_rate,
129
+ filter_radius,
130
+ resample_sr,
131
+ rms_mix_rate,
132
+ protect
133
+ ): # spk_item, input_audio0, vc_transform0,f0_file,f0method0
134
+ global tgt_sr, net_g, vc, hubert_model, version, cpt
135
+ try:
136
+ logs = []
137
+ print(f"Converting...")
138
+ logs.append(f"Converting...")
139
+ yield "\n".join(logs), None
140
+
141
+ f0_up_key = int(f0_up_key)
142
+ times = [0, 0, 0]
143
+ if hubert_model == None:
144
+ load_hubert()
145
+ if_f0 = cpt.get("f0", 1)
146
+ audio_opt = vc.pipeline(
147
+ hubert_model,
148
+ net_g,
149
+ sid,
150
+ audio,
151
+ input_audio_path,
152
+ times,
153
+ f0_up_key,
154
+ f0_method,
155
+ file_index,
156
+ # file_big_npy,
157
+ index_rate,
158
+ if_f0,
159
+ filter_radius,
160
+ tgt_sr,
161
+ resample_sr,
162
+ rms_mix_rate,
163
+ version,
164
+ protect,
165
+ f0_file=f0_file
166
+ )
167
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
168
+ tgt_sr = resample_sr
169
+ index_info = (
170
+ "Using index:%s." % file_index
171
+ if os.path.exists(file_index)
172
+ else "Index not used."
173
+ )
174
+ output_file_path = os.path.join("output", f"converted_audio_{sid}.wav") # Adjust path as needed
175
+ os.makedirs(os.path.dirname(output_file_path), exist_ok=True) # Create the output directory if it doesn't exist
176
+
177
+ # Save the audio file using the target sampling rate
178
+ sf.write(output_file_path, audio_opt, tgt_sr)
179
+
180
+ # Return the path to the saved file along with any other information
181
+
182
+ return (
183
+ f"Success. Audio saved to {output_file_path}\n{index_info}\nTime:\nnpy: %.2fs, f0: %.2fs, infer: %.2fs."
184
+ % (*times,),
185
+ output_file_path,
186
+ )
187
+ except:
188
+ info = traceback.format_exc()
189
+ logger.warning(info)
190
+ return info, (None, None)
191
+
192
+ def get_vc(sid, to_return_protect0):
193
+ global n_spk, tgt_sr, net_g, vc, cpt, version, weights_index
194
+ if sid == "" or sid == []:
195
+ global hubert_model
196
+ if hubert_model is not None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
197
+ print("clean_empty_cache")
198
+ del net_g, n_spk, vc, hubert_model, tgt_sr # ,cpt
199
+ hubert_model = net_g = n_spk = vc = hubert_model = tgt_sr = None
200
+ if torch.cuda.is_available():
201
+ torch.cuda.empty_cache()
202
+ ###楼下不这么折腾清理不干净
203
+ if_f0 = cpt.get("f0", 1)
204
+ version = cpt.get("version", "v1")
205
+ if version == "v1":
206
+ if if_f0 == 1:
207
+ net_g = SynthesizerTrnMs256NSFsid(
208
+ *cpt["config"], is_half=config.is_half
209
+ )
210
+ else:
211
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
212
+ elif version == "v2":
213
+ if if_f0 == 1:
214
+ net_g = SynthesizerTrnMs768NSFsid(
215
+ *cpt["config"], is_half=config.is_half
216
+ )
217
+ else:
218
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
219
+ del net_g, cpt
220
+ if torch.cuda.is_available():
221
+ torch.cuda.empty_cache()
222
+ cpt = None
223
+ return (
224
+ gr.Slider.update(maximum=2333, visible=False),
225
+ gr.Slider.update(visible=True),
226
+ gr.Dropdown.update(choices=sorted(weights_index), value=""),
227
+ gr.Markdown.update(value="# <center> No model selected")
228
+ )
229
+ print(f"Loading {sid} model...")
230
+ selected_model = sid[:-4]
231
+ cpt = torch.load(os.path.join(weight_root, sid), map_location="cpu")
232
+ tgt_sr = cpt["config"][-1]
233
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
234
+ if_f0 = cpt.get("f0", 1)
235
+ if if_f0 == 0:
236
+ to_return_protect0 = {
237
+ "visible": False,
238
+ "value": 0.5,
239
+ "__type__": "update",
240
+ }
241
+ else:
242
+ to_return_protect0 = {
243
+ "visible": True,
244
+ "value": to_return_protect0,
245
+ "__type__": "update",
246
+ }
247
+ version = cpt.get("version", "v1")
248
+ if version == "v1":
249
+ if if_f0 == 1:
250
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
251
+ else:
252
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
253
+ elif version == "v2":
254
+ if if_f0 == 1:
255
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
256
+ else:
257
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
258
+ del net_g.enc_q
259
+ print(net_g.load_state_dict(cpt["weight"], strict=False))
260
+ net_g.eval().to(config.device)
261
+ if config.is_half:
262
+ net_g = net_g.half()
263
+ else:
264
+ net_g = net_g.float()
265
+ vc = VC(tgt_sr, config)
266
+ n_spk = cpt["config"][-3]
267
+ weights_index = []
268
+ for _, _, index_files in os.walk(index_root):
269
+ for file in index_files:
270
+ if file.endswith('.index') and "trained" not in file:
271
+ weights_index.append(os.path.join(index_root, file))
272
+ if weights_index == []:
273
+ selected_index = gr.Dropdown.update(value="")
274
+ else:
275
+ selected_index = gr.Dropdown.update(value=weights_index[0])
276
+ for index, model_index in enumerate(weights_index):
277
+ if selected_model in model_index:
278
+ selected_index = gr.Dropdown.update(value=weights_index[index])
279
+ break
280
+ return (
281
+ gr.Slider.update(maximum=n_spk, visible=True),
282
+ to_return_protect0,
283
+ selected_index,
284
+ gr.Markdown.update(
285
+ f'## <center> {selected_model}\n'+
286
+ f'### <center> RVC {version} Model'
287
+ )
288
+ )
289
+
290
+ def find_audio_files(folder_path, extensions):
291
+ audio_files = []
292
+ for root, dirs, files in os.walk(folder_path):
293
+ for file in files:
294
+ if any(file.endswith(ext) for ext in extensions):
295
+ audio_files.append(file)
296
+ return audio_files
297
+
298
+ def vc_multi(
299
+ spk_item,
300
+ vc_input,
301
+ vc_output,
302
+ vc_transform0,
303
+ f0method0,
304
+ file_index,
305
+ index_rate,
306
+ filter_radius,
307
+ resample_sr,
308
+ rms_mix_rate,
309
+ protect,
310
+ ):
311
+ global tgt_sr, net_g, vc, hubert_model, version, cpt
312
+ logs = []
313
+ logs.append("Converting...")
314
+ yield "\n".join(logs)
315
+ print()
316
+ try:
317
+ if os.path.exists(vc_input):
318
+ folder_path = vc_input
319
+ extensions = [".mp3", ".wav", ".flac", ".ogg"]
320
+ audio_files = find_audio_files(folder_path, extensions)
321
+ for index, file in enumerate(audio_files, start=1):
322
+ audio, sr = librosa.load(os.path.join(folder_path, file), sr=16000, mono=True)
323
+ input_audio_path = folder_path, file
324
+ f0_up_key = int(vc_transform0)
325
+ times = [0, 0, 0]
326
+ if hubert_model == None:
327
+ load_hubert()
328
+ if_f0 = cpt.get("f0", 1)
329
+ audio_opt = vc.pipeline(
330
+ hubert_model,
331
+ net_g,
332
+ spk_item,
333
+ audio,
334
+ input_audio_path,
335
+ times,
336
+ f0_up_key,
337
+ f0method0,
338
+ file_index,
339
+ index_rate,
340
+ if_f0,
341
+ filter_radius,
342
+ tgt_sr,
343
+ resample_sr,
344
+ rms_mix_rate,
345
+ version,
346
+ protect,
347
+ f0_file=None
348
+ )
349
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
350
+ tgt_sr = resample_sr
351
+ output_path = f"{os.path.join(vc_output, file)}"
352
+ os.makedirs(os.path.join(vc_output), exist_ok=True)
353
+ sf.write(
354
+ output_path,
355
+ audio_opt,
356
+ tgt_sr,
357
+ )
358
+ info = f"{index} / {len(audio_files)} | {file}"
359
+ print(info)
360
+ logs.append(info)
361
+ yield "\n".join(logs)
362
+ else:
363
+ logs.append("Folder not found or path doesn't exist.")
364
+ yield "\n".join(logs)
365
+ except:
366
+ info = traceback.format_exc()
367
+ print(info)
368
+ logs.append(info)
369
+ yield "\n".join(logs)
370
+
371
+ def download_audio(url, audio_provider):
372
+ logs = []
373
+ os.makedirs("dl_audio", exist_ok=True)
374
+ if url == "":
375
+ logs.append("URL required!")
376
+ yield None, "\n".join(logs)
377
+ return None, "\n".join(logs)
378
+ if audio_provider == "Youtube":
379
+ logs.append("Downloading the audio...")
380
+ yield None, "\n".join(logs)
381
+ ydl_opts = {
382
+ 'noplaylist': True,
383
+ 'format': 'bestaudio/best',
384
+ 'postprocessors': [{
385
+ 'key': 'FFmpegExtractAudio',
386
+ 'preferredcodec': 'wav',
387
+ }],
388
+ "outtmpl": 'result/dl_audio/audio',
389
+ }
390
+ audio_path = "result/dl_audio/audio.wav"
391
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
392
+ ydl.download([url])
393
+ logs.append("Download Complete.")
394
+ yield audio_path, "\n".join(logs)
395
+
396
+ def cut_vocal_and_inst_yt(split_model):
397
+ logs = []
398
+ logs.append("Starting the audio splitting process...")
399
+ yield "\n".join(logs), None, None, None
400
+ command = f"demucs --two-stems=vocals -n {split_model} result/dl_audio/audio.wav -o output"
401
+ result = subprocess.Popen(command.split(), stdout=subprocess.PIPE, text=True)
402
+ for line in result.stdout:
403
+ logs.append(line)
404
+ yield "\n".join(logs), None, None, None
405
+ print(result.stdout)
406
+ vocal = f"output/{split_model}/audio/vocals.wav"
407
+ inst = f"output/{split_model}/audio/no_vocals.wav"
408
+ logs.append("Audio splitting complete.")
409
+ yield "\n".join(logs), vocal, inst, vocal
410
+
411
+ def cut_vocal_and_inst(split_model, audio_data):
412
+ logs = []
413
+ vocal_path = "output/result/audio.wav"
414
+ os.makedirs("output/result", exist_ok=True)
415
+ wavfile.write(vocal_path, audio_data[0], audio_data[1])
416
+ logs.append("Starting the audio splitting process...")
417
+ yield "\n".join(logs), None, None
418
+ command = f"demucs --two-stems=vocals -n {split_model} {vocal_path} -o output"
419
+ result = subprocess.Popen(command.split(), stdout=subprocess.PIPE, text=True)
420
+ for line in result.stdout:
421
+ logs.append(line)
422
+ yield "\n".join(logs), None, None
423
+ print(result.stdout)
424
+ vocal = f"output/{split_model}/audio/vocals.wav"
425
+ inst = f"output/{split_model}/audio/no_vocals.wav"
426
+ logs.append("Audio splitting complete.")
427
+ yield "\n".join(logs), vocal, inst
428
+
429
+ def combine_vocal_and_inst(audio_data, vocal_volume, inst_volume, split_model):
430
+ os.makedirs("output/result", exist_ok=True)
431
+ vocal_path = "output/result/output.wav"
432
+ output_path = "output/result/combine.mp3"
433
+ inst_path = f"output/{split_model}/audio/no_vocals.wav"
434
+ wavfile.write(vocal_path, audio_data[0], audio_data[1])
435
+ command = f'ffmpeg -y -i {inst_path} -i {vocal_path} -filter_complex [0:a]volume={inst_volume}[i];[1:a]volume={vocal_volume}[v];[i][v]amix=inputs=2:duration=longest[a] -map [a] -b:a 320k -c:a libmp3lame {output_path}'
436
+ result = subprocess.run(command.split(), stdout=subprocess.PIPE)
437
+ print(result.stdout.decode())
438
+ return output_path
439
+
440
+ def download_and_extract_models(urls):
441
+ logs = []
442
+ os.makedirs("zips", exist_ok=True)
443
+ os.makedirs(os.path.join("zips", "extract"), exist_ok=True)
444
+ os.makedirs(os.path.join(weight_root), exist_ok=True)
445
+ os.makedirs(os.path.join(index_root), exist_ok=True)
446
+ for link in urls.splitlines():
447
+ url = link.strip()
448
+ if not url:
449
+ raise gr.Error("URL Required!")
450
+ return "No URLs provided."
451
+ model_zip = urlparse(url).path.split('/')[-2] + '.zip'
452
+ model_zip_path = os.path.join('zips', model_zip)
453
+ logs.append(f"Downloading...")
454
+ yield "\n".join(logs)
455
+ if "drive.google.com" in url:
456
+ gdown.download(url, os.path.join("zips", "extract"), quiet=False)
457
+ elif "mega.nz" in url:
458
+ m = Mega()
459
+ m.download_url(url, 'zips')
460
+ else:
461
+ os.system(f"wget {url} -O {model_zip_path}")
462
+ logs.append(f"Extracting...")
463
+ yield "\n".join(logs)
464
+ for filename in os.listdir("zips"):
465
+ archived_file = os.path.join("zips", filename)
466
+ if filename.endswith(".zip"):
467
+ shutil.unpack_archive(archived_file, os.path.join("zips", "extract"), 'zip')
468
+ elif filename.endswith(".rar"):
469
+ with rarfile.RarFile(archived_file, 'r') as rar:
470
+ rar.extractall(os.path.join("zips", "extract"))
471
+ for _, dirs, files in os.walk(os.path.join("zips", "extract")):
472
+ logs.append(f"Searching Model and Index...")
473
+ yield "\n".join(logs)
474
+ model = False
475
+ index = False
476
+ if files:
477
+ for file in files:
478
+ if file.endswith(".pth"):
479
+ basename = file[:-4]
480
+ shutil.move(os.path.join("zips", "extract", file), os.path.join(weight_root, file))
481
+ model = True
482
+ if file.endswith('.index') and "trained" not in file:
483
+ shutil.move(os.path.join("zips", "extract", file), os.path.join(index_root, file))
484
+ index = True
485
+ else:
486
+ logs.append("No model in main folder.")
487
+ yield "\n".join(logs)
488
+ logs.append("Searching in subfolders...")
489
+ yield "\n".join(logs)
490
+ for sub_dir in dirs:
491
+ for _, _, sub_files in os.walk(os.path.join("zips", "extract", sub_dir)):
492
+ for file in sub_files:
493
+ if file.endswith(".pth"):
494
+ basename = file[:-4]
495
+ shutil.move(os.path.join("zips", "extract", sub_dir, file), os.path.join(weight_root, file))
496
+ model = True
497
+ if file.endswith('.index') and "trained" not in file:
498
+ shutil.move(os.path.join("zips", "extract", sub_dir, file), os.path.join(index_root, file))
499
+ index = True
500
+ shutil.rmtree(os.path.join("zips", "extract", sub_dir))
501
+ if index is False:
502
+ logs.append("Model only file, no Index file detected.")
503
+ yield "\n".join(logs)
504
+ logs.append("Download Completed!")
505
+ yield "\n".join(logs)
506
+ logs.append("Successfully download all models! Refresh your model list to load the model")
507
+ yield "\n".join(logs)