import os
import json
import time
import hashlib
import requests
import tarfile
import warnings
import argparse
from typing import Tuple, Union, Optional, List, Dict, Any
from collections import Counter

import numpy as np
from tqdm import tqdm

from utils.note_event_dataclasses import Note, NoteEvent, Event
from utils.note2event import note2note_event, slice_multiple_note_events_and_ties_to_bundle
from utils.midi import note_event2midi
from utils.event2note import merge_zipped_note_events_and_ties_to_notes
from utils.metrics import compute_track_metrics
from utils.tokenizer import EventTokenizer, NoteEventTokenizer
from config.vocabulary import GM_INSTR_FULL, GM_INSTR_CLASS_PLUS
from config.config import shared_cfg


def get_checksum(file_path: os.PathLike, buffer_size: int = 65536) -> str:
    """Compute the md5 checksum of a file, reading it in buffered chunks."""
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        while True:
            data = f.read(buffer_size)
            if not data:
                break
            md5.update(data)
    return md5.hexdigest()
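

# Example usage of `get_checksum` (a minimal sketch; the file name and digest
# below are hypothetical, for illustration only):
#
#   digest = get_checksum("dataset.tar.gz")
#   print(digest)  # e.g. '1b2cf2c415dd266d78e6ae9ed6b3e032'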


def download_and_extract(data_home: os.PathLike,
                         url: str,
                         remove_tar_file: bool = True,
                         check_sum: Optional[str] = None,
                         zenodo_token: Optional[str] = None) -> None:
    """Download a tar.gz archive (optionally with a Zenodo token), verify its
    md5 checksum, and extract it into `data_home`."""
    file_name = url.split("/")[-1].split("?")[0]
    tar_path = os.path.join(data_home, file_name)

    if not os.path.exists(data_home):
        os.makedirs(data_home)

    if zenodo_token is not None:
        url_with_token = f"{url}&token={zenodo_token}" if "?download=1" in url else f"{url}?token={zenodo_token}"
    else:
        url_with_token = url

    response = requests.get(url_with_token, stream=True)
    if response.status_code != 200:
        print(f"Failed to download file. Status code: {response.status_code}")
        return

    total_size = int(response.headers.get('content-length', 0))
    with open(tar_path, "wb") as f:
        # Each chunk is 8192 bytes, so the progress bar counts 8 KiB chunks.
        for chunk in tqdm(response.iter_content(chunk_size=8192),
                          total=total_size // 8192,
                          unit='chunk',
                          desc=file_name):
            f.write(chunk)

    _check_sum = get_checksum(tar_path)
    print(f"Checksum (md5): {_check_sum}")

    if check_sum is not None and check_sum != _check_sum:
        raise ValueError(f"Checksum doesn't match! Expected: {check_sum}, Actual: {_check_sum}")

    with tarfile.open(tar_path, "r:gz") as tar:
        tar.extractall(data_home)

    if remove_tar_file:
        os.remove(tar_path)
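

# Example usage of `download_and_extract` (a sketch; the URL and checksum are
# placeholders, not a real dataset location):
#
#   download_and_extract(data_home="../data",
#                        url="https://zenodo.org/record/1234567/files/dataset.tar.gz?download=1",
#                        check_sum="1b2cf2c415dd266d78e6ae9ed6b3e032",
#                        zenodo_token=None)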


def create_inverse_vocab(vocab: Dict) -> Dict:
    """Invert a {instrument_group: [programs]} vocab into
    {program: (primary_program, instrument_group)}."""
    inverse_vocab = {}
    for k, vnp in vocab.items():
        for v in vnp:
            inverse_vocab[v] = (vnp[0], k)
    return inverse_vocab
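

# Example for `create_inverse_vocab` (an illustrative toy vocab, not from config):
#
#   >>> create_inverse_vocab({'Piano': [0, 1], 'Drums': [128]})
#   {0: (0, 'Piano'), 1: (0, 'Piano'), 128: (128, 'Drums')}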


def create_program2channel_vocab(program_vocab: Dict, drum_program: int = 128, force_assign_13_ch: bool = False):
    """
    Create a direct map from programs to channel indices, instrument groups, and primary programs.

    Args:
        program_vocab (dict): A dictionary of program vocabularies.
        drum_program (int): The program number for drums. Default: 128.
        force_assign_13_ch (bool): Currently unused.

    Returns:
        program2channel_vocab (dict): A dictionary mapping each program to its channel index,
            instrument group, and primary program.
            e.g. {
                0: {'channel': 0, 'instrument_group': 'Piano', 'primary_program': 0},
                1: {'channel': 1, 'instrument_group': 'Chromatic Percussion', 'primary_program': 8},
                ...
                100: {'channel': 11, 'instrument_group': 'Singing Voice', 'primary_program': 100},
                128: {'channel': 12, 'instrument_group': 'Drums', 'primary_program': 128}
            }
            "primary_program" is not used at the moment.

        num_channels (int): The number of channels, typically len(program_vocab) + 1
            (the extra channel is for drums).
    """
    num_channels = len(program_vocab) + 1
    program2channel_vocab = {}
    for idx, (instrument_group, programs) in enumerate(program_vocab.items()):
        if idx > num_channels:
            raise ValueError(
                f"📕 The number of channels ({num_channels}) is less than the number of instrument groups ({idx})")
        for program in programs:
            if program in program2channel_vocab:
                raise ValueError(f"📕 program {program} is duplicated in program_vocab")
            else:
                program2channel_vocab[program] = {
                    "channel": int(idx),
                    "instrument_group": str(instrument_group),
                    "primary_program": int(programs[0]),
                }

    # Drums take the last channel, after all instrument groups.
    if drum_program in program2channel_vocab.keys():
        raise ValueError(
            f"📕 drum_program {drum_program} is duplicated in program_vocab. program_vocab should not include drum or program 128"
        )
    else:
        program2channel_vocab[drum_program] = {
            "channel": idx + 1,
            "instrument_group": "Drums",
            "primary_program": drum_program,
        }
    return program2channel_vocab, num_channels
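

# Example for `create_program2channel_vocab` (a toy two-group vocab; real vocabs
# such as GM_INSTR_CLASS_PLUS live in config.vocabulary):
#
#   >>> vocab, n_ch = create_program2channel_vocab({'Piano': [0, 1], 'Bass': [32]})
#   >>> vocab[1]
#   {'channel': 0, 'instrument_group': 'Piano', 'primary_program': 0}
#   >>> vocab[128]
#   {'channel': 2, 'instrument_group': 'Drums', 'primary_program': 128}
#   >>> n_ch
#   3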


def write_model_output_as_npy(data, output_dir, track_id):
    output_dir = os.path.join(output_dir, "model_output")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"output_{track_id}.npy")
    np.save(output_file, data, allow_pickle=True)


def write_model_output_as_midi(notes: List[Note],
                               output_dir: os.PathLike,
                               track_id: str,
                               output_inverse_vocab: Optional[Dict] = None,
                               output_dir_suffix: Optional[str] = None) -> None:
    if output_dir_suffix is not None:
        output_dir = os.path.join(output_dir, f"model_output/{output_dir_suffix}")
    else:
        output_dir = os.path.join(output_dir, "model_output")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{track_id}.mid")

    # Fall back to the original notes when no inverse vocab is given; otherwise
    # remap each non-drum note's program through the inverse vocab.
    new_notes = notes
    if output_inverse_vocab is not None:
        new_notes = []
        for note in notes:
            if note.is_drum:
                new_notes.append(note)
            else:
                new_notes.append(
                    Note(is_drum=note.is_drum,
                         program=output_inverse_vocab.get(note.program, [note.program])[0],
                         onset=note.onset,
                         offset=note.offset,
                         pitch=note.pitch,
                         velocity=note.velocity))

    note_events = note2note_event(new_notes, return_activity=False)
    note_event2midi(note_events, output_file, output_inverse_vocab=output_inverse_vocab)


def write_err_cnt_as_json(
    track_id: str,
    output_dir: os.PathLike,
    output_dir_suffix: Optional[str] = None,
    note_err_cnt: Optional[Counter] = None,
    note_event_err_cnt: Optional[Counter] = None,
):
    if output_dir_suffix is not None:
        output_dir = os.path.join(output_dir, f"model_output/{output_dir_suffix}")
    else:
        output_dir = os.path.join(output_dir, "model_output")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"error_count_{track_id}.json")

    output_dict = {}
    if note_err_cnt is not None:
        output_dict['note_err_cnt'] = dict(note_err_cnt)
    if note_event_err_cnt is not None:
        output_dict['note_event_err_cnt'] = dict(note_event_err_cnt)
    output_str = json.dumps(output_dict, indent=4)

    with open(output_file, 'w') as json_file:
        json_file.write(output_str)


class Timer:
    """A simple timer class to measure elapsed time.

    Usage:

        with Timer() as t:
            # Your code here
            time.sleep(2)
        t.print_elapsed_time()

    """

    def __init__(self) -> None:
        self.start_time = None
        self.end_time = None

    def start(self) -> None:
        self.start_time = time.time()

    def stop(self) -> None:
        self.end_time = time.time()

    def elapsed_time(self) -> float:
        if self.start_time is None:
            raise ValueError("Timer has not been started yet.")
        if self.end_time is None:
            raise ValueError("Timer has not been stopped yet.")
        return self.end_time - self.start_time

    def print_elapsed_time(self, message: Optional[str] = None) -> float:
        elapsed_seconds = self.elapsed_time()
        minutes, seconds = divmod(elapsed_seconds, 60)
        milliseconds = (elapsed_seconds % 1) * 1000
        if message is not None:
            text = f"⏰ {message}: {int(minutes)}m {int(seconds)}s {milliseconds:.2f}ms"
        else:
            text = f"⏰ elapsed time: {int(minutes)}m {int(seconds)}s {milliseconds:.2f}ms"
        print(text)
        return elapsed_seconds

    def reset(self) -> None:
        self.start_time = None
        self.end_time = None

    def __enter__(self) -> 'Timer':
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop()


def merge_file_lists(file_lists: List[Dict]) -> Dict[int, Any]:
    """Merge file lists from different datasets, and return a reindexed
    dictionary of file list."""
    merged_file_list = {}
    index = 0
    for file_list in file_lists:
        for v in file_list.values():
            merged_file_list[index] = v
            index += 1
    return merged_file_list
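

# Example for `merge_file_lists` (toy values; real entries hold per-track
# paths and metadata):
#
#   >>> merge_file_lists([{0: 'a', 1: 'b'}, {0: 'c'}])
#   {0: 'a', 1: 'b', 2: 'c'}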


def merge_splits(splits: List[str], dataset_name: Union[str, List[str]]) -> Dict[int, Any]:
    """
    merge_splits:
    - Merge multiple splits from different datasets, and return a reindexed
      dictionary of file list.
    - It is also possible to merge splits from different datasets.
    """
    n_splits = len(splits)
    if isinstance(dataset_name, str):
        # Broadcast a single dataset name to every split.
        dataset_name = [dataset_name] * n_splits
    elif isinstance(dataset_name, list) and len(dataset_name) != n_splits:
        raise ValueError("The number of dataset names in list must be equal to the number of splits.")

    data_home = shared_cfg['PATH']['data_home']
    file_lists = []
    for i, s in enumerate(splits):
        json_file = f"{data_home}/yourmt3_indexes/{dataset_name[i]}_{s}_file_list.json"
        if not os.path.exists(json_file):
            warnings.warn(
                f"File list {json_file} does not exist. If you don't have a complete package of the dataset, ignore this warning..."
            )
            return {}
        with open(json_file, 'r') as j:
            file_lists.append(json.load(j))
    merged_file_list = merge_file_lists(file_lists)
    return merged_file_list
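

# Example usage of `merge_splits` (a sketch; the split and dataset names are
# placeholders and require the corresponding index files under
# {data_home}/yourmt3_indexes/):
#
#   file_list = merge_splits(['train', 'validation'], dataset_name='slakh')
#   file_list = merge_splits(['train', 'train'], dataset_name=['slakh', 'musicnet'])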


def reindex_file_list_keys(file_list: Dict[str, Any]) -> Dict[int, Any]:
    """Reindex file list keys from 0 to total count."""
    return {i: v for i, v in enumerate(file_list.values())}


def remove_ids_from_file_list(file_list: Dict[str, Any],
                              selected_ids: List[int],
                              reindex: bool = True) -> Dict[int, Any]:
    """Remove entries with the selected ids from a file list."""
    # Find the first key that contains 'id' in any entry, e.g. 'track_id'.
    key = None
    for v in file_list.values():
        for k in v.keys():
            if 'id' in k:
                key = k
                break
        if key:
            break

    if key is None:
        raise ValueError("No key contains 'id'.")

    selected_ids = [str(id) for id in selected_ids]
    file_list = {k: v for k, v in file_list.items() if str(v[key]) not in selected_ids}
    if reindex:
        return reindex_file_list_keys(file_list)
    else:
        return file_list
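

# Example for `remove_ids_from_file_list` (toy entries; the id key is
# discovered automatically, here 'track_id'):
#
#   >>> fl = {0: {'track_id': 1, 'path': 'a'}, 1: {'track_id': 2, 'path': 'b'}}
#   >>> remove_ids_from_file_list(fl, [1])
#   {0: {'track_id': 2, 'path': 'b'}}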


def deduplicate_splits(split_a: Union[str, Dict],
                       split_b: Union[str, Dict],
                       dataset_name: Optional[str] = None,
                       reindex: bool = True) -> Dict[int, Any]:
    """Remove from file_list A any entries whose ids also appear in file_list B,
    and return a reindexed dictionary of files."""
    data_home = shared_cfg['PATH']['data_home']

    if isinstance(split_a, str):
        json_file_a = f"{data_home}/yourmt3_indexes/{dataset_name}_{split_a}_file_list.json"
        with open(json_file_a, 'r') as j:
            file_list_a = json.load(j)
    elif isinstance(split_a, dict):
        file_list_a = split_a
    else:
        raise TypeError("split_a must be a split name (str) or a file list (dict).")

    if isinstance(split_b, str):
        json_file_b = f"{data_home}/yourmt3_indexes/{dataset_name}_{split_b}_file_list.json"
        with open(json_file_b, 'r') as j:
            file_list_b = json.load(j)
    elif isinstance(split_b, dict):
        file_list_b = split_b
    else:
        raise TypeError("split_b must be a split name (str) or a file list (dict).")

    # Find the first key that contains 'id' in file_list_a, e.g. 'track_id'.
    id_key = None
    for v in file_list_a.values():
        for k in v.keys():
            if 'id' in k:
                id_key = k
                break
        if id_key:
            break
    if id_key is None:
        raise ValueError("No key contains 'id' in file_list_a.")

    ids_b = set(str(v.get(id_key, '')) for v in file_list_b.values())
    ids_a = [str(v.get(id_key, '')) for v in file_list_a.values()]
    ids_to_remove = list(set(ids_a).intersection(ids_b))
    filtered_file_list_a = remove_ids_from_file_list(file_list_a, ids_to_remove, reindex)
    return filtered_file_list_a
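

# Example usage of `deduplicate_splits` (a sketch; split names are placeholders):
#
#   # Drop from 'train' any tracks whose ids also appear in 'validation'.
#   train_list = deduplicate_splits('train', 'validation', dataset_name='slakh')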


def merge_vocab(vocab_list: List[Dict]) -> Dict[str, Any]:
    """Merge vocabularies from different datasets; on duplicate keys,
    the first occurrence wins."""
    merged_vocab = {}
    for vocab in vocab_list:
        for k, v in vocab.items():
            if k not in merged_vocab.keys():
                merged_vocab[k] = v
    return merged_vocab
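

# Example for `merge_vocab` (toy vocabs; the first occurrence of 'Piano' wins):
#
#   >>> merge_vocab([{'Piano': [0]}, {'Piano': [1], 'Drums': [128]}])
#   {'Piano': [0], 'Drums': [128]}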


def assert_note_events_almost_equal(actual_note_events,
                                    predicted_note_events,
                                    ignore_time=False,
                                    ignore_activity=True,
                                    delta=5.1e-3):
    """
    Asserts that the given lists of Note instances are equal up to a small
    floating-point tolerance, similar to `assertAlmostEqual` of `unittest`.
    Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second.

    If `ignore_time` is True, then the time field is ignored. (useful for
    comparing tie note events, default is False)

    If `ignore_activity` is True, then the activity field is ignored (default
    is True).
    """
    assert len(actual_note_events) == len(predicted_note_events)
    for j, (actual_note_event, predicted_note_event) in enumerate(zip(actual_note_events, predicted_note_events)):
        if ignore_time is False:
            assert abs(actual_note_event.time - predicted_note_event.time) <= delta, \
                (j, actual_note_event, predicted_note_event)
        assert actual_note_event.is_drum == predicted_note_event.is_drum, (j, actual_note_event, predicted_note_event)
        assert actual_note_event.program == predicted_note_event.program, (j, actual_note_event, predicted_note_event)
        assert actual_note_event.pitch == predicted_note_event.pitch, (j, actual_note_event, predicted_note_event)
        assert actual_note_event.velocity == predicted_note_event.velocity, \
            (j, actual_note_event, predicted_note_event)
        if ignore_activity is False:
            assert actual_note_event.activity == predicted_note_event.activity, \
                (j, actual_note_event, predicted_note_event)


def note_event2token2note_event_sanity_check(note_events: List[NoteEvent],
                                             notes: List[Note],
                                             report_err_cnt=False) -> Counter:
    """Round-trip sanity check: slice note events into segments, tokenize,
    detokenize, merge back into notes, and verify the reconstruction against
    the original."""
    # Slice note events into segments of 32767 samples at 16 kHz (~2.048 s each).
    max_time = note_events[-1].time
    num_segs = int(max_time * 16000 // 32767 + 1)
    seg_len_sec = 32767 / 16000
    start_times = [i * seg_len_sec for i in range(num_segs)]
    note_event_segments = slice_multiple_note_events_and_ties_to_bundle(
        note_events,
        start_times,
        seg_len_sec,
    )

    # Encode each segment into a fixed-length (1024) token sequence.
    tokenizer = NoteEventTokenizer()
    token_array = np.zeros((num_segs, 1024), dtype=np.int32)
    for i, tup in enumerate(list(zip(*note_event_segments.values()))):
        padded_tokens = tokenizer.encode_plus(*tup)
        token_array[i, :] = padded_tokens

    # Decode the tokens back into note events.
    zipped_note_events_and_tie, list_events, err_cnt = tokenizer.decode_list_batches([token_array],
                                                                                     start_times,
                                                                                     return_events=True)
    if report_err_cnt:
        err_cnt_all = err_cnt
    else:
        assert len(err_cnt) == 0
        err_cnt_all = Counter()

    # Count empty segments on both sides (informational only).
    cnt_org_empty = 0
    cnt_recon_empty = 0
    for i, (recon_note_events, recon_tie_note_events, _, _) in enumerate(zipped_note_events_and_tie):
        org_note_events = note_event_segments['note_events'][i]
        org_tie_note_events = note_event_segments['tie_note_events'][i]
        if org_note_events == []:
            cnt_org_empty += 1
        if recon_note_events == []:
            cnt_recon_empty += 1

    # Compare original and reconstructed segments event by event, after sorting
    # both into a canonical order.
    for i, (recon_note_events, recon_tie_note_events, _, _) in enumerate(zipped_note_events_and_tie):
        org_note_events = note_event_segments['note_events'][i]
        org_tie_note_events = note_event_segments['tie_note_events'][i]

        org_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch))
        org_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch))
        recon_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch))
        recon_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch))

        assert_note_events_almost_equal(org_note_events, recon_note_events)
        assert_note_events_almost_equal(org_tie_note_events, recon_tie_note_events, ignore_time=True)

    # Merge the reconstructed segments back into notes and score them against the input.
    recon_notes, err_cnt = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie, fix_offset=False)

    drum_metric, non_drum_metric, instr_metric = compute_track_metrics(recon_notes,
                                                                       notes,
                                                                       eval_vocab=GM_INSTR_FULL,
                                                                       onset_tolerance=0.005)
    if not np.isnan(non_drum_metric['offset_f']) and non_drum_metric['offset_f'] != 1.0:
        warnings.warn(f"non_drum_metric['offset_f'] = {non_drum_metric['offset_f']}")
    assert non_drum_metric['onset_f'] > 0.99
    # For drums, offsets are not meaningful, so check the drum onset F1 instead.
    if not np.isnan(drum_metric['onset_f_drum']) and drum_metric['onset_f_drum'] != 1.0:
        warnings.warn(f"drum_metric['onset_f_drum'] = {drum_metric['onset_f_drum']}")
    assert drum_metric['onset_f_drum'] > 0.99
    return err_cnt_all + Counter(err_cnt)


def str2bool(v):
    """
    Converts a string value to a boolean value.

    Args:
        v: The string value to convert.

    Returns:
        The boolean value equivalent of the input string.

    Raises:
        ArgumentTypeError: If the input string is not a valid boolean value.
    """
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
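

# Example usage of `str2bool` as an argparse type (a common pattern; the
# argument name is illustrative):
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--use_fp16', type=str2bool, default=False)
#   args = parser.parse_args(['--use_fp16', 'yes'])  # args.use_fp16 == True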


def freq_to_midi(freq):
    """Convert a frequency in Hz to the nearest MIDI note number (A4 = 440 Hz)."""
    return round(69 + 12 * np.log2(freq / 440))
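

# Example for `freq_to_midi`:
#
#   >>> freq_to_midi(440.0)
#   69
#   >>> freq_to_midi(261.63)  # ~C4
#   60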


def dict_iterator(d: Dict):
    """
    Iterate over a dictionary of parallel lists, yielding for each position a
    newly created dictionary whose values are single-element lists.
    """
    for values in zip(*d.values()):
        yield {k: [v] for k, v in zip(d.keys(), values)}
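

# Example for `dict_iterator` (a toy dict of parallel lists):
#
#   >>> list(dict_iterator({'a': [1, 2], 'b': [3, 4]}))
#   [{'a': [1], 'b': [3]}, {'a': [2], 'b': [4]}]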


def extend_dict(dict1: dict, dict2: dict) -> None:
    """
    Extends the lists in dict1 with the corresponding lists in dict2.
    Modifies dict1 in-place and does not return anything.

    Args:
        dict1 (dict): The dictionary to be extended.
        dict2 (dict): The dictionary with lists to extend dict1.

    Example:
        dict1 = {'a': [1, 2, 3], 'b': [4, 5, 6]}
        dict2 = {'a': [10], 'b': [17]}
        extend_dict(dict1, dict2)
        print(dict1)  # Outputs: {'a': [1, 2, 3, 10], 'b': [4, 5, 6, 17]}
    """
    for key in dict1:
        dict1[key].extend(dict2[key])