Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import json | |
import time | |
import math | |
import torch | |
import os | |
from matplotlib import pyplot as plt | |
import generation_config | |
import constants | |
from model import VAE | |
from utils import set_seed | |
from utils import mtp_from_logits, muspy_from_mtp, set_seed | |
from utils import print_divider | |
from utils import loop_muspy_music, save_midi, save_audio | |
from plots import plot_pianoroll, plot_structure | |
def generate_music(vae, z, s_cond=None, s_tensor_cond=None):
    """Decode latent codes into a multitrack pianoroll tensor.

    Args:
        vae: Model whose ``decoder`` is called as ``decoder(z, s_cond)`` and
            returns a ``(structure_logits, content_logits)`` pair.
        z: Latent codes passed to the decoder.
        s_cond: Optional structure conditioning forwarded to the decoder.
        s_tensor_cond: Optional binary structure tensor. When given it is
            used as-is instead of the structure predicted by the decoder.

    Returns:
        A ``(mtp, s_tensor)`` tuple: the multitrack pianoroll built by
        ``mtp_from_logits`` and the binary structure tensor that was used
        to build it.
    """
    # A single decoder pass yields both structure and content logits
    structure_logits, content_logits = vae.decoder(z, s_cond)

    # Prefer the caller-supplied structure; otherwise binarize the
    # structure logits predicted by the decoder
    s_tensor = (
        vae.decoder._binary_from_logits(structure_logits)
        if s_tensor_cond is None
        else s_tensor_cond
    )

    # (n_batches x n_bars x n_tracks x n_timesteps x Sigma x d_token)
    # multitrack pianoroll tensor: logits for each activation and hard
    # silences elsewhere
    return mtp_from_logits(content_logits, s_tensor), s_tensor
def save(mtp, dir, s_tensor=None, n_loops=1, audio=True, z=None,
         looped_only=False, plot_proll=False, plot_struct=False):
    """Write every generated sequence of a batch to its own sub-directory.

    For batch item ``i`` the directory ``<dir>/<i>`` is created and filled
    with the MIDI (and optionally audio / plots) of the sequence, an optional
    looped version, the binary structure as ``structure.json`` and the latent
    code as ``z``.

    Args:
        mtp: Batched multitrack pianoroll tensor. Dimension 1 is read as the
            number of bars and dimension 3 as the timesteps per bar
            (4 * resolution).
        dir: Root output directory (name kept for backward compatibility
            even though it shadows the builtin).
        s_tensor: Optional batched binary structure tensor. When ``None`` no
            structure file or structure plot is produced.
        n_loops: If greater than 1, an additional ``extended`` MIDI (and
            audio) file with the sequence looped ``n_loops`` times is saved.
        audio: Whether audio files are rendered next to the MIDI files.
        z: Optional batched latent codes; ``z[i]`` is saved with
            ``torch.save`` when available.
        looped_only: If true, the non-looped MIDI/audio/plots are skipped.
        plot_proll: Whether to plot the pianoroll of each sequence.
        plot_struct: Whether to plot the structure of each sequence.
    """
    n_bars = mtp.size(1)
    resolution = mtp.size(3) // 4

    # Clear matplotlib cache (this solves formatting problems with first plot)
    plt.clf()

    # Iterate over batches
    for i in range(mtp.size(0)):
        # Create the directory if it does not exist
        save_dir = os.path.join(dir, str(i))
        os.makedirs(save_dir, exist_ok=True)

        # Built unconditionally: the looped export below needs the song even
        # when looped_only is set (it used to be undefined in that case).
        muspy_song = muspy_from_mtp(mtp[i])

        if not looped_only:
            print("Saving MIDI sequence {} in {}...".format(str(i + 1),
                                                            save_dir))
            save_midi(muspy_song, save_dir, name='generated')
            if audio:
                print("Saving audio sequence {} in {}...".format(str(i + 1),
                                                                 save_dir))
                save_audio(muspy_song, save_dir, name='generated')
            if plot_proll:
                plot_pianoroll(muspy_song, save_dir)
            if plot_struct and s_tensor is not None:
                plot_structure(s_tensor[i].cpu(), save_dir)

        if n_loops > 1:
            # Copy the generated sequence n_loops times and save the looped
            # MIDI and audio files
            print("Saving MIDI sequence "
                  "{} looped {} times in {}...".format(str(i + 1), n_loops,
                                                       save_dir))
            extended = loop_muspy_music(muspy_song, n_loops,
                                        n_bars, resolution)
            save_midi(extended, save_dir, name='extended')
            if audio:
                print("Saving audio sequence "
                      "{} looped {} times in {}...".format(str(i + 1),
                                                           n_loops, save_dir))
                save_audio(extended, save_dir, name='extended')

        # Save structure (skipped when no structure tensor was provided)
        if s_tensor is not None:
            structure_path = os.path.join(save_dir, 'structure.json')
            with open(structure_path, 'w', encoding='utf-8') as file:
                json.dump(s_tensor[i].tolist(), file)

        # Save z. The container itself must be checked first: the default
        # z=None cannot be indexed.
        if z is not None and z[i] is not None:
            torch.save(z[i], os.path.join(save_dir, 'z'))

        print()
