OpenSound's picture
Upload 33 files
71de706 verified
import typing
from typing import List
import numpy as np
from torch import nn
from .. import AudioSignal
from .. import STFTParams
class MultiScaleSTFTLoss(nn.Module):
    """Multi-resolution STFT loss in the style of DDSP [1].

    For every configured window length the two signals are compared twice
    with ``loss_fn``: once on the log10 of the clamped, exponentiated
    magnitude and once on the raw magnitude. All terms are summed.

    Parameters
    ----------
    window_lengths : List[int], optional
        One STFT resolution per entry, by default [2048, 512]. The hop
        length of each resolution is a quarter of its window length.
    loss_fn : typing.Callable, optional
        Distance applied to each pair of spectra, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower clamp applied to the magnitude before the log, by default 1e-5
    mag_weight : float, optional
        Multiplier of the raw-magnitude term, by default 1.0
    log_weight : float, optional
        Multiplier of the log-magnitude term, by default 1.0
    pow : float, optional
        Exponent applied to the clamped magnitude before the log,
        by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the
        module; ``forward`` does not multiply by it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default
        False. It is recorded in each ``STFTParams``; ``forward`` does not
        pass it on to ``stft``.
    window_type : str, optional
        Window type recorded in each ``STFTParams`` and handed to ``stft``,
        by default None

    References
    ----------
    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
       "DDSP: Differentiable Digital Signal Processing."
       International Conference on Learning Representations. 2019.
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()

        # Comparison function and the scalars that shape each term.
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.pow = pow
        self.mag_weight = mag_weight
        self.log_weight = log_weight
        self.weight = weight

        # One STFT configuration per requested window length.
        resolutions = []
        for length in window_lengths:
            resolutions.append(
                STFTParams(
                    window_length=length,
                    hop_length=length // 4,
                    match_stride=match_stride,
                    window_type=window_type,
                )
            )
        self.stft_params = resolutions

    def _log_power(self, magnitude):
        """Clamp from below, raise to ``self.pow`` and take log10."""
        floored = magnitude.clamp(self.clamp_eps)
        return floored.pow(self.pow).log10()

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Accumulate the STFT loss over all configured resolutions.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Sum of the weighted log- and raw-magnitude terms over all
            resolutions (the float ``0.0`` when no resolution is configured).
        """
        total = 0.0
        for params in self.stft_params:
            # ``stft`` is called on both signals (estimate first) before the
            # magnitudes are read, so they correspond to this resolution.
            for signal in (x, y):
                signal.stft(
                    params.window_length, params.hop_length, params.window_type
                )

            log_term = self.loss_fn(
                self._log_power(x.magnitude), self._log_power(y.magnitude)
            )
            total += self.log_weight * log_term

            mag_term = self.loss_fn(x.magnitude, y.magnitude)
            total += self.mag_weight * mag_term
        return total
