import argparse
import os
from collections import defaultdict, namedtuple
from pathlib import Path

import musdb
import numpy as np
import torch as th
import tqdm
from torch.utils.data import DataLoader

from .audio import AudioFile

# A chunk is identified by the file it comes from, the sample offset of the
# chunk within that file, and the chunk index local to that file.
ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])


class Rawset:
    """
    Dataset of raw, normalized, float32 audio files.

    Expects a directory of `<name>.<stream>.raw` files, as produced by
    `build_raw` below, each holding interleaved float32 samples for
    `channels` channels. If `samples` is given, every item is a fixed-length
    chunk of that many samples, taken every `stride` samples; otherwise each
    item is a whole file.
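
    A minimal usage sketch (illustrative; the path is hypothetical)::

        dataset = Rawset("musdb_raw/train", samples=44100 * 10)
        chunk = dataset[0]  # tensor of shape [streams, channels, samples]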
    """
    def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
        self.path = Path(path)
        self.channels = channels
        self.samples = samples
        if stride is None:
            stride = samples if samples is not None else 0
        self.stride = stride
        entries = defaultdict(list)
        for root, folders, files in os.walk(self.path, followlinks=True):
            folders.sort()
            files.sort()
            for file in files:
                if file.endswith(".raw"):
                    path = Path(root) / file
                    name, stream = path.stem.rsplit('.', 1)
                    entries[(path.parent.relative_to(self.path), name)].append(int(stream))

        self._entries = list(entries.keys())

        sizes = []
        self._lengths = []
        ref_streams = sorted(entries[self._entries[0]])
        assert ref_streams == list(range(len(ref_streams)))
        if streams is None:
            self.streams = ref_streams
        else:
            self.streams = streams
        for entry in sorted(entries.keys()):
            entry_streams = entries[entry]
            assert sorted(entry_streams) == ref_streams
            file = self._path(*entry)
            # 4 bytes per float32 sample, `channels` samples per frame.
            length = file.stat().st_size // (4 * channels)
            if samples is None:
                sizes.append(1)
            else:
                if length < samples:
                    # Too short to yield even one chunk, drop the file.
                    self._entries.remove(entry)
                    continue
                sizes.append((length - samples) // stride + 1)
            self._lengths.append(length)
        if not sizes:
            raise ValueError(f"Empty dataset {self.path}")
        self._cumulative_sizes = np.cumsum(sizes)
        self._sizes = sizes

    def __len__(self):
        return self._cumulative_sizes[-1]

    @property
    def total_length(self):
        return sum(self._lengths)

    def chunk_info(self, index):
        # Map a flat dataset index to a file and a position inside it, e.g.
        # with per-file chunk counts [3, 5] (cumulative [3, 8]), index 4
        # falls in file 1 at local index 1, i.e. offset = 1 * stride.
        file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
        if file_index == 0:
            local_index = index
        else:
            local_index = index - self._cumulative_sizes[file_index - 1]
        return ChunkInfo(offset=local_index * self.stride,
                         file_index=file_index,
                         local_index=local_index)

    def _path(self, folder, name, stream=0):
        return self.path / folder / (name + f'.{stream}.raw')

    def __getitem__(self, index):
        chunk = self.chunk_info(index)
        entry = self._entries[chunk.file_index]

        length = self.samples or self._lengths[chunk.file_index]
        streams = []
        to_read = length * self.channels * 4
        offset = chunk.offset * 4 * self.channels
        for stream in self.streams:
            # Close the file promptly rather than leaking the handle.
            with open(self._path(*entry, stream=stream), 'rb') as file:
                file.seek(offset)
                content = file.read(to_read)
            assert len(content) == to_read
            # Copy so the tensor does not alias the read-only buffer.
            content = np.frombuffer(content, dtype=np.float32).copy()
            streams.append(th.from_numpy(content).view(length, self.channels).t())
        return th.stack(streams, dim=0)

    def name(self, index):
        chunk = self.chunk_info(index)
        folder, name = self._entries[chunk.file_index]
        return folder / name


class MusDBSet:
    """Expose a musdb.DB as (track name, streams tensor) pairs."""
    def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
        self.mus = mus
        self.streams = streams
        self.samplerate = samplerate
        self.channels = channels

    def __len__(self):
        return len(self.mus.tracks)

    def __getitem__(self, index):
        track = self.mus.tracks[index]
        return (track.name, AudioFile(track.path).read(channels=self.channels,
                                                       seek_time=0,
                                                       streams=self.streams,
                                                       samplerate=self.samplerate))
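
# A minimal sketch of reading one track directly (the musdb root below is
# hypothetical):
#
#   mus = musdb.DB(root="/data/musdb", subsets=["train"], split="train")
#   name, streams = MusDBSet(mus)[0]  # streams: [streams, channels, samples]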


def build_raw(mus, destination, normalize, workers, samplerate, channels):
    destination.mkdir(parents=True, exist_ok=True)
    loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
                        batch_size=1,
                        num_workers=workers,
                        collate_fn=lambda x: x[0])
    for name, streams in tqdm.tqdm(loader):
        if normalize:
            # Normalize all streams by the mean and std of the mono mixture.
            ref = streams[0].mean(dim=0)
            streams = (streams - ref.mean()) / ref.std()
        for index, stream in enumerate(streams):
            # Store interleaved float32 samples, one file per stream.
            with open(destination / (name + f'.{index}.raw'), "wb") as file:
                file.write(stream.t().numpy().tobytes())
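
# Example command line (illustrative; the exact module path depends on how
# this file is packaged):
#
#   python -m <package>.raw /path/to/musdb ./musdb_raw --workers 4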


def main():
    parser = argparse.ArgumentParser('rawset')
    parser.add_argument('--workers', type=int, default=10)
    parser.add_argument('--samplerate', type=int, default=44100)
    parser.add_argument('--channels', type=int, default=2)
    parser.add_argument('musdb', type=Path)
    parser.add_argument('destination', type=Path)

    args = parser.parse_args()

    build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"),
              args.destination / "train",
              normalize=True,
              channels=args.channels,
              samplerate=args.samplerate,
              workers=args.workers)
    build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"),
              args.destination / "valid",
              normalize=True,
              channels=args.channels,
              samplerate=args.samplerate,
              workers=args.workers)


if __name__ == "__main__":
    main()