Fast-GeCo / fastgeco /model.py
anonymous9a7b
1
d4c980e
raw
history blame
11 kB
import random
import time
from math import ceil
import warnings
import numpy as np
# from asteroid.losses.sdr import SingleSrcNegSDR
import torch
import pytorch_lightning as pl
from torch_ema import ExponentialMovingAverage
import torch.nn.functional as F
from geco import sampling
from geco.sdes import SDERegistry
from fastgeco.backbones import BackboneRegistry
from geco.util.inference import evaluate_model2
from geco.util.other import pad_spec
import numpy as np
import matplotlib.pyplot as plt
class ScoreModel(pl.LightningModule):
@staticmethod
def add_argparse_args(parser):
parser.add_argument("--lr", type=float, default=1e-5, help="The learning rate (1e-4 by default)")
parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
parser.add_argument("--loss_type", type=str, default="mse", help="The type of loss function to use.")
parser.add_argument("--loss_abs_exponent", type=float, default=0.5, help="magnitude transformation in the loss term")
parser.add_argument("--output_scale", type=str, choices=('sigma', 'time'), default= 'time', help="backbone model scale before last output layer")
return parser
def __init__(
self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2, loss_abs_exponent=0.5,
num_eval_files=20, loss_type='mse', data_module_cls=None, output_scale='time', inference_N=1,
inference_start=0.5, **kwargs
):
"""
Create a new ScoreModel.
Args:
backbone: Backbone DNN that serves as a score-based model.
sde: The SDE that defines the diffusion process.
lr: The learning rate of the optimizer. (1e-4 by default).
ema_decay: The decay constant of the parameter EMA (0.999 by default).
t_eps: The minimum time to practically run for to avoid issues very close to zero (1e-5 by default).
loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'
"""
super().__init__()
# Initialize Backbone DNN
dnn_cls = BackboneRegistry.get_by_name(backbone)
self.dnn = dnn_cls(**kwargs)
# Initialize SDE
sde_cls = SDERegistry.get_by_name(sde)
self.sde = sde_cls(**kwargs)
# Store hyperparams and save them
self.lr = lr
self.ema_decay = ema_decay
self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
self._error_loading_ema = False
self.t_eps = t_eps
self.loss_type = loss_type
self.num_eval_files = num_eval_files
self.loss_abs_exponent = loss_abs_exponent
self.output_scale = output_scale
self.save_hyperparameters(ignore=['no_wandb'])
self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
self.inference_N = inference_N
self.inference_start = inference_start
# self.si_snr = SingleSrcNegSDR("sisdr", reduction='mean', zero_mean=False)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def optimizer_step(self, *args, **kwargs):
# Method overridden so that the EMA params are updated after each optimizer step
super().optimizer_step(*args, **kwargs)
self.ema.update(self.parameters())
# on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
def on_load_checkpoint(self, checkpoint):
ema = checkpoint.get('ema', None)
if ema is not None:
self.ema.load_state_dict(checkpoint['ema'])
else:
self._error_loading_ema = True
warnings.warn("EMA state_dict not found in checkpoint!")
def on_save_checkpoint(self, checkpoint):
checkpoint['ema'] = self.ema.state_dict()
def train(self, mode, no_ema=False):
res = super().train(mode) # call the standard `train` method with the given mode
if not self._error_loading_ema:
if mode == False and not no_ema:
# eval
self.ema.store(self.parameters()) # store current params in EMA
self.ema.copy_to(self.parameters()) # copy EMA parameters over current params for evaluation
else:
# train
if self.ema.collected_params is not None:
self.ema.restore(self.parameters()) # restore the EMA weights (if stored)
return res
def eval(self, no_ema=False):
return self.train(False, no_ema=no_ema)
def sisnr(self, est, ref, eps = 1e-8):
est = est - torch.mean(est, dim = -1, keepdim = True)
ref = ref - torch.mean(ref, dim = -1, keepdim = True)
est_p = (torch.sum(est * ref, dim = -1, keepdim = True) * ref) / torch.sum(ref * ref, dim = -1, keepdim = True)
est_v = est - est_p
est_sisnr = 10 * torch.log10((torch.sum(est_p * est_p, dim = -1, keepdim = True) + eps) / (torch.sum(est_v * est_v, dim = -1, keepdim = True) + eps))
return -est_sisnr
def _loss(self, wav_x_tm1, wav_gt):
if self.loss_type == 'default':
min_leng = min(wav_x_tm1.shape[-1], wav_gt.shape[-1])
wav_x_tm1 = wav_x_tm1.squeeze(1)[:,:min_leng]
wav_gt = wav_gt.squeeze(1)[:,:min_leng]
loss = torch.mean(self.sisnr(wav_x_tm1, wav_gt))
else:
raise RuntimeError(f'{self.loss_type} loss not defined')
return loss
def euler_step(self, X, X_t, Y, M, t, dt):
f, g = self.sde.sde(X_t, t, Y)
vec_t = torch.ones(Y.shape[0], device=Y.device) * t
mean_x_tm1 = X_t - (f - g**2*self.forward(X_t, vec_t, Y, M, vec_t[:,None,None,None]))*dt
z = torch.randn_like(X)
X_t = mean_x_tm1 + z*g*torch.sqrt(dt)
return X_t
def training_step(self, batch, batch_idx):
X, Y, M = batch
reverse_start_time = random.uniform(self.t_rsp_min, self.t_rsp_max)
N_reverse = random.randint(self.N_min, self.N_max)
if self.stop_iteration_random == "random":
stop_iteration = random.randint(0, N_reverse-1)
elif self.stop_iteration_random == "last":
#Used in publication. This means that only the last step is used for updating weights.
stop_iteration = N_reverse-1
else:
raise RuntimeError(f'{self.stop_iteration_random} not defined')
timesteps = torch.linspace(reverse_start_time, self.t_eps, N_reverse, device=Y.device)
#prior sampling starting from reverse_start_time
std = self.sde._std(reverse_start_time*torch.ones((Y.shape[0],), device=Y.device))
z = torch.randn_like(Y)
X_t = Y + z * std[:, None, None, None]
#reverse steps by Euler Maruyama
for i in range(len(timesteps)):
t = timesteps[i]
if i != len(timesteps) - 1:
dt = t - timesteps[i+1]
else:
dt = timesteps[-1]
if i != stop_iteration:
with torch.no_grad():
#take Euler step here
X_t = self.euler_step(X, X_t, Y, M, t, dt)
else:
#take a Euler step and compute loss
f, g = self.sde.sde(X_t, t, Y)
vec_t = torch.ones(Y.shape[0], device=Y.device) * t
score = self.forward(X_t, vec_t, Y, M, vec_t[:,None,None,None])
mean_x_tm1 = X_t - (f - g**2*score)*dt #mean of x t minus 1 = mu(x_{t-1})
mean_gt, _ = self.sde.marginal_prob(X, torch.ones(Y.shape[0], device=Y.device) * (t-dt), Y)
wav_gt = self.to_audio(mean_gt.squeeze())
wav_x_tm1 = self.to_audio(mean_x_tm1.squeeze())
loss = self._loss(wav_x_tm1, wav_gt)
break
self.log('train_loss', loss, on_step=True, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
# Evaluate speech enhancement performance, compute loss only for a few val data
if batch_idx == 0 and self.num_eval_files != 0:
pesq, si_sdr, estoi, loss = evaluate_model2(self, self.num_eval_files, self.inference_N, inference_start=self.inference_start)
self.log('pesq', pesq, on_step=False, on_epoch=True)
self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
self.log('estoi', estoi, on_step=False, on_epoch=True)
self.log('valid_loss', loss, on_step=False, on_epoch=True)
return loss
def forward(self, x, t, y, m, divide_scale):
# Concatenate y as an extra channel
dnn_input = torch.cat([x, y, m], dim=1)
# the minus is most likely unimportant here - taken from Song's repo
score = -self.dnn(dnn_input, t, divide_scale)
return score
def to(self, *args, **kwargs):
"""Override PyTorch .to() to also transfer the EMA of the model weights"""
self.ema.to(*args, **kwargs)
return super().to(*args, **kwargs)
def train_dataloader(self):
return self.data_module.train_dataloader()
def val_dataloader(self):
return self.data_module.val_dataloader()
def test_dataloader(self):
return self.data_module.test_dataloader()
def setup(self, stage=None):
return self.data_module.setup(stage=stage)
def to_audio(self, spec, length=None):
return self._istft(self._backward_transform(spec), length)
def _forward_transform(self, spec):
return self.data_module.spec_fwd(spec)
def _backward_transform(self, spec):
return self.data_module.spec_back(spec)
def _stft(self, sig):
return self.data_module.stft(sig)
def _istft(self, spec, length=None):
return self.data_module.istft(spec, length)
def add_para(self, N_min=1, N_max=1, t_rsp_min=0.5, t_rsp_max=0.5, batch_size=64, loss_type='default', lr=5e-5, stop_iteration_random='last', inference_N=1, inference_start=0.5):
self.t_rsp_min = t_rsp_min
self.t_rsp_max = t_rsp_max
self.N_min = N_min
self.N_max = N_max
self.data_module.batch_size = batch_size
self.data_module.num_workers = 4
self.data_module.gpu = True
self.loss_type = loss_type
self.lr = lr
self.stop_iteration_random = stop_iteration_random
self.inference_N = inference_N
self.inference_start = inference_start