__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' |
import os |
import random |
import numpy as np |
import torch |
import soundfile as sf |
import pickle |
import time |
import itertools |
import multiprocessing |
from tqdm.auto import tqdm |
from glob import glob |
import audiomentations as AU |
import pedalboard as PB |
import warnings |
warnings.filterwarnings("ignore") |
def load_chunk(path, length, chunk_size, offset=None):
    """Read a fixed-size excerpt of an audio file.

    Args:
        path: Audio file readable by ``soundfile``.
        length: Number of frames the caller believes the file holds
            (callers in this file pass the value stored in the metadata).
        chunk_size: Number of frames to return.
        offset: First frame to read. When ``None`` and ``length`` is at least
            ``chunk_size``, a random offset is drawn with ``np.random``.

    Returns:
        float32 array of shape ``(channels, chunk_size)``. Input that is
        shorter than ``chunk_size`` is zero-padded at the end.
    """
    if chunk_size <= length:
        if offset is None:
            offset = np.random.randint(length - chunk_size + 1)
        x = sf.read(path, dtype='float32', start=offset, frames=chunk_size)[0]
    else:
        # File is reported as shorter than one chunk: take all of it.
        x = sf.read(path, dtype='float32')[0]

    # Mono files come back 1-D; give them an explicit channel axis.
    if x.ndim == 1:
        x = np.expand_dims(x, axis=1)

    # Size the result from the frames actually read rather than from `length`,
    # so a stale or mismatched `length` can never yield a wrong-sized chunk.
    x = x[:chunk_size]
    missing = chunk_size - x.shape[0]
    if missing > 0:
        # Pad in the audio's own dtype; a default float64 pad would upcast it.
        pad = np.zeros((missing, x.shape[1]), dtype=x.dtype)
        x = np.concatenate([x, pad], axis=0)
    return x.T
def get_track_set_length(params):
    """Return the shortest stem length (in frames) of one track folder.

    Takes a single tuple so it can be used directly with ``Pool.imap``.

    Args:
        params: ``(path, instruments, file_types)`` where ``path`` is the
            track folder, ``instruments`` the stem names to look for and
            ``file_types`` the extensions tried, in order, for each stem.

    Returns:
        ``(path, min_length)``.

    Raises:
        ValueError: if none of the requested stems exists in ``path``.
    """
    path, instruments, file_types = params
    lengths = []
    for instr in instruments:
        length = -1
        for extension in file_types:
            path_to_audio_file = path + '/{}.{}'.format(instr, extension)
            if os.path.isfile(path_to_audio_file):
                # Header lookup only: avoids decoding the whole file just to count frames.
                length = sf.info(path_to_audio_file).frames
                break
        if length == -1:
            print('Cant find file "{}" in folder {}'.format(instr, path))
            continue
        lengths.append(length)

    if not lengths:
        # Previously this surfaced as numpy's opaque "zero-size array" ValueError.
        raise ValueError('No stems ({}) found in folder {}'.format(
            ', '.join(str(instr) for instr in instruments), path))

    lengths_arr = np.array(lengths)
    shortest, longest = lengths_arr.min(), lengths_arr.max()
    if shortest != longest:
        print('Warning: lengths of stems are different for path: {}. ({} != {})'.format(
            path, shortest, longest))
    return path, shortest
def get_track_length(params):
    """Return ``(path, frame_count)`` for a single audio file.

    ``params`` is just the file path; the single-argument form lets the
    function be passed directly to ``Pool.imap``.
    """
    path = params
    # Header lookup only: avoids decoding the whole file just to count frames.
    length = sf.info(path).frames
    return (path, length)
