import os
import glob
import torch
import warnings
import torchaudio
import pyloudnorm as pyln


class AudioFile(object):
    def __init__(self, filepath, preload=False, half=False, target_loudness=None):
        """Base class for audio files to handle metadata and loading.

        Args:
            filepath (str): Path to audio file to load from disk.
            preload (bool, optional): If set, load audio data into RAM. Default: False
            half (bool, optional): If set, store audio data as float16 to save space. Default: False
            target_loudness (float, optional): Loudness normalize to this dB LUFS value. Default: None
        """
        super().__init__()
        self.filepath = filepath
        self.half = half
        self.target_loudness = target_loudness
        self.loaded = False

        if preload:
            # load the audio now; load() also sets self.sample_rate
            self.load()
            num_frames = self.audio.shape[-1]
            num_channels = self.audio.shape[0]
        else:
            # only read the metadata and defer loading the samples
            metadata = torchaudio.info(filepath)
            self.audio = None
            self.sample_rate = metadata.sample_rate
            num_frames = metadata.num_frames
            num_channels = metadata.num_channels

        self.num_frames = num_frames
        self.num_channels = num_channels
    def load(self):
        audio, sr = torchaudio.load(self.filepath, normalize=True)
        self.audio = audio
        self.sample_rate = sr

        if self.target_loudness is not None:
            self.loudness_normalize()

        if self.half:
            self.audio = self.audio.half()

        self.loaded = True
    def loudness_normalize(self):
        meter = pyln.Meter(self.sample_rate)

        # convert mono to stereo for the loudness measurement
        if self.audio.shape[0] == 1:
            tmp_audio = self.audio.repeat(2, 1)
        else:
            tmp_audio = self.audio

        # measure integrated loudness in dB LUFS
        input_loudness = meter.integrated_loudness(tmp_audio.numpy().T)

        # compute and apply linear gain to reach the target loudness
        gain_dB = self.target_loudness - input_loudness
        gain_ln = 10 ** (gain_dB / 20.0)
        self.audio *= gain_ln

        # check for potentially clipped samples
        if self.audio.abs().max() >= 1.0:
            warnings.warn("Possible clipped samples in output.")
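
# A minimal usage sketch for AudioFile (hedged: the file path below is a
# hypothetical placeholder, not part of this repository):
#
#     af = AudioFile("example.wav", preload=True, target_loudness=-24.0)
#     print(af.sample_rate, af.num_channels, af.num_frames)
#     print(af.audio.shape)  # (num_channels, num_frames), float32 in [-1, 1]
#
# Loudness normalization applies a linear gain of 10 ** ((target - measured) / 20),
# so a file measured at -30 LUFS and normalized to -24 LUFS is scaled by roughly 2.0.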

class AudioFileDataset(torch.utils.data.Dataset):
    """Base class for audio file datasets loaded from disk.

    Datasets can be either paired or unpaired. A paired dataset requires passing the `target_dirs` paths.

    Args:
        input_dirs (List[str]): List of paths to the directories containing input audio files.
        target_dirs (List[str], optional): List of paths to the directories containing corresponding target audio files. Default: []
        subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
        length (int, optional): Number of samples to load for each example. Default: 65536
        normalize (bool, optional): Normalize audio amplitude to the range -1 to 1. Default: True
        train_per (float, optional): Fraction of the files to use for the training subset. Default: 0.8
        val_per (float, optional): Fraction of the files to use for the validation subset. Default: 0.1
        preload (bool, optional): Read audio files into RAM at the start of training. Default: False
        num_examples_per_epoch (int, optional): Define an epoch as a certain number of audio examples. Default: 10000
        ext (str, optional): Expected audio file extension. Default: "wav"
    """
    def __init__(
        self,
        input_dirs,
        target_dirs=[],
        subset="train",
        length=65536,
        normalize=True,
        train_per=0.8,
        val_per=0.1,
        preload=False,
        num_examples_per_epoch=10000,
        ext="wav",
    ):
        super().__init__()
        self.input_dirs = input_dirs
        self.target_dirs = target_dirs
        self.subset = subset
        self.length = length
        self.normalize = normalize
        self.train_per = train_per
        self.val_per = val_per
        self.preload = preload
        self.num_examples_per_epoch = num_examples_per_epoch
        self.ext = ext
        # gather input filepaths
        self.input_filepaths = []
        for input_dir in input_dirs:
            search_path = os.path.join(input_dir, f"*.{ext}")
            self.input_filepaths += glob.glob(search_path)
        self.input_filepaths = sorted(self.input_filepaths)

        # gather target filepaths (empty for an unpaired dataset)
        self.target_filepaths = []
        for target_dir in target_dirs:
            search_path = os.path.join(target_dir, f"*.{ext}")
            self.target_filepaths += glob.glob(search_path)
        self.target_filepaths = sorted(self.target_filepaths)

        # both sets must have the same number of files in a paired dataset
        if len(target_dirs) > 0:
            assert len(self.target_filepaths) == len(self.input_filepaths)

        # get details about the audio files
        self.input_files = []
        for input_filepath in self.input_filepaths:
            self.input_files.append(AudioFile(input_filepath, preload=preload))

        self.target_files = []
        for target_filepath in self.target_filepaths:
            self.target_files.append(AudioFile(target_filepath, preload=preload))
    def __len__(self):
        return self.num_examples_per_epoch

    def __getitem__(self, idx):
        """Return a random patch of `self.length` samples (with its target, if paired)."""
        # index the current audio file (wrap around, since an epoch may hold more examples than files)
        file_idx = idx % len(self.input_files)
        input_file = self.input_files[file_idx]

        # load the audio data if needed
        if not input_file.loaded:
            input_file.load()

        # get a random patch of size `self.length`
        start_idx = int(torch.rand(1).item() * (input_file.num_frames - self.length))
        stop_idx = start_idx + self.length
        input_audio = input_file.audio[:, start_idx:stop_idx]

        # if there is a target file, get it (and load if needed)
        if len(self.target_files) > 0:
            target_file = self.target_files[file_idx]
            if not target_file.loaded:
                target_file.load()
            # use the same cropping indices
            target_audio = target_file.audio[:, start_idx:stop_idx]
            return input_audio, target_audio
        else:
            return input_audio
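

# A minimal, hedged usage sketch. The directory names below are hypothetical
# placeholders; point them at your own folders of "wav" files. This assumes a
# paired setup where the input and target folders contain matching files.
if __name__ == "__main__":
    dataset = AudioFileDataset(
        input_dirs=["data/input"],    # hypothetical path
        target_dirs=["data/target"],  # hypothetical path
        length=65536,
        preload=False,
    )

    # wrap the dataset in a standard PyTorch DataLoader for batched training
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=0,
    )

    x, y = next(iter(dataloader))
    print(x.shape, y.shape)  # e.g. torch.Size([8, num_channels, 65536])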