Spaces:
Running
Running
import os | |
import string | |
from configparser import ConfigParser | |
from shlex import shlex | |
from typing import Any, List, Optional, Tuple, Type, TypeVar, Union | |
from loguru import logger | |
# Generic type variable linking a requested value's type to its `cast` callable
# (e.g. `Config.set` / `Config.get` take `cast: Type[T]` and return `T`).
T = TypeVar("T")
class DfParams:
    """Deep-filtering (DF) signal-processing parameters read from the "DF" config section.

    Each attribute is resolved through the module-level `config` object with
    `section="DF"`; the literal defaults below apply when the option is not set.
    `config` must already be loaded (e.g. via `config.load(...)` or
    `config.use_defaults()`), otherwise the first lookup raises ValueError.
    """

    def __init__(self):
        # Sampling rate used for training
        self.sr: int = config("SR", cast=int, default=48_000, section="DF")
        # FFT size in samples
        self.fft_size: int = config("FFT_SIZE", cast=int, default=960, section="DF")
        # STFT hop size in samples
        self.hop_size: int = config("HOP_SIZE", cast=int, default=480, section="DF")
        # Number of ERB bands
        self.nb_erb: int = config("NB_ERB", cast=int, default=32, section="DF")
        # Number of deep filtering bins; DF is applied from the 0th to the nb_df-th frequency bin
        self.nb_df: int = config("NB_DF", cast=int, default=96, section="DF")
        # Normalization decay factor; used for complex and ERB features
        self.norm_tau: float = config("NORM_TAU", 1, float, section="DF")
        # Local SNR maximum value; ground truth will be truncated
        self.lsnr_max: int = config("LSNR_MAX", 35, int, section="DF")
        # Local SNR minimum value; ground truth will be truncated
        self.lsnr_min: int = config("LSNR_MIN", -15, int, section="DF")
        # Minimum number of frequency bins per ERB band
        self.min_nb_freqs: int = config("MIN_NB_ERB_FREQS", 2, int, section="DF")
        # Deep filtering order
        self.df_order: int = config("DF_ORDER", cast=int, default=5, section="DF")
        # Deep filtering look-ahead
        self.df_lookahead: int = config("DF_LOOKAHEAD", cast=int, default=0, section="DF")
        # Pad mode (default "input_specf"). By default, padding is handled on the input side:
        # - `input`, which pads the input features passed to the model
        # - `output`, which pads the output spectrogram corresponding to `df_lookahead`
        # NOTE(review): the default "input_specf" is not one of the two values listed
        # above; confirm the accepted set against the code that consumes `pad_mode`.
        self.pad_mode: str = config("PAD_MODE", default="input_specf", section="DF")
