JacobLinCool commited on
Commit
3b574d5
1 Parent(s): 5459d70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -9
app.py CHANGED
@@ -1,9 +1,9 @@
1
- from typing import Tuple
2
  import gradio as gr
3
  import zipfile
4
  import os
5
  import tempfile
6
  import shutil
 
7
  from infer.modules.train.preprocess import PreProcess, preprocess_trainset
8
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
9
  from zero import zero
@@ -13,11 +13,7 @@ def extract_audio_files(zip_file: str, target_dir: str) -> list[str]:
13
  with zipfile.ZipFile(zip_file, "r") as zip_ref:
14
  zip_ref.extractall(target_dir)
15
 
16
- audio_files = [
17
- os.path.join(target_dir, f)
18
- for f in os.listdir(target_dir)
19
- if f.endswith((".wav", ".mp3", ".ogg"))
20
- ]
21
  if not audio_files:
22
  raise gr.Error("No audio files found in the zip archive.")
23
 
@@ -35,9 +31,6 @@ def preprocess(zip_file: str) -> str:
35
  data_dir = os.path.join(temp_dir, "_data")
36
  os.makedirs(data_dir)
37
  audio_files = extract_audio_files(zip_file, data_dir)
38
- if not audio_files:
39
- shutil.rmtree(temp_dir)
40
- raise gr.Error("No audio files found in the zip archive.")
41
 
42
  pp = PreProcess(48000, temp_dir, 3.0, False)
43
  pp.pipeline_mp_inp_dir(data_dir, 4)
 
 
1
  import gradio as gr
2
  import zipfile
3
  import os
4
  import tempfile
5
  import shutil
6
+ from glob import glob
7
  from infer.modules.train.preprocess import PreProcess, preprocess_trainset
8
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
9
  from zero import zero
 
13
  with zipfile.ZipFile(zip_file, "r") as zip_ref:
14
  zip_ref.extractall(target_dir)
15
 
16
+ audio_files = glob(f"{target_dir}/**/*.wav", recursive=True)
 
 
 
 
17
  if not audio_files:
18
  raise gr.Error("No audio files found in the zip archive.")
19
 
 
31
  data_dir = os.path.join(temp_dir, "_data")
32
  os.makedirs(data_dir)
33
  audio_files = extract_audio_files(zip_file, data_dir)
 
 
 
34
 
35
  pp = PreProcess(48000, temp_dir, 3.0, False)
36
  pp.pipeline_mp_inp_dir(data_dir, 4)