from collections import OrderedDict
import hashlib
import json
import math
from pathlib import Path

import julius
import torch as th
import torchaudio as ta
from torch import distributed
from torch.nn import functional as F

from .audio import convert_audio_channels
from .compressed import get_musdb_tracks

MIXTURE = "mixture"
EXT = ".wav"


def _track_metadata(track, sources):
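    """
    Collect metadata for one track folder: length (in samples), sample rate,
    and the mean/std of the mono mixture, which `Wavset` uses for
    normalization. Raises `ValueError` if the stems disagree on length or
    sample rate.
    """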
    track_length = None
    track_samplerate = None
    for source in sources + [MIXTURE]:
        file = track / f"{source}{EXT}"
        info = ta.info(str(file))
        length = info.num_frames
        if track_length is None:
            track_length = length
            track_samplerate = info.sample_rate
        elif track_length != length:
            raise ValueError(
                f"Invalid length for file {file}: "
                f"expecting {track_length} but got {length}.")
        elif info.sample_rate != track_samplerate:
            raise ValueError(
                f"Invalid sample rate for file {file}: "
                f"expecting {track_samplerate} but got {info.sample_rate}.")
        if source == MIXTURE:
            # Average the mixture down to mono to compute normalization stats.
            wav, _ = ta.load(str(file))
            wav = wav.mean(0)
            mean = wav.mean().item()
            std = wav.std().item()

    return {"length": track_length, "mean": mean, "std": std, "samplerate": track_samplerate}


def _build_metadata(path, sources):
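    """
    Build a metadata dict for a dataset root: one `_track_metadata` entry per
    track folder, keyed by folder name.
    """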
    meta = {}
    path = Path(path)
    for file in path.iterdir():
        meta[file.name] = _track_metadata(file, sources)
    return meta


class Wavset:
    def __init__(
            self,
            root, metadata, sources,
            length=None, stride=None, normalize=True,
            samplerate=44100, channels=2):
        """
        Wavset (or mp3 set for that matter). Can be used to train
        with arbitrary sources. Each track should be one folder inside of `root`.
        The folder should contain files named `{source}{EXT}` for each source
        name in `sources`.

        Sample rate and channels will be converted on the fly.

        `length` is the number of samples to extract per example (a sample
        count, not a duration). `stride` is how many samples to move by
        between each example.

        See the usage sketch at the end of this module.
        """
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.length = length
        self.stride = stride or length
        self.normalize = normalize
        self.sources = sources
        self.channels = channels
        self.samplerate = samplerate
        self.num_examples = []
        for name, meta in self.metadata.items():
            # Track length once resampled to the target sample rate.
            track_length = int(self.samplerate * meta['length'] / meta['samplerate'])
            if length is None or track_length < length:
                examples = 1
            else:
                # Number of windows of size `length`, moving by `stride` each time.
                examples = int(math.ceil((track_length - self.length) / self.stride) + 1)
            self.num_examples.append(examples)

    def __len__(self):
        return sum(self.num_examples)

    def get_file(self, name, source):
        return self.root / name / f"{source}{EXT}"

    def __getitem__(self, index):
        # Walk the tracks until we find the one this index falls in.
        for name, examples in zip(self.metadata, self.num_examples):
            if index >= examples:
                index -= examples
                continue
            meta = self.metadata[name]
            num_frames = -1
            offset = 0
            if self.length is not None:
                # Offset and frame count, expressed in the file's own sample rate.
                offset = int(math.ceil(
                    meta['samplerate'] * self.stride * index / self.samplerate))
                num_frames = int(math.ceil(
                    meta['samplerate'] * self.length / self.samplerate))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                example = (example - meta['mean']) / meta['std']
            if self.length:
                # Trim or pad so every example has exactly `self.length` samples.
                example = example[..., :self.length]
                example = F.pad(example, (0, self.length - example.shape[-1]))
            return example


def get_wav_datasets(args, samples, sources):
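    """
    Return train and valid `Wavset` for a dataset laid out as
    `args.wav/{train,valid}/{track}/{source}.wav`. Rank 0 builds the metadata
    once and caches it in `args.metadata`, keyed by a hash of the dataset path.
    """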
    sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8]
    metadata_file = args.metadata / (sig + ".json")
    train_path = args.wav / "train"
    valid_path = args.wav / "valid"
    if not metadata_file.is_file() and args.rank == 0:
        train = _build_metadata(train_path, sources)
        valid = _build_metadata(valid_path, sources)
        with open(metadata_file, "w") as f:
            json.dump([train, valid], f)
    if args.world_size > 1:
        # Make sure rank 0 has finished writing the metadata file.
        distributed.barrier()
    with open(metadata_file) as f:
        train, valid = json.load(f)
    train_set = Wavset(train_path, train, sources,
                       length=samples, stride=args.data_stride,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    valid_set = Wavset(valid_path, valid, [MIXTURE] + sources,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    return train_set, valid_set


def get_musdb_wav_datasets(args, samples, sources):
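    """
    Return train and valid `Wavset` for the wav version of MusDB. Tracks that
    musdb assigns to the train split form the train set; the remaining tracks
    of the train subset form the valid set.
    """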
    metadata_file = args.metadata / "musdb_wav.json"
    root = args.musdb / "train"
    if not metadata_file.is_file() and args.rank == 0:
        metadata = _build_metadata(root, sources)
        with open(metadata_file, "w") as f:
            json.dump(metadata, f)
    if args.world_size > 1:
        # Make sure rank 0 has finished writing the metadata file.
        distributed.barrier()
    with open(metadata_file) as f:
        metadata = json.load(f)

    train_tracks = get_musdb_tracks(args.musdb, is_wav=True, subsets=["train"], split="train")
    metadata_train = {name: meta for name, meta in metadata.items() if name in train_tracks}
    metadata_valid = {name: meta for name, meta in metadata.items() if name not in train_tracks}
    train_set = Wavset(root, metadata_train, sources,
                       length=samples, stride=args.data_stride,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + sources,
                       samplerate=args.samplerate, channels=args.audio_channels,
                       normalize=args.norm_wav)
    return train_set, valid_set
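

# A minimal usage sketch (illustrative only: the attribute values on `args`
# below are assumptions, not defaults defined in this module):
#
#     from argparse import Namespace
#     from pathlib import Path
#     args = Namespace(musdb=Path("data/musdb18hq"), metadata=Path("metadata"),
#                      rank=0, world_size=1, data_stride=44100,
#                      samplerate=44100, audio_channels=2, norm_wav=True)
#     sources = ["drums", "bass", "other", "vocals"]
#     train_set, valid_set = get_musdb_wav_datasets(args, samples=441000, sources=sources)
#     example = train_set[0]  # tensor of shape [len(sources), channels, samples]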