# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Author: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf
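"""Training and evaluation loop (Solver) for the SVoice source separation model."""
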
import json
import logging
from pathlib import Path
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from . import distrib
from .separate import separate
from .evaluate import evaluate
from .models.sisnr_loss import cal_loss
from .models.swave import SWave
from .utils import bold, copy_state, pull_metric, serialize_model, swap_state, LogProgress

logger = logging.getLogger(__name__)


class Solver(object):
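    """Training/evaluation loop for the SVoice separation model.

    Wraps the model for (optionally distributed) training, runs the
    train/validation epochs, drives the LR scheduler, keeps the best
    validation state, periodically evaluates on the test set and writes
    separated samples, and serializes checkpoints and metric history.
    """
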
    def __init__(self, data, model, optimizer, args):
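        """
        Args:
            data: dict with 'tr_loader', 'cv_loader' and 'tt_loader' data loaders.
            model: the separation network to train.
            optimizer: torch optimizer updating the model parameters.
            args: experiment configuration (expects nested options such as
                args.step.*, args.plateau.* and the fields read below).
        """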
        self.tr_loader = data['tr_loader']
        self.cv_loader = data['cv_loader']
        self.tt_loader = data['tt_loader']
        self.model = model
        self.dmodel = distrib.wrap(model)
        self.optimizer = optimizer
        if args.lr_sched == 'step':
            self.sched = StepLR(
                self.optimizer, step_size=args.step.step_size, gamma=args.step.gamma)
        elif args.lr_sched == 'plateau':
            self.sched = ReduceLROnPlateau(
                self.optimizer, factor=args.plateau.factor, patience=args.plateau.patience)
        else:
            self.sched = None

        # Training config
        self.device = args.device
        self.epochs = args.epochs
        self.max_norm = args.max_norm

        # Checkpoints
        self.continue_from = args.continue_from
        self.eval_every = args.eval_every
        self.checkpoint = Path(
            args.checkpoint_file) if args.checkpoint else None
        if self.checkpoint:
            logger.debug("Checkpoint will be saved to %s",
                         self.checkpoint.resolve())
        self.history_file = args.history_file

        self.best_state = None
        self.restart = args.restart
        # keep track of losses
        self.history = []
        # Where to save samples
        self.samples_dir = args.samples_dir
        # logging
        self.num_prints = args.num_prints
        # for separation tests
        self.args = args
        self._reset()

    def _serialize(self, path):
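        """Save the model, optimizer state, history, best state and args to `path`."""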
        package = {}
        package['model'] = serialize_model(self.model)
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        torch.save(package, path)

    def _reset(self):
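        """Resume from an existing checkpoint or from `continue_from`, if requested."""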
        load_from = None
        # Reset
        if self.checkpoint and self.checkpoint.exists() and not self.restart:
            load_from = self.checkpoint
        elif self.continue_from:
            load_from = self.continue_from

        if load_from:
            logger.info(f'Loading checkpoint model: {load_from}')
            package = torch.load(load_from, 'cpu')
            if load_from == self.continue_from and self.args.continue_best:
                self.model.load_state_dict(package['best_state'])
            else:
                self.model.load_state_dict(package['model']['state'])
            if 'optimizer' in package and not self.args.continue_best:
                self.optimizer.load_state_dict(package['optimizer'])
            self.history = package['history']
            self.best_state = package['best_state']

    def train(self):
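        """Run the full optimization loop: train, validate, schedule the LR,
        track the best model, periodically evaluate/separate, and checkpoint.
        """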
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
            for epoch, metrics in enumerate(self.history):
                info = " ".join(f"{k}={v:.5f}" for k, v in metrics.items())
                logger.info(f"Epoch {epoch}: {info}")

        for epoch in range(len(self.history), self.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            start = time.time()
            logger.info('-' * 70)
            logger.info("Training...")
            train_loss = self._run_one_epoch(epoch)
            logger.info(bold(f'Train Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Train Loss {train_loss:.5f}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off BatchNorm & Dropout
            with torch.no_grad():
                valid_loss = self._run_one_epoch(epoch, cross_valid=True)
            logger.info(bold(f'Valid Summary | End of Epoch {epoch + 1} | '
                             f'Time {time.time() - start:.2f}s | Valid Loss {valid_loss:.5f}'))

            # Learning rate scheduling
            if self.sched:
                if self.args.lr_sched == 'plateau':
                    self.sched.step(valid_loss)
                else:
                    self.sched.step()
                logger.info(
                    f'Learning rate adjusted: {self.optimizer.state_dict()["param_groups"][0]["lr"]:.5f}')

            best_loss = min(pull_metric(self.history, 'valid') + [valid_loss])
            metrics = {'train': train_loss,
                       'valid': valid_loss, 'best': best_loss}
            # Save the best model
            if valid_loss == best_loss or self.args.keep_last:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = copy_state(self.model.state_dict())

            # Evaluate and separate samples every `eval_every` epochs,
            # and also on the last epoch.
            if (epoch + 1) % self.eval_every == 0 or epoch == self.epochs - 1:
                # Evaluate on the test set
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                with swap_state(self.model, self.best_state):
                    sisnr, pesq, stoi = evaluate(
                        self.args, self.model, self.tt_loader, self.args.sample_rate)
                metrics.update({'sisnr': sisnr, 'pesq': pesq, 'stoi': stoi})

                # Separate some samples
                logger.info('Separate and save samples...')
                separate(self.args, self.model, self.samples_dir)

            self.history.append(metrics)
            info = " | ".join(
                f"{k.capitalize()} {v:.5f}" for k, v in metrics.items())
            logger.info('-' * 70)
            logger.info(bold(f"Overall Summary | Epoch {epoch + 1} | {info}"))

            if distrib.rank == 0:
                json.dump(self.history, open(self.history_file, "w"), indent=2)
                # Save model each epoch
                if self.checkpoint:
                    self._serialize(self.checkpoint)
                    logger.debug("Checkpoint saved to %s",
                                 self.checkpoint.resolve())

    def _run_one_epoch(self, epoch, cross_valid=False):
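        """Run one pass over the train (or validation) loader and return the
        average loss, averaged across distributed workers.
        """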
        total_loss = 0
        data_loader = self.tr_loader if not cross_valid else self.cv_loader
        # get a different order for distributed training, otherwise this will get ignored
        data_loader.epoch = epoch

        label = ["Train", "Valid"][cross_valid]
        name = label + f" | Epoch {epoch + 1}"
        logprog = LogProgress(logger, data_loader,
                              updates=self.num_prints, name=name)
        for i, data in enumerate(logprog):
            mixture, lengths, sources = [x.to(self.device) for x in data]
            estimate_source = self.dmodel(mixture)
            # only eval last layer
            if cross_valid:
                estimate_source = estimate_source[-1:]
            loss = 0
            cnt = len(estimate_source)
            # apply a loss function after each layer
            with torch.autograd.set_detect_anomaly(True):
                for c_idx, est_src in enumerate(estimate_source):
                    coeff = ((c_idx + 1) * (1 / cnt))
                    loss_i = 0
                    # SI-SNR loss
                    sisnr_loss, snr, est_src, reorder_est_src = cal_loss(
                        sources, estimate_source[c_idx], lengths)
                    loss += (coeff * sisnr_loss)
                loss /= len(estimate_source)

                if not cross_valid:
                    # optimize model in training mode
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_norm)
                    self.optimizer.step()

            total_loss += loss.item()
            logprog.update(loss=format(total_loss / (i + 1), ".5f"))
            # Just in case, clear some memory
            del loss, estimate_source
        return distrib.average([total_loss / (i + 1)], i + 1)[0]
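
# Usage sketch: the training entry point typically builds the data loaders,
# the model and the optimizer, then hands them to the Solver. The config
# fields shown here (args.swave, args.lr) are illustrative assumptions, not
# a fixed API:
#
#     data = {'tr_loader': tr_loader, 'cv_loader': cv_loader, 'tt_loader': tt_loader}
#     model = SWave(**args.swave)
#     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
#     solver = Solver(data, model, optimizer, args)
#     solver.train()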