File size: 3,535 Bytes
32b2aaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import argparse
import random
import time
from pathlib import Path
import torch
import torchaudio
from tqdm import tqdm
from .inference import denoise, enhance
def _build_parser():
    """Build the argument parser for the batch enhancement command line."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
    parser.add_argument("out_dir", type=Path, help="Output folder")
    parser.add_argument(
        "--run_dir",
        type=Path,
        default=None,
        help="Path to the enhancer run folder, if None, use the default model",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default=".wav",
        help="Audio file suffix",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation, recommended to use CUDA",
    )
    parser.add_argument(
        "--denoise_only",
        action="store_true",
        help="Only apply denoising without enhancement",
    )
    parser.add_argument(
        "--lambd",
        type=float,
        default=1.0,
        help="Denoise strength for enhancement (0.0 to 1.0)",
    )
    parser.add_argument(
        "--tau",
        type=float,
        default=0.5,
        help="CFM prior temperature (0.0 to 1.0)",
    )
    parser.add_argument(
        "--solver",
        type=str,
        default="midpoint",
        choices=["midpoint", "rk4", "euler"],
        help="Numerical solver to use",
    )
    parser.add_argument(
        "--nfe",
        type=int,
        default=64,
        help="Number of function evaluations",
    )
    parser.add_argument(
        "--parallel_mode",
        action="store_true",
        help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
    )
    return parser


@torch.inference_mode()
def main():
    """Enhance (or only denoise) every matching audio file under ``in_dir``.

    Command-line arguments are read from ``sys.argv``. Files whose names end
    with ``--suffix`` are searched recursively below ``in_dir``; each result is
    written to ``out_dir`` at the same relative path, creating parent folders
    as needed.

    With ``--parallel_mode`` the file order is shuffled and files whose output
    already exists are skipped, so several jobs can share one output folder.

    Returns:
        None. Progress and a final summary are printed to stdout.
    """
    args = _build_parser().parse_args()

    # Fall back to CPU instead of failing when CUDA was requested but is absent.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available but --device is set to cuda, using CPU instead")
        device = "cpu"

    start_time = time.perf_counter()

    paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
    if not paths:
        print(f"No {args.suffix} files found in the following path: {args.in_dir}")
        return

    if args.parallel_mode:
        # A random order makes concurrent jobs less likely to pick the same file.
        random.shuffle(paths)

    # Count the files actually written; in parallel mode some are skipped, so
    # len(paths) would overstate the work done by this job.
    num_processed = 0

    pbar = tqdm(paths)
    for path in pbar:
        out_path = args.out_dir / path.relative_to(args.in_dir)
        if args.parallel_mode and out_path.exists():
            continue
        pbar.set_description(f"Processing {out_path}")

        dwav, sr = torchaudio.load(path)
        # Average over the first dimension (the channel axis of the tensor
        # returned by torchaudio.load) to obtain a single-channel signal.
        dwav = dwav.mean(0)

        if args.denoise_only:
            hwav, sr = denoise(
                dwav=dwav,
                sr=sr,
                device=device,
                run_dir=args.run_dir,
            )
        else:
            hwav, sr = enhance(
                dwav=dwav,
                sr=sr,
                device=device,
                nfe=args.nfe,
                solver=args.solver,
                lambd=args.lambd,
                tau=args.tau,
                run_dir=args.run_dir,
            )

        out_path.parent.mkdir(parents=True, exist_ok=True)
        # hwav[None] adds a leading dimension before handing the result to torchaudio.save.
        torchaudio.save(out_path, hwav[None], sr)
        num_processed += 1

    # Cool emoji effect saying the job is done
    elapsed_time = time.perf_counter() - start_time
    print(f"🌟 Enhancement done! {num_processed} files processed in {elapsed_time:.2f}s")
# Run the command-line entry point only when this file is executed as a
# script (or via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()
|