import time
import os
import math
import pprint
from statistics import mean
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.auto import tqdm

import constants
from constants import PitchToken, DurationToken
from utils import append_dict, print_divider


class StepBetaScheduler():
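    """Step-wise scheduler for the KL weight (beta) in the VAE loss.

    Beta stays at 0 until `anneal_start`, then increases by `step_size`
    every `inc_every` updates, reaching `beta_max` by `anneal_end`.
    """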

    def __init__(self, anneal_start, beta_max, step_size, anneal_end):
        self.anneal_start = anneal_start
        self.beta_max = beta_max
        self.step_size = step_size
        self.anneal_end = anneal_end
        self.update_steps = 0
        self.beta = 0

        n_steps = self.beta_max // self.step_size
        self.inc_every = (self.anneal_end-self.anneal_start) // n_steps

    def step(self):
        self.update_steps += 1
        if (self.update_steps >= self.anneal_start and
                self.update_steps < self.anneal_end):
            # If we are annealing, update beta according to the current step,
            # clamping it so that it never exceeds beta_max
            curr_step = (self.update_steps-self.anneal_start) // self.inc_every
            self.beta = min(self.step_size * (curr_step+1), self.beta_max)
        return self.beta


class ExpDecayLRScheduler():
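    """Learning rate scheduler with a flat warmup and exponential decay.

    The learning rate is held at `peak_lr` for `warmup_steps` updates and
    then decays exponentially, reaching `peak_lr * final_lr_scale` after
    `decay_steps` post-warmup updates.
    """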

    def __init__(self, optimizer, peak_lr, warmup_steps, final_lr_scale,
                 decay_steps):
        self.optimizer = optimizer
        self.peak_lr = peak_lr
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps

        # Find the decay factor needed to reach the specified
        # learning rate scale after decay_steps steps
        self.decay_factor = -math.log(final_lr_scale) / self.decay_steps

        self.update_steps = 0
        # Expose the current learning rate even before the first step() call
        self.lr = self.peak_lr

    def set_lr(self, optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def step(self):
        self.update_steps += 1

        if self.update_steps <= self.warmup_steps:
            # Keep the learning rate at its peak value during warmup
            self.lr = self.peak_lr
        else:
            # Decay lr exponentially
            steps_after_warmup = self.update_steps - self.warmup_steps
            self.lr = \
                self.peak_lr * math.exp(-self.decay_factor*steps_after_warmup)

        self.set_lr(self.optimizer, self.lr)
        return self.lr


class PolyphemusTrainer():
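    """Trainer for the Polyphemus VAE.

    Handles the training loop (with optional mixed precision on CUDA and
    gradient accumulation), periodic evaluation on a validation set, stat
    tracking and checkpointing to `model_dir`.
    """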

    def __init__(self, model_dir, model, optimizer, init_lr=1e-4,
                 lr_scheduler=None, beta_scheduler=None, device=None,
                 print_every=1, save_every=1, eval_every=100,
                 iters_to_accumulate=1, **kwargs):
        self.__dict__.update(kwargs)

        self.model_dir = model_dir
        self.model = model
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.lr_scheduler = lr_scheduler
        self.beta_scheduler = beta_scheduler
        self.device = device if device is not None else torch.device("cpu")
        self.cuda = (self.device.type == 'cuda')
        self.print_every = print_every
        self.save_every = save_every
        self.eval_every = eval_every
        self.iters_to_accumulate = iters_to_accumulate

        # Losses (ignoring PAD tokens)
        self.bce_unreduced = nn.BCEWithLogitsLoss(reduction='none')
        self.ce_p = nn.CrossEntropyLoss(ignore_index=PitchToken.PAD.value)
        self.ce_d = nn.CrossEntropyLoss(ignore_index=DurationToken.PAD.value)

        # Training stats
        self.tr_losses = defaultdict(list)
        self.tr_accuracies = defaultdict(list)
        self.val_losses = defaultdict(list)
        self.val_accuracies = defaultdict(list)
        self.lrs = []
        self.betas = []
        self.times = []

    def train(self, trainloader, validloader=None, epochs=100, early_exit=None):
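        """Train the model on `trainloader` for `epochs` epochs.

        If `validloader` is given, the model is evaluated every `eval_every`
        batches and the model with the lowest total validation loss is saved
        as 'best_model'. If `early_exit` is set, training stops after that
        many batches.
        """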

        self.tot_batches = 0
        self.beta = 0
        self.min_val_loss = np.inf
        early_stop = False

        start = time.time()
        self.times.append(start)

        self.model.train()
        scaler = torch.cuda.amp.GradScaler() if self.cuda else None
        self.optimizer.zero_grad()

        progress_bar = tqdm(range(len(trainloader)))

        for epoch in range(epochs):
            self.cur_epoch = epoch

            for batch_idx, graph in enumerate(trainloader):
                self.cur_batch_idx = batch_idx

                # Move batch of graphs to device. Note: a single graph here
                # represents a bar in the original sequence.
                graph = graph.to(self.device)
                s_tensor, c_tensor = graph.s_tensor, graph.c_tensor

                with torch.cuda.amp.autocast(enabled=self.cuda):
                    # Forward pass to obtain mu, log(sigma^2), computed by the
                    # encoder, and structure and content logits, computed by
                    # the decoder
                    (s_logits, c_logits), mu, log_var = self.model(graph)

                    # Compute losses
                    tot_loss, losses = self._losses(
                        s_tensor, s_logits,
                        c_tensor, c_logits,
                        mu, log_var
                    )
                    tot_loss = tot_loss / self.iters_to_accumulate

                # Backpropagation
                if self.cuda:
                    scaler.scale(tot_loss).backward()
                else:
                    tot_loss.backward()

                # Update weights with accumulated gradients
                if (self.tot_batches + 1) % self.iters_to_accumulate == 0:
                    if self.cuda:
                        scaler.step(self.optimizer)
                        scaler.update()
                    else:
                        self.optimizer.step()
                    self.optimizer.zero_grad()

                    # Update lr and beta
                    if self.lr_scheduler is not None:
                        self.lr_scheduler.step()
                    if self.beta_scheduler is not None:
                        self.beta = self.beta_scheduler.step()

                # Compute accuracies
                accs = self._accuracies(
                    s_tensor, s_logits,
                    c_tensor, c_logits,
                    graph.is_drum
                )

                # Update the stats
                append_dict(self.tr_losses, losses)
                append_dict(self.tr_accuracies, accs)
                last_lr = (self.lr_scheduler.lr
                           if self.lr_scheduler is not None else self.init_lr)
                self.lrs.append(last_lr)
                self.betas.append(self.beta)
                now = time.time()
                self.times.append(now)

                # Print stats
                if (self.tot_batches + 1) % self.print_every == 0:
                    print("Training on batch {}/{} of epoch {}/{} complete."
                          .format(batch_idx+1,
                                  len(trainloader),
                                  epoch+1,
                                  epochs))
                    self._print_stats()
                    print_divider()

                # Eval on VL every `eval_every` batches
                if (validloader is not None and
                        (self.tot_batches + 1) % self.eval_every == 0):
                    # Evaluate on VL
                    print("\nEvaluating on validation set...\n")
                    val_losses, val_accuracies = self.evaluate(validloader)

                    # Update stats
                    append_dict(self.val_losses, val_losses)
                    append_dict(self.val_accuracies, val_accuracies)

                    print("Val losses:")
                    print(val_losses)
                    print("Val accuracies:")
                    print(val_accuracies)

                    # Save model if VL loss (tot) reached a new minimum
                    tot_loss = val_losses['tot']
                    if tot_loss < self.min_val_loss:
                        print("\nValidation loss improved.")
                        print("Saving new best model to disk...\n")
                        self._save_model('best_model')
                        self.min_val_loss = tot_loss

                    self.model.train()

                progress_bar.update(1)
# Save model and stats on disk | |
if (self.save_every > 0 and | |
(self.tot_batches + 1) % self.save_every == 0): | |
self._save_model('checkpoint') | |
# Stop prematurely if early_exit is set and reached | |
if (early_exit is not None and | |
(self.tot_batches + 1) > early_exit): | |
break | |
self.tot_batches += 1 | |
end = time.time() | |
hours, rem = divmod(end-start, 3600) | |
minutes, seconds = divmod(rem, 60) | |
print("Training completed in (h:m:s): {:0>2}:{:0>2}:{:05.2f}" | |
.format(int(hours), int(minutes), seconds)) | |
self._save_model('checkpoint') | |

    def evaluate(self, loader):
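        """Evaluate the model on `loader`.

        Returns two dicts mapping each loss and accuracy name to its average
        value over the batches of `loader`.
        """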

        losses = defaultdict(list)
        accs = defaultdict(list)

        self.model.eval()
        progress_bar = tqdm(range(len(loader)))

        with torch.no_grad():
            for graph in loader:
                # Get the inputs and move them to device
                graph = graph.to(self.device)
                s_tensor, c_tensor = graph.s_tensor, graph.c_tensor

                with torch.cuda.amp.autocast(enabled=self.cuda):
                    # Forward pass, get the reconstructions
                    (s_logits, c_logits), mu, log_var = self.model(graph)
                    _, losses_b = self._losses(
                        s_tensor, s_logits,
                        c_tensor, c_logits,
                        mu, log_var
                    )

                accs_b = self._accuracies(
                    s_tensor, s_logits,
                    c_tensor, c_logits,
                    graph.is_drum
                )

                # Save losses and accuracies
                append_dict(losses, losses_b)
                append_dict(accs, accs_b)

                progress_bar.update(1)

        # Compute avg losses and accuracies
        avg_losses = {}
        for k, l in losses.items():
            avg_losses[k] = mean(l)
        avg_accs = {}
        for k, l in accs.items():
            avg_accs[k] = mean(l)

        return avg_losses, avg_accs

    def _losses(self, s_tensor, s_logits, c_tensor, c_logits, mu, log_var):

        # Do not consider SOS token
        c_tensor = c_tensor[..., 1:, :]
        c_logits = c_logits.reshape(-1, c_logits.size(-1))
        c_tensor = c_tensor.reshape(-1, c_tensor.size(-1))

        # Reshape logits to match s_tensor dimensions:
        # n_graphs (in batch) x n_tracks x n_timesteps
        s_logits = s_logits.reshape(-1, *s_logits.shape[2:])

        # Binary structure tensor loss (binary cross entropy)
        s_loss = self.bce_unreduced(
            s_logits.view(-1), s_tensor.view(-1).float())
        s_loss = torch.mean(s_loss)

        # Content tensor loss (pitches)
        # argmax is used to obtain token ids from onehot rep
        pitch_logits = c_logits[:, :constants.N_PITCH_TOKENS]
        pitch_true = c_tensor[:, :constants.N_PITCH_TOKENS].argmax(dim=1)
        pitch_loss = self.ce_p(pitch_logits, pitch_true)

        # Content tensor loss (durations)
        dur_logits = c_logits[:, constants.N_PITCH_TOKENS:]
        dur_true = c_tensor[:, constants.N_PITCH_TOKENS:].argmax(dim=1)
        dur_loss = self.ce_d(dur_logits, dur_true)

        # Kullback-Leibler divergence loss
        # Derivation in Kingma, Diederik P., and Max Welling. "Auto-encoding
        # variational bayes." (2013), Appendix B.
        # (https://arxiv.org/pdf/1312.6114.pdf)
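        # For a diagonal Gaussian posterior q(z|x) = N(mu, sigma^2) and a
        # standard normal prior, the KL term has the closed form
        #   KL = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2),
        # which the expression below computes from mu and log_var.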
        kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(),
                                    dim=1)
        kld_loss = torch.mean(kld_loss)

        # Reconstruction loss and total loss
        rec_loss = pitch_loss + dur_loss + s_loss
        tot_loss = rec_loss + self.beta*kld_loss

        losses = {
            'tot': tot_loss.item(),
            'pitch': pitch_loss.item(),
            'dur': dur_loss.item(),
            'structure': s_loss.item(),
            'reconstruction': rec_loss.item(),
            'kld': kld_loss.item(),
            'beta*kld': self.beta*kld_loss.item()
        }

        return tot_loss, losses

    def _accuracies(self, s_tensor, s_logits, c_tensor, c_logits, is_drum):

        # Do not consider SOS token
        c_tensor = c_tensor[..., 1:, :]

        # Reshape logits to match s_tensor dimensions:
        # n_graphs (in batch) x n_tracks x n_timesteps
        s_logits = s_logits.reshape(-1, *s_logits.shape[2:])

        # Note accuracy considers both pitches and durations
        note_acc = self._note_accuracy(c_logits, c_tensor)
        pitch_acc = self._pitch_accuracy(c_logits, c_tensor)

        # Compute pitch accuracies for drums and non drums separately
        pitch_acc_drums = self._pitch_accuracy(
            c_logits, c_tensor, drums=True, is_drum=is_drum
        )
        pitch_acc_non_drums = self._pitch_accuracy(
            c_logits, c_tensor, drums=False, is_drum=is_drum
        )

        dur_acc = self._duration_accuracy(c_logits, c_tensor)

        s_acc = self._structure_accuracy(s_logits, s_tensor)
        s_precision = self._structure_precision(s_logits, s_tensor)
        s_recall = self._structure_recall(s_logits, s_tensor)
        s_f1 = (2*s_recall*s_precision / (s_recall+s_precision))

        accs = {
            'note': note_acc.item(),
            'pitch': pitch_acc.item(),
            'pitch_drums': pitch_acc_drums.item(),
            'pitch_non_drums': pitch_acc_non_drums.item(),
            'dur': dur_acc.item(),
            's_acc': s_acc.item(),
            's_precision': s_precision.item(),
            's_recall': s_recall.item(),
            's_f1': s_f1.item()
        }

        return accs

    def _pitch_accuracy(self, c_logits, c_tensor, drums=None, is_drum=None):

        # When drums is None, just compute the global pitch accuracy without
        # distinguishing between drum and non drum pitches
        if drums is not None:
            if drums:
                c_logits = c_logits[is_drum]
                c_tensor = c_tensor[is_drum]
            else:
                c_logits = c_logits[torch.logical_not(is_drum)]
                c_tensor = c_tensor[torch.logical_not(is_drum)]

        # Apply softmax to obtain pitch reconstructions
        pitch_rec = c_logits[..., :constants.N_PITCH_TOKENS]
        pitch_rec = F.softmax(pitch_rec, dim=-1)
        pitch_rec = torch.argmax(pitch_rec, dim=-1)
        pitch_true = c_tensor[..., :constants.N_PITCH_TOKENS]
        pitch_true = torch.argmax(pitch_true, dim=-1)

        # Do not consider PAD tokens when computing accuracies
        not_pad = (pitch_true != PitchToken.PAD.value)
        correct = (pitch_rec == pitch_true)
        correct = torch.logical_and(correct, not_pad)

        return torch.sum(correct) / torch.sum(not_pad)

    def _duration_accuracy(self, c_logits, c_tensor):

        # Apply softmax to obtain reconstructed durations
        dur_rec = c_logits[..., constants.N_PITCH_TOKENS:]
        dur_rec = F.softmax(dur_rec, dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = c_tensor[..., constants.N_PITCH_TOKENS:]
        dur_true = torch.argmax(dur_true, dim=-1)

        # Do not consider PAD tokens when computing accuracies
        not_pad = (dur_true != DurationToken.PAD.value)
        correct = (dur_rec == dur_true)
        correct = torch.logical_and(correct, not_pad)

        return torch.sum(correct) / torch.sum(not_pad)

    def _note_accuracy(self, c_logits, c_tensor):

        # Apply softmax to obtain pitch reconstructions
        pitch_rec = c_logits[..., :constants.N_PITCH_TOKENS]
        pitch_rec = F.softmax(pitch_rec, dim=-1)
        pitch_rec = torch.argmax(pitch_rec, dim=-1)
        pitch_true = c_tensor[..., :constants.N_PITCH_TOKENS]
        pitch_true = torch.argmax(pitch_true, dim=-1)

        not_pad_p = (pitch_true != PitchToken.PAD.value)
        correct_p = (pitch_rec == pitch_true)
        correct_p = torch.logical_and(correct_p, not_pad_p)

        dur_rec = c_logits[..., constants.N_PITCH_TOKENS:]
        dur_rec = F.softmax(dur_rec, dim=-1)
        dur_rec = torch.argmax(dur_rec, dim=-1)
        dur_true = c_tensor[..., constants.N_PITCH_TOKENS:]
        dur_true = torch.argmax(dur_true, dim=-1)

        not_pad_d = (dur_true != DurationToken.PAD.value)
        correct_d = (dur_rec == dur_true)
        correct_d = torch.logical_and(correct_d, not_pad_d)

        note_accuracy = torch.sum(
            torch.logical_and(correct_p, correct_d)) / torch.sum(not_pad_p)

        return note_accuracy

    def _structure_accuracy(self, s_logits, s_tensor):
        # Binarize the predicted structure with a 0.5 threshold
        s_logits = torch.sigmoid(s_logits)
        s_logits[s_logits < 0.5] = 0
        s_logits[s_logits >= 0.5] = 1
        return torch.sum(s_logits == s_tensor) / s_tensor.numel()

    def _structure_precision(self, s_logits, s_tensor):
        s_logits = torch.sigmoid(s_logits)
        s_logits[s_logits < 0.5] = 0
        s_logits[s_logits >= 0.5] = 1
        # Precision: fraction of predicted activations that are actually active
        tp = torch.sum(s_tensor[s_logits == 1])
        return tp / torch.sum(s_logits)

    def _structure_recall(self, s_logits, s_tensor):
        s_logits = torch.sigmoid(s_logits)
        s_logits[s_logits < 0.5] = 0
        s_logits[s_logits >= 0.5] = 1
        # Recall: fraction of true activations that have been predicted
        tp = torch.sum(s_tensor[s_logits == 1])
        return tp / torch.sum(s_tensor)

    def _save_model(self, filename):
        path = os.path.join(self.model_dir, filename)

        print("Saving model to disk...")
        torch.save({
            'epoch': self.cur_epoch,
            'batch': self.cur_batch_idx,
            'tot_batches': self.tot_batches,
            'betas': self.betas,
            'min_val_loss': self.min_val_loss,
            'print_every': self.print_every,
            'save_every': self.save_every,
            'eval_every': self.eval_every,
            'lrs': self.lrs,
            'tr_losses': self.tr_losses,
            'tr_accuracies': self.tr_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }, path)
        print("The model has been successfully saved.")

    def _print_stats(self):
        hours, rem = divmod(self.times[-1]-self.times[0], 3600)
        minutes, seconds = divmod(rem, 60)
        print("Elapsed time from start (h:m:s): {:0>2}:{:0>2}:{:05.2f}"
              .format(int(hours), int(minutes), seconds))

        # Take mean of the last non-printed batches for each loss and accuracy
        avg_losses = {}
        for k, l in self.tr_losses.items():
            v = mean(l[-self.print_every:])
            avg_losses[k] = round(v, 2)
        avg_accs = {}
        for k, l in self.tr_accuracies.items():
            v = mean(l[-self.print_every:])
            avg_accs[k] = round(v, 2)

        print("Losses:")
        pprint.pprint(avg_losses, indent=2)
        print("Accuracies:")
        pprint.pprint(avg_accs, indent=2)
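

# Minimal usage sketch. The model class `PolyphemusVAE`, the `get_loaders`
# helper and all hyperparameter values below are hypothetical placeholders,
# not part of this module; only the trainer and scheduler APIs are real.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = PolyphemusVAE().to(device)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   lr_scheduler = ExpDecayLRScheduler(optimizer, peak_lr=1e-4,
#                                      warmup_steps=8000,
#                                      final_lr_scale=0.1, decay_steps=100000)
#   beta_scheduler = StepBetaScheduler(anneal_start=1000, beta_max=0.1,
#                                      step_size=0.01, anneal_end=11000)
#   trainloader, validloader = get_loaders()
#   trainer = PolyphemusTrainer('runs/exp1', model, optimizer,
#                               lr_scheduler=lr_scheduler,
#                               beta_scheduler=beta_scheduler, device=device)
#   trainer.train(trainloader, validloader, epochs=100)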