def generate_z(bs, d_model, device):
    """Sample a batch of latent codes from a standard normal distribution.

    Args:
        bs: Number of latent codes (batch size).
        d_model: Dimensionality of each latent code.
        device: Device (string or ``torch.device``) on which the tensor is
            created.

    Returns:
        A ``(bs, d_model)`` tensor of i.i.d. N(0, 1) samples on ``device``.
    """
    # torch.randn samples N(0, 1) directly, without first materializing
    # separate mean and std tensors.
    return torch.randn(bs, d_model, device=device)
def load_model(model_dir, device):
    """Restore a trained VAE and its configuration from a model directory.

    Args:
        model_dir: Directory containing the ``checkpoint`` and
            ``configuration`` files.
        device: Device the model is built for and moved to.

    Returns:
        A ``(model, configuration)`` tuple; the model is in evaluation mode.
    """
    def _load_on_cpu(filename):
        # Files are always deserialized on CPU; only the model is moved to
        # the target device afterwards.
        # NOTE: torch.load unpickles its input, so model_dir should only
        # point to trusted files.
        return torch.load(os.path.join(model_dir, filename),
                          map_location='cpu')

    checkpoint = _load_on_cpu('checkpoint')
    configuration = _load_on_cpu('configuration')

    model = VAE(**configuration['model'], device=device).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model, configuration
def _build_arg_parser():
    """Create the command-line argument parser of the generation script."""
    parser = argparse.ArgumentParser(
        description='Generates MIDI music with a trained model.'
    )
    parser.add_argument(
        'model_dir',
        type=str, help='Directory of the model.'
    )
    parser.add_argument(
        'output_dir',
        type=str,
        help='Directory to save the generated MIDI files.'
    )
    parser.add_argument(
        '--n',
        type=int,
        default=5,
        help='Number of sequences to be generated. Default is 5.'
    )
    parser.add_argument(
        '--n_loops',
        type=int,
        default=1,
        help="If greater than 1, outputs an additional MIDI file containing "
             "the sequence looped n_loops times."
    )
    parser.add_argument(
        '--no_audio',
        action='store_true',
        default=False,
        help="Flag to disable audio files generation."
    )
    parser.add_argument(
        '--s_file',
        type=str,
        help='Path to the JSON file containing the binary structure tensor.'
    )
    parser.add_argument(
        '--z_file',
        type=str,
        help='Path to a latent vector z saved with torch.save. When given, '
             'a single sequence is generated from it instead of sampling '
             '--n random vectors.'
    )
    parser.add_argument(
        '--z_change',
        action='store_true',
        default=False,
        help='Flag to perturb the z loaded from --z_file with uniform noise '
             'in [-0.25, 0.25).'
    )
    parser.add_argument(
        '--use_gpu',
        action='store_true',
        default=False,
        help='Flag to enable GPU usage.'
    )
    parser.add_argument(
        '--gpu_id',
        type=int,
        default=0,
        help='Index of the GPU to be used. Default is 0.'
    )
    parser.add_argument(
        '--seed',
        type=int,
        help='Random seed for reproducible generation.'
    )
    return parser


