""" note_event_roundtrip_test.py: |
|
This file contains tests for the round trip conversion between Note and |
|
NoteEvent and Event. |
|
|
|
Itinerary 1: |
|
NoteEvent β Event β Token β Event β NoteEvent |
|
|
|
Itinerary 2: |
|
Note β NoteEvent β Event β Token β Event β NoteEvent β Note |
|
|
|
Training: |
|
(Dataloader) NoteEvent β (augmentation) β Event β Token |
|
|
|
Evaluation : |
|
(Model side) Token β Event β NoteEvent β Note β (mir_eval) |
|
(Ground Truth) Note β (mir_eval) |
|
|
|
β’ This conversion may fail for unsorted and unquantized timing events. |
|
β’ Acitivity attribute of NoteEvent is often ignorable. |
|
|
|
""" |

import unittest

import numpy as np

from assert_fns import assert_notes_almost_equal
from assert_fns import assert_note_events_almost_equal
from assert_fns import assert_track_metrics_score1

from utils.note_event_dataclasses import Note, NoteEvent, Event
from utils.note2event import note2note_event, note_event2event
from utils.note2event import validate_notes, trim_overlapping_notes
from utils.note2event import slice_multiple_note_events_and_ties_to_bundle
from utils.event2note import event2note_event, note_event2note
from utils.event2note import merge_zipped_note_events_and_ties_to_notes
from utils.tokenizer import EventTokenizer, NoteEventTokenizer
from utils.midi import note_event2midi
from utils.midi import midi2note
from utils.metrics import compute_track_metrics
from config.vocabulary import GM_INSTR_FULL, SINGING_SOLO_CLASS


class TestNoteEventRoundTrip1(unittest.TestCase):

    def setUp(self) -> None:
        # Fixture: a short multi-instrument sequence of note events. velocity=1 marks
        # a note-on and velocity=0 the matching note-off; the drum events
        # (program=128) appear as onsets only here.
        self.note_events = [
            NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()),
            NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity=set()),
            NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=set()),
            NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()),
            NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity=set()),
            NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity=set()),
            NoteEvent(is_drum=True, program=128, time=2.0, velocity=1, pitch=38, activity=set()),
            NoteEvent(is_drum=False, program=33, time=2.0, velocity=0, pitch=62, activity=set())
        ]
        self.tokenizer = EventTokenizer()

    def test_note_event_rt_ne2e2ne(self):
        """ NoteEvent -> Event -> NoteEvent """
        note_events = self.note_events.copy()
        events = note_event2event(
            note_events=note_events, tie_note_events=None, start_time=0, sort=True)
        recon_note_events, unused_tie_note_events, unused_last_activity, err_cnt = event2note_event(
            events, start_time=0, sort=True, tps=100)

        self.assertSequenceEqual(note_events, recon_note_events)
        self.assertEqual(len(err_cnt), 0)

    def test_note_event_rt_ne2e2t2e2ne(self):
        """ NoteEvent -> Event -> Token -> Event -> NoteEvent """
        note_events = self.note_events.copy()
        events = note_event2event(
            note_events=note_events, tie_note_events=None, start_time=0, sort=True)
        tokens = self.tokenizer.encode(events)
        events = self.tokenizer.decode(tokens)
        recon_note_events, unused_tie_note_events, unused_last_activity, err_cnt = event2note_event(
            events, start_time=0, sort=True, tps=100)

        self.assertSequenceEqual(note_events, recon_note_events)
        self.assertEqual(len(err_cnt), 0)


