""" preprocess_idmt_smt_bass.py """ import os import glob import json import wave import numpy as np from typing import Dict, Tuple from sklearn.model_selection import train_test_split from utils.audio import get_audio_file_info, load_audio_file, write_wav_file, guess_onset_offset_by_amp_envelope from utils.midi import midi2note, note_event2midi from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes from utils.event2note import event2note_event from utils.note_event_dataclasses import Note, NoteEvent from utils.utils import assert_note_events_almost_equal SPLIT_INFO_FILE = 'stratified_split_crepe_smt.json' # Plucking style to GM program PS2program = { "FS": 33, # Fingered Elec Bass "MU": 33, # Muted Elec Bass "PK": 34, # Picked Elec Bass "SP": 36, # Slap-Pluck Elec Bass "ST": 37, # Salp-Thumb Elec Bass } PREPEND_SILENCE = 1.8 # seconds APPEND_SILENCE = 1.8 # seconds def bass_string_to_midi_pitch(string_number: int, fret: int, string_pitches=[28, 33, 38, 43, 48]): """ sring_number: 1, 2, 3, 4, fret: 0, 1, 2, ...""" return string_pitches[string_number - 1] + fret def regenerate_stratified_split(audio_files_dict): train_ids_dict = {} val_ids_dict = {} offset = 0 for key, files in audio_files_dict.items(): ids = np.arange(len(files)) + offset train_ids, val_ids = train_test_split( ids, test_size=0.2, random_state=42, stratify=np.zeros_like(ids)) train_ids_dict[key] = train_ids val_ids_dict[key] = val_ids offset += len(files) train_ids = np.concatenate(list(train_ids_dict.values())) val_ids = np.concatenate(list(val_ids_dict.values())) assert len(train_ids) == 1872 and len(val_ids) == 470 return train_ids, val_ids def create_note_event_and_note_from_midi(mid_file: str, id: str, program: int = 0, ignore_pedal: bool = True) -> Tuple[Dict, Dict]: """Extracts note or note_event and metadata from midi: Returns: notes (dict): note events and metadata. note_events (dict): note events and metadata. 
""" notes, dur_sec = midi2note( mid_file, binary_velocity=True, force_all_program_to=program, fix_offset=True, quantize=True, verbose=0, minimum_offset_sec=0.01, ignore_pedal=ignore_pedal) return { # notes 'idmt_smt_bass_id': str(id), 'program': [program], 'is_drum': [0], 'duration_sec': dur_sec, 'notes': notes, }, { # note_events 'idmt_smt_bass_id': str(id), 'program': [0], 'is_drum': [0], 'duration_sec': dur_sec, 'note_events': note2note_event(notes), } def preprocess_idmt_smt_bass_16k(data_home=os.PathLike, dataset_name='idmt_smt_bass', sanity_check=True, edit_audio=True, regenerate_split=False) -> None: """ Splits: stratified by plucking style 'train': 1872 'validation': 470 Total: 2342 Writes: - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: { index: { 'idmt_smt_bass_id': idmt_smt_bass_id, 'n_frames': (int), 'mix_audio_file': 'path/to/mix.wav', 'notes_file': 'path/to/notes.npy', 'note_events_file': 'path/to/note_events.npy', 'midi_file': 'path/to/midi.mid', 'program': List[int], see PS2program above 'is_drum': List[int], # always [0] for this dataset } } """ # Directory and file paths base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') output_index_dir = os.path.join(data_home, 'yourmt3_indexes') os.makedirs(output_index_dir, exist_ok=True) # # audio file list # FS_audio_pattern = os.path.join(base_dir, 'PS/FS/*.wav') # MU_audio_pattern = os.path.join(base_dir, 'PS/MU/*.wav') # PK_audio_pattern = os.path.join(base_dir, 'PS/PK/*.wav') # SP_audio_pattern = os.path.join(base_dir, 'PS/SP/*.wav') # ST_audio_pattern = os.path.join(base_dir, 'PS/ST/*.wav') # FS_audio_files = sorted(glob.glob(FS_audio_pattern, recursive=False)) # MU_audio_files = sorted(glob.glob(MU_audio_pattern, recursive=False)) # PK_audio_files = sorted(glob.glob(PK_audio_pattern, recursive=False)) # SP_audio_files = sorted(glob.glob(SP_audio_pattern, recursive=False)) # ST_audio_files = sorted(glob.glob(ST_audio_pattern, recursive=False)) # assert len(FS_audio_files) == 469 # assert len(MU_audio_files) == 468 # assert len(PK_audio_files) == 468 # assert len(SP_audio_files) == 469 # assert len(ST_audio_files) == 468 # audio_files_dict = { # 'FS': FS_audio_files, # 'MU': MU_audio_files, # 'PK': PK_audio_files, # 'SP': SP_audio_files, # 'ST': ST_audio_files # } # splits: split_info_file = os.path.join(base_dir, SPLIT_INFO_FILE) with open(split_info_file, 'r') as f: split_info = json.load(f) all_info_dict = {} id = 0 for split in ['train', 'validation']: for file_path in split_info[split]: audio_file = os.path.join(base_dir, file_path) assert os.path.exists(audio_file) all_info_dict[id] = { 'idmt_smt_bass_id': id, 'n_frames': None, 'mix_audio_file': audio_file, 'notes_file': None, 'note_events_file': None, 'midi_file': None, 'program': None, 'is_drum': [0] } id += 1 train_ids = np.arange(len(split_info['train'])) val_ids = np.arange(len(split_info['validation'])) + len(train_ids) # if regenerate_split is True: # train_ids, val_ids = regenerate_stratified_split(audio_files_dict) # else: # val_ids = VALIDATION_IDS # train_ids = [i for i in range(len(all_info_dict)) if i not in val_ids] # Audio processing: prepend/append 1.8s silence if edit_audio is True: for v in all_info_dict.values(): audio_file = v['mix_audio_file'] fs, x_len, _ = get_audio_file_info(audio_file) x = load_audio_file(audio_file) # (T,) prefix_len = int(fs * PREPEND_SILENCE) suffix_len = int(fs * APPEND_SILENCE) x_new_len = prefix_len + x_len + suffix_len x_new = np.zeros(x_new_len) x_new[prefix_len:prefix_len + 
            # Overwrite the audio file
            print(f'Overwriting {audio_file} with silence prepended/appended')
            write_wav_file(audio_file, x_new, fs)

    # Guess Program/Pitch/Onset/Offset and Generate Notes/NoteEvents/MIDI
    for id in all_info_dict.keys():
        audio_file = all_info_dict[id]['mix_audio_file']

        # Guess program/pitch from the audio file name
        _, _, _, _, pluck_style, _, string_num, fret_num = os.path.basename(audio_file).split(
            '.')[0].split('_')
        program = PS2program[pluck_style]
        pitch = bass_string_to_midi_pitch(int(string_num), int(fret_num))

        # Guess onset/offset from the audio signal x
        fs, n_frames, _ = get_audio_file_info(audio_file)
        x = load_audio_file(audio_file, fs=fs)
        onset, offset, _ = guess_onset_offset_by_amp_envelope(
            x, fs=fs, onset_threshold=0.05, offset_threshold=0.02, frame_size=256)
        onset = round((onset / fs) * 1000) / 1000  # samples -> seconds, rounded to ms
        offset = round((offset / fs) * 1000) / 1000

        # Notes and NoteEvents
        notes = [
            Note(
                is_drum=False,
                program=program,
                onset=onset,
                offset=offset,
                pitch=pitch,
                velocity=1,
            )
        ]
        note_events = note2note_event(notes)

        # Write MIDI
        midi_file = audio_file.replace('.wav', '.mid')
        note_event2midi(note_events, midi_file)

        # Reconvert MIDI to Notes/NoteEvents, and validate
        notes_dict, note_events_dict = create_note_event_and_note_from_midi(
            midi_file, id, program=program, ignore_pedal=True)
        if sanity_check:
            assert_note_events_almost_equal(note_events_dict['note_events'], note_events)

        # Write notes and note_events
        notes_file = audio_file.replace('.wav', '_notes.npy')
        note_events_file = audio_file.replace('.wav', '_note_events.npy')
        np.save(notes_file, notes_dict, allow_pickle=True, fix_imports=False)
        np.save(note_events_file, note_events_dict, allow_pickle=True, fix_imports=False)
        print(f'Created {notes_file}')
        print(f'Created {note_events_file}')

        # Update all_info_dict
        all_info_dict[id]['n_frames'] = n_frames
        all_info_dict[id]['notes_file'] = notes_file
        all_info_dict[id]['note_events_file'] = note_events_file
        all_info_dict[id]['midi_file'] = midi_file
        all_info_dict[id]['program'] = [program]

    # Save index
    ids = {'train': train_ids, 'validation': val_ids, 'all': list(all_info_dict.keys())}
    for split in ['train', 'validation']:
        fl = {}
        for i, id in enumerate(ids[split]):
            fl[i] = all_info_dict[id]
        output_index_file = os.path.join(output_index_dir,
                                         f'{dataset_name}_{split}_file_list.json')
        with open(output_index_file, 'w') as f:
            json.dump(fl, f, indent=4)
        print(f'Created {output_index_file}')


def test_guess_onset_offset_by_amp_envelope(all_info_dict):
    import matplotlib.pyplot as plt
    id = np.random.randint(0, len(all_info_dict))
    x = load_audio_file(all_info_dict[id]['mix_audio_file'])
    onset, offset, amp_env = guess_onset_offset_by_amp_envelope(x)
    plt.plot(x)
    plt.axvline(x=onset, color='r', linestyle='--', label='onset')
    plt.axvline(x=offset, color='g', linestyle='--', label='offset')
    plt.show()
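
# A minimal usage sketch (the data_home path below is hypothetical; it must
# contain the 16kHz-converted 'idmt_smt_bass_yourmt3_16k' directory and the
# split file named by SPLIT_INFO_FILE):
if __name__ == '__main__':
    preprocess_idmt_smt_bass_16k(
        data_home='../../data',  # hypothetical dataset root
        dataset_name='idmt_smt_bass',
        sanity_check=True,
        edit_audio=True)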