|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""metrics_test.py: |
|
|
|
This file contains tests for the following classes and functions:

• AMTMetrics
• compute_track_metrics
|
|
|
""" |
|
import unittest |
|
import warnings |
|
import torch |
|
import numpy as np |
|
from utils.metrics import AMTMetrics |
|
from utils.metrics import compute_track_metrics |
|
|
|
|
|
class TestAMTMetrics(unittest.TestCase):
    """Tests for ``AMTMetrics`` and ``compute_track_metrics``."""

    def test_individual_attributes(self):
        """Feed, compute and reset a single metric via attribute access."""
        metric = AMTMetrics()

        # The same metric is fed three ways: explicit ``update``, calling the
        # metric object directly, and calling it with an explicit weight.
        metric.onset_f.update(0.5)
        metric.onset_f(0.5)
        metric.onset_f(0, weight=1.0)

        # Expected result is the mean of the three inputs: (0.5 + 0.5 + 0) / 3.
        computed_value = metric.onset_f.compute()
        self.assertAlmostEqual(computed_value, 0.3333333333333333)

        # After a reset, computing is expected to warn and yield NaN.
        metric.onset_f.reset()
        with self.assertWarns(UserWarning):
            empty_value = metric.onset_f.compute()
        # ``torch._assert(value, torch.nan)`` only checked truthiness (its
        # second argument is the failure message), so test for NaN explicitly.
        # ``as_tensor`` accepts both Python floats and tensors.
        self.assertTrue(bool(torch.isnan(torch.as_tensor(empty_value)).all()))

        # Bulk computation on the reset collection is expected to warn too;
        # only the warning matters here, the returned dict is not inspected.
        with self.assertWarns(UserWarning):
            metric.bulk_compute()

    def test_bulk_update_and_compute(self):
        """Bulk-update with both accepted dict layouts, then bulk-compute."""
        metric = AMTMetrics()

        # Layout 1: metric name -> plain value.
        metric.bulk_update({'onset_f': 0.5, 'offset_f': 0.5})

        # Layout 2: metric name -> {'value': ..., 'weight': ...}.
        metric.bulk_update({
            'onset_f': {'value': 0.5, 'weight': 1.0},
            'offset_f': {'value': 0.5, 'weight': 1.0},
        })

        computed_metrics = metric.bulk_compute()

        # Both updated metrics must be reported, each averaging to 0.5.
        for name in ('onset_f', 'offset_f'):
            self.assertIn(name, computed_metrics)
            self.assertAlmostEqual(computed_metrics[name], 0.5)

    def test_compute_track_metrics_singing(self):
        """Score a singing example against itself under two vocabularies.

        The estimated notes are decoded from the stored reference note
        events, and every 'Singing Voice' metric is expected to be exactly
        1.0, with six such metrics reported per vocabulary.
        """
        from config.vocabulary import SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS
        from utils.event2note import note_event2note

        # NOTE(security): allow_pickle=True unpickles arbitrary objects. This
        # is acceptable only because these are fixture files shipped with the
        # repository; never point these loads at untrusted data.
        ref_notes_dict = np.load('extras/examples/singing_notes.npy',
                                 allow_pickle=True).tolist()
        ref_note_events_dict = np.load('extras/examples/singing_note_events.npy',
                                       allow_pickle=True).tolist()
        est_notes, _ = note_event2note(ref_note_events_dict['note_events'])
        ref_notes = ref_notes_dict['notes']

        vocabularies = (
            ('SINGING_SOLO_CLASS', SINGING_SOLO_CLASS),
            ('GM_INSTR_CLASS_PLUS', GM_INSTR_CLASS_PLUS),
        )
        for vocab_name, vocab in vocabularies:
            # subTest labels a failure with the vocabulary that caused it and
            # lets the remaining vocabulary still run.
            with self.subTest(eval_vocab=vocab_name):
                metric = AMTMetrics(prefix='test/',
                                    extra_classes=list(vocab.keys()))
                drum_metric, non_drum_metric, instr_metric = compute_track_metrics(
                    est_notes,
                    ref_notes,
                    eval_vocab=vocab,
                    eval_drum_vocab=None,
                    onset_tolerance=0.05)
                for partial_metric in (drum_metric, non_drum_metric, instr_metric):
                    metric.bulk_update(partial_metric)
                computed_metrics = metric.bulk_compute()

                singing_values = [
                    value for key, value in computed_metrics.items()
                    if 'Singing Voice' in key
                ]
                for value in singing_values:
                    self.assertEqual(value, 1.0)
                self.assertEqual(len(singing_values), 6)
|
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
|