import os

from matplotlib import pyplot as plt
import matplotlib as mpl
import muspy
import torch

import constants


def _save_current_figure(save_dir, name, fformat, dpi):
    """Save the current pyplot figure as ``<save_dir>/<name>.<fformat>``."""
    out_path = os.path.join(save_dir, name + "." + fformat)
    plt.savefig(out_path, format=fformat, dpi=dpi)


def _load_checkpoint(model_dir):
    """Load the file named ``checkpoint`` inside ``model_dir`` onto the CPU."""
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    return torch.load(checkpoint_path, map_location='cpu')


def plot_pianoroll(muspy_song, save_dir=None, name='pianoroll'):
    """Draw a piano roll of a song using one stacked subplot per track.

    Args:
        muspy_song: Music object forwarded to ``muspy.show_pianoroll``.
        save_dir: Directory in which the figure is written as
            ``<name>.png``. Nothing is saved when this is falsy.
        name: File name (without extension) used when saving.
    """
    lines_linewidth = 4
    axes_linewidth = 4
    font_size = 34
    fformat = 'png'
    xticklabel = False
    label = 'y'
    figsize = (20, 10)
    dpi = 200

    rc_overrides = {
        'lines.linewidth': lines_linewidth,
        'axes.linewidth': axes_linewidth,
        'font.size': font_size,
    }

    with mpl.rc_context(rc_overrides):
        # squeeze=False makes subplots() always return a 2-D array of Axes,
        # so flattening also works when constants.N_TRACKS == 1 (a bare Axes
        # object would have no .tolist()).
        fig, axs_grid = plt.subplots(constants.N_TRACKS, sharex=True,
                                     figsize=figsize, squeeze=False)
        fig.subplots_adjust(hspace=0)
        axs = axs_grid.ravel().tolist()

        muspy.show_pianoroll(music=muspy_song, yticklabel='off', xtick='off',
                             label=label, xticklabel=xticklabel,
                             grid_axis='off', axs=axs, preset='full')

        if save_dir:
            _save_current_figure(save_dir, name, fformat, dpi)


def plot_structure(s_tensor, save_dir=None, name='structure'):
    """Plot a structure tensor as a grid with bars laid side by side.

    Args:
        s_tensor: 3-D tensor supporting ``size``, ``permute`` and ``reshape``.
            Dimension 0 is read as the number of bars and dimension 2 as the
            number of timesteps per bar; dimension 1 is presumably the track
            axis, since its rows are labelled with ``constants.TRACKS``
            (TODO confirm against callers).
        save_dir: Directory in which the figure is written as
            ``<name>.svg``. Nothing is saved when this is falsy.
        name: File name (without extension) used when saving.
    """
    lines_linewidth = 1
    axes_linewidth = 1
    font_size = 14
    fformat = 'svg'
    dpi = 200
    beats_per_bar = 4

    n_bars = s_tensor.shape[0]
    figsize = (3 * n_bars, 3)

    n_timesteps = s_tensor.size(2)
    resolution = n_timesteps // beats_per_bar

    # Bring dimension 1 to the front and flatten (bars, timesteps) into a
    # single horizontal axis.
    s_tensor = s_tensor.permute(1, 0, 2)
    s_tensor = s_tensor.reshape(s_tensor.shape[0], -1)

    # One tick per beat of every bar. Building the positions explicitly
    # guarantees exactly beats_per_bar * n_bars ticks (matching the labels)
    # even when n_timesteps is not a multiple of beats_per_bar; when it is,
    # the positions equal range(0, n_bars * n_timesteps, resolution).
    tick_positions = [
        bar * n_timesteps + beat * resolution
        for bar in range(n_bars)
        for beat in range(beats_per_bar)
    ]
    tick_labels = range(1, beats_per_bar * n_bars + 1)

    rc_overrides = {
        'lines.linewidth': lines_linewidth,
        'axes.linewidth': axes_linewidth,
        'font.size': font_size,
    }

    with mpl.rc_context(rc_overrides):
        plt.figure(figsize=figsize)
        plt.pcolormesh(s_tensor, edgecolors='k', linewidth=1)
        ax = plt.gca()

        plt.xticks(tick_positions, tick_labels)
        plt.yticks(range(0, s_tensor.shape[0]), constants.TRACKS)
        # Put the first row at the top of the plot.
        ax.invert_yaxis()

        if save_dir:
            _save_current_figure(save_dir, name, fformat, dpi)


