from pathlib import Path from typing import Callable from torch import Tensor def walk_paths(root, suffix): for path in Path(root).iterdir(): if path.is_dir(): yield from walk_paths(path, suffix) elif path.suffix == suffix: yield path def rglob_audio_files(path: Path): return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac")) def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7): """ Args: fg: (b, t) bg: (b, t) """ assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}" fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps) bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps) fg_energy = fg.pow(2).sum(dim=-1, keepdim=True) bg_energy = bg.pow(2).sum(dim=-1, keepdim=True) fg = fg / (fg_energy + eps).sqrt() bg = bg / (bg_energy + eps).sqrt() if callable(alpha): alpha = alpha() assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}" mx = alpha * fg + (1 - alpha) * bg mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps) return mx