|
"""preprocess_egmd.py""" |
|
import os |
|
import csv |
|
import glob |
|
import re |
|
import json |
|
from typing import Dict, List, Tuple |
|
import numpy as np |
|
from utils.audio import get_audio_file_info |
|
from utils.midi import midi2note, note_event2midi |
|
from utils.note2event import note2note_event, note_event2event |
|
from utils.event2note import event2note_event |
|
from utils.note_event_dataclasses import Note, NoteEvent |
|
from utils.utils import note_event2token2note_event_sanity_check |
|
|
|
|
|
|
|
def create_note_event_and_note_from_midi(mid_file: str, id: str) -> Tuple[Dict, Dict]:
    """Extract notes and note events, each with metadata, from an E-GMD MIDI file.

    The MIDI file is parsed by ``midi2note`` with drum-oriented options
    (``force_all_drum``, ``ch_9_as_drum``, ``binary_velocity``, quantization,
    10 ms minimum/drum offsets, pedal ignored). Both returned dicts are tagged
    as a single drum track (program 128, is_drum 1).

    Args:
        mid_file: Path to the MIDI file.
        id: E-GMD identifier stored under ``'egmd_id'`` in both returned dicts.
            The name shadows the builtin but is kept for caller compatibility.

    Returns:
        A ``(notes, note_events)`` tuple:
            notes (dict): keys 'egmd_id', 'program', 'is_drum', 'duration_sec',
                'notes' (the notes returned by ``midi2note``).
            note_events (dict): keys 'egmd_id', 'program', 'is_drum',
                'duration_sec', 'note_events' (``note2note_event(notes)``).
    """
    notes, dur_sec = midi2note(
        mid_file,
        binary_velocity=True,
        ch_9_as_drum=True,
        force_all_drum=True,
        trim_overlap=True,
        fix_offset=True,
        quantize=True,
        verbose=0,
        minimum_offset_sec=0.01,
        drum_offset_sec=0.01,
        ignore_pedal=True)

    # Both dicts use the same 'egmd_id' key. Previously the note_events dict
    # used 'maps_id', a copy-paste leftover from the MAPS preprocessor.
    # Separate list literals are built for each dict so they never alias.
    notes_dict = {
        'egmd_id': id,
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'notes': notes,
    }
    note_events_dict = {
        'egmd_id': id,
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'note_events': note2note_event(notes),
    }
    return notes_dict, note_events_dict
|
|
|
|
|
# Expected sizes of the E-GMD v1.0.0 metadata CSV and of each generated index.
_EGMD_CSV_NUM_ROWS = 45537
_EGMD_SPLIT_SIZES = {'train': 35217, 'validation': 5031, 'test': 5289}
_EGMD_TEST_REDUCED_SIZE = 246
# MIDI filename suffixes that select the 'test_reduced' subset of 'test'.
_EGMD_TEST_REDUCED_SUFFIXES = ('_5.midi', '_10.midi')


def _egmd_derived_paths(midi_file: str) -> Tuple[str, str, str]:
    """Return the (notes .npy, note_events .npy, quantized .mid) paths next to `midi_file`."""
    # splitext touches only the real extension. str.replace('.midi', ...) would
    # also rewrite any '.midi' occurring earlier in the path.
    stem = os.path.splitext(midi_file)[0]
    return (stem + '_notes.npy',
            stem + '_note_events.npy',
            stem + '_quantized_120bpm.mid')


def _check_count(what: str, actual: int, expected: int) -> None:
    """Raise ValueError when `actual` differs from the expected E-GMD count."""
    # An explicit raise, unlike `assert`, is not stripped when Python runs with -O.
    if actual != expected:
        raise ValueError(f'E-GMD {what}: expected {expected} entries, found {actual}')


def _egmd_file_entry(base_dir: str, row: Dict[str, str]) -> Dict:
    """Build one file-list entry for a CSV row, verifying every referenced file exists."""
    midi_filename = row['midi_filename']
    midi_file = os.path.join(base_dir, midi_filename)
    mix_audio_file = os.path.join(base_dir, row['audio_filename'])
    notes_file, note_events_file, _ = _egmd_derived_paths(midi_file)
    # Check existence before reading audio info so a missing file yields a clear error.
    for path in (mix_audio_file, midi_file, notes_file, note_events_file):
        if not os.path.exists(path):
            raise FileNotFoundError(f'Missing E-GMD file: {path}')
    return {
        'egmd_id': os.path.splitext(midi_filename)[0],
        'n_frames': get_audio_file_info(mix_audio_file)[1],
        'mix_audio_file': mix_audio_file,
        'notes_file': notes_file,
        'note_events_file': note_events_file,
        'midi_file': midi_file,
        'program': [128],
        'is_drum': [1],
    }


