# polyphemus / plots.py
# Last commit d896bd4 by EmanueleCosenza: "Working version" (file size 4.51 kB).
import os
from matplotlib import pyplot as plt
import matplotlib as mpl
import muspy
import torch
import constants
def plot_pianoroll(muspy_song, save_dir=None, name='pianoroll'):
    """Plot the piano roll of a MusPy song as vertically stacked subplots.

    ``constants.N_TRACKS`` subplots (presumably one per track) share the x
    axis and are drawn with no vertical gap between them. Rendering is
    delegated to ``muspy.show_pianoroll``, with y tick labels, x ticks,
    x tick labels and grids requested off and only y labels requested on.

    Args:
        muspy_song: The ``muspy.Music`` object to plot.
        save_dir: If truthy, the figure is saved as ``<save_dir>/<name>.png``
            at 200 dpi. Otherwise nothing is written to disk.
        name: File name (without extension) used when saving.
    """
    lines_linewidth = 4
    axes_linewidth = 4
    font_size = 34
    fformat = 'png'
    xticklabel = False
    label = 'y'
    figsize = (20, 10)
    dpi = 200

    with mpl.rc_context({'lines.linewidth': lines_linewidth,
                         'axes.linewidth': axes_linewidth,
                         'font.size': font_size}):
        # squeeze=False always yields a 2-D array of Axes. Without it, a
        # single-track setup gets a bare Axes back, which has no .tolist().
        fig, axs_ = plt.subplots(constants.N_TRACKS, sharex=True,
                                 figsize=figsize, squeeze=False)
        fig.subplots_adjust(hspace=0)
        axs = axs_.ravel().tolist()

        muspy.show_pianoroll(music=muspy_song, yticklabel='off', xtick='off',
                             label=label, xticklabel=xticklabel,
                             grid_axis='off', axs=axs, preset='full')

        if save_dir:
            # Save this figure explicitly rather than pyplot's "current" one.
            fig.savefig(os.path.join(save_dir, name + "." + fformat),
                        format=fformat, dpi=dpi)