def _load_structure_tensor(path, n_bars, n_tracks, n_timesteps):
    """Load a binary (n_bars x n_tracks x n_timesteps) structure from JSON.

    A structure with fewer bars than ``n_bars`` is tiled along the bar
    dimension and truncated to ``n_bars``. Bars without any activation get a
    fake activation in position [0, 0].

    Raises:
        ValueError: If the loaded tensor is not 3-dimensional, if its track or
            timestep dimensions differ from the expected ones, or if it has
            more than ``n_bars`` bars.
    """
    with open(path, 'r') as f:
        s_tensor = torch.tensor(json.load(f), dtype=torch.bool)

    # Check structure dimensions
    dims = list(s_tensor.size())
    expected = [n_bars, n_tracks, n_timesteps]
    if dims != expected:
        if len(dims) != len(expected) or dims[1:] != expected[1:]:
            raise ValueError(f"Loaded structure tensor dimensions {dims} "
                             f"do not match expected dimensions {expected}")
        if dims[0] > n_bars:
            raise ValueError(f"First structure tensor dimension {dims[0]} "
                             f"is higher than {n_bars}")
        # Repeat partial structure tensor until it covers n_bars bars
        repeats = math.ceil(n_bars / dims[0])
        s_tensor = s_tensor.repeat(repeats, 1, 1)[:n_bars, ...]

    # Avoid empty bars by creating a fake activation for each empty
    # (n_tracks x n_timesteps) bar matrix in position [0, 0]
    empty_mask = ~s_tensor.any(dim=-1).any(dim=-1)
    if empty_mask.any():
        print("The provided structure tensor contains empty bars. Fake "
              "track activations will be created to avoid processing "
              "empty bars.")
        empty_bars = torch.nonzero(empty_mask, as_tuple=True)[0]
        s_tensor[empty_bars, 0, 0] = True

    return s_tensor


def main():
    """Generate music with a trained model and save it to disk.

    Parses the command line, loads the model, optionally loads a structure
    tensor (``--s_file``) and/or a latent vector (``--z_file``), runs the
    decoder and writes the results through ``save``.
    """
    args = _build_arg_parser().parse_args()

    if args.seed is not None:
        set_seed(args.seed)

    audio = not args.no_audio

    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
    if args.use_gpu:
        torch.cuda.set_device(args.gpu_id)

    print_divider()
    print("Loading the model on {} device...".format(device))
    model, configuration = load_model(args.model_dir, device)

    d_model = configuration['model']['d']
    n_bars = configuration['model']['n_bars']
    n_tracks = constants.N_TRACKS
    n_timesteps = 4 * configuration['model']['resolution']
    output_dir = args.output_dir

    # A z loaded from file is a single latent vector, so the batch holds one
    # sequence in that case. The structure must be repeated by the same
    # amount, otherwise z and the structure disagree on the batch size.
    batch_size = 1 if args.z_file is not None else args.n

    s, s_tensor = None, None
    if args.s_file is not None:
        print("Loading the structure tensor "
              "from {}...".format(args.s_file))
        s_tensor = _load_structure_tensor(args.s_file, n_bars, n_tracks,
                                          n_timesteps)
        # Repeat structure along new batch dimension
        s_tensor = s_tensor.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        s = model.decoder._structure_from_binary(s_tensor)

    print()

    if args.z_file is not None:
        print("Loading z...")
        # map_location places z on the model's device even if it was saved
        # from a different one.
        z = torch.load(args.z_file, map_location=device).unsqueeze(0)
        if args.z_change:
            e = 0.5
            # Noise is created on z's own device to avoid device mismatches
            z = z + e * (torch.rand_like(z) - 0.5)
    else:
        print("Generating z...")
        z = generate_z(batch_size, d_model, device)

    print("Generating music with the model...")
    s_t = time.time()
    # Pure inference: no autograd graph is needed
    with torch.no_grad():
        mtp, s_tensor = generate_music(model, z, s, s_tensor)
    print("Inference time: {:.3f} s".format(time.time() - s_t))

    print()
    print("Saving MIDI files in {}...\n".format(output_dir))
    save(mtp, output_dir, s_tensor, args.n_loops, audio, z)
    print("Finished saving MIDI files.")
    print_divider()
# Run the generation pipeline only when executed as a script, not on import
if __name__ == "__main__":
    main()