# polyphemus/training.py
import math
import os
import pprint
import time
from collections import defaultdict
from statistics import mean

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm

import constants
from constants import PitchToken, DurationToken
from utils import append_dict, print_divider


class StepBetaScheduler:
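    """Step scheduler for the KL weight (beta) of the VAE loss.

    Beta stays at 0 until `anneal_start` update steps, then grows by
    `step_size` every `inc_every` steps, reaching `beta_max` by
    `anneal_end`.
    """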
def __init__(self, anneal_start, beta_max, step_size, anneal_end):
self.anneal_start = anneal_start
self.beta_max = beta_max
self.step_size = step_size
self.anneal_end = anneal_end
self.update_steps = 0
self.beta = 0
        # Number of beta increments needed to reach beta_max
        n_steps = round(self.beta_max / self.step_size)
        self.inc_every = (self.anneal_end-self.anneal_start) // n_steps
def step(self):
self.update_steps += 1
        if self.anneal_start <= self.update_steps < self.anneal_end:
            # If we are annealing, update beta according to the current step
            curr_step = (self.update_steps-self.anneal_start) // self.inc_every
            self.beta = min(self.step_size * (curr_step+1), self.beta_max)
return self.beta


class ExpDecayLRScheduler:
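    """Scheduler that keeps the learning rate at `peak_lr` for
    `warmup_steps` update steps and then decays it exponentially, reaching
    `final_lr_scale * peak_lr` after `decay_steps` post-warmup steps.
    """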
def __init__(self, optimizer, peak_lr, warmup_steps, final_lr_scale,
decay_steps):
self.optimizer = optimizer
self.peak_lr = peak_lr
self.warmup_steps = warmup_steps
self.decay_steps = decay_steps
# Find the decay factor needed to reach the specified
# learning rate scale after decay_steps steps
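        # (solving peak_lr * exp(-decay_factor * decay_steps)
        #  = final_lr_scale * peak_lr for decay_factor)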
self.decay_factor = -math.log(final_lr_scale) / self.decay_steps
        self.update_steps = 0
        # Start from the peak learning rate so that self.lr is defined even
        # before the first call to step()
        self.lr = self.peak_lr
def set_lr(self, optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def step(self):
self.update_steps += 1
if self.update_steps <= self.warmup_steps:
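            # Hold the learning rate at its peak during warmup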
self.lr = self.peak_lr
else:
# Decay lr exponentially
            steps_after_warmup = self.update_steps - self.warmup_steps
self.lr = \
self.peak_lr * math.exp(-self.decay_factor*steps_after_warmup)
self.set_lr(self.optimizer, self.lr)
return self.lr


class PolyphemusTrainer:
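    """Trainer for the Polyphemus model.

    Handles mixed precision training, gradient accumulation, learning rate
    and beta scheduling, periodic evaluation on a validation set, and
    checkpointing of the best and latest models.
    """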
def __init__(self, model_dir, model, optimizer, init_lr=1e-4,
lr_scheduler=None, beta_scheduler=None, device=None,
print_every=1, save_every=1, eval_every=100,
iters_to_accumulate=1, **kwargs):
self.__dict__.update(kwargs)
self.model_dir = model_dir
self.model = model
self.optimizer = optimizer
self.init_lr = init_lr
self.lr_scheduler = lr_scheduler
self.beta_scheduler = beta_scheduler
self.device = device if device is not None else torch.device("cpu")
        self.cuda = (self.device.type == 'cuda')
self.print_every = print_every
self.save_every = save_every
self.eval_every = eval_every
self.iters_to_accumulate = iters_to_accumulate
# Losses (ignoring PAD tokens)
self.bce_unreduced = nn.BCEWithLogitsLoss(reduction='none')
self.ce_p = nn.CrossEntropyLoss(ignore_index=PitchToken.PAD.value)
self.ce_d = nn.CrossEntropyLoss(ignore_index=DurationToken.PAD.value)
# Training stats
self.tr_losses = defaultdict(list)
self.tr_accuracies = defaultdict(list)
self.val_losses = defaultdict(list)
self.val_accuracies = defaultdict(list)
self.lrs = []
self.betas = []
self.times = []
def train(self, trainloader, validloader=None, epochs=100, early_exit=None):
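        """Train the model on `trainloader` for `epochs` epochs, evaluating
        on `validloader` every `eval_every` batches and saving a checkpoint
        every `save_every` batches. If `early_exit` is set, training stops
        after that many batches.
        """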
self.tot_batches = 0
self.beta = 0
self.min_val_loss = np.inf
start = time.time()
self.times.append(start)
self.model.train()
scaler = torch.cuda.amp.GradScaler() if self.cuda else None
self.optimizer.zero_grad()
        progress_bar = tqdm(total=epochs * len(trainloader))
for epoch in range(epochs):
self.cur_epoch = epoch
for batch_idx, graph in enumerate(trainloader):
self.cur_batch_idx = batch_idx
# Move batch of graphs to device. Note: a single graph here
# represents a bar in the original sequence.
graph = graph.to(self.device)
s_tensor, c_tensor = graph.s_tensor, graph.c_tensor
with torch.cuda.amp.autocast(enabled=self.cuda):
# Forward pass to obtain mu, log(sigma^2), computed by the
# encoder, and structure and content logits, computed by the
# decoder
(s_logits, c_logits), mu, log_var = self.model(graph)
# Compute losses
tot_loss, losses = self._losses(
s_tensor, s_logits,
c_tensor, c_logits,
mu, log_var
)
tot_loss = tot_loss / self.iters_to_accumulate
# Backpropagation
if self.cuda:
scaler.scale(tot_loss).backward()
else:
tot_loss.backward()
# Update weights with accumulated gradients
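                # (one optimizer step every `iters_to_accumulate` batches,
                # so the effective batch size is the loader batch size
                # multiplied by `iters_to_accumulate`)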
if (self.tot_batches + 1) % self.iters_to_accumulate == 0:
if self.cuda:
scaler.step(self.optimizer)
scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
# Update lr and beta
if self.lr_scheduler is not None:
self.lr_scheduler.step()
                    if self.beta_scheduler is not None:
                        self.beta = self.beta_scheduler.step()
# Compute accuracies
accs = self._accuracies(
s_tensor, s_logits,
c_tensor, c_logits,
graph.is_drum
)
# Update the stats
append_dict(self.tr_losses, losses)
append_dict(self.tr_accuracies, accs)
last_lr = (self.lr_scheduler.lr
if self.lr_scheduler is not None else self.init_lr)
self.lrs.append(last_lr)
self.betas.append(self.beta)
now = time.time()
self.times.append(now)
# Print stats
if (self.tot_batches + 1) % self.print_every == 0:
print("Training on batch {}/{} of epoch {}/{} complete."
.format(batch_idx+1,
len(trainloader),
epoch+1,
epochs))
self._print_stats()
print_divider()
                # Evaluate on the validation set every `eval_every` batches
                if (validloader is not None and
                        (self.tot_batches + 1) % self.eval_every == 0):
print("\nEvaluating on validation set...\n")
val_losses, val_accuracies = self.evaluate(validloader)
# Update stats
append_dict(self.val_losses, val_losses)
append_dict(self.val_accuracies, val_accuracies)
print("Val losses:")
print(val_losses)
print("Val accuracies:")
print(val_accuracies)
                    # Save model if total validation loss hit a new minimum
tot_loss = val_losses['tot']
if tot_loss < self.min_val_loss:
print("\nValidation loss improved.")
print("Saving new best model to disk...\n")
self._save_model('best_model')
self.min_val_loss = tot_loss
self.model.train()
progress_bar.update(1)
# Save model and stats on disk
if (self.save_every > 0 and
(self.tot_batches + 1) % self.save_every == 0):
self._save_model('checkpoint')
                # Stop prematurely if early_exit is set and reached
                if (early_exit is not None and
                        (self.tot_batches + 1) > early_exit):
                    break
                self.tot_batches += 1
            else:
                # No early exit in this epoch: move on to the next one
                continue
            # Propagate the early exit out of the epoch loop
            break
end = time.time()
hours, rem = divmod(end-start, 3600)
minutes, seconds = divmod(rem, 60)
print("Training completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
.format(int(hours), int(minutes), seconds))
self._save_model('checkpoint')
def evaluate(self, loader):
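        """Evaluate the model on `loader` and return two dicts containing
        the losses and accuracies averaged over all batches.
        """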
losses = defaultdict(list)
accs = defaultdict(list)
self.model.eval()
progress_bar = tqdm(range(len(loader)))
with torch.no_grad():
            for graph in loader:
# Get the inputs and move them to device
graph = graph.to(self.device)
s_tensor, c_tensor = graph.s_tensor, graph.c_tensor
                with torch.cuda.amp.autocast(enabled=self.cuda):
# Forward pass, get the reconstructions
(s_logits, c_logits), mu, log_var = self.model(graph)
_, losses_b = self._losses(
s_tensor, s_logits,
c_tensor, c_logits,
mu, log_var
)
accs_b = self._accuracies(
s_tensor, s_logits,
c_tensor, c_logits,
graph.is_drum
)
# Save losses and accuracies
append_dict(losses, losses_b)
append_dict(accs, accs_b)
progress_bar.update(1)
        # Compute average losses and accuracies over all batches
        avg_losses = {k: mean(v) for k, v in losses.items()}
        avg_accs = {k: mean(v) for k, v in accs.items()}
        return avg_losses, avg_accs
def _losses(self, s_tensor, s_logits, c_tensor, c_logits, mu, log_var):
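        """Compute the structure, pitch, duration and KLD losses and return
        the total loss tensor along with a dict of scalar values for logging.
        """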
# Do not consider SOS token
c_tensor = c_tensor[..., 1:, :]
c_logits = c_logits.reshape(-1, c_logits.size(-1))
c_tensor = c_tensor.reshape(-1, c_tensor.size(-1))
# Reshape logits to match s_tensor dimensions:
# n_graphs (in batch) x n_tracks x n_timesteps
        s_logits = s_logits.reshape(-1, *s_logits.shape[2:])
# Binary structure tensor loss (binary cross entropy)
s_loss = self.bce_unreduced(
s_logits.view(-1), s_tensor.view(-1).float())
s_loss = torch.mean(s_loss)
# Content tensor loss (pitches)
# argmax is used to obtain token ids from onehot rep
pitch_logits = c_logits[:, :constants.N_PITCH_TOKENS]
pitch_true = c_tensor[:, :constants.N_PITCH_TOKENS].argmax(dim=1)
pitch_loss = self.ce_p(pitch_logits, pitch_true)
# Content tensor loss (durations)
dur_logits = c_logits[:, constants.N_PITCH_TOKENS:]
dur_true = c_tensor[:, constants.N_PITCH_TOKENS:].argmax(dim=1)
dur_loss = self.ce_d(dur_logits, dur_true)
# Kullback-Leibler divergence loss
# Derivation in Kingma, Diederik P., and Max Welling. "Auto-encoding
# variational bayes." (2013), Appendix B.
# (https://arxiv.org/pdf/1312.6114.pdf)
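        # For a Gaussian posterior q(z|x) = N(mu, diag(sigma^2)) and a
        # standard normal prior, the KLD has the closed form
        # -1/2 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)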
kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(),
dim=1)
kld_loss = torch.mean(kld_loss)
# Reconstruction loss and total loss
rec_loss = pitch_loss + dur_loss + s_loss
tot_loss = rec_loss + self.beta*kld_loss
losses = {
'tot': tot_loss.item(),
'pitch': pitch_loss.item(),
'dur': dur_loss.item(),
'structure': s_loss.item(),
'reconstruction': rec_loss.item(),
'kld': kld_loss.item(),
'beta*kld': self.beta*kld_loss.item()
}
return tot_loss, losses
def _accuracies(self, s_tensor, s_logits, c_tensor, c_logits, is_drum):
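        """Compute reconstruction accuracies for notes, pitches and
        durations, plus accuracy, precision, recall and F1 for the binary
        structure tensor.
        """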
# Do not consider SOS token
c_tensor = c_tensor[..., 1:, :]
# Reshape logits to match s_tensor dimensions:
# n_graphs (in batch) x n_tracks x n_timesteps
        s_logits = s_logits.reshape(-1, *s_logits.shape[2:])
# Note accuracy considers both pitches and durations
note_acc = self._note_accuracy(c_logits, c_tensor)
pitch_acc = self._pitch_accuracy(c_logits, c_tensor)
# Compute pitch accuracies for drums and non drums separately
pitch_acc_drums = self._pitch_accuracy(
c_logits, c_tensor, drums=True, is_drum=is_drum
)
pitch_acc_non_drums = self._pitch_accuracy(
c_logits, c_tensor, drums=False, is_drum=is_drum
)
dur_acc = self._duration_accuracy(c_logits, c_tensor)
s_acc = self._structure_accuracy(s_logits, s_tensor)
s_precision = self._structure_precision(s_logits, s_tensor)
s_recall = self._structure_recall(s_logits, s_tensor)
s_f1 = (2*s_recall*s_precision / (s_recall+s_precision))
accs = {
'note': note_acc.item(),
'pitch': pitch_acc.item(),
'pitch_drums': pitch_acc_drums.item(),
'pitch_non_drums': pitch_acc_non_drums.item(),
'dur': dur_acc.item(),
's_acc': s_acc.item(),
's_precision': s_precision.item(),
's_recall': s_recall.item(),
's_f1': s_f1.item()
}
return accs
def _pitch_accuracy(self, c_logits, c_tensor, drums=None, is_drum=None):
# When drums is None, just compute the global pitch accuracy without
# distinguishing between drum and non drum pitches
if drums is not None:
if drums:
c_logits = c_logits[is_drum]
c_tensor = c_tensor[is_drum]
else:
c_logits = c_logits[torch.logical_not(is_drum)]
c_tensor = c_tensor[torch.logical_not(is_drum)]
# Apply softmax to obtain pitch reconstructions
pitch_rec = c_logits[..., :constants.N_PITCH_TOKENS]
pitch_rec = F.softmax(pitch_rec, dim=-1)
pitch_rec = torch.argmax(pitch_rec, dim=-1)
pitch_true = c_tensor[..., :constants.N_PITCH_TOKENS]
pitch_true = torch.argmax(pitch_true, dim=-1)
# Do not consider PAD tokens when computing accuracies
not_pad = (pitch_true != PitchToken.PAD.value)
correct = (pitch_rec == pitch_true)
correct = torch.logical_and(correct, not_pad)
return torch.sum(correct) / torch.sum(not_pad)
def _duration_accuracy(self, c_logits, c_tensor):
# Apply softmax to obtain reconstructed durations
dur_rec = c_logits[..., constants.N_PITCH_TOKENS:]
dur_rec = F.softmax(dur_rec, dim=-1)
dur_rec = torch.argmax(dur_rec, dim=-1)
dur_true = c_tensor[..., constants.N_PITCH_TOKENS:]
dur_true = torch.argmax(dur_true, dim=-1)
# Do not consider PAD tokens when computing accuracies
not_pad = (dur_true != DurationToken.PAD.value)
correct = (dur_rec == dur_true)
correct = torch.logical_and(correct, not_pad)
return torch.sum(correct) / torch.sum(not_pad)
def _note_accuracy(self, c_logits, c_tensor):
# Apply softmax to obtain pitch reconstructions
pitch_rec = c_logits[..., :constants.N_PITCH_TOKENS]
pitch_rec = F.softmax(pitch_rec, dim=-1)
pitch_rec = torch.argmax(pitch_rec, dim=-1)
pitch_true = c_tensor[..., :constants.N_PITCH_TOKENS]
pitch_true = torch.argmax(pitch_true, dim=-1)
not_pad_p = (pitch_true != PitchToken.PAD.value)
correct_p = (pitch_rec == pitch_true)
correct_p = torch.logical_and(correct_p, not_pad_p)
dur_rec = c_logits[..., constants.N_PITCH_TOKENS:]
dur_rec = F.softmax(dur_rec, dim=-1)
dur_rec = torch.argmax(dur_rec, dim=-1)
dur_true = c_tensor[..., constants.N_PITCH_TOKENS:]
dur_true = torch.argmax(dur_true, dim=-1)
not_pad_d = (dur_true != DurationToken.PAD.value)
correct_d = (dur_rec == dur_true)
correct_d = torch.logical_and(correct_d, not_pad_d)
note_accuracy = torch.sum(
torch.logical_and(correct_p, correct_d)) / torch.sum(not_pad_p)
return note_accuracy
    def _structure_accuracy(self, s_logits, s_tensor):
        s_pred = (torch.sigmoid(s_logits) >= 0.5).float()
        return torch.sum(s_pred == s_tensor) / s_tensor.numel()

    def _structure_precision(self, s_logits, s_tensor):
        s_pred = (torch.sigmoid(s_logits) >= 0.5).float()
        # True positives over predicted positives
        tp = torch.sum(s_tensor[s_pred == 1])
        return tp / torch.sum(s_pred)

    def _structure_recall(self, s_logits, s_tensor):
        s_pred = (torch.sigmoid(s_logits) >= 0.5).float()
        # True positives over actual positives
        tp = torch.sum(s_tensor[s_pred == 1])
        return tp / torch.sum(s_tensor)
def _save_model(self, filename):
path = os.path.join(self.model_dir, filename)
print("Saving model to disk...")
torch.save({
'epoch': self.cur_epoch,
'batch': self.cur_batch_idx,
'tot_batches': self.tot_batches,
'betas': self.betas,
'min_val_loss': self.min_val_loss,
'print_every': self.print_every,
'save_every': self.save_every,
'eval_every': self.eval_every,
'lrs': self.lrs,
'tr_losses': self.tr_losses,
'tr_accuracies': self.tr_accuracies,
'val_losses': self.val_losses,
'val_accuracies': self.val_accuracies,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()
}, path)
print("The model has been successfully saved.")
def _print_stats(self):
hours, rem = divmod(self.times[-1]-self.times[0], 3600)
minutes, seconds = divmod(rem, 60)
print("Elapsed time from start (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
.format(int(hours), int(minutes), seconds))
        # Average each loss and accuracy over the last `print_every` batches
        avg_losses = {k: round(mean(l[-self.print_every:]), 2)
                      for k, l in self.tr_losses.items()}
        avg_accs = {k: round(mean(l[-self.print_every:]), 2)
                    for k, l in self.tr_accuracies.items()}
print("Losses:")
pprint.pprint(avg_losses, indent=2)
print("Accuracies:")
pprint.pprint(avg_accs, indent=2)
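

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `model`, the loaders and all
# hyperparameter values below are assumptions, not part of this module).
# The model is expected to return ((s_logits, c_logits), mu, log_var) from
# its forward pass, and the loaders to yield batched bar graphs exposing
# s_tensor, c_tensor and is_drum attributes, as used by the trainer above.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   lr_scheduler = ExpDecayLRScheduler(optimizer, peak_lr=1e-4,
#                                      warmup_steps=8000,
#                                      final_lr_scale=0.01,
#                                      decay_steps=200000)
#   beta_scheduler = StepBetaScheduler(anneal_start=10000, beta_max=1.0,
#                                      step_size=0.1, anneal_end=100000)
#   trainer = PolyphemusTrainer('runs/exp1', model, optimizer,
#                               lr_scheduler=lr_scheduler,
#                               beta_scheduler=beta_scheduler,
#                               device=device, iters_to_accumulate=2)
#   trainer.train(trainloader, validloader=validloader, epochs=100)
# ---------------------------------------------------------------------------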