class Config:
    """INI-backed configuration store with environment-variable overrides.

    Adapted from python-decouple. A lookup (`config(option, ...)`) checks, in order:
    an environment variable named `option`, the requested section as given, the
    lower-cased section, the default settings section, and finally the caller's
    default. Unless `save=False`, the resolved value is written back into the
    lower-cased section so that `save()` can persist the effective configuration.
    """

    # Section used when callers pass none, and the last fallback for lookups.
    DEFAULT_SECTION = "settings"

    def __init__(self):
        # Created by `load()`; None means no configuration has been loaded yet.
        self.parser: ConfigParser = None  # type: ignore
        # Path given to the last `load()` call (may be None).
        self.path = ""
        # Set whenever a value is written that differs from what the parser held.
        self.modified = False
        # When False, falling back to a caller-supplied default raises ValueError.
        self.allow_defaults = True

    def load(
        self, path: Optional[str], config_must_exist=False, allow_defaults=True, allow_reload=False
    ):
        """Load an INI file into a fresh parser and migrate legacy names.

        Args:
            path: INI file to read. If None or not an existing file, an empty
                configuration is used (unless `config_must_exist` is set).
            config_must_exist: Raise instead of starting empty when `path` is not a file.
            allow_defaults: Whether later lookups may fall back to caller defaults.
            allow_reload: Permit replacing an already loaded configuration.

        Raises:
            ValueError: If a config is already loaded and `allow_reload` is False,
                or if `config_must_exist` is set and no file exists at `path`.
        """
        self.allow_defaults = allow_defaults
        if self.parser is not None and not allow_reload:
            raise ValueError("Config already loaded")
        self.parser = ConfigParser()
        self.path = path
        if path is not None and os.path.isfile(path):
            with open(path) as f:
                self.parser.read_file(f)
        elif config_must_exist:
            raise ValueError(f"No config file found at '{path}'.")
        if not self.parser.has_section(self.DEFAULT_SECTION):
            self.parser.add_section(self.DEFAULT_SECTION)
        # Rename sections/options written by older model versions.
        self._fix_clc()
        self._fix_df()

    def use_defaults(self):
        """Start from an empty configuration (no file is read).

        Raises ValueError (from `load`) if a configuration is already loaded.
        """
        self.load(path=None, config_must_exist=False)

    def save(self, path: str):
        """Write the configuration to `path`, dropping empty sections first.

        Does nothing when `modified` is False.
        """
        if not self.modified:
            logger.debug("Config not modified. No need to overwrite on disk.")
            return
        if self.parser is None:
            self.parser = ConfigParser()
        for section in self.parser.sections():
            if len(self.parser[section]) == 0:
                self.parser.remove_section(section)
        with open(path, mode="w") as f:
            self.parser.write(f)

    def tostr(self, value, cast):
        """Serialize `value` as the text stored in the INI file.

        Lists/tuples destined for a `Csv` cast are joined with that Csv's
        delimiter; everything else is converted with str().
        """
        # Check the value first: `Csv` (defined later in this module) is then
        # only resolved when a sequence is actually being stored.
        if isinstance(value, (tuple, list)) and isinstance(cast, Csv):
            return cast.delimiter.join(str(v) for v in value)
        return str(value)

    def set(self, option: str, value: T, cast: Type[T], section: Optional[str] = None) -> T:
        """Store `value` under `option` in the lower-cased `section` (created if missing).

        `modified` is only set when the option is new or its cast value changes.
        Returns `value` unchanged.
        """
        section = self.DEFAULT_SECTION if section is None else section
        section = section.lower()
        if not self.parser.has_section(section):
            self.parser.add_section(section)
        if self.parser.has_option(section, option):
            if value == self.cast(self.parser.get(section, option), cast):
                return value
        self.modified = True
        self.parser.set(section, option, self.tostr(value, cast))
        return value

    def __call__(
        self,
        option: str,
        default: Any = None,
        cast: Type[T] = str,
        save: bool = True,
        section: Optional[str] = None,
    ) -> T:
        """Look up `option` and return it converted with `cast`.

        Args:
            option: Option name; also the name of the overriding environment variable.
            default: Value used when the option is found nowhere.
            cast: Callable converting the raw value (see `cast` for bool handling).
            save: Write the resolved value back into the lower-cased `section`.
            section: Section to look in; defaults to `DEFAULT_SECTION`.

        Raises:
            ValueError: If no config is loaded, the option is missing and `default`
                is None, or a default would be used while defaults are disallowed.
        """
        section = self.DEFAULT_SECTION if section is None else section
        if self.parser is None:
            raise ValueError("No configuration loaded")
        if not self.parser.has_section(section.lower()):
            self.parser.add_section(section.lower())
        if option in os.environ:
            # Persisted below via `self.set`, which lower-cases the section. Writing
            # to `section` directly raised NoSectionError for names like "DF" and
            # skipped the `modified` bookkeeping.
            value = os.environ[option]
        elif self.parser.has_option(section, option):
            value = self.read_from_section(section, option, default, cast=cast, save=save)
        elif self.parser.has_option(section.lower(), option):
            value = self.read_from_section(section.lower(), option, default, cast=cast, save=save)
        elif self.parser.has_option(self.DEFAULT_SECTION, option):
            logger.warning(
                f"Couldn't find option {option} in section {section}. "
                "Falling back to default settings section."
            )
            value = self.read_from_section(self.DEFAULT_SECTION, option, cast=cast, save=save)
        elif default is None:
            raise ValueError("Value {} not found.".format(option))
        elif not self.allow_defaults and save:
            raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
        else:
            value = default
        if save:
            self.set(option, value, cast, section)
        return self.cast(value, cast)

    def cast(self, value, cast):
        """Convert `value` with `cast`; bool accepts common textual spellings.

        Raises:
            ValueError: For a bool cast whose value is not a recognized spelling.
        """
        if cast is bool:
            value = str(value).lower()
            if value in {"true", "yes", "y", "on", "1"}:
                return True  # type: ignore
            elif value in {"false", "no", "n", "off", "0"}:
                return False  # type: ignore
            raise ValueError("Parse error")
        return cast(value)

    def get(self, option: str, cast: Type[T] = str, section: Optional[str] = None) -> T:
        """Return the stored value of `option` converted with `cast`.

        `section` is tried as given, then lower-cased (the form `set` writes to).

        Raises:
            KeyError: If the section or the option does not exist.
        """
        section = self.DEFAULT_SECTION if section is None else section
        if not self.parser.has_section(section):
            if not self.parser.has_section(section.lower()):
                raise KeyError(section)
            section = section.lower()
        if not self.parser.has_option(section, option):
            raise KeyError(option)
        return self.cast(self.parser.get(section, option), cast)

    def read_from_section(
        self, section: str, option: str, default: Any = None, cast: Type = str, save: bool = True
    ) -> str:
        """Return the raw string stored for `option` in `section`.

        With `save=False`, the stored entry is reset to `default` (or removed when
        `default` is None). With `save=True`, an entry found in a mixed-case section
        is moved to the lower-cased section and `modified` is set.

        Raises:
            ValueError: If `save` is False, a default is given and defaults are disallowed.
        """
        value = self.parser.get(section, option)
        if not save:
            # Set to default or remove, so it is not read again at training start
            if default is None:
                self.parser.remove_option(section, option)
            elif not self.allow_defaults:
                raise ValueError(f"Value '{option}' not found in config (defaults not allowed).")
            else:
                self.parser.set(section, option, self.tostr(default, cast))
        elif section.lower() != section:
            self.parser.set(section.lower(), option, self.tostr(value, cast))
            self.parser.remove_option(section, option)
            self.modified = True
        return value

    def overwrite(self, section: str, option: str, value: Any):
        """Replace the value of an existing option and mark the config as modified.

        Raises:
            ValueError: If `section` or `option` does not exist.
        """
        if not self.parser.has_section(section):
            raise ValueError(f"Section not found: '{section}'")
        if not self.parser.has_option(section, option):
            raise ValueError(f"Option not found '{option}' in section '{section}'")
        self.modified = True
        cast = type(value)
        return self.parser.set(section, option, self.tostr(value, cast))

    def _fix_df(self):
        """Move `df_order`/`df_lookahead` from [deepfilternet] to [df] (old-model compatibility).

        Only applies when both sections exist.
        """
        if self.parser.has_section("deepfilternet") and self.parser.has_section("df"):
            sec_deepfilternet = self.parser["deepfilternet"]
            sec_df = self.parser["df"]
            if "df_order" in sec_deepfilternet:
                sec_df["df_order"] = sec_deepfilternet["df_order"]
                del sec_deepfilternet["df_order"]
            if "df_lookahead" in sec_deepfilternet:
                sec_df["df_lookahead"] = sec_deepfilternet["df_lookahead"]
                del sec_deepfilternet["df_lookahead"]

    def _fix_clc(self):
        """Rename legacy groups/options for compatibility with old models.

        - train.model `convgru5` becomes `deepfilternet`, and [convgru] (if present)
          is copied into a new [deepfilternet] section.
        - [clc] is renamed to [df] when no [df] exists.
        - Any option whose name contains `clc` is renamed with `df` instead.
        """
        if (
            not self.parser.has_section("deepfilternet")
            and self.parser.has_section("train")
            # `fallback` avoids NoOptionError for a [train] section without `model`.
            and self.parser.get("train", "model", fallback=None) == "convgru5"
        ):
            self.overwrite("train", "model", "deepfilternet")
            self.parser.add_section("deepfilternet")
            if self.parser.has_section("convgru"):
                self.parser["deepfilternet"] = self.parser["convgru"]
                del self.parser["convgru"]
        if not self.parser.has_section("df") and self.parser.has_section("clc"):
            self.parser["df"] = self.parser["clc"]
            del self.parser["clc"]
        for section in self.parser.sections():
            # Snapshot the items because options are deleted while iterating.
            for k, v in list(self.parser[section].items()):
                if "clc" in k.lower():
                    self.parser.set(section, k.lower().replace("clc", "df"), v)
                    del self.parser[section][k]

    def __repr__(self):
        """List every section followed by its indented `key: value` pairs."""
        msg = ""
        for section in self.parser.sections():
            msg += f"{section}:\n"
            for k, v in self.parser[section].items():
                msg += f"  {k}: {v}\n"
        return msg
