|
import argparse |
|
import os |
|
import json |
|
import time |
|
import math |
|
|
|
import torch |
|
import os |
|
from matplotlib import pyplot as plt |
|
|
|
import generation_config |
|
import constants |
|
from model import VAE |
|
from utils import set_seed |
|
from utils import mtp_from_logits, muspy_from_mtp, set_seed |
|
from utils import print_divider |
|
from utils import loop_muspy_music, save_midi, save_audio |
|
from plots import plot_pianoroll, plot_structure |
|
|
|
|
|
def generate_music(vae, z, s_cond=None, s_tensor_cond=None):
    """Decode latent vectors into a multi-track pianoroll tensor.

    Args:
        vae: Trained VAE whose decoder maps latents to logits.
        z: Batch of latent vectors to decode.
        s_cond: Optional structure conditioning forwarded to the decoder.
        s_tensor_cond: Optional precomputed binary structure tensor; when
            provided it is used as-is instead of being derived from the
            decoder's structure logits.

    Returns:
        Tuple ``(mtp, s_tensor)``: the multi-track pianoroll tensor built
        from the content logits, and the binary structure tensor used.
    """
    structure_logits, content_logits = vae.decoder(z, s_cond)

    # Prefer the caller-supplied structure; otherwise binarize the
    # decoder's own structure logits.
    if s_tensor_cond is None:
        s_tensor = vae.decoder._binary_from_logits(structure_logits)
    else:
        s_tensor = s_tensor_cond

    return mtp_from_logits(content_logits, s_tensor), s_tensor
|
|
|
|
|
def save(mtp, dir, s_tensor=None, n_loops=1, audio=True, z=None,
         looped_only=False, plot_proll=False, plot_struct=False):
    """Save a batch of generated sequences, one subdirectory per sequence.

    For each sequence ``i`` in the batch, writes a MIDI file (and
    optionally audio, plots, a structure JSON dump and the latent vector)
    under ``dir/i``.

    Args:
        mtp: Multi-track pianoroll tensor; dim 0 is the batch, dim 1 the
            bars, dim 3 the timesteps (assumed to be 4 * resolution —
            TODO confirm against the model configuration).
        dir: Output root directory (name kept for backward compatibility,
            although it shadows the ``dir`` builtin).
        s_tensor: Optional binary structure tensor, one entry per sequence.
        n_loops: When > 1, additionally writes an 'extended' file with the
            sequence looped ``n_loops`` times.
        audio: Whether to render an audio file next to each MIDI file.
        z: Optional batch of latent vectors to store alongside the outputs.
        looped_only: When True, skip the plain (non-looped) outputs.
        plot_proll: Whether to save a pianoroll plot.
        plot_struct: Whether to save a structure plot (needs ``s_tensor``).
    """
    n_bars = mtp.size(1)
    resolution = mtp.size(3) // 4

    plt.clf()

    for i in range(mtp.size(0)):

        save_dir = os.path.join(dir, str(i))
        os.makedirs(save_dir, exist_ok=True)

        # Convert once up front: the looped branch also needs the song.
        # (Previously this was built only inside `if not looped_only`,
        # causing a NameError when looped_only=True and n_loops > 1.)
        muspy_song = muspy_from_mtp(mtp[i])

        if not looped_only:

            print("Saving MIDI sequence {} in {}...".format(str(i + 1),
                                                            save_dir))
            save_midi(muspy_song, save_dir, name='generated')
            if audio:
                print("Saving audio sequence {} in {}...".format(str(i + 1),
                                                                 save_dir))
                save_audio(muspy_song, save_dir, name='generated')

            if plot_proll:
                plot_pianoroll(muspy_song, save_dir)

            # Skip silently if no structure tensor was provided.
            if plot_struct and s_tensor is not None:
                plot_structure(s_tensor[i].cpu(), save_dir)

        if n_loops > 1:

            print("Saving MIDI sequence "
                  "{} looped {} times in {}...".format(str(i + 1), n_loops,
                                                       save_dir))
            extended = loop_muspy_music(muspy_song, n_loops,
                                        n_bars, resolution)
            save_midi(extended, save_dir, name='extended')
            if audio:
                print("Saving audio sequence "
                      "{} looped {} times in {}...".format(str(i + 1), n_loops,
                                                           save_dir))
                save_audio(extended, save_dir, name='extended')

        # Only persist the structure when one was provided; the original
        # indexed s_tensor unconditionally and crashed on the default None.
        if s_tensor is not None:
            with open(os.path.join(save_dir, 'structure.json'), 'wb') as file:
                file.write(json.dumps(s_tensor[i].tolist()).encode('utf-8'))

        # Guard on z itself: indexing a None z raised TypeError, and for a
        # tensor batch z[i] is never None, so the old check was vacuous.
        if z is not None:
            torch.save(z[i], os.path.join(save_dir, 'z'))

        print()
|
|
|
|
|
def generate_z(bs, d_model, device):
    """Sample a batch of latent vectors from a standard normal.

    Args:
        bs: Number of latent vectors to draw (batch size).
        d_model: Dimensionality of each latent vector.
        device: Torch device on which to allocate the samples.

    Returns:
        Tensor of shape ``(bs, d_model)`` drawn from N(0, I).
    """
    mean = torch.zeros((bs, d_model), device=device)
    std = torch.ones((bs, d_model), device=device)
    return torch.normal(mean, std)
