zhzluke96
update
32b2aaa
raw
history blame
3.54 kB
import argparse
import random
import time
from pathlib import Path
import torch
import torchaudio
from tqdm import tqdm
from .inference import denoise, enhance
@torch.inference_mode()
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
parser.add_argument("out_dir", type=Path, help="Output folder")
parser.add_argument(
"--run_dir",
type=Path,
default=None,
help="Path to the enhancer run folder, if None, use the default model",
)
parser.add_argument(
"--suffix",
type=str,
default=".wav",
help="Audio file suffix",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for computation, recommended to use CUDA",
)
parser.add_argument(
"--denoise_only",
action="store_true",
help="Only apply denoising without enhancement",
)
parser.add_argument(
"--lambd",
type=float,
default=1.0,
help="Denoise strength for enhancement (0.0 to 1.0)",
)
parser.add_argument(
"--tau",
type=float,
default=0.5,
help="CFM prior temperature (0.0 to 1.0)",
)
parser.add_argument(
"--solver",
type=str,
default="midpoint",
choices=["midpoint", "rk4", "euler"],
help="Numerical solver to use",
)
parser.add_argument(
"--nfe",
type=int,
default=64,
help="Number of function evaluations",
)
parser.add_argument(
"--parallel_mode",
action="store_true",
help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
)
args = parser.parse_args()
device = args.device
if device == "cuda" and not torch.cuda.is_available():
print("CUDA is not available but --device is set to cuda, using CPU instead")
device = "cpu"
start_time = time.perf_counter()
run_dir = args.run_dir
paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
if args.parallel_mode:
random.shuffle(paths)
if len(paths) == 0:
print(f"No {args.suffix} files found in the following path: {args.in_dir}")
return
pbar = tqdm(paths)
for path in pbar:
out_path = args.out_dir / path.relative_to(args.in_dir)
if args.parallel_mode and out_path.exists():
continue
pbar.set_description(f"Processing {out_path}")
dwav, sr = torchaudio.load(path)
dwav = dwav.mean(0)
if args.denoise_only:
hwav, sr = denoise(
dwav=dwav,
sr=sr,
device=device,
run_dir=args.run_dir,
)
else:
hwav, sr = enhance(
dwav=dwav,
sr=sr,
device=device,
nfe=args.nfe,
solver=args.solver,
lambd=args.lambd,
tau=args.tau,
run_dir=run_dir,
)
out_path.parent.mkdir(parents=True, exist_ok=True)
torchaudio.save(out_path, hwav[None], sr)
# Cool emoji effect saying the job is done
elapsed_time = time.perf_counter() - start_time
print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s")
if __name__ == "__main__":
main()