Spaces:
Running
Running
import argparse | |
import os | |
import time | |
import warnings | |
from typing import Optional, Tuple, Union | |
import torch | |
import torchaudio as ta | |
from loguru import logger | |
from numpy import ndarray | |
from torch import Tensor, nn | |
from torch.nn import functional as F | |
from torchaudio.backend.common import AudioMetaData | |
import df_local | |
from df_local import config | |
from df_local.checkpoint import load_model as load_model_cp | |
from df_local.logger import init_logger, warn_once | |
from df_local.model import ModelParams | |
from df_local.modules import get_device | |
from df_local.utils import as_complex, as_real, get_norm_alpha, resample | |
from libdf import DF, erb, erb_norm, unit_norm | |
def main(args): | |
model, df_state, suffix = init_df( | |
args.model_base_dir, | |
post_filter=args.pf, | |
log_level=args.log_level, | |
config_allow_defaults=True, | |
epoch=args.epoch, | |
) | |
if args.output_dir is None: | |
args.output_dir = "." | |
elif not os.path.isdir(args.output_dir): | |
os.mkdir(args.output_dir) | |
df_sr = ModelParams().sr | |
n_samples = len(args.noisy_audio_files) | |
for i, file in enumerate(args.noisy_audio_files): | |
progress = (i + 1) / n_samples * 100 | |
audio, meta = load_audio(file, df_sr) | |
t0 = time.time() | |
audio = enhance( | |
model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim | |
) | |
t1 = time.time() | |
t_audio = audio.shape[-1] / df_sr | |
t = t1 - t0 | |
rtf = t / t_audio | |
fn = os.path.basename(file) | |
p_str = f"{progress:2.0f}% | " if n_samples > 1 else "" | |
logger.info(f"{p_str}Enhanced noisy audio file '{fn}' in {t:.1f}s (RT factor: {rtf:.3f})") | |
audio = resample(audio, df_sr, meta.sample_rate) | |
save_audio( | |
file, audio, sr=meta.sample_rate, output_dir=args.output_dir, suffix=suffix, log=False | |
) | |
def init_df( | |
model_base_dir: Optional[str] = None, | |
post_filter: bool = False, | |
log_level: str = "INFO", | |
log_file: Optional[str] = "enhance.log", | |
config_allow_defaults: bool = False, | |
epoch: Union[str, int, None] = "best", | |
default_model: str = "DeepFilterNet2", | |
) -> Tuple[nn.Module, DF, str]: | |
"""Initializes and loads config, model and deep filtering state. | |
Args: | |
model_base_dir (str): Path to the model directory containing checkpoint and config. If None, | |
load the pretrained DeepFilterNet2 model. | |
post_filter (bool): Enable post filter for some minor, extra noise reduction. | |
log_level (str): Control amount of logging. Defaults to `INFO`. | |
log_file (str): Optional log file name. None disables it. Defaults to `enhance.log`. | |
config_allow_defaults (bool): Whether to allow initializing new config values with defaults. | |
epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`. | |
`none` disables checkpoint loading. Defaults to `best`. | |
Returns: | |
model (nn.Modules): Intialized model, moved to GPU if available. | |
df_state (DF): Deep filtering state for stft/istft/erb | |
suffix (str): Suffix based on the model name. This can be used for saving the enhanced | |
audio. | |
""" | |
try: | |
from icecream import ic, install | |
ic.configureOutput(includeContext=True) | |
install() | |
except ImportError: | |
pass | |
use_default_model = False | |
if model_base_dir == "DeepFilterNet": | |
default_model = "DeepFilterNet" | |
use_default_model = True | |
elif model_base_dir == "DeepFilterNet2": | |
use_default_model = True | |
if model_base_dir is None or use_default_model: | |
use_default_model = True | |
model_base_dir = os.path.relpath( | |
os.path.join( | |
os.path.dirname(df_local.__file__), os.pardir, "pretrained_models", default_model | |
) | |
) | |
if not os.path.isdir(model_base_dir): | |
raise NotADirectoryError("Base directory not found at {}".format(model_base_dir)) | |
log_file = os.path.join(model_base_dir, log_file) if log_file is not None else None | |
init_logger(file=log_file, level=log_level, model=model_base_dir) | |
if use_default_model: | |
logger.info(f"Using {default_model} model at {model_base_dir}") | |
config.load( | |
os.path.join(model_base_dir, "config.ini"), | |
config_must_exist=True, | |
allow_defaults=config_allow_defaults, | |
allow_reload=True, | |
) | |
if post_filter: | |
config.set("mask_pf", True, bool, ModelParams().section) | |
logger.info("Running with post-filter") | |
p = ModelParams() | |
df_state = DF( | |
sr=p.sr, | |
fft_size=p.fft_size, | |
hop_size=p.hop_size, | |
nb_bands=p.nb_erb, | |
min_nb_erb_freqs=p.min_nb_freqs, | |
) | |
checkpoint_dir = os.path.join(model_base_dir, "checkpoints") | |
load_cp = epoch is not None and not (isinstance(epoch, str) and epoch.lower() == "none") | |
if not load_cp: | |
checkpoint_dir = None | |
try: | |
mask_only = config.get("mask_only", cast=bool, section="train") | |
except KeyError: | |
mask_only = False | |
model, epoch = load_model_cp(checkpoint_dir, df_state, epoch=epoch, mask_only=mask_only) | |
if (epoch is None or epoch == 0) and load_cp: | |
logger.error("Could not find a checkpoint") | |
exit(1) | |
logger.debug(f"Loaded checkpoint from epoch {epoch}") | |
model = model.to(get_device()) | |
# Set suffix to model name | |
suffix = os.path.basename(os.path.abspath(model_base_dir)) | |
if post_filter: | |
suffix += "_pf" | |
logger.info("Model loaded") | |
return model, df_state, suffix | |
def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]: | |
spec = df.analysis(audio.numpy()) # [C, Tf] -> [C, Tf, F] | |
a = get_norm_alpha(False) | |
erb_fb = df.erb_widths() | |
with warnings.catch_warnings(): | |
warnings.simplefilter("ignore", UserWarning) | |
erb_feat = torch.as_tensor(erb_norm(erb(spec, erb_fb), a)).unsqueeze(1) | |
spec_feat = as_real(torch.as_tensor(unit_norm(spec[..., :nb_df], a)).unsqueeze(1)) | |
spec = as_real(torch.as_tensor(spec).unsqueeze(1)) | |
if device is not None: | |
spec = spec.to(device) | |
erb_feat = erb_feat.to(device) | |
spec_feat = spec_feat.to(device) | |
return spec, erb_feat, spec_feat | |
def load_audio( | |
file: str, sr: Optional[int], verbose=True, **kwargs | |
) -> Tuple[Tensor, AudioMetaData]: | |
"""Loads an audio file using torchaudio. | |
Args: | |
file (str): Path to an audio file. | |
sr (int): Optionally resample audio to specified target sampling rate. | |
**kwargs: Passed to torchaudio.load(). Depends on the backend. The resample method | |
may be set via `method` which is passed to `resample()`. | |
Returns: | |
audio (Tensor): Audio tensor of shape [C, T], if channels_first=True (default). | |
info (AudioMetaData): Meta data of the original audio file. Contains the original sr. | |
""" | |
ikwargs = {} | |
if "format" in kwargs: | |
ikwargs["format"] = kwargs["format"] | |
rkwargs = {} | |
if "method" in kwargs: | |
rkwargs["method"] = kwargs.pop("method") | |
info: AudioMetaData = ta.info(file, **ikwargs) | |
audio, orig_sr = ta.load(file, **kwargs) | |
if sr is not None and orig_sr != sr: | |
if verbose: | |
warn_once( | |
f"Audio sampling rate does not match model sampling rate ({orig_sr}, {sr}). " | |
"Resampling..." | |
) | |
audio = resample(audio, orig_sr, sr, **rkwargs) | |
return audio, info | |
def save_audio( | |
file: str, | |
audio: Union[Tensor, ndarray], | |
sr: int, | |
output_dir: Optional[str] = None, | |
suffix: Optional[str] = None, | |
log: bool = False, | |
dtype=torch.int16, | |
): | |
outpath = file | |
if suffix is not None: | |
file, ext = os.path.splitext(file) | |
outpath = file + f"_{suffix}" + ext | |
if output_dir is not None: | |
outpath = os.path.join(output_dir, os.path.basename(outpath)) | |
if log: | |
logger.info(f"Saving audio file '{outpath}'") | |
audio = torch.as_tensor(audio) | |
if audio.ndim == 1: | |
audio.unsqueeze_(0) | |
if dtype == torch.int16 and audio.dtype != torch.int16: | |
audio = (audio * (1 << 15)).to(torch.int16) | |
if dtype == torch.float32 and audio.dtype != torch.float32: | |
audio = audio.to(torch.float32) / (1 << 15) | |
ta.save(outpath, audio, sr) | |
def enhance( | |
model: nn.Module, df_state: DF, audio: Tensor, pad=False, atten_lim_db: Optional[float] = None | |
): | |
model.eval() | |
bs = audio.shape[0] | |
if hasattr(model, "reset_h0"): | |
model.reset_h0(batch_size=bs, device=get_device()) | |
orig_len = audio.shape[-1] | |
n_fft, hop = 0, 0 | |
if pad: | |
n_fft, hop = df_state.fft_size(), df_state.hop_size() | |
# Pad audio to compensate for the delay due to the real-time STFT implementation | |
audio = F.pad(audio, (0, n_fft)) | |
nb_df = getattr(model, "nb_df", getattr(model, "df_bins", ModelParams().nb_df)) | |
spec, erb_feat, spec_feat = df_features(audio, df_state, nb_df, device=get_device()) | |
enhanced = model(spec, erb_feat, spec_feat)[0].cpu() | |
enhanced = as_complex(enhanced.squeeze(1)) | |
if atten_lim_db is not None and abs(atten_lim_db) > 0: | |
lim = 10 ** (-abs(atten_lim_db) / 20) | |
enhanced = as_complex(spec.squeeze(1)) * lim + enhanced * (1 - lim) | |
audio = torch.as_tensor(df_state.synthesis(enhanced.numpy())) | |
if pad: | |
# The frame size is equal to p.hop_size. Given a new frame, the STFT loop requires e.g. | |
# ceil((n_fft-hop)/hop). I.e. for 50% overlap, then hop=n_fft//2 | |
# requires 1 additional frame lookahead; 75% requires 3 additional frames lookahead. | |
# Thus, the STFT/ISTFT loop introduces an algorithmic delay of n_fft - hop. | |
assert n_fft % hop == 0 # This is only tested for 50% and 75% overlap | |
d = n_fft - hop | |
audio = audio[:, d : orig_len + d] | |
return audio | |
def parse_epoch_type(value: str) -> Union[int, str]: | |
try: | |
return int(value) | |
except ValueError: | |
assert value in ("best", "latest") | |
return value | |
def setup_df_argument_parser(default_log_level: str = "INFO") -> argparse.ArgumentParser: | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--model-base-dir", | |
"-m", | |
type=str, | |
default=None, | |
help="Model directory containing checkpoints and config. " | |
"To load a pretrained model, you may just provide the model name, e.g. `DeepFilterNet`. " | |
"By default, the pretrained DeepFilterNet2 model is loaded.", | |
) | |
parser.add_argument( | |
"--pf", | |
help="Post-filter that slightly over-attenuates very noisy sections.", | |
action="store_true", | |
) | |
parser.add_argument( | |
"--output-dir", | |
"-o", | |
type=str, | |
default=None, | |
help="Directory in which the enhanced audio files will be stored.", | |
) | |
parser.add_argument( | |
"--log-level", | |
type=str, | |
default=default_log_level, | |
help="Logger verbosity. Can be one of (debug, info, error, none)", | |
) | |
parser.add_argument("--debug", "-d", action="store_const", const="DEBUG", dest="log_level") | |
parser.add_argument( | |
"--epoch", | |
"-e", | |
default="best", | |
type=parse_epoch_type, | |
help="Epoch for checkpoint loading. Can be one of ['best', 'latest', <int>].", | |
) | |
return parser | |
def run(): | |
parser = setup_df_argument_parser() | |
parser.add_argument( | |
"--compensate-delay", | |
"-D", | |
action="store_true", | |
help="Add some paddig to compensate the delay introduced by the real-time STFT/ISTFT implementation.", | |
) | |
parser.add_argument( | |
"--atten-lim", | |
"-a", | |
type=int, | |
default=None, | |
help="Attenuation limit in dB by mixing the enhanced signal with the noisy signal.", | |
) | |
parser.add_argument( | |
"noisy_audio_files", | |
type=str, | |
nargs="+", | |
help="List of noise files to mix with the clean speech file.", | |
) | |
main(parser.parse_args()) | |
if __name__ == "__main__": | |
run() | |