def _write_file_list(file_list: Dict[int, Dict], output_index_dir: str,
                     dataset_name: str, split: str) -> None:
    """Dump `file_list` as JSON to {output_index_dir}/{dataset_name}_{split}_file_list.json."""
    output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
    with open(output_file, 'w') as f:
        json.dump(file_list, f, indent=4)
    print(f'Wrote {output_file}')


def preprocess_egmd16k(data_home: os.PathLike, dataset_name='egmd') -> None:
    """Preprocess E-GMD (16 kHz) MIDI files and write YourMT3 file-list indexes.

    For every row of ``{data_home}/{dataset_name}_yourmt3_16k/e-gmd-v1.0.0.csv``
    this writes, next to the MIDI file:
        - ``<stem>_notes.npy`` and ``<stem>_note_events.npy`` (pickled dicts), and
        - ``<stem>_quantized_120bpm.mid`` (MIDI re-rendered from the note events).

    Splits:
        - train: 35217 files
        - validation: 5031 files
        - test: 5289 files
        - test_reduced: 246 test files whose MIDI filename ends with
          '_5.midi' or '_10.midi'

    Writes ``{data_home}/yourmt3_indexes/{dataset_name}_{split}_file_list.json``,
    a dictionary of the form:
        {
            index:
            {
                'egmd_id': egmd_id,  # MIDI filename without extension
                'n_frames': (int),
                'mix_audio_file': 'path/to/mix.wav',
                'notes_file': 'path/to/notes.npy',
                'note_events_file': 'path/to/note_events.npy',
                'midi_file': 'path/to/midi.mid',
                'program': List[int],
                'is_drum': List[int],  # 0 or 1
            }
        }

    Raises:
        ValueError: if the CSV or any split does not have the expected size.
            Each split's JSON is written before its size is checked.
        FileNotFoundError: if a file referenced by an index entry is missing.
    """
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # The csv module documentation requires newline='' when opening CSV files.
    csv_file = os.path.join(base_dir, 'e-gmd-v1.0.0.csv')
    with open(csv_file, 'r', newline='') as f:
        rows = list(csv.DictReader(f))
    _check_count('csv rows', len(rows), _EGMD_CSV_NUM_ROWS)

    # Step 1: convert every MIDI file into notes / note_events / quantized MIDI.
    for row in rows:
        midi_file = os.path.join(base_dir, row['midi_filename'])
        egmd_id = os.path.splitext(row['midi_filename'])[0]
        notes, note_events = create_note_event_and_note_from_midi(midi_file, egmd_id)

        notes_file, note_events_file, quantized_midi_file = _egmd_derived_paths(midi_file)
        np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
        print(f"Created {notes_file}")
        np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
        print(f"Created {note_events_file}")

        note_event2midi(note_events['note_events'], quantized_midi_file)
        print(f'Wrote {quantized_midi_file}')

    # Step 2: one index per official split, in CSV order.
    test_entries: List[Dict] = []
    for split, expected in _EGMD_SPLIT_SIZES.items():
        entries = [_egmd_file_entry(base_dir, row) for row in rows if row['split'] == split]
        _write_file_list(dict(enumerate(entries)), output_index_dir, dataset_name, split)
        _check_count(split, len(entries), expected)
        if split == 'test':
            test_entries = entries

    # Step 3: test_reduced is a filtered view of the already-validated test
    # entries, so audio info is not read a second time.
    reduced = [
        entry for entry in test_entries
        if os.path.basename(entry['midi_file']).endswith(_EGMD_TEST_REDUCED_SUFFIXES)
    ]
    _write_file_list(dict(enumerate(reduced)), output_index_dir, dataset_name, 'test_reduced')
    _check_count('test_reduced', len(reduced), _EGMD_TEST_REDUCED_SIZE)
|
|