Spaces:
Paused
Paused
# Copyright (c) 2024 NVIDIA CORPORATION. | |
# Licensed under the MIT license. | |
# Adapted from https://github.com/jik876/hifi-gan under the MIT license. | |
# LICENSE is in incl_licenses directory. | |
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
from librosa.filters import mel as librosa_mel_fn | |
from scipy import signal | |
import typing | |
from typing import Optional, List, Union, Dict, Tuple | |
from collections import namedtuple | |
import math | |
import functools | |
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license. | |
# LICENSE is in incl_licenses directory. | |
class MultiScaleMelSpectrogramLoss(nn.Module): | |
"""Compute distance between mel spectrograms. Can be used | |
in a multi-scale way. | |
Parameters | |
---------- | |
n_mels : List[int] | |
Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320], | |
window_lengths : List[int], optional | |
Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048] | |
loss_fn : typing.Callable, optional | |
How to compare each loss, by default nn.L1Loss() | |
clamp_eps : float, optional | |
Clamp on the log magnitude, below, by default 1e-5 | |
mag_weight : float, optional | |
Weight of raw magnitude portion of loss, by default 0.0 (no ampliciation on mag part) | |
log_weight : float, optional | |
Weight of log magnitude portion of loss, by default 1.0 | |
pow : float, optional | |
Power to raise magnitude to before taking log, by default 1.0 | |
weight : float, optional | |
Weight of this loss, by default 1.0 | |
match_stride : bool, optional | |
Whether to match the stride of convolutional layers, by default False | |
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py | |
Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py | |
""" | |
def __init__( | |
self, | |
sampling_rate: int, | |
n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], | |
window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048], | |
loss_fn: typing.Callable = nn.L1Loss(), | |
clamp_eps: float = 1e-5, | |
mag_weight: float = 0.0, | |
log_weight: float = 1.0, | |
pow: float = 1.0, | |
weight: float = 1.0, | |
match_stride: bool = False, | |
mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0], | |
mel_fmax: List[float] = [None, None, None, None, None, None, None], | |
window_type: str = "hann", | |
): | |
super().__init__() | |
self.sampling_rate = sampling_rate | |
STFTParams = namedtuple( | |
"STFTParams", | |
["window_length", "hop_length", "window_type", "match_stride"], | |
) | |
self.stft_params = [ | |
STFTParams( | |
window_length=w, | |
hop_length=w // 4, | |
match_stride=match_stride, | |
window_type=window_type, | |
) | |
for w in window_lengths | |
] | |
self.n_mels = n_mels | |
self.loss_fn = loss_fn | |
self.clamp_eps = clamp_eps | |
self.log_weight = log_weight | |
self.mag_weight = mag_weight | |
self.weight = weight | |
self.mel_fmin = mel_fmin | |
self.mel_fmax = mel_fmax | |
self.pow = pow | |
def get_window( | |
window_type, | |
window_length, | |
): | |
return signal.get_window(window_type, window_length) | |
def get_mel_filters(sr, n_fft, n_mels, fmin, fmax): | |
return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) | |
def mel_spectrogram( | |
self, | |
wav, | |
n_mels, | |
fmin, | |
fmax, | |
window_length, | |
hop_length, | |
match_stride, | |
window_type, | |
): | |
""" | |
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from: | |
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py | |
""" | |
B, C, T = wav.shape | |
if match_stride: | |
assert ( | |
hop_length == window_length // 4 | |
), "For match_stride, hop must equal n_fft // 4" | |
right_pad = math.ceil(T / hop_length) * hop_length - T | |
pad = (window_length - hop_length) // 2 | |
else: | |
right_pad = 0 | |
pad = 0 | |
wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect") | |
window = self.get_window(window_type, window_length) | |
window = torch.from_numpy(window).to(wav.device).float() | |
stft = torch.stft( | |
wav.reshape(-1, T), | |
n_fft=window_length, | |
hop_length=hop_length, | |
window=window, | |
return_complex=True, | |
center=True, | |
) | |
_, nf, nt = stft.shape | |
stft = stft.reshape(B, C, nf, nt) | |
if match_stride: | |
""" | |
Drop first two and last two frames, which are added, because of padding. Now num_frames * hop_length = num_samples. | |
""" | |
stft = stft[..., 2:-2] | |
magnitude = torch.abs(stft) | |
nf = magnitude.shape[2] | |
mel_basis = self.get_mel_filters( | |
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax | |
) | |
mel_basis = torch.from_numpy(mel_basis).to(wav.device) | |
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T | |
mel_spectrogram = mel_spectrogram.transpose(-1, 2) | |
return mel_spectrogram | |
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | |
"""Computes mel loss between an estimate and a reference | |
signal. | |
Parameters | |
---------- | |
x : torch.Tensor | |
Estimate signal | |
y : torch.Tensor | |
Reference signal | |
Returns | |
------- | |
torch.Tensor | |
Mel loss. | |
""" | |
loss = 0.0 | |
for n_mels, fmin, fmax, s in zip( | |
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params | |
): | |
kwargs = { | |
"n_mels": n_mels, | |
"fmin": fmin, | |
"fmax": fmax, | |
"window_length": s.window_length, | |
"hop_length": s.hop_length, | |
"match_stride": s.match_stride, | |
"window_type": s.window_type, | |
} | |
x_mels = self.mel_spectrogram(x, **kwargs) | |
y_mels = self.mel_spectrogram(y, **kwargs) | |
x_logmels = torch.log( | |
x_mels.clamp(min=self.clamp_eps).pow(self.pow) | |
) / torch.log(torch.tensor(10.0)) | |
y_logmels = torch.log( | |
y_mels.clamp(min=self.clamp_eps).pow(self.pow) | |
) / torch.log(torch.tensor(10.0)) | |
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels) | |
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels) | |
return loss | |
# Loss functions | |
def feature_loss( | |
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]] | |
) -> torch.Tensor: | |
loss = 0 | |
for dr, dg in zip(fmap_r, fmap_g): | |
for rl, gl in zip(dr, dg): | |
loss += torch.mean(torch.abs(rl - gl)) | |
return loss * 2 # This equates to lambda=2.0 for the feature matching loss | |
def discriminator_loss( | |
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] | |
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: | |
loss = 0 | |
r_losses = [] | |
g_losses = [] | |
for dr, dg in zip(disc_real_outputs, disc_generated_outputs): | |
r_loss = torch.mean((1 - dr) ** 2) | |
g_loss = torch.mean(dg**2) | |
loss += r_loss + g_loss | |
r_losses.append(r_loss.item()) | |
g_losses.append(g_loss.item()) | |
return loss, r_losses, g_losses | |
def generator_loss( | |
disc_outputs: List[torch.Tensor], | |
) -> Tuple[torch.Tensor, List[torch.Tensor]]: | |
loss = 0 | |
gen_losses = [] | |
for dg in disc_outputs: | |
l = torch.mean((1 - dg) ** 2) | |
gen_losses.append(l) | |
loss += l | |
return loss, gen_losses | |