YourMT3 / amt /src /utils /preprocess /preprocess_egmd.py
mimbres's picture
.
a03c9b4
raw
history blame
6.91 kB
"""preprocess_egmd.py"""
import os
import csv
import glob
import re
import json
from typing import Dict, List, Tuple
import numpy as np
from utils.audio import get_audio_file_info
from utils.midi import midi2note, note_event2midi
from utils.note2event import note2note_event, note_event2event
from utils.event2note import event2note_event
from utils.note_event_dataclasses import Note, NoteEvent
from utils.utils import note_event2token2note_event_sanity_check
# from utils.utils import assert_note_events_almost_equal
def create_note_event_and_note_from_midi(mid_file: str, id: str) -> Tuple[Dict, Dict]:
    """Extract notes and note events, each with metadata, from an E-GMD MIDI file.

    The MIDI is parsed with ``force_all_drum=True`` and both returned
    dictionaries are tagged ``program=[128]`` and ``is_drum=[1]``, i.e. every
    note is treated as a drum hit.

    Args:
        mid_file: Path to the MIDI file.
        id: Identifier stored in the metadata (``preprocess_egmd16k`` passes
            the CSV ``midi_filename`` truncated at its first '.').

    Returns:
        notes (dict): notes and metadata, with keys ``egmd_id``, ``program``,
            ``is_drum``, ``duration_sec``, ``notes``.
        note_events (dict): note events and metadata, with keys ``egmd_id``,
            ``maps_id``, ``program``, ``is_drum``, ``duration_sec``,
            ``note_events``. ``maps_id`` duplicates ``egmd_id``; it is a
            legacy key kept so readers of previously generated files keep working.
    """
    # Option semantics are defined by utils.midi.midi2note; values kept as in the original pipeline.
    notes, dur_sec = midi2note(
        mid_file,
        binary_velocity=True,
        ch_9_as_drum=True,
        force_all_drum=True,
        trim_overlap=True,
        fix_offset=True,
        quantize=True,
        verbose=0,
        minimum_offset_sec=0.01,
        drum_offset_sec=0.01,
        ignore_pedal=True)
    notes_dict = {
        'egmd_id': id,
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'notes': notes,
    }
    note_events_dict = {
        'egmd_id': id,  # consistent with notes_dict and the index files
        'maps_id': id,  # legacy key (copied from the MAPS preprocessor); kept for compatibility
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'note_events': note2note_event(notes),
    }
    return notes_dict, note_events_dict
def _egmd_id_from_midi_filename(midi_filename: str) -> str:
    """Return the E-GMD id: the CSV ``midi_filename`` truncated at its first '.'."""
    return midi_filename.split('.')[0]


def _midi_sibling_file(midi_file: str, suffix: str) -> str:
    """Return `midi_file` with its trailing '.midi' replaced by `suffix`.

    Only the final extension is replaced. A plain ``str.replace`` would also
    rewrite any '.midi' that appears earlier in the path.

    Raises:
        ValueError: if `midi_file` does not end with '.midi'.
    """
    if not midi_file.endswith('.midi'):
        raise ValueError(f"Expected a '.midi' file, got: {midi_file}")
    return midi_file[:-len('.midi')] + suffix


def _make_egmd_file_list_entry(base_dir: str, d: Dict[str, str]) -> Dict:
    """Build one file-list entry for an E-GMD CSV row, asserting its files exist."""
    midi_file = os.path.join(base_dir, d['midi_filename'])
    mix_audio_file = os.path.join(base_dir, d['audio_filename'])
    notes_file = _midi_sibling_file(midi_file, '_notes.npy')
    note_events_file = _midi_sibling_file(midi_file, '_note_events.npy')
    # Check existence before reading the audio header so a missing file gives a clear message.
    for path in (mix_audio_file, midi_file, notes_file, note_events_file):
        assert os.path.exists(path), f'Missing file: {path}'
    return {
        'egmd_id': _egmd_id_from_midi_filename(d['midi_filename']),
        'n_frames': get_audio_file_info(mix_audio_file)[1],
        'mix_audio_file': mix_audio_file,
        'notes_file': notes_file,
        'note_events_file': note_events_file,
        'midi_file': midi_file,
        'program': [128],
        'is_drum': [1],
    }


