import datetime
import logging
import os
import pickle
from typing import Dict, Optional

import librosa
import numpy as np
import yaml


def create_logging(log_dir: str, filemode: str) -> logging.Logger:
    r"""Create logging to write out log files.

    Args:
        log_dir: str, directory to write out logs
        filemode: str, e.g., "w"

    Returns:
        logging.Logger
    """
    os.makedirs(log_dir, exist_ok=True)

    # Find the next unused log file name, e.g., 0000.log, 0001.log, ...
    i1 = 0
    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1

    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Also print to console.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging.getLogger("")
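
# Example usage (illustrative sketch; "./workspace/logs" is an assumed path,
# not part of this module):
#
#     create_logging("./workspace/logs", filemode="w")
#     logging.info("Training started.")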


def load_audio(
    audio_path: str,
    mono: bool,
    sample_rate: float,
    offset: float = 0.0,
    duration: Optional[float] = None,
) -> np.ndarray:
    r"""Load audio.

    Args:
        audio_path: str
        mono: bool
        sample_rate: float
        offset: float, start time (in seconds) to begin reading
        duration: Optional[float], seconds to read, None to read the full file

    Returns:
        audio: (channels_num, audio_samples)
    """
    audio, _ = librosa.core.load(
        audio_path, sr=sample_rate, mono=mono, offset=offset, duration=duration
    )
    # (audio_samples,) | (channels_num, audio_samples)

    if audio.ndim == 1:
        audio = audio[None, :]
        # (1, audio_samples)

    return audio
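
# Example usage (illustrative sketch; "mixture.wav" is an assumed file name):
#
#     audio = load_audio("mixture.wav", mono=False, sample_rate=44100)
#     print(audio.shape)  # (channels_num, audio_samples)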


def load_random_segment(
    audio_path: str, random_state, segment_seconds: float, mono: bool, sample_rate: int
) -> np.ndarray:
    r"""Randomly select an audio segment from a recording."""
    duration = librosa.get_duration(filename=audio_path)

    # Draw a random start time so that the segment fits inside the recording.
    start_time = random_state.uniform(0.0, duration - segment_seconds)

    audio = load_audio(
        audio_path=audio_path,
        mono=mono,
        sample_rate=sample_rate,
        offset=start_time,
        duration=segment_seconds,
    )
    # (channels_num, audio_samples)

    return audio
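
# Example usage (illustrative sketch; "mixture.wav" is an assumed file name):
#
#     random_state = np.random.RandomState(1234)
#     segment = load_random_segment(
#         audio_path="mixture.wav",
#         random_state=random_state,
#         segment_seconds=3.0,
#         mono=True,
#         sample_rate=44100,
#     )
#     print(segment.shape)  # (1, 132300)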


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    x = np.clip(x, a_min=-1, a_max=1)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    return (x / 32767.0).astype(np.float32)
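
# Example round trip (illustrative sketch): quantizing to int16 and back is
# lossy only up to about 1/32767 per sample, and out-of-range values are clipped.
#
#     x = np.array([0.0, 0.5, -1.2], dtype=np.float32)
#     y = int16_to_float32(float32_to_int16(x))  # approximately [0.0, 0.5, -1.0]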


def read_yaml(config_yaml: str) -> Dict:
    with open(config_yaml, "r") as fr:
        configs = yaml.load(fr, Loader=yaml.FullLoader)

    return configs


def check_configs_gramma(configs: Dict) -> None:
    r"""Check that the grammar of the training config dictionary is legal."""
    input_source_types = configs['train']['input_source_types']

    for augmentation_type in configs['train']['augmentations'].keys():
        augmentation_dict = configs['train']['augmentations'][augmentation_type]

        for source_type in augmentation_dict.keys():
            if source_type not in input_source_types:
                error_msg = (
                    "The source type '{}' in configs['train']['augmentations']['{}'] "
                    "must be one of input_source_types {}".format(
                        source_type, augmentation_type, input_source_types
                    )
                )
                raise Exception(error_msg)
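
# Example (illustrative sketch; the dictionary below only mirrors the keys
# accessed above, the source and augmentation names are assumed):
#
#     configs = {
#         "train": {
#             "input_source_types": ["vocals", "accompaniment"],
#             "augmentations": {"pitch_shift": {"vocals": 2}},
#         }
#     }
#     check_configs_gramma(configs)  # passes; an unknown source type would raise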


def magnitude_to_db(x: float) -> float:
    eps = 1e-10
    return 20.0 * np.log10(max(x, eps))


def db_to_magnitude(x: float) -> float:
    return 10.0 ** (x / 20.0)


def get_pitch_shift_factor(shift_pitch: float) -> float:
    r"""The factor by which the audio length is scaled after a pitch shift."""
    return 2 ** (shift_pitch / 12)
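
# Example (illustrative sketch): a magnitude of 0.5 is about -6.02 dB, and a
# +12 semitone pitch shift corresponds to a time-scaling factor of 2.
#
#     magnitude_to_db(0.5)        # approximately -6.02
#     db_to_magnitude(-6.02)      # approximately 0.5
#     get_pitch_shift_factor(12)  # 2.0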


class StatisticsContainer(object):
    def __init__(self, statistics_path):
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"train": [], "test": []}

    def append(self, steps, statistics, split):
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

    def dump(self):
        with open(self.statistics_path, "wb") as fw:
            pickle.dump(self.statistics_dict, fw)

        with open(self.backup_statistics_path, "wb") as fw:
            pickle.dump(self.statistics_dict, fw)

        logging.info(" Dump statistics to {}".format(self.statistics_path))
        logging.info(" Dump statistics to {}".format(self.backup_statistics_path))

    '''
    def load_state_dict(self, resume_steps):
        self.statistics_dict = pickle.load(open(self.statistics_path, "rb"))

        resume_statistics_dict = {"train": [], "test": []}

        for key in self.statistics_dict.keys():
            for statistics in self.statistics_dict[key]:
                if statistics["steps"] <= resume_steps:
                    resume_statistics_dict[key].append(statistics)

        self.statistics_dict = resume_statistics_dict
    '''
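
# Example usage (illustrative sketch; "statistics.pkl" is an assumed path and
# the SDR value is made up):
#
#     container = StatisticsContainer("statistics.pkl")
#     container.append(steps=1000, statistics={"sdr": 5.3}, split="test")
#     container.dump()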


def calculate_sdr(ref: np.ndarray, est: np.ndarray) -> float:
    r"""Calculate the signal-to-distortion ratio (SDR), in dB, between an
    estimation and a reference."""
    s_true = ref
    s_artif = est - ref

    sdr = 10.0 * (
        np.log10(np.clip(np.mean(s_true ** 2), 1e-8, np.inf))
        - np.log10(np.clip(np.mean(s_artif ** 2), 1e-8, np.inf))
    )
    return sdr
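
# Example (illustrative sketch): a noisier estimation gives a lower SDR; a
# perfect one is bounded above only by the 1e-8 clipping.
#
#     ref = np.sin(np.linspace(0, 100, 44100))
#     est = ref + 0.01 * np.random.randn(44100)
#     print(calculate_sdr(ref, est))  # roughly 37 dB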