|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
from fractions import Fraction |
|
from concurrent import futures |
|
|
|
import musdb |
|
from torch import distributed |
|
|
|
from .audio import AudioFile |
|
|
|
|
|
def get_musdb_tracks(root, *args, **kwargs):
    """Return a {track name: track path} mapping for the MusDB dataset at `root`.

    Any extra positional/keyword arguments are forwarded to `musdb.DB`
    (e.g. `subsets=["train"]`, `split="valid"`).
    """
    database = musdb.DB(root, *args, **kwargs)
    paths = {}
    for track in database:
        paths[track.name] = track.path
    return paths
|
|
|
|
|
class StemsSet:
    """Dataset of fixed-length examples extracted from a set of audio tracks.

    Each track is cut into examples of `duration` seconds taken every
    `stride` seconds. When `duration` is None, every track yields exactly
    one example (the whole track). Examples are normalized with the
    per-track `mean`/`std` stored in the metadata.
    """

    def __init__(self, tracks, metadata, duration=None, stride=1,
                 samplerate=44100, channels=2, streams=slice(None)):
        """
        Args:
            tracks (dict): mapping track name -> audio file path.
            metadata (dict): mapping track name -> dict with at least
                "duration" (seconds) and, for normalization, "mean"/"std".
            duration: length in seconds of each example, or None for
                whole-track examples. Tracks shorter than `duration`
                raise ValueError.
            stride: offset in seconds between consecutive examples.
            samplerate (int): target sample rate passed to AudioFile.read.
            channels (int): target channel count passed to AudioFile.read.
            streams: which streams (stems) to read, as an int or slice.

        Raises:
            ValueError: if any track is shorter than `duration`.
        """
        self.metadata = []
        for name, path in tracks.items():
            meta = dict(metadata[name])
            meta["path"] = path
            meta["name"] = name
            self.metadata.append(meta)
            if duration is not None and meta["duration"] < duration:
                raise ValueError(f"Track {name} duration is too small {meta['duration']}")
        # Sort by name for a deterministic example ordering across runs.
        self.metadata.sort(key=lambda x: x["name"])
        self.duration = duration
        self.stride = stride
        self.channels = channels
        self.samplerate = samplerate
        self.streams = streams

    def __len__(self):
        return sum(self._examples_count(m) for m in self.metadata)

    def _examples_count(self, meta):
        # Number of examples a single track yields.
        if self.duration is None:
            return 1
        return int((meta["duration"] - self.duration) // self.stride + 1)

    def _resolve_index(self, index):
        """Map a flat example index to (track metadata, index within track).

        Raises:
            IndexError: if `index` is out of range. (The previous
            implementation silently returned None, which would make
            iteration over the dataset never terminate.)
        """
        remaining = index
        for meta in self.metadata:
            examples = self._examples_count(meta)
            if remaining >= examples:
                remaining -= examples
                continue
            return meta, remaining
        raise IndexError(f"Index {index} is out of range for {len(self)} examples")

    def track_metadata(self, index):
        """Return the metadata dict of the track containing example `index`."""
        meta, _ = self._resolve_index(index)
        return meta

    def __getitem__(self, index):
        meta, offset = self._resolve_index(index)
        streams = AudioFile(meta["path"]).read(seek_time=offset * self.stride,
                                               duration=self.duration,
                                               channels=self.channels,
                                               samplerate=self.samplerate,
                                               streams=self.streams)
        # Normalize with the precomputed per-track statistics.
        return (streams - meta["mean"]) / meta["std"]
|
|
|
|
|
def _get_track_metadata(path):
    """Compute duration and normalization statistics for a single track.

    Reads the mixture (stream 0) as mono at 44.1 kHz and returns a dict
    with the track duration and the mean/std of the mixture samples.
    """
    audio = AudioFile(path)
    mono_mix = audio.read(streams=0, channels=1, samplerate=44100)
    return {
        "duration": audio.duration,
        "std": mono_mix.std().item(),
        "mean": mono_mix.mean().item(),
    }
|
|
|
|
|
def _build_metadata(tracks, workers=10): |
|
pendings = [] |
|
with futures.ProcessPoolExecutor(workers) as pool: |
|
for name, path in tracks.items(): |
|
pendings.append((name, pool.submit(_get_track_metadata, path))) |
|
return {name: p.result() for name, p in pendings} |
|
|
|
|
|
def _build_musdb_metadata(path, musdb, workers):
    """Build the MusDB metadata cache and save it as JSON at `path`.

    Args:
        path (Path): destination JSON file; parent directories are created.
        musdb (str): root directory of the MusDB dataset. NOTE(review): this
            parameter shadows the imported `musdb` module inside this
            function; kept as-is for interface compatibility.
        workers (int): number of processes used to extract metadata.
    """
    tracks = get_musdb_tracks(musdb)
    metadata = _build_metadata(tracks, workers)
    path.parent.mkdir(exist_ok=True, parents=True)
    # Context manager guarantees the file is flushed and closed; the original
    # passed an inline open() to json.dump and relied on GC to close it.
    with open(path, "w") as file:
        json.dump(metadata, file)
|
|
|
|
|
def get_compressed_datasets(args, samples):
    """Build the MusDB train and valid datasets as `StemsSet` instances.

    Rank 0 builds the metadata cache when it is missing; in distributed
    runs all ranks then synchronize on a barrier before reading it.

    Args:
        args: namespace with `metadata` (Path), `musdb` (str), `rank`,
            `world_size`, `workers`, `samplerate`, `data_stride` and
            `audio_channels` attributes.
        samples (int): length in samples of each training example.

    Returns:
        tuple: (train_set, valid_set).
    """
    metadata_file = args.metadata / "musdb.json"
    if not metadata_file.is_file() and args.rank == 0:
        _build_musdb_metadata(metadata_file, args.musdb, args.workers)
    if args.world_size > 1:
        # Make sure every rank sees the metadata file written by rank 0.
        distributed.barrier()
    # Context manager closes the handle deterministically; the original
    # passed an inline open() to json.load and relied on GC.
    with open(metadata_file) as file:
        metadata = json.load(file)
    # Fractions keep seek times exact (no float rounding of sample counts).
    duration = Fraction(samples, args.samplerate)
    stride = Fraction(args.data_stride, args.samplerate)
    train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                         metadata,
                         duration=duration,
                         stride=stride,
                         # Skip stream 0 (the mixture): train on the stems only.
                         streams=slice(1, None),
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                         metadata,
                         samplerate=args.samplerate,
                         channels=args.audio_channels)
    return train_set, valid_set
|
|