def _write_egmd_file_list(output_index_dir: str, dataset_name: str, split: str, entries: List[Dict]) -> None:
    """Write `entries` as ``{index: entry}`` JSON to ``{dataset_name}_{split}_file_list.json``."""
    output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
    with open(output_file, 'w') as f:
        json.dump(dict(enumerate(entries)), f, indent=4)
    print(f'Wrote {output_file}')


def preprocess_egmd16k(data_home: os.PathLike, dataset_name='egmd') -> None:
    """Preprocess E-GMD: save note/note-event files and write split index files.

    For every row of ``e-gmd-v1.0.0.csv`` this:
      - saves ``*_notes.npy`` and ``*_note_events.npy`` next to the MIDI file;
      - writes a re-quantized ``*_quantized_120bpm.mid`` copy.

    It then writes one index file per split.

    Splits:
        - train: 35217 files
        - validation: 5031 files
        - test: 5289 files
        - test_reduced: 246 test files whose MIDI filename ends with
          '_5.midi' or '_10.midi'

    Writes:
        - {dataset_name}_{split}_file_list.json: a dictionary of the form
            {
                index:
                {
                    'egmd_id': egmd_id,  # filename without extension
                    'n_frames': (int),
                    'mix_audio_file': 'path/to/mix.wav',
                    'notes_file': 'path/to/notes.npy',
                    'note_events_file': 'path/to/note_events.npy',
                    'midi_file': 'path/to/midi.mid',
                    'program': List[int],
                    'is_drum': List[int],  # 0 or 1
                }
            }

    Raises:
        AssertionError: if the CSV row count, a split size, or a referenced
            file does not match expectations.
    """
    # Directory and file paths
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Load the metadata CSV (newline='' as recommended by the csv module)
    csv_file = os.path.join(base_dir, 'e-gmd-v1.0.0.csv')
    with open(csv_file, 'r', newline='') as f:
        egmd_dict_list_all = list(csv.DictReader(f))
    assert len(egmd_dict_list_all) == 45537, f'Unexpected CSV row count: {len(egmd_dict_list_all)}'

    # Process MIDI files: save notes / note events and a 120 bpm quantized MIDI copy
    for d in egmd_dict_list_all:
        egmd_id = _egmd_id_from_midi_filename(d['midi_filename'])
        midi_file = os.path.join(base_dir, d['midi_filename'])
        notes, note_events = create_note_event_and_note_from_midi(midi_file, egmd_id)

        notes_file = _midi_sibling_file(midi_file, '_notes.npy')
        note_events_file = _midi_sibling_file(midi_file, '_note_events.npy')
        np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
        print(f"Created {notes_file}")
        np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
        print(f"Created {note_events_file}")

        quantized_midi_file = _midi_sibling_file(midi_file, '_quantized_120bpm.mid')
        note_event2midi(note_events['note_events'], quantized_midi_file)
        print(f'Wrote {quantized_midi_file}')

    # Audio files are not processed here.
    # NOTE(review): presumably already prepared (16 kHz) in base_dir — confirm.

    # Create one index file per official split
    expected_split_sizes = {'train': 35217, 'validation': 5031, 'test': 5289}
    test_entries: List[Dict] = []
    for split, expected_size in expected_split_sizes.items():
        entries = [_make_egmd_file_list_entry(base_dir, d) for d in egmd_dict_list_all if d['split'] == split]
        _write_egmd_file_list(output_index_dir, dataset_name, split, entries)
        assert len(entries) == expected_size, f'{split}: expected {expected_size} files, got {len(entries)}'
        if split == 'test':
            test_entries = entries

    # Reduced test set: a subset of the test entries, reused so audio headers are not read twice.
    reduced_entries = [e for e in test_entries if e['midi_file'].endswith(('_5.midi', '_10.midi'))]
    _write_egmd_file_list(output_index_dir, dataset_name, 'test_reduced', reduced_entries)
    assert len(reduced_entries) == 246, f'test_reduced: expected 246 files, got {len(reduced_entries)}'