def plot_structure(s_tensor, save_dir=None, name='structure', ticks_per_bar=4):
    """Plot a structure tensor as a grid of cells, one row per track.

    Args:
        s_tensor: Torch tensor indexed as ``(bar, track, timestep)``. This
            layout is inferred from the ``permute(1, 0, 2)`` below, which
            moves tracks to the row axis; the values are presumably note
            activations (TODO confirm). Row labels come from
            ``constants.TRACKS``, so dim 1 should match its length.
        save_dir: If truthy, the figure is saved as ``<save_dir>/<name>.svg``.
            Otherwise nothing is written to disk.
        name: File name (without extension) used when saving.
        ticks_per_bar: Number of evenly spaced, 1-based numbered x ticks per
            bar. The default of 4 reproduces the previously hard-coded value.

    Raises:
        ValueError: If ``ticks_per_bar`` is smaller than 1.
    """
    if ticks_per_bar < 1:
        raise ValueError("ticks_per_bar must be >= 1")

    lines_linewidth = 1
    axes_linewidth = 1
    font_size = 14
    fformat = 'svg'
    dpi = 200

    n_bars = s_tensor.shape[0]
    n_timesteps = s_tensor.size(2)
    figsize = (3 * n_bars, 3)
    # Tick spacing in timesteps. Clamped to 1 so range() below never gets a
    # zero step when a bar has fewer timesteps than ticks_per_bar.
    resolution = max(1, n_timesteps // ticks_per_bar)

    # (bar, track, time) -> (track, bar, time) -> (track, bar * time): one row
    # per track, with all bars laid out consecutively along x.
    # detach()/cpu() let matplotlib convert tensors that require grad or live
    # on a GPU.
    s_tensor = s_tensor.detach().cpu().permute(1, 0, 2)
    s_tensor = s_tensor.reshape(s_tensor.shape[0], -1)
    n_tracks = s_tensor.shape[0]

    # Exactly one label per tick. When n_timesteps is a multiple of
    # ticks_per_bar this is 1..ticks_per_bar * n_bars, as before. Otherwise
    # the counts still match, instead of matplotlib raising.
    tick_positions = range(0, s_tensor.shape[1], resolution)
    tick_labels = range(1, len(tick_positions) + 1)
    # pcolormesh draws row i between y=i and y=i+1, so centre labels on rows.
    row_centres = [row + 0.5 for row in range(n_tracks)]

    with mpl.rc_context({'lines.linewidth': lines_linewidth,
                         'axes.linewidth': axes_linewidth,
                         'font.size': font_size}):
        fig = plt.figure(figsize=figsize)
        plt.pcolormesh(s_tensor, edgecolors='k', linewidth=1)
        ax = plt.gca()
        plt.xticks(tick_positions, tick_labels)
        plt.yticks(row_centres, constants.TRACKS)
        # Put the first track at the top of the plot.
        ax.invert_yaxis()

        if save_dir:
            fig.savefig(os.path.join(save_dir, name + "." + fformat),
                        format=fformat, dpi=dpi)
def plot_stats(stat_names, stats_tr, stats_val=None, eval_every=None,
               labels=None, rx=None, ry=None):
    """Plot training (and optionally validation) curves on the current axes.

    Args:
        stat_names: Names of the statistics to plot.
        stats_tr: Mapping from statistic name to a sequence of training
            values, drawn as a line against x = 1, 2, ..., len(values).
        stats_val: Optional mapping from statistic name to validation values.
            They are drawn as points at x = eval_every, 2 * eval_every, ...
            up to the number of training values. Each sequence must have
            exactly that many entries, or matplotlib will reject the call.
        eval_every: Spacing between validation points. Required (a positive
            int) whenever ``stats_val`` is non-empty.
        labels: Optional legend labels, parallel to ``stat_names``. Defaults
            to the statistic names themselves.
        rx: x limits passed to ``plt.xlim``. When falsy, the left limit is
            set to 0.
        ry: y limits passed to ``plt.ylim``. When falsy, the bottom limit is
            set to 0.

    Raises:
        ValueError: If ``stats_val`` is given without a positive
            ``eval_every``.
    """
    if stats_val and (eval_every is None or eval_every < 1):
        raise ValueError(
            "eval_every must be a positive int when stats_val is given")

    for i, stat in enumerate(stat_names):
        label = labels[i] if labels else stat
        n_steps = len(stats_tr[stat])
        plt.plot(range(1, n_steps + 1), stats_tr[stat],
                 label=label + ' (TR)')
        if stats_val:
            plt.plot(range(eval_every, n_steps + 1, eval_every),
                     stats_val[stat], '.', label=label + ' (VL)')

    plt.grid()
    if ry:
        plt.ylim(ry)
    else:
        plt.ylim(0)
    if rx:
        plt.xlim(rx)
    else:
        plt.xlim(0)
    plt.legend()
# Maps loss statistic names (keys of a checkpoint's 'tr_losses' /
# 'val_losses' histories) to human-readable plot labels.
loss_labels = {
    'tot': 'Total Loss',
    'structure': 'Structure',
    'pitch': 'Pitches',
    'dur': 'Duration',
    'reconstruction': 'Reconstruction Term',
    'kld': 'KLD',
    'beta*kld': 'beta * KLD'
}


def plot_losses(model_dir, losses, plot_val=False):
    """Plot loss curves stored in a training checkpoint.

    Loads ``<model_dir>/checkpoint`` with ``torch.load`` onto the CPU and
    passes the requested entries of its ``'tr_losses'`` history to
    ``plot_stats``. If ``plot_val`` is truthy, its ``'val_losses'`` history
    and ``'eval_every'`` spacing are passed as well.

    Args:
        model_dir: Directory containing the ``checkpoint`` file.
        losses: Loss statistic names to plot. Names found in ``loss_labels``
            get their readable label; any other name is used as its own label.
        plot_val: Whether to also plot the validation losses.
    """
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    labels = [loss_labels.get(loss, loss) for loss in losses]
    tr_losses = checkpoint['tr_losses']
    val_losses = checkpoint['val_losses'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    # rx=0 is falsy, so plot_stats pins the left x limit at 0.
    plot_stats(losses, tr_losses, stats_val=val_losses,
               eval_every=eval_every, labels=labels, rx=0)
# Maps accuracy statistic names (keys of a checkpoint's 'tr_accuracies' /
# 'val_accuracies' histories) to human-readable plot labels.
accuracy_labels = {
    's_acc': 'Struct. Accuracy',
    's_precision': 'Struct. Precision',
    's_recall': 'Struct. Recall',
    's_f1': 'Struct. F1',
    'pitch': 'Pitch Accuracy',
    'pitch_drums': 'Pitch Accuracy (Drums)',
    'pitch_non_drums': 'Pitch Accuracy (Non Drums)',
    'dur': 'Duration Accuracy',
    'note': 'Note Accuracy'
}


def plot_accuracies(model_dir, accuracies, plot_val=False):
    """Plot accuracy curves stored in a training checkpoint.

    Loads ``<model_dir>/checkpoint`` with ``torch.load`` onto the CPU and
    passes the requested entries of its ``'tr_accuracies'`` history to
    ``plot_stats`` with the y range fixed to (0, 1). If ``plot_val`` is
    truthy, its ``'val_accuracies'`` history and ``'eval_every'`` spacing are
    passed as well.

    Args:
        model_dir: Directory containing the ``checkpoint`` file.
        accuracies: Accuracy statistic names to plot. Names found in
            ``accuracy_labels`` get their readable label; any other name is
            used as its own label.
        plot_val: Whether to also plot the validation accuracies.
    """
    checkpoint_path = os.path.join(model_dir, 'checkpoint')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    labels = [accuracy_labels.get(acc, acc) for acc in accuracies]
    tr_accuracies = checkpoint['tr_accuracies']
    val_accuracies = checkpoint['val_accuracies'] if plot_val else None
    eval_every = checkpoint['eval_every'] if plot_val else None

    plot_stats(accuracies, tr_accuracies, stats_val=val_accuracies,
               eval_every=eval_every, labels=labels, ry=(0, 1))