|
|
|
|
|
def load_model(model_dir, device):
    """Restore a trained VAE and its configuration from ``model_dir``.

    Expects two torch-serialized files inside the directory:
    ``checkpoint`` (containing a ``model_state_dict`` entry) and
    ``configuration`` (containing the VAE constructor kwargs under the
    ``model`` key).

    Returns:
        Tuple ``(model, configuration)``: the VAE moved to *device* and
        switched to eval mode, plus the full configuration dict.
    """
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    config_path = os.path.join(model_dir, 'configuration')

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    configuration = torch.load(config_path, map_location='cpu')

    model = VAE(**configuration['model'], device=device).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model, configuration
|
|
|
|
|
def main():
    """CLI entry point: generate MIDI music with a trained VAE model."""

    parser = argparse.ArgumentParser(
        description='Generates MIDI music with a trained model.'
    )
    parser.add_argument(
        'model_dir',
        type=str, help='Directory of the model.'
    )
    parser.add_argument(
        'output_dir',
        type=str,
        help='Directory to save the generated MIDI files.'
    )
    parser.add_argument(
        '--n',
        type=int,
        default=5,
        help='Number of sequences to be generated. Default is 5.'
    )
    parser.add_argument(
        '--n_loops',
        type=int,
        default=1,
        help="If greater than 1, outputs an additional MIDI file containing "
             "the sequence looped n_loops times."
    )
    parser.add_argument(
        '--no_audio',
        action='store_true',
        default=False,
        help="Flag to disable audio files generation."
    )
    parser.add_argument(
        '--s_file',
        type=str,
        help='Path to the JSON file containing the binary structure tensor.'
    )
    parser.add_argument(
        '--z_file',
        type=str,
        help='Path to a torch-serialized latent tensor to decode instead '
             'of sampling a new one.'
    )
    parser.add_argument(
        '--z_change',
        action='store_true',
        default=False,
        help='Flag to perturb the loaded latent tensor with uniform noise '
             '(only meaningful together with --z_file).'
    )
    parser.add_argument(
        '--use_gpu',
        action='store_true',
        default=False,
        help='Flag to enable GPU usage.'
    )
    parser.add_argument(
        '--gpu_id',
        type=int,
        default=0,  # was the string '0'; an int default matches type=int
        help='Index of the GPU to be used. Default is 0.'
    )
    parser.add_argument(
        '--seed',
        type=int
    )

    args = parser.parse_args()

    if args.seed is not None:
        set_seed(args.seed)

    audio = not args.no_audio

    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
    if args.use_gpu:
        torch.cuda.set_device(args.gpu_id)

    print_divider()
    print("Loading the model on {} device...".format(device))

    model, configuration = load_model(args.model_dir, device)

    d_model = configuration['model']['d']
    n_bars = configuration['model']['n_bars']
    n_tracks = constants.N_TRACKS
    n_timesteps = 4 * configuration['model']['resolution']
    output_dir = args.output_dir

    s, s_tensor = None, None

    if args.s_file is not None:

        # Bug fix: this message previously printed args.model_dir.
        print("Loading the structure tensor "
              "from {}...".format(args.s_file))

        with open(args.s_file, 'r') as f:
            s_tensor = json.load(f)

        s_tensor = torch.tensor(s_tensor, dtype=bool)

        dims = list(s_tensor.size())
        expected = [n_bars, n_tracks, n_timesteps]
        if dims != expected:
            # The trailing (tracks, timesteps) dims must match exactly.
            if len(dims) != len(expected) or dims[1:] != expected[1:]:
                raise ValueError(f"Loaded structure tensor dimensions {dims} "
                                 f"do not match expected dimensions {expected}")
            elif dims[0] > n_bars:
                # Bug fix: this branch was unreachable before because the
                # first condition also tested dims[0] > n_bars, shadowing
                # this more specific error message.
                raise ValueError(f"First structure tensor dimension {dims[0]} "
                                 f"is higher than {n_bars}")
            else:
                # Fewer bars than required: tile the tensor and truncate.
                r = math.ceil(n_bars / dims[0])
                s_tensor = s_tensor.repeat(r, 1, 1)
                s_tensor = s_tensor[:n_bars, ...]

        # Bars with no active track get one fake activation so the
        # decoder does not have to process empty bars.
        empty_mask = ~s_tensor.any(dim=-1).any(dim=-1)
        if empty_mask.any():
            print("The provided structure tensor contains empty bars. Fake "
                  "track activations will be created to avoid processing "
                  "empty bars.")
            idxs = torch.nonzero(empty_mask, as_tuple=True)
            s_tensor[idxs + (0, 0)] = True

        # Replicate the structure for every requested sequence.
        s_tensor = s_tensor.unsqueeze(0).repeat(args.n, 1, 1, 1)

        s = model.decoder._structure_from_binary(s_tensor)

    print()
    if args.z_file is not None:
        print("Loading z...")
        z = torch.load(args.z_file)
        z = z.unsqueeze(0)
        if args.z_change:
            # Perturb the latent with uniform noise in [-e/2, e/2).
            e = 0.5
            z = z + e*(torch.rand(list(z.size())) - 0.5)
    else:
        print("Generating z...")
        z = generate_z(args.n, d_model, device)

    print("Generating music with the model...")
    s_t = time.time()
    mtp, s_tensor = generate_music(model, z, s, s_tensor)
    print("Inference time: {:.3f} s".format(time.time() - s_t))

    print()
    print("Saving MIDI files in {}...\n".format(output_dir))
    save(mtp, output_dir, s_tensor, args.n_loops, audio, z)
    print("Finished saving MIDI files.")
    print_divider()
|
|
|
|
|
# Entry-point guard: run generation only when executed as a script.
if __name__ == '__main__':

    main()
|
|