class MSSDataset(torch.utils.data.Dataset):
    """Training dataset yielding ``(stems, mixture)`` pairs for source separation.

    ``dataset_type`` selects how ``data_path`` is interpreted:

    * 1 - folder(s) of track folders, each holding ``<instrument>.wav|flac``;
      every stem of a sample is drawn from an independently chosen track.
    * 2 - folder(s) holding one sub-folder per instrument with wav/flac files.
    * 3 - CSV file(s) with ``instrum`` and ``path`` columns.
    * 4 - same layout as 1, but all stems of a sample come from the same
      track at the same offset.

    File lengths are cached in a pickle at ``metadata_path`` and reused on
    the next run.
    """

    # Sample rate passed to the augmentation libraries; ``__init__`` replaces it
    # with ``config.audio.sample_rate`` when the config provides one.
    sample_rate = 44100

    # audiomentations transforms, applied in this order:
    # (probability key, AU class name, {constructor kwarg: config key}, fixed kwargs)
    _AU_AUGS = (
        ('pitch_shift', 'PitchShift', {
            'min_semitones': 'pitch_shift_min_semitones',
            'max_semitones': 'pitch_shift_max_semitones',
        }, {}),
        ('seven_band_parametric_eq', 'SevenBandParametricEQ', {
            'min_gain_db': 'seven_band_parametric_eq_min_gain_db',
            'max_gain_db': 'seven_band_parametric_eq_max_gain_db',
        }, {}),
        ('tanh_distortion', 'TanhDistortion', {
            'min_distortion': 'tanh_distortion_min',
            'max_distortion': 'tanh_distortion_max',
        }, {}),
        ('mp3_compression', 'Mp3Compression', {
            'min_bitrate': 'mp3_compression_min_bitrate',
            'max_bitrate': 'mp3_compression_max_bitrate',
            'backend': 'mp3_compression_backend',
        }, {}),
        ('gaussian_noise', 'AddGaussianNoise', {
            'min_amplitude': 'gaussian_noise_min_amplitude',
            'max_amplitude': 'gaussian_noise_max_amplitude',
        }, {}),
        ('time_stretch', 'TimeStretch', {
            'min_rate': 'time_stretch_min_rate',
            'max_rate': 'time_stretch_max_rate',
        }, {'leave_length_unchanged': True}),
    )

    # pedalboard effects, applied in this order:
    # (probability key, PB class name, ((constructor kwarg, config key stem), ...), fixed kwargs)
    # Each kwarg is drawn uniformly from [<stem>_min, <stem>_max].
    _PB_AUGS = (
        ('pedalboard_reverb', 'Reverb', (
            ('room_size', 'pedalboard_reverb_room_size'),
            ('damping', 'pedalboard_reverb_damping'),
            ('wet_level', 'pedalboard_reverb_wet_level'),
            ('dry_level', 'pedalboard_reverb_dry_level'),
            ('width', 'pedalboard_reverb_width'),
        ), {'freeze_mode': 0.0}),
        ('pedalboard_chorus', 'Chorus', (
            ('rate_hz', 'pedalboard_chorus_rate_hz'),
            ('depth', 'pedalboard_chorus_depth'),
            ('centre_delay_ms', 'pedalboard_chorus_centre_delay_ms'),
            ('feedback', 'pedalboard_chorus_feedback'),
            ('mix', 'pedalboard_chorus_mix'),
        ), {}),
        ('pedalboard_phazer', 'Phaser', (
            ('rate_hz', 'pedalboard_phazer_rate_hz'),
            ('depth', 'pedalboard_phazer_depth'),
            ('centre_frequency_hz', 'pedalboard_phazer_centre_frequency_hz'),
            ('feedback', 'pedalboard_phazer_feedback'),
            ('mix', 'pedalboard_phazer_mix'),
        ), {}),
        ('pedalboard_distortion', 'Distortion', (
            ('drive_db', 'pedalboard_distortion_drive_db'),
        ), {}),
        ('pedalboard_pitch_shift', 'PitchShift', (
            ('semitones', 'pedalboard_pitch_shift_semitones'),
        ), {}),
        ('pedalboard_resample', 'Resample', (
            ('target_sample_rate', 'pedalboard_resample_target_sample_rate'),
        ), {}),
        ('pedalboard_bitcrash', 'Bitcrush', (
            ('bit_depth', 'pedalboard_bitcrash_bit_depth'),
        ), {}),
        # The doubled name in the key stem is the key this code has always read.
        ('pedalboard_mp3_compressor', 'MP3Compressor', (
            ('vbr_quality', 'pedalboard_mp3_compressor_pedalboard_mp3_compressor'),
        ), {}),
    )

    def __init__(self, config, data_path, metadata_path="metadata.pkl", dataset_type=1, batch_size=None, verbose=True):
        """Collect (or load cached) track metadata and read settings from ``config``.

        Args:
            config: Mapping-like config with ``training`` and ``audio`` sections
                and an optional ``augmentations`` section.
            data_path: Path, or list of paths, interpreted per ``dataset_type``.
            metadata_path: Pickle file used as the length cache.
            dataset_type: 1, 2, 3 or 4 (see class docstring).
            batch_size: Overrides ``config.training.batch_size`` when given.
            verbose: Print progress information.

        Raises:
            SystemExit: for an unknown ``dataset_type`` or when dataset
                type 1/4 finds no tracks.
        """
        self.verbose = verbose
        self.config = config
        self.dataset_type = dataset_type
        self.data_path = data_path
        self.instruments = config.training.instruments
        self.batch_size = config.training.batch_size if batch_size is None else batch_size
        self.file_types = ['wav', 'flac']
        self.metadata_path = metadata_path
        if 'sample_rate' in config.audio:
            self.sample_rate = config.audio.sample_rate

        # Augmentations are used only if the block exists and `enable` is exactly True.
        self.aug = False
        if 'augmentations' in config:
            if config['augmentations'].enable is True:
                if self.verbose:
                    print('Use augmentation for training')
                self.aug = True
        else:
            if self.verbose:
                print('There is no augmentations block in config. Augmentations disabled for training...')

        metadata = self.get_metadata()

        if self.dataset_type in [1, 4]:
            if len(metadata) > 0:
                if self.verbose:
                    print('Found tracks in dataset: {}'.format(len(metadata)))
            else:
                print('No tracks found for training. Check paths you provided!')
                # Same exception `exit()` raised, but without relying on the
                # `site` module and with a failing exit status.
                raise SystemExit(1)
        else:
            for instr in self.instruments:
                if self.verbose:
                    print('Found tracks for {} in dataset: {}'.format(instr, len(metadata[instr])))

        self.metadata = metadata
        self.chunk_size = config.audio.chunk_size
        self.min_mean_abs = config.audio.min_mean_abs

    def __len__(self):
        """Number of samples per epoch: ``num_steps * batch_size``."""
        return self.config.training.num_steps * self.batch_size

    def read_from_metadata_cache(self, track_paths, instr=None):
        """Split ``track_paths`` into cached entries and paths still to be measured.

        Args:
            track_paths: Paths whose lengths are needed.
            instr: Instrument key to look up when the cache is a per-instrument
                dict (dataset types 2 and 3); ``None`` for a flat list cache.

        Returns:
            ``(remaining_paths, metadata)``: ``metadata`` is a list of
            ``[path, length]`` taken from the cache and ``remaining_paths``
            holds the requested paths that were not cached, de-duplicated and
            in their original order. Without a cache file the input list is
            returned unchanged together with an empty list.
        """
        metadata = []
        if not os.path.isfile(self.metadata_path):
            return track_paths, metadata

        if self.verbose:
            print('Found metadata cache file: {}'.format(self.metadata_path))
        # NOTE(security): unpickling runs arbitrary code; metadata_path must
        # only ever point at a file this program wrote itself.
        with open(self.metadata_path, 'rb') as cache_file:
            old_metadata = pickle.load(cache_file)

        # A cache written for another dataset type or instrument list simply
        # contributes nothing instead of raising KeyError/TypeError.
        if instr:
            old_metadata = old_metadata.get(instr, []) if isinstance(old_metadata, dict) else []
        elif isinstance(old_metadata, dict):
            old_metadata = []

        # dict.fromkeys de-duplicates like a set but keeps insertion order, so
        # the resulting metadata order is the same on every run.
        remaining = dict.fromkeys(track_paths)
        for old_path, file_size in old_metadata:
            if old_path in remaining:
                metadata.append([old_path, file_size])
                del remaining[old_path]
        track_paths = list(remaining)

        if len(metadata) > 0:
            print('Old metadata was used for {} tracks.'.format(len(metadata)))
        return track_paths, metadata

    def get_metadata(self):
        """Build the length metadata for the configured dataset type and cache it.

        Returns:
            Dataset types 1/4: list of ``(track_folder, length)``.
            Dataset types 2/3: dict ``instrument -> list of (file, length)``.
            Entries reused from the cache are ``[path, length]`` lists.

        Raises:
            SystemExit: if ``dataset_type`` is not 1, 2, 3 or 4.
        """
        read_metadata_procs = multiprocessing.cpu_count()
        if 'read_metadata_procs' in self.config['training']:
            read_metadata_procs = int(self.config['training']['read_metadata_procs'])

        if self.verbose:
            print(
                'Dataset type:', self.dataset_type,
                'Processes to use:', read_metadata_procs,
                '\nCollecting metadata for', str(self.data_path),
            )

        if self.dataset_type in [1, 4]:
            metadata = self._scan_track_folders(read_metadata_procs)
        elif self.dataset_type == 2:
            metadata = self._scan_instrument_folders(read_metadata_procs)
        elif self.dataset_type == 3:
            metadata = self._scan_csv_files()
        else:
            print('Unknown dataset type: {}. Must be 1, 2, 3 or 4'.format(self.dataset_type))
            raise SystemExit(1)

        # Save metadata so the next run can skip measuring these files.
        with open(self.metadata_path, 'wb') as cache_file:
            pickle.dump(metadata, cache_file)
        return metadata

    def _data_roots(self):
        """Return ``data_path`` as a list, wrapping a single path."""
        if isinstance(self.data_path, (list, tuple)):
            return list(self.data_path)
        return [self.data_path]

    @staticmethod
    def _measure(worker, jobs, total, procs):
        """Run ``worker`` over ``jobs`` with a progress bar; results keep job order."""
        if total == 0:
            return []
        if procs <= 1:
            return [worker(job) for job in tqdm(jobs, total=total)]
        # The context manager shuts the workers down even if a job raises.
        with multiprocessing.Pool(processes=procs) as pool:
            return list(tqdm(pool.imap(worker, jobs), total=total))

    def _scan_track_folders(self, procs):
        """Dataset types 1/4: one entry per track folder with its shortest stem length."""
        track_paths = []
        for root in self._data_roots():
            tracks_for_folder = sorted(glob(root + '/*'))
            if len(tracks_for_folder) == 0:
                print('Warning: no tracks found in folder \'{}\'. Please check it!'.format(root))
            track_paths += tracks_for_folder

        # Keep only visible directories.
        track_paths = [
            path for path in track_paths
            if os.path.basename(path)[0] != '.' and os.path.isdir(path)
        ]
        track_paths, metadata = self.read_from_metadata_cache(track_paths, None)

        jobs = zip(track_paths, itertools.repeat(self.instruments), itertools.repeat(self.file_types))
        metadata.extend(self._measure(get_track_set_length, jobs, len(track_paths), procs))
        return metadata

    def _scan_instrument_folders(self, procs):
        """Dataset type 2: per instrument, every wav/flac file in ``<root>/<instrument>/``."""
        metadata = dict()
        roots = self._data_roots()
        for instr in self.instruments:
            track_paths = []
            for root in roots:
                track_paths += sorted(glob(root + '/{}/*.wav'.format(instr)))
                track_paths += sorted(glob(root + '/{}/*.flac'.format(instr)))
            track_paths, metadata[instr] = self.read_from_metadata_cache(track_paths, instr)
            metadata[instr].extend(self._measure(get_track_length, track_paths, len(track_paths), procs))
        return metadata

    def _scan_csv_files(self):
        """Dataset type 3: per instrument, the files listed in the CSV(s).

        Entries from all CSV files are accumulated; unreadable or missing
        files are reported and skipped.
        """
        import pandas as pd

        metadata = {instr: [] for instr in self.instruments}
        for csv_path in self._data_roots():
            if self.verbose:
                print('Reading tracks from: {}'.format(csv_path))
            df = pd.read_csv(csv_path)

            skipped = 0
            for instr in self.instruments:
                print('Tracks found for {}: {}'.format(instr, len(df[df['instrum'] == instr])))
            for instr in self.instruments:
                track_paths = list(df[df['instrum'] == instr]['path'].values)
                track_paths, cached = self.read_from_metadata_cache(track_paths, instr)
                # Extend rather than assign so earlier CSV files are not discarded.
                metadata[instr].extend(cached)
                for path in tqdm(track_paths):
                    if not os.path.isfile(path):
                        print('Cant find track: {}'.format(path))
                        skipped += 1
                        continue
                    # Best effort: an unreadable file is reported and skipped.
                    try:
                        length = get_track_length(path)[1]
                    except Exception:
                        print('Problem with path: {}'.format(path))
                        skipped += 1
                        continue
                    metadata[instr].append((path, length))
            if skipped > 0:
                print('Missing tracks: {} from {}'.format(skipped, len(df)))
        return metadata

    def _silence(self):
        """Return the all-zero float32 chunk of shape (2, chunk_size) used as fallback."""
        return np.zeros((2, self.chunk_size), dtype=np.float32)

    def _read_chunk_or_silence(self, path, length, offset=None):
        """Load one chunk; on any read error report it and return silence."""
        try:
            return load_chunk(path, length, self.chunk_size, offset=offset)
        except Exception as e:
            print('Error: {} Path: {}'.format(e, path))
            return self._silence()

    def _read_stem(self, track_path, instr, track_length, offset=None):
        """Load the chunk of ``instr`` from a track folder, trying each file type in order.

        A track without a file for ``instr`` yields silence, the same fallback
        used for unreadable files, so the caller always gets an array.
        """
        for extension in self.file_types:
            path_to_audio_file = track_path + '/{}.{}'.format(instr, extension)
            if os.path.isfile(path_to_audio_file):
                return self._read_chunk_or_silence(path_to_audio_file, track_length, offset)
        return self._silence()

    def load_source(self, metadata, instr):
        """Draw random chunks of ``instr`` until one is loud enough.

        A chunk is accepted when its mean absolute amplitude is at least
        ``min_mean_abs``. Augmentations are applied to the accepted chunk
        when enabled.

        Returns:
            float32 tensor holding the chunk.
        """
        while True:
            if self.dataset_type in [1, 4]:
                track_path, track_length = random.choice(metadata)
                source = self._read_stem(track_path, instr, track_length)
            else:
                track_path, track_length = random.choice(metadata[instr])
                source = self._read_chunk_or_silence(track_path, track_length)
            if np.abs(source).mean() >= self.min_mean_abs:  # remove quiet chunks
                break
        if self.aug:
            source = self.augm_data(source, instr)
        return torch.tensor(source, dtype=torch.float32)

    def load_random_mix(self):
        """Load one independently drawn chunk per instrument and stack them.

        With the ``mixup`` augmentation enabled, each entry of ``mixup_probs``
        may add another chunk of the same instrument; the chunks are scaled by
        random loudness factors and averaged.
        """
        use_mixup = (
            self.aug
            and 'mixup' in self.config['augmentations']
            and self.config['augmentations'].mixup
        )
        res = []
        for instr in self.instruments:
            s1 = self.load_source(self.metadata, instr)
            if use_mixup:
                mixup = [s1]
                for prob in self.config.augmentations.mixup_probs:
                    if random.uniform(0, 1) < prob:
                        mixup.append(self.load_source(self.metadata, instr))
                mixup = torch.stack(mixup, dim=0)
                loud_values = np.random.uniform(
                    low=self.config.augmentations.loudness_min,
                    high=self.config.augmentations.loudness_max,
                    size=(len(mixup),)
                )
                loud_values = torch.tensor(loud_values, dtype=torch.float32)
                mixup *= loud_values[:, None, None]
                s1 = mixup.mean(dim=0, dtype=torch.float32)
            res.append(s1)
        return torch.stack(res)

    def load_aligned_data(self):
        """Load all stems from one random track at a shared offset.

        Up to 10 offsets are tried until no stem is quieter than
        ``min_mean_abs``; the last attempt is used if none qualifies. Tracks
        shorter than one chunk are read once, since there is no offset to vary.
        """
        track_path, track_length = random.choice(self.metadata)
        attempts = 10
        while attempts:
            if track_length >= self.chunk_size:
                common_offset = np.random.randint(track_length - self.chunk_size + 1)
            else:
                common_offset = None
            res = [
                self._read_stem(track_path, instr, track_length, common_offset)
                for instr in self.instruments
            ]
            silent_chunks = sum(1 for source in res if np.abs(source).mean() < self.min_mean_abs)
            if silent_chunks == 0:
                break

            attempts -= 1
            if attempts <= 0:
                print('Attempts max!', track_path)
            if common_offset is None:
                # If the track is too small, break immediately.
                break

        res = np.stack(res, axis=0)
        if self.aug:
            for i, instr in enumerate(self.instruments):
                res[i] = self.augm_data(res[i], instr)
        return torch.tensor(res, dtype=torch.float32)

    def augm_data(self, source, instr):
        """Randomly apply the configured augmentations to one stem.

        Settings come from ``config['augmentations']['all']`` with the entries
        of ``config['augmentations'][instr]`` added on top. Each augmentation
        key holds its probability; it is applied when the key is present,
        positive and a uniform draw falls below it.

        Args:
            source: Array indexed ``[channel, sample]``.
            instr: Instrument name used to look up stem-specific settings.

        Returns:
            The augmented array.
        """
        source_shape = source.shape
        aug_config = self.config['augmentations']

        # Build a private dict: writing stem-specific entries into the shared
        # 'all' mapping would leak them into every other instrument.
        augs = dict()
        if 'all' in aug_config:
            for el in aug_config['all']:
                augs[el] = aug_config['all'][el]
        if instr in aug_config:
            for el in aug_config[instr]:
                augs[el] = aug_config[instr][el]

        def fires(name):
            # Short-circuits so a random number is drawn only for enabled entries.
            return name in augs and augs[name] > 0 and random.uniform(0, 1) < augs[name]

        # Channel shuffle
        if fires('channel_shuffle'):
            source = source[::-1].copy()
        # Time inverse
        if fires('random_inverse'):
            source = source[:, ::-1].copy()
        # Random polarity (multiply by -1)
        if fires('random_polarity'):
            source = -source.copy()

        for name, class_name, key_map, fixed in self._AU_AUGS:
            if not fires(name):
                continue
            kwargs = {arg: augs[key] for arg, key in key_map.items()}
            apply_aug = getattr(AU, class_name)(p=1.0, **kwargs, **fixed)
            source = apply_aug(samples=source, sample_rate=self.sample_rate)
            # Restore the original length if time stretching changed it.
            if name == 'time_stretch' and source_shape != source.shape:
                source = source[..., :source_shape[-1]]

        for name, class_name, ranges, fixed in self._PB_AUGS:
            if not fires(name):
                continue
            kwargs = {
                arg: random.uniform(augs[key + '_min'], augs[key + '_max'])
                for arg, key in ranges
            }
            board = PB.Pedalboard([getattr(PB, class_name)(**kwargs, **fixed)])
            source = board(source, self.sample_rate)

        return source

    def __getitem__(self, index):
        """Return one training sample; ``index`` is ignored (samples are random).

        Returns:
            ``(stems, mix)`` where ``mix`` is the sum of the stems. When
            ``config.training.target_instrument`` is set, ``stems`` holds only
            that instrument (leading axis of length 1).
        """
        if self.dataset_type in [1, 2, 3]:
            res = self.load_random_mix()
        else:
            res = self.load_aligned_data()

        # Randomly change loudness of each stem
        if self.aug and 'loudness' in self.config['augmentations'] and self.config['augmentations']['loudness']:
            loud_values = np.random.uniform(
                low=self.config['augmentations']['loudness_min'],
                high=self.config['augmentations']['loudness_max'],
                size=(len(res),)
            )
            loud_values = torch.tensor(loud_values, dtype=torch.float32)
            res *= loud_values[:, None, None]

        mix = res.sum(0)

        if self.aug and 'mp3_compression_on_mixture' in self.config['augmentations']:
            apply_aug = AU.Mp3Compression(
                min_bitrate=self.config['augmentations']['mp3_compression_on_mixture_bitrate_min'],
                max_bitrate=self.config['augmentations']['mp3_compression_on_mixture_bitrate_max'],
                backend=self.config['augmentations']['mp3_compression_on_mixture_backend'],
                p=self.config['augmentations']['mp3_compression_on_mixture']
            )
            mix_conv = mix.cpu().numpy().astype(np.float32)
            required_shape = mix_conv.shape
            mix = apply_aug(samples=mix_conv, sample_rate=self.sample_rate)
            # Sometimes the length changes slightly after compression.
            if mix.shape != required_shape:
                mix = mix[..., :required_shape[-1]]
            mix = torch.tensor(mix, dtype=torch.float32)

        # If only one stem is being optimized, return just that stem.
        if self.config.training.target_instrument is not None:
            target = self.config.training.instruments.index(self.config.training.target_instrument)
            return res[target:target + 1], mix

        return res, mix