class TestNoteEvent2(unittest.TestCase):

    def setUp(self) -> None:
        notes = [
            Note(is_drum=False, program=33, onset=0, offset=1.5, pitch=60, velocity=1),
            Note(is_drum=True, program=128, onset=0.2, offset=0.21, pitch=36, velocity=1),
            Note(is_drum=False, program=25, onset=0.4, offset=1.1, pitch=55, velocity=1),
            Note(is_drum=True, program=128, onset=1, offset=1.01, pitch=42, velocity=1),
            Note(is_drum=False, program=33, onset=1.2, offset=1.8, pitch=80, velocity=1),
            Note(is_drum=False, program=33, onset=1.6, offset=2.0, pitch=62, velocity=1),
            Note(is_drum=False, program=100, onset=1.6, offset=2.0, pitch=77, velocity=1),
            Note(is_drum=False, program=98, onset=1.7, offset=2.0, pitch=77, velocity=1),
            Note(is_drum=True, program=128, onset=1.9, offset=1.91, pitch=38, velocity=1)
        ]

        # These notes are already valid and non-overlapping, so validation and
        # trimming should leave them unchanged.
        _notes = validate_notes(notes, fix=True)
        self.assertSequenceEqual(notes, _notes)
        _notes = trim_overlapping_notes(notes, sort=True)
        self.assertSequenceEqual(notes, _notes)

        self.notes = notes
        self.tokenizer = EventTokenizer()

    def test_note_event_rt_n2ne2e2t2e2ne2n(self):
        """ Note -> NoteEvent -> Event -> Token -> Event -> NoteEvent -> Note """
        notes = self.notes.copy()
        note_events = note2note_event(notes=notes, sort=True)
        events = note_event2event(
            note_events=note_events, tie_note_events=None, start_time=0, tps=100, sort=True)
        tokens = self.tokenizer.encode(events)
        events = self.tokenizer.decode(tokens)
        recon_note_events, unused_tie_note_events, unused_last_activity, err_cnt = event2note_event(
            events, start_time=0, sort=True, tps=100)
        self.assertEqual(len(err_cnt), 0)

        recon_notes, err_cnt = note_event2note(note_events=recon_note_events, sort=True)
        self.assertEqual(len(err_cnt), 0)
        assert_notes_almost_equal(notes, recon_notes, delta=5e-3)

    def test_encoding_from_midi_with_slicing_zz(self):
        """ MIDI -> Note -> NoteEvent -> (slice) -> Token -> Event -> NoteEvent -> Note -> metrics """
        src_midi_file = 'extras/examples/2106.mid'
        notes, max_time = midi2note(src_midi_file, quantize=False)
        note_events = note2note_event(notes=notes, sort=True)

        # Slice the note events into fixed-length segments (with tie information),
        # one segment per 32767 samples at 16 kHz.
        num_segs = int(max_time * 16000 // 32767 + 1)
        seg_len_sec = 32767 / 16000
        start_times = [i * seg_len_sec for i in range(num_segs)]
        note_event_segments = slice_multiple_note_events_and_ties_to_bundle(
            note_events,
            start_times,
            seg_len_sec,
        )

        # Encode each (note_events, tie_note_events, ...) segment bundle into a
        # fixed-length, padded token sequence.
        tokenizer = NoteEventTokenizer()
        token_array = np.zeros((num_segs, 1024), dtype=np.int32)
        for i, tup in enumerate(list(zip(*note_event_segments.values()))):
            padded_tokens = tokenizer.encode_plus(*tup)
            token_array[i, :] = padded_tokens

        # Decode the token batch back into note events and tie note events.
        zipped_note_events_and_tie, list_events, err_cnt = tokenizer.decode_list_batches(
            [token_array], start_times, return_events=True)
        self.assertEqual(len(err_cnt), 0)

        # Each reconstructed segment should contain the same number of note events
        # as its source segment (empty segments are counted on both sides).
        cnt_org_empty = 0
        cnt_recon_empty = 0
        for i, (recon_note_events, recon_tie_note_events, recon_last_activity,
                recon_start_times) in enumerate(zipped_note_events_and_tie):
            org_note_events = note_event_segments['note_events'][i]
            org_tie_note_events = note_event_segments['tie_note_events'][i]
            if org_note_events == []:
                cnt_org_empty += 1
            if recon_note_events == []:
                cnt_recon_empty += 1

            assert len(org_note_events) == len(recon_note_events)

        # Compare original and reconstructed segments event by event, after sorting
        # both sides with the same key.
        for i, (recon_note_events, recon_tie_note_events, recon_last_activity,
                recon_start_times) in enumerate(zipped_note_events_and_tie):
            org_note_events = note_event_segments['note_events'][i]
            org_tie_note_events = note_event_segments['tie_note_events'][i]

            org_note_events.sort(
                key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch))
            org_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch))
            recon_note_events.sort(
                key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch))
            recon_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch))

            assert_note_events_almost_equal(org_note_events, recon_note_events)
            assert_note_events_almost_equal(
                org_tie_note_events, recon_tie_note_events, ignore_time=True)

        # Merge the per-segment note events and ties back into a single note list.
        recon_notes, err_cnt = merge_zipped_note_events_and_ties_to_notes(
            zipped_note_events_and_tie, fix_offset=False)
        self.assertEqual(len(err_cnt), 0)
        assert_notes_almost_equal(notes, recon_notes, delta=5.1e-3)

        # The round-tripped notes should score a perfect non-drum onset F-measure
        # against the original notes.
        drum_metric, non_drum_metric, instr_metric = compute_track_metrics(
            recon_notes, notes, eval_vocab=GM_INSTR_FULL, onset_tolerance=0.005)
        self.assertEqual(non_drum_metric['onset_f'], 1.0)


if __name__ == '__main__':
    unittest.main()