Spaces:
Sleeping
Sleeping
File size: 4,508 Bytes
d896bd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
from matplotlib import pyplot as plt
import matplotlib as mpl
import muspy
import torch
import constants
def plot_pianoroll(muspy_song, save_dir=None, name='pianoroll'):
    """Draw a multi-track pianoroll on vertically stacked subplots.

    The figure has ``constants.N_TRACKS`` rows sharing the x axis, with no
    vertical gap between them. Drawing is delegated to
    ``muspy.show_pianoroll``.

    Args:
        muspy_song: Song object forwarded to ``muspy.show_pianoroll`` as
            its ``music`` argument.
        save_dir: Directory in which the figure is saved as
            ``<name>.png``. Nothing is written when this is falsy.
        name: File name (without extension) used when saving.
    """
    fformat = 'png'
    figsize = (20, 10)
    dpi = 200
    xticklabel = False
    label = 'y'

    # Thick lines and a large font, applied only while this figure is built.
    rc_overrides = {
        'lines.linewidth': 4,
        'axes.linewidth': 4,
        'font.size': 34,
    }

    with mpl.rc_context(rc_overrides):
        # squeeze=False always returns a 2-D array of Axes. Without it,
        # a single-row layout returns a bare Axes object that has no
        # .tolist(), so the function would crash for N_TRACKS == 1.
        fig, axs_ = plt.subplots(constants.N_TRACKS, sharex=True,
                                 figsize=figsize, squeeze=False)
        fig.subplots_adjust(hspace=0)
        axs = axs_[:, 0].tolist()

        muspy.show_pianoroll(music=muspy_song, yticklabel='off', xtick='off',
                             label=label, xticklabel=xticklabel,
                             grid_axis='off', axs=axs, preset='full')

        if save_dir:
            # Save through the figure handle so the right figure is written
            # regardless of which one pyplot considers "current".
            fig.savefig(os.path.join(save_dir, name + "." + fformat),
                        format=fformat, dpi=dpi)
def plot_structure(s_tensor, save_dir=None, name='structure'):
    """Plot a structure tensor as a grid of cells.

    The tensor is rearranged so that its second dimension becomes the rows
    of the grid and the first and third dimensions are flattened, bar after
    bar, into the columns. Rows are labelled with ``constants.TRACKS`` and
    four numbered ticks are placed along the x axis for every bar.

    Args:
        s_tensor: 3-D torch tensor whose first dimension is the number of
            bars and whose third dimension is the number of timesteps per
            bar (the second is presumably the track axis, since its rows
            are labelled with ``constants.TRACKS``).
        save_dir: Directory in which the figure is saved as
            ``<name>.svg``. Nothing is written when this is falsy.
        name: File name (without extension) used when saving.
    """
    fformat = 'svg'
    dpi = 200

    n_bars = s_tensor.shape[0]
    n_timesteps = s_tensor.size(2)
    resolution = n_timesteps // 4
    figsize = (3 * n_bars, 3)

    # matplotlib converts its input through NumPy, which fails for tensors
    # living on an accelerator or tracked by autograd; detach and move to
    # the CPU first (both are no-ops for a plain CPU tensor).
    grid = s_tensor.detach().cpu().permute(1, 0, 2)
    grid = grid.reshape(grid.shape[0], -1)

    # Four ticks per bar, computed bar by bar. When n_timesteps is a
    # multiple of 4 this equals range(0, grid.shape[1], resolution); for
    # other sizes it still yields exactly 4 * n_bars positions, so the
    # tick/label counts always match and a zero step is never needed.
    tick_positions = [bar * n_timesteps + quarter * resolution
                      for bar in range(n_bars)
                      for quarter in range(4)]
    tick_labels = range(1, 4 * n_bars + 1)

    with mpl.rc_context({'lines.linewidth': 1,
                         'axes.linewidth': 1,
                         'font.size': 14}):
        fig = plt.figure(figsize=figsize)
        plt.pcolormesh(grid, edgecolors='k', linewidth=1)
        ax = plt.gca()

        plt.xticks(tick_positions, tick_labels)
        plt.yticks(range(0, grid.shape[0]), constants.TRACKS)
        # Put the first row at the top of the plot.
        ax.invert_yaxis()

        if save_dir:
            fig.savefig(os.path.join(save_dir, name + "." + fformat),
                        format=fformat, dpi=dpi)