def plot_stats(stat_names, stats_tr, stats_val=None, eval_every=None,
               labels=None, rx=None, ry=None):
    """Plot training (and optionally validation) statistics on current axes.

    Args:
        stat_names: Names of the statistics to plot; used as keys into
            ``stats_tr`` and ``stats_val``.
        stats_tr: Mapping from statistic name to a sequence of training
            values, drawn as a line against x = 1..len(values).
        stats_val: Optional mapping from statistic name to validation
            values, drawn as dots at x = eval_every, 2 * eval_every, ...
        eval_every: Spacing between validation points. Must be an integer
            whenever ``stats_val`` is given.
        labels: Optional legend labels, parallel to ``stat_names``. The
            statistic names are used when omitted.
        rx: x-axis limits passed to ``plt.xlim``; the axis starts at 0 when
            falsy.
        ry: y-axis limits passed to ``plt.ylim``; the axis starts at 0 when
            falsy.
    """
    for i, stat in enumerate(stat_names):
        label = labels[i] if labels else stat
        tr_values = stats_tr[stat]
        n_steps = len(tr_values)

        plt.plot(range(1, n_steps + 1), tr_values, label=label + ' (TR)')

        if stats_val:
            plt.plot(range(eval_every, n_steps + 1, eval_every),
                     stats_val[stat], '.', label=label + ' (VL)')

    # A bare plt.grid() toggles grid visibility, so a second call on the same
    # axes would switch the grid off again; request it explicitly instead.
    plt.grid(True)
    plt.ylim(ry if ry else 0)
    plt.xlim(rx if rx else 0)
    plt.legend()


# Dictionary that maps loss statistic name to plot label
loss_labels = {
    'tot': 'Total Loss',
    'structure': 'Structure',
    'pitch': 'Pitches',
    'dur': 'Duration',
    'reconstruction': 'Reconstruction Term',
    'kld': 'KLD',
    'beta*kld': 'beta * KLD'
}


def plot_losses(model_dir, losses, plot_val=False):
    """Plot the losses stored in a model directory's checkpoint.

    Args:
        model_dir: Directory containing the ``checkpoint`` file.
        losses: Loss names to plot; each must be a key of ``loss_labels``.
        plot_val: When true, also plot the validation losses using the
            checkpoint's ``val_losses`` and ``eval_every`` entries.
    """
    checkpoint = _load_checkpoint(model_dir)

    labels = [loss_labels[loss] for loss in losses]
    tr_losses = checkpoint['tr_losses']
    val_losses = checkpoint['val_losses'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    # rx=0 (originally spelled "(0)", which is an int and not a tuple) leaves
    # plot_stats on its default of starting the x axis at 0.
    plot_stats(losses, tr_losses, stats_val=val_losses,
               eval_every=eval_every, labels=labels, rx=0)


# Dictionary that maps accuracy statistic name to plot label
accuracy_labels = {
    's_acc': 'Struct. Accuracy',
    's_precision': 'Struct. Precision',
    's_recall': 'Struct. Recall',
    's_f1': 'Struct. F1',
    'pitch': 'Pitch Accuracy',
    'pitch_drums': 'Pitch Accuracy (Drums)',
    'pitch_non_drums': 'Pitch Accuracy (Non Drums)',
    'dur': 'Duration Accuracy',
    'note': 'Note Accuracy'
}


def plot_accuracies(model_dir, accuracies, plot_val=False):
    """Plot the accuracies stored in a model directory's checkpoint.

    Args:
        model_dir: Directory containing the ``checkpoint`` file.
        accuracies: Accuracy names to plot; each must be a key of
            ``accuracy_labels``.
        plot_val: When true, also plot the validation accuracies using the
            checkpoint's ``val_accuracies`` and ``eval_every`` entries.
    """
    checkpoint = _load_checkpoint(model_dir)

    labels = [accuracy_labels[accuracy] for accuracy in accuracies]
    tr_accuracies = checkpoint['tr_accuracies']
    val_accuracies = checkpoint['val_accuracies'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    # The y axis is pinned to the [0, 1] range.
    plot_stats(accuracies, tr_accuracies, stats_val=val_accuracies,
               eval_every=eval_every, labels=labels, ry=(0, 1))