smjain commited on
Commit
8850937
·
verified ·
1 Parent(s): 7009380

Upload infer_new.py

Browse files
Files changed (1) hide show
  1. infer_new.py +569 -0
infer_new.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, os, traceback, sys, warnings, shutil, numpy as np
2
+ import gradio as gr
3
+ import librosa
4
+ import asyncio
5
+ import rarfile
6
+ import edge_tts
7
+ import yt_dlp
8
+ import ffmpeg
9
+ import gdown
10
+ import subprocess
11
+ import wave
12
+ import soundfile as sf
13
+ from scipy.io import wavfile
14
+ from datetime import datetime
15
+ from urllib.parse import urlparse
16
+ from mega import Mega
17
+
18
+ import base64
19
+ import tempfile
20
+ import os
21
+
22
+ from pydub import AudioSegment
23
+
24
+
25
+
26
+
27
+
28
+ now_dir = os.getcwd()
29
+ tmp = os.path.join(now_dir, "TEMP")
30
+ shutil.rmtree(tmp, ignore_errors=True)
31
+ os.makedirs(tmp, exist_ok=True)
32
+ os.environ["TEMP"] = tmp
33
+ split_model="htdemucs"
34
+ from lib.infer_pack.models import (
35
+ SynthesizerTrnMs256NSFsid,
36
+ SynthesizerTrnMs256NSFsid_nono,
37
+ SynthesizerTrnMs768NSFsid,
38
+ SynthesizerTrnMs768NSFsid_nono,
39
+ )
40
+ from fairseq import checkpoint_utils
41
+ from vc_infer_pipeline import VC
42
+ from config import Config
43
+ config = Config()
44
+
45
+ tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
46
+ voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
47
+
48
+ hubert_model = None
49
+
50
+ f0method_mode = ["pm", "harvest", "crepe"]
51
+ f0method_info = "PM is fast, Harvest is good but extremely slow, and Crepe effect is good but requires GPU (Default: PM)"
52
+
53
+ if os.path.isfile("rmvpe.pt"):
54
+ f0method_mode.insert(2, "rmvpe")
55
+ f0method_info = "PM is fast, Harvest is good but extremely slow, Rvmpe is alternative to harvest (might be better), and Crepe effect is good but requires GPU (Default: PM)"
56
+
57
+ def load_hubert():
58
+ global hubert_model
59
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
60
+ ["hubert_base.pt"],
61
+ suffix="",
62
+ )
63
+ hubert_model = models[0]
64
+ hubert_model = hubert_model.to(config.device)
65
+ if config.is_half:
66
+ hubert_model = hubert_model.half()
67
+ else:
68
+ hubert_model = hubert_model.float()
69
+ hubert_model.eval()
70
+
71
+ load_hubert()
72
+
73
+ weight_root = "weights"
74
+ index_root = "weights/index"
75
+ weights_model = []
76
+ weights_index = []
77
+ for _, _, model_files in os.walk(weight_root):
78
+ for file in model_files:
79
+ if file.endswith(".pth"):
80
+ weights_model.append(file)
81
+ for _, _, index_files in os.walk(index_root):
82
+ for file in index_files:
83
+ if file.endswith('.index') and "trained" not in file:
84
+ weights_index.append(os.path.join(index_root, file))
85
+
86
+ def check_models():
87
+ weights_model = []
88
+ weights_index = []
89
+ for _, _, model_files in os.walk(weight_root):
90
+ for file in model_files:
91
+ if file.endswith(".pth"):
92
+ weights_model.append(file)
93
+ for _, _, index_files in os.walk(index_root):
94
+ for file in index_files:
95
+ if file.endswith('.index') and "trained" not in file:
96
+ weights_index.append(os.path.join(index_root, file))
97
+ return (
98
+ gr.Dropdown.update(choices=sorted(weights_model), value=weights_model[0]),
99
+ gr.Dropdown.update(choices=sorted(weights_index))
100
+ )
101
+
102
+ def clean():
103
+ return (
104
+ gr.Dropdown.update(value=""),
105
+ gr.Slider.update(visible=False)
106
+ )
107
+
108
+
109
+
110
+ def api_convert_voice(spk_id,voice_transform,input_audio_path):
111
+
112
+ #split audio
113
+ cut_vocal_and_inst(input_audio_path,spk_id)
114
+ print("audio splitting performed")
115
+ vocal_path = f"output/{split_model}/{spk_id}_input_audio/vocals.wav"
116
+ inst = f"output/{split_model}/{spk_id}_input_audio/no_vocals.wav"
117
+
118
+ output_path = convert_voice(spk_id, vocal_path, voice_transform)
119
+ output_path1= combine_vocal_and_inst(output_path,inst)
120
+ print(output_path1)
121
+ return output_path1
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+ def convert_voice(spk_id, input_audio_path, voice_transform):
130
+ get_vc(spk_id,0.5)
131
+ output_audio_path = vc_single(
132
+ sid=0,
133
+ input_audio_path=input_audio_path,
134
+ f0_up_key=voice_transform, # Assuming voice_transform corresponds to f0_up_key
135
+ f0_file=None ,
136
+ f0_method="rmvpe",
137
+ file_index=spk_id, # Assuming file_index_path corresponds to file_index
138
+ index_rate=0.75,
139
+ filter_radius=3,
140
+ resample_sr=0,
141
+ rms_mix_rate=0.25,
142
+ protect=0.33 # Adjusted from protect_rate to protect to match the function signature
143
+ )
144
+ print(output_audio_path)
145
+ return output_audio_path
146
+
147
+
148
+ def vc_single(
149
+ sid,
150
+ input_audio_path,
151
+ f0_up_key,
152
+ f0_file,
153
+ f0_method,
154
+ file_index,
155
+ index_rate,
156
+ filter_radius,
157
+ resample_sr,
158
+ rms_mix_rate,
159
+ protect
160
+ ): # spk_item, input_audio0, vc_transform0,f0_file,f0method0
161
+ global tgt_sr, net_g, vc, hubert_model, version, cpt
162
+
163
+ try:
164
+ logs = []
165
+ print(f"Converting...")
166
+
167
+ audio, sr = librosa.load(input_audio_path, sr=16000, mono=True)
168
+ print(f"found audio ")
169
+ f0_up_key = int(f0_up_key)
170
+ times = [0, 0, 0]
171
+ if hubert_model == None:
172
+ load_hubert()
173
+ print("loaded hubert")
174
+ if_f0 = 1
175
+ audio_opt = vc.pipeline(
176
+ hubert_model,
177
+ net_g,
178
+ 0,
179
+ audio,
180
+ input_audio_path,
181
+ times,
182
+ f0_up_key,
183
+ f0_method,
184
+ file_index,
185
+ # file_big_npy,
186
+ index_rate,
187
+ if_f0,
188
+ filter_radius,
189
+ tgt_sr,
190
+ resample_sr,
191
+ rms_mix_rate,
192
+ version,
193
+ protect,
194
+ f0_file=f0_file
195
+ )
196
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
197
+ tgt_sr = resample_sr
198
+ index_info = (
199
+ "Using index:%s." % file_index
200
+ if os.path.exists(file_index)
201
+ else "Index not used."
202
+ )
203
+ print("writing to FS")
204
+ output_file_path = os.path.join("output", f"converted_audio_{sid}.wav") # Adjust path as needed
205
+
206
+ os.makedirs(os.path.dirname(output_file_path), exist_ok=True) # Create the output directory if it doesn't exist
207
+ print("create dir")
208
+ # Save the audio file using the target sampling rate
209
+ sf.write(output_file_path, audio_opt, tgt_sr)
210
+
211
+ print("wrote to FS")
212
+
213
+ # Return the path to the saved file along with any other information
214
+
215
+ return output_file_path
216
+
217
+
218
+ except:
219
+ info = traceback.format_exc()
220
+
221
+ return info, (None, None)
222
+
223
+ def get_vc(sid, to_return_protect0):
224
+ global n_spk, tgt_sr, net_g, vc, cpt, version, weights_index
225
+ if sid == "" or sid == []:
226
+ global hubert_model
227
+ if hubert_model is not None: # 考虑到轮询, 需要加个判断看是否 sid 是由有模型切换到无模型的
228
+ print("clean_empty_cache")
229
+ del net_g, n_spk, vc, hubert_model, tgt_sr # ,cpt
230
+ hubert_model = net_g = n_spk = vc = hubert_model = tgt_sr = None
231
+ if torch.cuda.is_available():
232
+ torch.cuda.empty_cache()
233
+ ###楼下不这么折腾清理不干净
234
+ if_f0 = cpt.get("f0", 1)
235
+ version = cpt.get("version", "v1")
236
+ if version == "v1":
237
+ if if_f0 == 1:
238
+ net_g = SynthesizerTrnMs256NSFsid(
239
+ *cpt["config"], is_half=config.is_half
240
+ )
241
+ else:
242
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
243
+ elif version == "v2":
244
+ if if_f0 == 1:
245
+ net_g = SynthesizerTrnMs768NSFsid(
246
+ *cpt["config"], is_half=config.is_half
247
+ )
248
+ else:
249
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
250
+ del net_g, cpt
251
+ if torch.cuda.is_available():
252
+ torch.cuda.empty_cache()
253
+ cpt = None
254
+ return (
255
+ gr.Slider.update(maximum=2333, visible=False),
256
+ gr.Slider.update(visible=True),
257
+ gr.Dropdown.update(choices=sorted(weights_index), value=""),
258
+ gr.Markdown.update(value="# <center> No model selected")
259
+ )
260
+ print(f"Loading {sid} model...")
261
+ selected_model = sid[:-4]
262
+ cpt = torch.load(os.path.join(weight_root, sid), map_location="cpu")
263
+ tgt_sr = cpt["config"][-1]
264
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
265
+ if_f0 = cpt.get("f0", 1)
266
+ if if_f0 == 0:
267
+ to_return_protect0 = {
268
+ "visible": False,
269
+ "value": 0.5,
270
+ "__type__": "update",
271
+ }
272
+ else:
273
+ to_return_protect0 = {
274
+ "visible": True,
275
+ "value": to_return_protect0,
276
+ "__type__": "update",
277
+ }
278
+ version = cpt.get("version", "v1")
279
+ if version == "v1":
280
+ if if_f0 == 1:
281
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
282
+ else:
283
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
284
+ elif version == "v2":
285
+ if if_f0 == 1:
286
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
287
+ else:
288
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
289
+ del net_g.enc_q
290
+ print(net_g.load_state_dict(cpt["weight"], strict=False))
291
+ net_g.eval().to(config.device)
292
+ if config.is_half:
293
+ net_g = net_g.half()
294
+ else:
295
+ net_g = net_g.float()
296
+ vc = VC(tgt_sr, config)
297
+ n_spk = cpt["config"][-3]
298
+ weights_index = []
299
+ for _, _, index_files in os.walk(index_root):
300
+ for file in index_files:
301
+ if file.endswith('.index') and "trained" not in file:
302
+ weights_index.append(os.path.join(index_root, file))
303
+ if weights_index == []:
304
+ selected_index = gr.Dropdown.update(value="")
305
+ else:
306
+ selected_index = gr.Dropdown.update(value=weights_index[0])
307
+ for index, model_index in enumerate(weights_index):
308
+ if selected_model in model_index:
309
+ selected_index = gr.Dropdown.update(value=weights_index[index])
310
+ break
311
+ return (
312
+ gr.Slider.update(maximum=n_spk, visible=True),
313
+ to_return_protect0,
314
+ selected_index,
315
+ gr.Markdown.update(
316
+ f'## <center> {selected_model}\n'+
317
+ f'### <center> RVC {version} Model'
318
+ )
319
+ )
320
+
321
+ def find_audio_files(folder_path, extensions):
322
+ audio_files = []
323
+ for root, dirs, files in os.walk(folder_path):
324
+ for file in files:
325
+ if any(file.endswith(ext) for ext in extensions):
326
+ audio_files.append(file)
327
+ return audio_files
328
+
329
+ def vc_multi(
330
+ spk_item,
331
+ vc_input,
332
+ vc_output,
333
+ vc_transform0,
334
+ f0method0,
335
+ file_index,
336
+ index_rate,
337
+ filter_radius,
338
+ resample_sr,
339
+ rms_mix_rate,
340
+ protect,
341
+ ):
342
+ global tgt_sr, net_g, vc, hubert_model, version, cpt
343
+ logs = []
344
+ logs.append("Converting...")
345
+ yield "\n".join(logs)
346
+ print()
347
+ try:
348
+ if os.path.exists(vc_input):
349
+ folder_path = vc_input
350
+ extensions = [".mp3", ".wav", ".flac", ".ogg"]
351
+ audio_files = find_audio_files(folder_path, extensions)
352
+ for index, file in enumerate(audio_files, start=1):
353
+ audio, sr = librosa.load(os.path.join(folder_path, file), sr=16000, mono=True)
354
+ input_audio_path = folder_path, file
355
+ f0_up_key = int(vc_transform0)
356
+ times = [0, 0, 0]
357
+ if hubert_model == None:
358
+ load_hubert()
359
+ if_f0 = cpt.get("f0", 1)
360
+ audio_opt = vc.pipeline(
361
+ hubert_model,
362
+ net_g,
363
+ spk_item,
364
+ audio,
365
+ input_audio_path,
366
+ times,
367
+ f0_up_key,
368
+ f0method0,
369
+ file_index,
370
+ index_rate,
371
+ if_f0,
372
+ filter_radius,
373
+ tgt_sr,
374
+ resample_sr,
375
+ rms_mix_rate,
376
+ version,
377
+ protect,
378
+ f0_file=None
379
+ )
380
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
381
+ tgt_sr = resample_sr
382
+ output_path = f"{os.path.join(vc_output, file)}"
383
+ os.makedirs(os.path.join(vc_output), exist_ok=True)
384
+ sf.write(
385
+ output_path,
386
+ audio_opt,
387
+ tgt_sr,
388
+ )
389
+ info = f"{index} / {len(audio_files)} | {file}"
390
+ print(info)
391
+ logs.append(info)
392
+ yield "\n".join(logs)
393
+ else:
394
+ logs.append("Folder not found or path doesn't exist.")
395
+ yield "\n".join(logs)
396
+ except:
397
+ info = traceback.format_exc()
398
+ print(info)
399
+ logs.append(info)
400
+ yield "\n".join(logs)
401
+
402
+ def download_audio(url, audio_provider):
403
+ logs = []
404
+ os.makedirs("dl_audio", exist_ok=True)
405
+ if url == "":
406
+ logs.append("URL required!")
407
+ yield None, "\n".join(logs)
408
+ return None, "\n".join(logs)
409
+ if audio_provider == "Youtube":
410
+ logs.append("Downloading the audio...")
411
+ yield None, "\n".join(logs)
412
+ ydl_opts = {
413
+ 'noplaylist': True,
414
+ 'format': 'bestaudio/best',
415
+ 'postprocessors': [{
416
+ 'key': 'FFmpegExtractAudio',
417
+ 'preferredcodec': 'wav',
418
+ }],
419
+ "outtmpl": 'result/dl_audio/audio',
420
+ }
421
+ audio_path = "result/dl_audio/audio.wav"
422
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
423
+ ydl.download([url])
424
+ logs.append("Download Complete.")
425
+ yield audio_path, "\n".join(logs)
426
+
427
+ def cut_vocal_and_inst_yt(split_model,spk_id):
428
+ logs = []
429
+ logs.append("Starting the audio splitting process...")
430
+ yield "\n".join(logs), None, None, None
431
+ command = f"demucs --two-stems=vocals -n {split_model} result/dl_audio/audio.wav -o output"
432
+ result = subprocess.Popen(command.split(), stdout=subprocess.PIPE, text=True)
433
+ for line in result.stdout:
434
+ logs.append(line)
435
+ yield "\n".join(logs), None, None, None
436
+ print(result.stdout)
437
+ vocal = f"output/{split_model}/{spk_id}_input_audio/vocals.wav"
438
+ inst = f"output/{split_model}/{spk_id}_input_audio/no_vocals.wav"
439
+ logs.append("Audio splitting complete.")
440
+ yield "\n".join(logs), vocal, inst, vocal
441
+
442
+ def cut_vocal_and_inst(audio_path,spk_id):
443
+
444
+ vocal_path = "output/result/audio.wav"
445
+ os.makedirs("output/result", exist_ok=True)
446
+ #wavfile.write(vocal_path, audio_data[0], audio_data[1])
447
+ #logs.append("Starting the audio splitting process...")
448
+ #yield "\n".join(logs), None, None
449
+ print("before executing splitter")
450
+ command = f"demucs --two-stems=vocals -n {split_model} {audio_path} -o output"
451
+ #result = subprocess.Popen(command.split(), stdout=subprocess.PIPE, text=True)
452
+ result = subprocess.run(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
453
+ if result.returncode != 0:
454
+ print("Demucs process failed:", result.stderr)
455
+ else:
456
+ print("Demucs process completed successfully.")
457
+ print("after executing splitter")
458
+ #for line in result.stdout:
459
+ # logs.append(line)
460
+ # yield "\n".join(logs), None, None
461
+
462
+ print(result.stdout)
463
+ vocal = f"output/{split_model}/{spk_id}_input_audio/vocals.wav"
464
+ inst = f"output/{split_model}/{spk_id}_input_audio/no_vocals.wav"
465
+ #logs.append("Audio splitting complete.")
466
+
467
+
468
+ def combine_vocal_and_inst(vocal_path, inst_path):
469
+
470
+ vocal_volume=1
471
+ inst_volume=1
472
+ os.makedirs("output/result", exist_ok=True)
473
+ # Assuming vocal_path and inst_path are now directly passed as arguments
474
+ output_path = "output/result/combine.mp3"
475
+ #command = f'ffmpeg -y -i "{inst_path}" -i "{vocal_path}" -filter_complex [0:a]volume={inst_volume}[i];[1:a]volume={vocal_volume}[v];[i][v]amix=inputs=2:duration=longest[a] -map [a] -b:a 320k -c:a libmp3lame "{output_path}"'
476
+ #command=f'ffmpeg -y -i "{inst_path}" -i "{vocal_path}" -filter_complex "amix=inputs=2:duration=longest" -b:a 320k -c:a libmp3lame "{output_path}"'
477
+ # Load the audio files
478
+ vocal = AudioSegment.from_file(vocal_path)
479
+ instrumental = AudioSegment.from_file(inst_path)
480
+
481
+ # Overlay the vocal track on top of the instrumental track
482
+ combined = vocal.overlay(instrumental)
483
+
484
+ # Export the result
485
+ combined.export(output_path, format="mp3")
486
+
487
+ #result = subprocess.run(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
488
+ return output_path
489
+
490
+ #def combine_vocal_and_inst(audio_data, vocal_volume, inst_volume):
491
+ # os.makedirs("output/result", exist_ok=True)
492
+ ## output_path = "output/result/combine.mp3"
493
+ # inst_path = f"output/{split_model}/audio/no_vocals.wav"
494
+ #wavfile.write(vocal_path, audio_data[0], audio_data[1])
495
+ #command = f'ffmpeg -y -i {inst_path} -i {vocal_path} -filter_complex [0:a]volume={inst_volume}[i];[1:a]volume={vocal_volume}[v];[i][v]amix=inputs=2:duration=longest[a] -map [a] -b:a 320k -c:a libmp3lame {output_path}'
496
+ #result = subprocess.run(command.split(), stdout=subprocess.PIPE)
497
+ #print(result.stdout.decode())
498
+ #return output_path
499
+
500
+ def download_and_extract_models(urls):
501
+ logs = []
502
+ os.makedirs("zips", exist_ok=True)
503
+ os.makedirs(os.path.join("zips", "extract"), exist_ok=True)
504
+ os.makedirs(os.path.join(weight_root), exist_ok=True)
505
+ os.makedirs(os.path.join(index_root), exist_ok=True)
506
+ for link in urls.splitlines():
507
+ url = link.strip()
508
+ if not url:
509
+ raise gr.Error("URL Required!")
510
+ return "No URLs provided."
511
+ model_zip = urlparse(url).path.split('/')[-2] + '.zip'
512
+ model_zip_path = os.path.join('zips', model_zip)
513
+ logs.append(f"Downloading...")
514
+ yield "\n".join(logs)
515
+ if "drive.google.com" in url:
516
+ gdown.download(url, os.path.join("zips", "extract"), quiet=False)
517
+ elif "mega.nz" in url:
518
+ m = Mega()
519
+ m.download_url(url, 'zips')
520
+ else:
521
+ os.system(f"wget {url} -O {model_zip_path}")
522
+ logs.append(f"Extracting...")
523
+ yield "\n".join(logs)
524
+ for filename in os.listdir("zips"):
525
+ archived_file = os.path.join("zips", filename)
526
+ if filename.endswith(".zip"):
527
+ shutil.unpack_archive(archived_file, os.path.join("zips", "extract"), 'zip')
528
+ elif filename.endswith(".rar"):
529
+ with rarfile.RarFile(archived_file, 'r') as rar:
530
+ rar.extractall(os.path.join("zips", "extract"))
531
+ for _, dirs, files in os.walk(os.path.join("zips", "extract")):
532
+ logs.append(f"Searching Model and Index...")
533
+ yield "\n".join(logs)
534
+ model = False
535
+ index = False
536
+ if files:
537
+ for file in files:
538
+ if file.endswith(".pth"):
539
+ basename = file[:-4]
540
+ shutil.move(os.path.join("zips", "extract", file), os.path.join(weight_root, file))
541
+ model = True
542
+ if file.endswith('.index') and "trained" not in file:
543
+ shutil.move(os.path.join("zips", "extract", file), os.path.join(index_root, file))
544
+ index = True
545
+ else:
546
+ logs.append("No model in main folder.")
547
+ yield "\n".join(logs)
548
+ logs.append("Searching in subfolders...")
549
+ yield "\n".join(logs)
550
+ for sub_dir in dirs:
551
+ for _, _, sub_files in os.walk(os.path.join("zips", "extract", sub_dir)):
552
+ for file in sub_files:
553
+ if file.endswith(".pth"):
554
+ basename = file[:-4]
555
+ shutil.move(os.path.join("zips", "extract", sub_dir, file), os.path.join(weight_root, file))
556
+ model = True
557
+ if file.endswith('.index') and "trained" not in file:
558
+ shutil.move(os.path.join("zips", "extract", sub_dir, file), os.path.join(index_root, file))
559
+ index = True
560
+ shutil.rmtree(os.path.join("zips", "extract", sub_dir))
561
+ if index is False:
562
+ logs.append("Model only file, no Index file detected.")
563
+ yield "\n".join(logs)
564
+ logs.append("Download Completed!")
565
+ yield "\n".join(logs)
566
+ logs.append("Successfully download all models! Refresh your model list to load the model")
567
+ yield "\n".join(logs)
568
+ if __name__ == '__main__':
569
+ app.run(debug=False, port=5000,host='0.0.0.0')