# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Author: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf
import json
import logging
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

from . import distrib
from .separate import separate
from .evaluate import evaluate
from .models.sisnr_loss import cal_loss
from .models.swave import SWave
from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress
logger = logging.getLogger(__name__)


class Solver(object):
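    """Training loop for the source separation model.

    Runs train and validation epochs, steps the learning-rate scheduler,
    tracks the best validation loss, serializes checkpoints, and
    periodically evaluates on the test set and separates sample mixtures.
    """
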
    def __init__(self, data, model, optimizer, args):
        self.tr_loader = data['tr_loader']
        self.cv_loader = data['cv_loader']
        self.tt_loader = data['tt_loader']
        self.model = model
        self.dmodel = distrib.wrap(model)
        self.optimizer = optimizer
        if args.lr_sched == 'step':
            self.sched = StepLR(
                self.optimizer, step_size=args.step.step_size, gamma=args.step.gamma)
        elif args.lr_sched == 'plateau':
            self.sched = ReduceLROnPlateau(
                self.optimizer, factor=args.plateau.factor, patience=args.plateau.patience)
        else:
            self.sched = None
        # Training config
        self.device = args.device
        self.epochs = args.epochs
        self.max_norm = args.max_norm
        # Checkpoints
        self.continue_from = args.continue_from
        self.eval_every = args.eval_every
        self.checkpoint = Path(args.checkpoint_file) if args.checkpoint else None
        if self.checkpoint:
            logger.debug("Checkpoint will be saved to %s", self.checkpoint.resolve())
        self.history_file = args.history_file
        self.best_state = None
        self.restart = args.restart
        # keep track of losses
        self.history = []
        # Where to save samples
        self.samples_dir = args.samples_dir
        # logging
        self.num_prints = args.num_prints
        # for separation tests
        self.args = args
        self._reset()

    def _serialize(self, path):
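        """Save model weights, optimizer state, history, best state and args to `path`."""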
        package = {}
        package['model'] = serialize_model(self.model)
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        torch.save(package, path)

    def _reset(self):
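        """Resume from the checkpoint (or a `continue_from` model) when one is available."""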
        load_from = None
        # Reset
        if self.checkpoint and self.checkpoint.exists() and not self.restart:
            load_from = self.checkpoint
        elif self.continue_from:
            load_from = self.continue_from
        if load_from:
            logger.info(f'Loading checkpoint model: {load_from}')
            package = torch.load(load_from, map_location='cpu')
            if load_from == self.continue_from and self.args.continue_best:
                self.model.load_state_dict(package['best_state'])
            else:
                self.model.load_state_dict(package['model']['state'])
            if 'optimizer' in package and not self.args.continue_best:
                self.optimizer.load_state_dict(package['optimizer'])
            self.history = package['history']
            self.best_state = package['best_state']

    def train(self):
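        """Run the full optimization: replay any logged history, then train up to `epochs`."""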
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
            for epoch, metrics in enumerate(self.history):
                info = " ".join(f"{k}={v:.5f}" for k, v in metrics.items())
                logger.info(f"Epoch {epoch}: {info}")

        for epoch in range(len(self.history), self.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            start = time.time()
            logger.info('-' * 70)
            logger.info("Training...")
            train_loss = self._run_one_epoch(epoch)
            logger.info(bold(f'Train Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off BatchNorm & Dropout
            with torch.no_grad():
                valid_loss = self._run_one_epoch(epoch, cross_valid=True)
            logger.info(bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))

            # Learning rate scheduling
            if self.sched:
                if self.args.lr_sched == 'plateau':
                    self.sched.step(valid_loss)
                else:
                    self.sched.step()
                logger.info(
                    f'Learning rate adjusted: {self.optimizer.state_dict()["param_groups"][0]["lr"]:.5f}')

            best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
            metrics = {'train': train_loss, 'valid': valid_loss, 'best': best_loss}
            # Save the best model
            if valid_loss == best_loss or self.args.keep_last:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = copy_state(self.model.state_dict())

            # Evaluate and separate samples every `eval_every` epochs, and on the last epoch
            if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
                # Evaluate on the test set
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # Switch to the best known model for testing
                with swap_state(self.model, self.best_state):
                    sisnr, pesq, stoi = evaluate(
                        self.args, self.model, self.tt_loader, self.args.sample_rate)
                metrics.update({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi})
                # Separate some samples
                logger.info('Separate and save samples...')
                separate(self.args, self.model, self.samples_dir)

            self.history.append(metrics)
            info = " | ".join(f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
            logger.info('-' * 70)
            logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

            if distrib.rank == 0:
                with open(self.history_file, "w") as fh:
                    json.dump(self.history, fh, indent=2)
                # Save the model each epoch
                if self.checkpoint:
                    self._serialize(self.checkpoint)
                    logger.debug("Checkpoint saved to %s", self.checkpoint.resolve())

    def _run_one_epoch(self, epoch, cross_valid=False):
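        """Run one pass over the train (or validation) loader; return the average loss."""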
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader
        # get a different order for distributed training, otherwise this will get ignored
        data_loader.epoch = epoch

        label = ["Train", "Valid"][cross_valid]
        name = label + f" | Epoch {epoch + 1}"
        logprog = LogProgress(logger, data_loader, updates=self.num_prints, name=name)
        for i, data in enumerate(logprog):
            mixture, lengths, sources = [x.to(self.device) for x in data]
            estimate_source = self.dmodel(mixture)
            # only eval last layer
            if cross_valid:
                estimate_source = estimate_source[-1:]
            loss = 0
            cnt = len(estimate_source)
            # apply a loss function after each layer
            # (anomaly detection surfaces NaNs in the backward pass, at some overhead)
            with torch.autograd.set_detect_anomaly(True):
                for c_idx, est_src in enumerate(estimate_source):
                    # linearly increasing weight, so deeper layers contribute more
                    coeff = (c_idx + 1) * (1 / cnt)
                    # SI-SNR loss
                    sisnr_loss, snr, est_src, reorder_est_src = cal_loss(
                        sources, est_src, lengths)
                    loss += coeff * sisnr_loss
                loss /= len(estimate_source)
            if not cross_valid:
                # optimize model in training mode
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
                self.optimizer.step()
            total_loss += loss.item()
            logprog.update(loss=format(total_loss / (i + 1), ".5f"))
            # Just in case, clear some memory
            del loss, estimate_source
        return distrib.average([total_loss / (i + 1)], i + 1)[0]
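

# A minimal usage sketch (hypothetical wiring; in the real project the entry
# point builds the loaders, model, optimizer and args from the training config):
#
#     data = {'tr_loader': tr_loader, 'cv_loader': cv_loader, 'tt_loader': tt_loader}
#     model = SWave(...)  # any model returning per-layer source estimates
#     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
#     Solver(data, model, optimizer, args).train()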