def plot_stats(stat_names, stats_tr, stats_val=None, eval_every=None,
               labels=None, rx=None, ry=None):
    """Plot training curves, optionally with validation points, on pyplot.

    Args:
        stat_names: Names of the statistics to plot; each is used as a key
            into ``stats_tr`` (and ``stats_val`` when given).
        stats_tr: Mapping from statistic name to its sequence of training
            values, plotted as a line against 1, 2, ..., len(values).
        stats_val: Optional mapping from statistic name to validation
            values, plotted as dots. Skipped when falsy.
        eval_every: Spacing of the validation points on the x axis; the
            points sit at eval_every, 2 * eval_every, ... Required
            whenever ``stats_val`` is used.
        labels: Optional legend labels, parallel to ``stat_names``. The
            statistic names themselves are used when this is falsy.
        rx: Value passed to ``plt.xlim``; the axis starts at 0 when falsy.
        ry: Value passed to ``plt.ylim``; the axis starts at 0 when falsy.
    """
    for idx, stat_name in enumerate(stat_names):
        curve_label = labels[idx] if labels else stat_name

        n_points = len(stats_tr[stat_name])
        plt.plot(range(1, n_points + 1), stats_tr[stat_name],
                 label=curve_label + ' (TR)')

        if not stats_val:
            continue

        val_steps = range(eval_every, n_points + 1, eval_every)
        plt.plot(val_steps, stats_val[stat_name], '.',
                 label=curve_label + ' (VL)')

    plt.grid()

    if ry:
        plt.ylim(ry)
    else:
        plt.ylim(0)

    if rx:
        plt.xlim(rx)
    else:
        plt.xlim(0)

    plt.legend()
# Legend label for each loss statistic, keyed by the statistic's name.
loss_labels = {
    "tot": "Total Loss",
    "structure": "Structure",
    "pitch": "Pitches",
    "dur": "Duration",
    "reconstruction": "Reconstruction Term",
    "kld": "KLD",
    "beta*kld": "beta * KLD",
}
def plot_losses(model_dir, losses, plot_val=False):
    """Plot loss curves stored in a model checkpoint.

    Loads ``<model_dir>/checkpoint`` on the CPU and hands the requested
    training losses (and, optionally, validation losses) to ``plot_stats``.

    Args:
        model_dir: Directory containing the ``checkpoint`` file.
        losses: Names of the losses to plot. Each must be a key of
            ``loss_labels`` and of the checkpoint's ``'tr_losses'`` entry.
        plot_val: When true, also plot the checkpoint's ``'val_losses'``
            using its ``'eval_every'`` value as the spacing.

    Raises:
        KeyError: If a loss name has no entry in ``loss_labels``, or the
            checkpoint lacks one of the keys read here.
    """
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    # NOTE: torch.load unpickles the file, which can execute arbitrary
    # code; only open checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    labels = [loss_labels[loss] for loss in losses]
    tr_losses = checkpoint['tr_losses']

    # Validation entries are only read when they are going to be plotted.
    val_losses = checkpoint['val_losses'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    # rx=0 starts the x axis at 0 (the original "(0)" was just the int 0,
    # not a tuple).
    plot_stats(losses, tr_losses, stats_val=val_losses,
               eval_every=eval_every, labels=labels, rx=0)
# Legend label for each accuracy statistic, keyed by the statistic's name.
accuracy_labels = {
    "s_acc": "Struct. Accuracy",
    "s_precision": "Struct. Precision",
    "s_recall": "Struct. Recall",
    "s_f1": "Struct. F1",
    "pitch": "Pitch Accuracy",
    "pitch_drums": "Pitch Accuracy (Drums)",
    "pitch_non_drums": "Pitch Accuracy (Non Drums)",
    "dur": "Duration Accuracy",
    "note": "Note Accuracy",
}
def plot_accuracies(model_dir, accuracies, plot_val=False):
    """Plot accuracy curves stored in a model checkpoint.

    Loads ``<model_dir>/checkpoint`` on the CPU and hands the requested
    training accuracies (and, optionally, validation accuracies) to
    ``plot_stats`` with the y axis fixed to the range 0 to 1.

    Args:
        model_dir: Directory containing the ``checkpoint`` file.
        accuracies: Names of the accuracies to plot. Each must be a key of
            ``accuracy_labels`` and of the checkpoint's
            ``'tr_accuracies'`` entry.
        plot_val: When true, also plot the checkpoint's
            ``'val_accuracies'`` using its ``'eval_every'`` value as the
            spacing.

    Raises:
        KeyError: If an accuracy name has no entry in ``accuracy_labels``,
            or the checkpoint lacks one of the keys read here.
    """
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    # NOTE: torch.load unpickles the file, which can execute arbitrary
    # code; only open checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    labels = [accuracy_labels[accuracy] for accuracy in accuracies]
    tr_accuracies = checkpoint['tr_accuracies']

    # Validation entries are only read when they are going to be plotted.
    val_accuracies = checkpoint['val_accuracies'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    plot_stats(accuracies, tr_accuracies, stats_val=val_accuracies,
               eval_every=eval_every, labels=labels, ry=(0, 1))