|
import copy |
|
import os |
|
import random |
|
|
|
import numpy as np |
|
import torch |
|
import muspy |
|
from prettytable import PrettyTable |
|
|
|
from constants import PitchToken, DurationToken |
|
import constants |
|
import generation_config |
|
|
|
|
|
def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), exports ``seed`` as ``PYTHONHASHSEED`` and
    switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Python-level and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # Exported for the environment; stored as a string because environment
    # variables can only hold text.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # PyTorch generators on CPU and on every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: force deterministic kernels and disable the auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
|
|
|
|
def append_dict(dest_d, source_d):
    """Append each value of ``source_d`` to the matching list in ``dest_d``.

    ``dest_d`` is modified in place. Every key of ``source_d`` is looked up
    in ``dest_d`` with plain indexing, so a key that ``dest_d`` cannot
    resolve raises ``KeyError``.

    Args:
        dest_d: Mapping from keys to objects supporting ``append``.
        source_d: Mapping whose values are appended under the same keys.
    """
    for key, value in source_d.items():
        bucket = dest_d[key]
        bucket.append(value)
|
|
|
|
|
def print_params(model):
    """Print a per-parameter table of trainable parameter counts.

    Iterates over ``model.named_parameters()``, keeps only the entries with
    ``requires_grad`` set, prints one table row per kept parameter followed
    by the grand total.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a torch module).

    Returns:
        The total number of trainable parameter elements.
    """
    # (name, element count) for every parameter that receives gradients.
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])

    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Parameters: {total_params}")

    return total_params
|
|
|
|
|
def print_divider():
    """Print a horizontal divider made of 40 em-dash characters."""
    divider = '—' * 40
    print(divider)
|
|
|
|
|
|
|
|
|
|
|
|
|
def mtp_from_logits(c_logits, s_tensor):
    """Scatter content logits into a dense multitrack pianoroll tensor.

    The result has the four leading sizes of ``s_tensor`` followed by the
    last two sizes of ``c_logits``. Positions where ``s_tensor`` is truthy
    receive the rows of ``c_logits`` in flattened order; every other
    position receives a fixed "silence" block.

    Args:
        c_logits: Tensor whose last two axes are (note slot, token). Its
            first axis must match the number of truthy entries of
            ``s_tensor`` for the masked assignment to succeed.
        s_tensor: Tensor with four axes used as an activation mask
            (presumably binary -- it is converted with ``.bool()``).

    Returns:
        Tensor with the device and dtype of ``c_logits``.
    """
    n_slots, d_token = c_logits.size(-2), c_logits.size(-1)
    out_shape = tuple(s_tensor.size(axis) for axis in range(4))
    out_shape += (n_slots, d_token)
    tensor_opts = dict(device=c_logits.device, dtype=c_logits.dtype)

    # Silence block: the first note slot is marked with the pitch EOS
    # index, every following slot with the pitch PAD index.
    silence = torch.zeros((n_slots, d_token), **tensor_opts)
    silence[0, PitchToken.EOS.value] = 1.
    silence[1:, PitchToken.PAD.value] = 1.

    # Work on a view with all leading axes flattened so that one boolean
    # mask addresses every (slot, token) block.
    active = s_tensor.bool().reshape(-1)
    flat = torch.zeros(out_shape, **tensor_opts).reshape(-1, n_slots, d_token)
    flat[active] = c_logits
    flat[~active] = silence

    return flat.reshape(out_shape)
|
|
|
|
|
|
|
def muspy_from_mtp(mtp):
    """Decode a multitrack pianoroll tensor into a ``muspy.Music`` object.

    ``mtp`` must be a 5-D tensor (it is permuted over five axes). Axis 1
    selects the track (it indexes ``constants.TRACKS``), axis 2 is the
    timestep axis, axis 3 enumerates the note slots of one timestep and the
    last axis holds ``constants.N_PITCH_TOKENS`` pitch scores followed by
    the duration scores. Axis 0 is merged with the timestep axis into one
    timeline per track (presumably axis 0 indexes bars -- confirm against
    the caller).

    For every timestep the note slots are read in order: the first slot
    whose pitch or duration decodes to EOS/PAD ends the timestep, slots
    decoding to SOS are skipped, every other slot becomes a note with
    velocity 64.

    NOTE(review): the SOS check on the duration token assumes
    ``DurationToken`` defines ``SOS`` like ``PitchToken`` does -- confirm in
    ``constants``.

    Args:
        mtp: 5-D tensor of token scores as described above.

    Returns:
        ``muspy.Music`` with one track per entry of axis 1 and
        ``resolution`` set to ``mtp.size(2) // 4``.
    """
    n_timesteps = mtp.size(2)
    resolution = n_timesteps // 4

    # Put the track axis first, then merge the original axes 0 and 2 so
    # that every track owns a single flat timeline.
    mtp = mtp.permute(1, 0, 2, 3, 4)
    mtp = mtp.reshape(mtp.shape[0], -1, mtp.shape[3], mtp.shape[4])
    n_tracks, n_steps = mtp.size(0), mtp.size(1)

    # Decode all tokens with two batched argmax calls instead of running
    # two tiny tensor ops per note slot inside the Python loops.
    n_pitch = constants.N_PITCH_TOKENS
    pitch_ids = torch.argmax(mtp[..., :n_pitch], dim=-1).tolist()
    dur_ids = torch.argmax(mtp[..., n_pitch:], dim=-1).tolist()

    pitch_end = (PitchToken.EOS.value, PitchToken.PAD.value)
    dur_end = (DurationToken.EOS.value, DurationToken.PAD.value)
    pitch_sos = PitchToken.SOS.value
    dur_sos = DurationToken.SOS.value

    tracks = []

    for track_idx in range(n_tracks):

        notes = []
        timeline = zip(pitch_ids[track_idx], dur_ids[track_idx])

        for t, (slot_pitches, slot_durs) in enumerate(timeline):
            for pitch, dur in zip(slot_pitches, slot_durs):

                if pitch in pitch_end or dur in dur_end:
                    # No further notes at this timestep.
                    break

                if pitch == pitch_sos or dur == dur_sos:
                    # Start-of-sequence marker in either component: not a
                    # playable note, move on to the next slot.
                    continue

                # Duration token k encodes a length of k + 1 steps; clip it
                # so the note never extends past the end of the timeline.
                length = min(dur + 1, n_steps - t)

                notes.append(muspy.Note(t, pitch, length, 64))

        track_name = constants.TRACKS[track_idx]
        midi_program = generation_config.MIDI_PROGRAMS[track_name]
        is_drum = (track_name == 'Drums')

        tracks.append(muspy.Track(
            name=track_name,
            is_drum=is_drum,
            # Drum tracks always use program 0.
            program=(0 if is_drum else midi_program),
            notes=notes,
        ))

    meta = muspy.Metadata()
    music = muspy.Music(tracks=tracks, metadata=meta, resolution=resolution)

    return music