class MelSpectrogramLoss(nn.Module):
    """Distance between mel spectrograms, optionally at several scales.

    Each scale pairs one entry of ``n_mels``, ``mel_fmin``, ``mel_fmax`` and
    ``window_lengths``. The lists are walked with ``zip``, so the number of
    scales actually evaluated is the length of the shortest list; surplus
    entries in longer lists are ignored without warning.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        One STFT resolution per entry, by default [2048, 512]. The hop
        length of each resolution is a quarter of its window length.
    loss_fn : typing.Callable, optional
        Distance applied to each pair of spectrograms, by default nn.L1Loss()
    clamp_eps : float, optional
        Lower clamp applied to the mel spectrogram before the log,
        by default 1e-5
    mag_weight : float, optional
        Multiplier of the raw mel-spectrogram term, by default 1.0
    log_weight : float, optional
        Multiplier of the log mel-spectrogram term, by default 1.0
    pow : float, optional
        Exponent applied to the clamped spectrogram before the log,
        by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the
        module; ``forward`` does not multiply by it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default
        False. It is recorded in each ``STFTParams``; ``forward`` does not
        pass it on to ``mel_spectrogram``.
    mel_fmin : List[float], optional
        Per-scale value passed as ``mel_fmin`` to ``mel_spectrogram``
        (presumably the lowest mel filter frequency), by default [0.0, 0.0]
    mel_fmax : List[float], optional
        Per-scale value passed as ``mel_fmax`` to ``mel_spectrogram``
        (presumably the highest mel filter frequency), by default
        [None, None]
    window_type : str, optional
        Window type recorded in each ``STFTParams`` and handed to
        ``mel_spectrogram``, by default None
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()

        # Comparison function and the scalars that shape each term.
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.pow = pow
        self.mag_weight = mag_weight
        self.log_weight = log_weight
        self.weight = weight

        # Per-scale mel filterbank settings.
        self.n_mels = n_mels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax

        # One STFT configuration per requested window length.
        resolutions = []
        for length in window_lengths:
            resolutions.append(
                STFTParams(
                    window_length=length,
                    hop_length=length // 4,
                    match_stride=match_stride,
                    window_type=window_type,
                )
            )
        self.stft_params = resolutions

    def _log_power(self, spectrogram):
        """Clamp from below, raise to ``self.pow`` and take log10."""
        floored = spectrogram.clamp(self.clamp_eps)
        return floored.pow(self.pow).log10()

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Accumulate the mel-spectrogram loss over all configured scales.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Sum of the weighted log- and raw-spectrogram terms over all
            scales (the float ``0.0`` when no scale is configured).
        """
        total = 0.0
        scales = zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params)
        for n_bands, low, high, params in scales:
            spec_kwargs = dict(
                mel_fmin=low,
                mel_fmax=high,
                window_length=params.window_length,
                hop_length=params.hop_length,
                window_type=params.window_type,
            )
            # Estimate first, then reference, with identical settings.
            est_mels = x.mel_spectrogram(n_bands, **spec_kwargs)
            ref_mels = y.mel_spectrogram(n_bands, **spec_kwargs)

            log_term = self.loss_fn(
                self._log_power(est_mels), self._log_power(ref_mels)
            )
            total += self.log_weight * log_term

            mag_term = self.loss_fn(est_mels, ref_mels)
            total += self.mag_weight * mag_term
        return total
class PhaseLoss(nn.Module):
    """Magnitude-weighted squared difference between phase spectrograms.

    Parameters
    ----------
    window_length : int, optional
        Length of STFT window, by default 2048
    hop_length : int, optional
        Hop length of STFT window, by default 512
    weight : float, optional
        Weight of loss, by default 1.0. It is only stored on the module;
        ``forward`` does not multiply by it.
    """

    def __init__(
        self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
    ):
        super().__init__()
        self.weight = weight
        self.stft_params = STFTParams(window_length, hop_length)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes phase loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mean over all bins of ``(weight * wrapped_phase_difference) ** 2``.
        """
        s = self.stft_params
        x.stft(s.window_length, s.hop_length, s.window_type)
        y.stft(s.window_length, s.hop_length, s.window_type)

        # Take circular difference: fold values outside [-pi, pi] back in by
        # one full turn. A single fold is enough as long as both phases lie
        # in [-pi, pi] (assumed here -- TODO confirm against AudioSignal.phase).
        diff = x.phase - y.phase
        diff[diff < -np.pi] += 2 * np.pi
        # Fixed: this used to read ``-= -2 * np.pi``, which added a turn and
        # pushed these entries further away instead of wrapping them.
        diff[diff > np.pi] -= 2 * np.pi

        # Min-max scale the magnitude of ``x`` to weights in [0, 1].
        # NOTE(review): ``x`` is documented as the estimate, while the
        # original comment spoke of the "true" magnitude -- confirm whether
        # ``y`` was intended.
        magnitude = x.magnitude
        x_min, x_max = magnitude.min(), magnitude.max()
        # Floor the range so a flat magnitude (e.g. silence) yields zero
        # weights instead of a 0/0 NaN loss.
        span = (x_max - x_min).clamp(min=1e-8)
        weights = (magnitude - x_min) / span

        # Take weighted mean of all phase errors
        loss = ((weights * diff) ** 2).mean()
        return loss