Spaces:
Runtime error
Runtime error
import copy | |
import os | |
import random | |
import numpy as np | |
import torch | |
import muspy | |
from prettytable import PrettyTable | |
from constants import PitchToken, DurationToken | |
import constants | |
import generation_config | |
def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds PyTorch (CPU generator and all CUDA devices), NumPy and Python's
    ``random`` module, switches cuDNN to deterministic, non-benchmarking
    mode, and stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value handed to each seeding function; its ``str`` form is
            written to ``PYTHONHASHSEED``.
    """
    # PyTorch generators first.
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)

    # Give up cuDNN autotuning in exchange for repeatable behaviour.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NumPy and the standard library.
    for seed_python in (np.random.seed, random.seed):
        seed_python(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)
def append_dict(dest_d, source_d):
    """Append every value of ``source_d`` to the matching entry of ``dest_d``.

    ``dest_d`` is updated in place and nothing is returned. Each key of
    ``source_d`` is looked up in ``dest_d`` and the value is pushed onto the
    object found there via its ``append`` method.

    Args:
        dest_d: Mapping from keys to appendable containers.
        source_d: Mapping from keys to the values to append.
    """
    for key, value in source_d.items():
        bucket = dest_d[key]
        bucket.append(value)
def print_params(model):
    """Print a table of trainable parameters and return their total count.

    Parameters whose ``requires_grad`` flag is false are left out of both the
    table and the total.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        The summed ``numel()`` of all trainable parameters.
    """
    table = PrettyTable(["Modules", "Parameters"])

    # Gather (name, element count) for the parameters that will be trained.
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Parameters: {total_params}")
    return total_params
def print_divider():
    """Print a horizontal rule made of 40 em-dash characters."""
    divider = '—' * 40
    print(divider)
# Builds multitrack pianoroll (mtp) from content tensor containing logits and | |
# structure binary tensor | |
# c_logits: num_nodes x MAX_SIMU_TOKENS x d_token | |
# s_tensor: n_batches x n_bars x n_tracks x n_timesteps | |
def mtp_from_logits(c_logits, s_tensor):
    """Scatter content logits into a dense multitrack pianoroll.

    Args:
        c_logits: Content tensor, num_nodes x MAX_SIMU_TOKENS x d_token.
        s_tensor: Binary structure tensor,
            n_batches x n_bars x n_tracks x n_timesteps. Its non-zero
            entries mark the cells that receive a row of ``c_logits``.

    Returns:
        Tensor of shape n_batches x n_bars x n_tracks x n_timesteps x
        MAX_SIMU_TOKENS x d_token, with the device and dtype of
        ``c_logits``. Cells not selected by ``s_tensor`` hold a "silence"
        pattern.
    """
    options = dict(device=c_logits.device, dtype=c_logits.dtype)
    token_shape = (c_logits.size(-2), c_logits.size(-1))
    full_shape = torch.Size(
        tuple(s_tensor.size(i) for i in range(4)) + token_shape)

    # Work on a flat (cell, token slot, token feature) layout.
    flat = torch.zeros(full_shape, **options).reshape(-1, *token_shape)

    # Silence: pitch EOS in the first token slot, pitch PAD in all others.
    silence = torch.zeros(token_shape, **options)
    silence[0, PitchToken.EOS.value] = 1.
    silence[1:, PitchToken.PAD.value] = 1.

    # Active cells take the logits, every other cell takes the silence.
    active = s_tensor.bool().reshape(-1)
    flat[active] = c_logits
    flat[torch.logical_not(active)] = silence

    return flat.reshape(full_shape)
# mtp: n_bars x n_tracks x n_timesteps x MAX_SIMU_TOKENS x d_token | |
def muspy_from_mtp(mtp):
    """Convert a multitrack pianoroll tensor into a ``muspy.Music`` object.

    Args:
        mtp: Tensor laid out as n_bars x n_tracks x n_timesteps x
            MAX_SIMU_TOKENS x d_token. Along the last dimension the first
            ``constants.N_PITCH_TOKENS`` entries score the pitch token and
            the remaining entries score the duration token.

    Returns:
        A ``muspy.Music`` holding one track per entry of the track
        dimension. Track names come from ``constants.TRACKS`` and programs
        from ``generation_config.MIDI_PROGRAMS``; its resolution is a
        quarter of the per-bar timestep count.
    """
    n_timesteps = mtp.size(2)
    # NOTE(review): presumably 4 beats per bar -- confirm against the dataset.
    resolution = n_timesteps // 4

    # Collapse the bars dimension: tracks x (bars * timesteps) x notes x token
    mtp = mtp.permute(1, 0, 2, 3, 4)
    mtp = mtp.reshape(mtp.shape[0], -1, mtp.shape[3], mtp.shape[4])
    total_timesteps = mtp.size(1)

    # Decode every token in one pass instead of one argmax call per note.
    n_pitch = constants.N_PITCH_TOKENS
    pitch_ids = mtp[..., :n_pitch].argmax(dim=-1).tolist()
    dur_ids = mtp[..., n_pitch:].argmax(dim=-1).tolist()

    chord_end_pitches = (PitchToken.EOS.value, PitchToken.PAD.value)
    chord_end_durs = (DurationToken.EOS.value, DurationToken.PAD.value)

    tracks = []
    for track_idx, (track_pitches, track_durs) in enumerate(
            zip(pitch_ids, dur_ids)):
        notes = []
        for t, (chord_pitches, chord_durs) in enumerate(
                zip(track_pitches, track_durs)):
            for pitch, dur in zip(chord_pitches, chord_durs):
                if pitch in chord_end_pitches or dur in chord_end_durs:
                    # The chord contains no additional notes, go to next chord
                    break
                # Skip start-of-sequence tokens. The duration must be tested
                # as well as the pitch: a duration SOS token is not a real
                # duration and would otherwise be emitted as a note length.
                if (pitch == PitchToken.SOS.value or
                        dur == DurationToken.SOS.value):
                    continue

                # Remap duration values from [0, 95] to [1, 96] and do not
                # sustain notes beyond the sequence limit.
                duration = min(dur + 1, total_timesteps - t)
                notes.append(muspy.Note(t, pitch, duration, 64))

        track_name = constants.TRACKS[track_idx]
        midi_program = generation_config.MIDI_PROGRAMS[track_name]
        is_drum = (track_name == 'Drums')

        # ``notes`` is a fresh list for every track, so it can be handed over
        # without copying.
        tracks.append(muspy.Track(
            name=track_name,
            is_drum=is_drum,
            program=(0 if is_drum else midi_program),
            notes=notes,
        ))

    meta = muspy.Metadata()
    return muspy.Music(tracks=tracks, metadata=meta, resolution=resolution)
def loop_muspy_music(muspy_music, n_loop, num_bars, resolution):
    """Return a copy of ``muspy_music`` with its notes repeated ``n_loop`` times.

    The input object is left untouched. Repetition ``i`` (starting from 1)
    contains copies of the original notes whose ``time`` is shifted by
    ``i * num_bars * 4 * resolution``.

    Args:
        muspy_music: Object with a ``tracks`` list; each track has a
            ``notes`` list of objects carrying a ``time`` attribute.
        n_loop: Total number of times the sequence should be heard; values
            of 1 or less yield a plain copy.
        num_bars: Number of bars in the original sequence.
        resolution: Factor used, together with the constant 4, to turn bars
            into time steps.

    Returns:
        The deep-copied, extended music object.
    """
    looped_music = copy.deepcopy(muspy_music)

    for repetition in range(1, n_loop):
        # Shift applied to every note of this repetition.
        offset = repetition * num_bars * 4 * resolution
        track_pairs = zip(muspy_music.tracks, looped_music.tracks)
        for source_track, target_track in track_pairs:
            for note in source_track.notes:
                shifted = copy.deepcopy(note)
                shifted.time += offset
                target_track.notes.append(shifted)

    return looped_music
def add_end_of_track(muspy_music, bar_length=32):
    """Make sure the last bar of ``muspy_music`` is marked up to its end.

    The bar containing the latest note onset is located. Unless some note
    ends exactly on that bar's end line, a one-step marker note (pitch 70,
    velocity 64) is appended to the first track at the bar's end position.
    The music object is modified in place.

    Args:
        muspy_music: Object with a ``tracks`` list whose tracks hold
            ``notes`` exposing ``start`` and ``duration``.
        bar_length: Number of time steps per bar. Defaults to 32.

    Returns:
        The same ``muspy_music`` object. A song without any note is returned
        unchanged, since there is no last bar to complete.
    """
    all_notes = [
        note for track in muspy_music.tracks for note in track.notes]
    if not all_notes:
        # max() over an empty sequence would raise ValueError, and there may
        # be no track to hold a marker note either.
        return muspy_music

    # End line of the bar in which the last note starts.
    last_time = max(note.start for note in all_notes)
    bar_end = (last_time // bar_length) * bar_length + bar_length

    # Only notes that end exactly on the bar line count as reaching it.
    reaches_bar_end = any(
        note.start + note.duration == bar_end for note in all_notes)
    if not reaches_bar_end:
        muspy_music.tracks[0].notes.append(
            muspy.Note(time=bar_end, duration=1, pitch=70, velocity=64))

    return muspy_music
def save_midi(muspy_song, save_dir, name):
    """Write ``muspy_song`` to ``<save_dir>/<name>.mid``.

    Before writing, a note with velocity 1 (pitch 23, time 61, duration 3)
    is appended to the first track, so the passed-in song is modified in
    place.

    Args:
        muspy_song: Music object with at least one track.
        save_dir: Directory that receives the file.
        name: File name without the ``.mid`` extension.
    """
    # NOTE(review): time=61 with duration=3 is fixed; presumably it stretches
    # the file to the end of a 64-step sequence -- confirm with the caller.
    first_track_notes = muspy_song.tracks[0].notes
    first_track_notes.append(
        muspy.Note(time=61, duration=3, pitch=23, velocity=1))

    midi_path = os.path.join(save_dir, name + ".mid")
    muspy.write_midi(midi_path, muspy_song)
def save_audio(muspy_song, save_dir, name):
    """Render ``muspy_song`` to ``<save_dir>/<name>.wav``.

    The soundfont configured in ``generation_config.SOUNDFONT_PATH`` is used
    when that path exists; otherwise ``None`` is passed and the choice is
    left to ``muspy.write_audio``.

    Args:
        muspy_song: Music object to synthesize.
        save_dir: Directory that receives the file.
        name: File name without the ``.wav`` extension.
    """
    if os.path.exists(generation_config.SOUNDFONT_PATH):
        soundfont_path = generation_config.SOUNDFONT_PATH
    else:
        soundfont_path = None

    wav_path = os.path.join(save_dir, name + ".wav")
    muspy.write_audio(wav_path, muspy_song, soundfont_path=soundfont_path)