|
|
|
|
|
def loop_muspy_music(muspy_music, n_loop, num_bars, resolution):
    """Return a copy of ``muspy_music`` with its notes repeated ``n_loop`` times.

    The input is left untouched. Repetition ``i`` (for ``i`` from 1 to
    ``n_loop - 1``) adds a copy of every original note shifted forward by
    ``i * num_bars * 4 * resolution`` time steps. New notes are appended to
    the end of each track's note list.

    Args:
        muspy_music: Object with a ``tracks`` sequence whose items expose a
            ``notes`` list of objects with a ``time`` attribute.
        n_loop: Total number of passes, the original one included.
        num_bars: Number of bars covered by one pass.
        resolution: Time steps per beat; a bar is taken to be 4 beats long.

    Returns:
        The looped deep copy.
    """
    result = copy.deepcopy(muspy_music)

    for repetition in range(1, n_loop):
        # Offset of this repetition on the timeline.
        shift = repetition * num_bars * 4 * resolution

        # Always read from the untouched input so earlier repetitions are
        # not duplicated again.
        for source, target in zip(muspy_music.tracks, result.tracks):
            for note in source.notes:
                shifted = copy.deepcopy(note)
                shifted.time += shift
                target.notes.append(shifted)

    return result
|
|
|
|
|
def add_end_of_track(muspy_music, bar_length=32):
    """Append a marker note when no note ends on the final bar line.

    The bar containing the latest note start is located, and its end
    (``bar start + bar_length``) is taken as the final bar line. If no note
    in any track ends exactly on that line, a one-step note (pitch 70,
    velocity 64) starting on the line is appended to the first track.
    The music object is modified in place.

    Args:
        muspy_music: Object with a ``tracks`` sequence whose items expose a
            ``notes`` list of objects with ``start`` and ``duration``.
        bar_length: Length of one bar in time steps. Defaults to 32.

    Returns:
        The same ``muspy_music`` object. Music without any note (or without
        any track) is returned unchanged.
    """
    all_notes = [
        note for track in muspy_music.tracks for note in track.notes
    ]

    if not all_notes:
        # Nothing to anchor a bar on; max() would raise ValueError on an
        # empty sequence, so bail out early.
        return muspy_music

    last_start = max(note.start for note in all_notes)

    # Start and end of the bar that holds the latest note onset.
    bar_start = (last_start // bar_length) * bar_length
    bar_end = bar_start + bar_length

    ends_on_bar_line = any(
        note.start + note.duration == bar_end for note in all_notes
    )

    if not ends_on_bar_line:
        muspy_music.tracks[0].notes.append(
            muspy.Note(time=bar_end, duration=1, pitch=70, velocity=64))

    return muspy_music
|
|
|
|
|
def save_midi(muspy_song, save_dir, name):
    """Write ``muspy_song`` to ``<save_dir>/<name>.mid``.

    Before writing, a fixed note (time 61, duration 3, pitch 23,
    velocity 1) is appended to the first track. The note is added to the
    passed-in object itself, so it remains on ``muspy_song`` after the call.
    NOTE(review): presumably a near-silent filler that makes the file
    extend to time step 64 -- confirm the intent.

    Args:
        muspy_song: ``muspy.Music`` object with at least one track.
        save_dir: Directory receiving the file.
        name: File name without extension.
    """
    filler = muspy.Note(time=61, duration=3, pitch=23, velocity=1)
    muspy_song.tracks[0].notes.append(filler)

    midi_path = os.path.join(save_dir, name + ".mid")
    muspy.write_midi(midi_path, muspy_song)
|
|
|
|
|
def save_audio(muspy_song, save_dir, name):
    """Render ``muspy_song`` to ``<save_dir>/<name>.wav``.

    ``generation_config.SOUNDFONT_PATH`` is passed to muspy as the
    soundfont when that path exists on disk; otherwise ``None`` is passed.

    Args:
        muspy_song: ``muspy.Music`` object to render.
        save_dir: Directory receiving the file.
        name: File name without extension.
    """
    soundfont_path = None
    if os.path.exists(generation_config.SOUNDFONT_PATH):
        soundfont_path = generation_config.SOUNDFONT_PATH

    wav_path = os.path.join(save_dir, name + ".wav")
    muspy.write_audio(wav_path, muspy_song, soundfont_path=soundfont_path)
|
|