# Module-wide configuration instance. It must be loaded (`config.load(...)` or
# `config.use_defaults()`) before values are requested, otherwise lookups raise ValueError.
config = Config()
class Csv(object):
    """
    Produces a csv parser that returns a list of transformed elements. From python-decouple.
    """

    def __init__(
        self, cast: Type[T] = str, delimiter=",", strip=string.whitespace, post_process=list
    ):
        """
        Parameters:
        cast -- callable that transforms each item just before it's added to the list.
        delimiter -- string of delimiter characters passed to shlex; any of them splits items.
        strip -- string of non-relevant characters passed to str.strip after the split.
        post_process -- callable to post-process all cast values. Default is `list`.
        """
        self.cast: Type[T] = cast
        self.delimiter = delimiter
        self.strip = strip
        self.post_process = post_process

    def __call__(self, value: Union[str, Tuple[T], List[T]]) -> List[T]:
        """Split `value` on the delimiter characters and cast each stripped item.

        A list/tuple (e.g. a default given as a sequence) is first joined with the
        delimiter. Any other non-string value is converted with str() first, so
        e.g. ``Csv(int)(5) == [5]`` rather than passing a non-stream object to shlex.
        """
        if isinstance(value, (tuple, list)):
            value = self.delimiter.join(str(v) for v in value)
        elif not isinstance(value, str):
            value = str(value)

        def transform(s):
            return self.cast(s.strip(self.strip))

        splitter = shlex(value, posix=True)
        splitter.whitespace = self.delimiter
        splitter.whitespace_split = True
        return self.post_process(transform(s) for s in splitter)