# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
from torchaudio.transforms import MelSpectrogram
import torch
from torch import nn
from torch.nn import functional as F

from ..modules import pad_for_conv1d


class MelSpectrogramWrapper(nn.Module):
    """Wrapper around the MelSpectrogram torchaudio transform, providing proper padding
    and additional post-processing, including log scaling.

    Args:
        n_fft (int): FFT size.
        hop_length (int): Hop size.
        win_length (int or None): Window length.
        n_mels (int): Number of mel bins.
        sample_rate (int): Sample rate.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Floor level based on human perception (default=1e-5).
    """
    def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
                 n_mels: int = 80, sample_rate: int = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.n_fft = n_fft
        hop_length = int(hop_length)
        self.hop_length = hop_length
        self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft,
                                            hop_length=hop_length, win_length=win_length,
                                            f_min=f_min, f_max=f_max, normalized=normalized,
                                            window_fn=torch.hann_window, center=False)
        self.floor_level = floor_level
        self.log = log

    def forward(self, x):
        p = int((self.n_fft - self.hop_length) // 2)
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        x = F.pad(x, (p, p), "reflect")
        # Make sure that all the frames are full.
        # The combination of `pad_for_conv1d` and the above padding
        # will make the output of size ceil(T / hop).
        x = pad_for_conv1d(x, self.n_fft, self.hop_length)
        self.mel_transform.to(x.device)
        mel_spec = self.mel_transform(x)
        B, C, freqs, frame = mel_spec.shape
        if self.log:
            mel_spec = torch.log10(self.floor_level + mel_spec)
        return mel_spec.reshape(B, C * freqs, frame)
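
# Example usage (illustrative sketch, not part of the original module): converting a
# batch of waveforms into a (log-)mel representation. Shapes follow the comment in
# `forward` above; the sample values are arbitrary.
#
#   wrapper = MelSpectrogramWrapper(n_fft=1024, hop_length=256, n_mels=80, sample_rate=22050)
#   wav = torch.randn(4, 1, 22050)   # (batch, channels, time)
#   mel = wrapper(wav)               # (batch, channels * n_mels, ceil(time / hop_length))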


class MelSpectrogramL1Loss(torch.nn.Module):
    """L1 loss on MelSpectrogram.

    Args:
        sample_rate (int): Sample rate.
        n_fft (int): FFT size.
        hop_length (int): Hop size.
        win_length (int): Window length.
        n_mels (int): Number of mel bins.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Floor level value based on human perception (default=1e-5).
    """
    def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
                 n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.l1 = torch.nn.L1Loss()
        self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                                             n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                                             log=log, normalized=normalized, floor_level=floor_level)

    def forward(self, x, y):
        self.melspec.to(x.device)
        s_x = self.melspec(x)
        s_y = self.melspec(y)
        return self.l1(s_x, s_y)
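
# Example usage (illustrative sketch, not part of the original module): L1 distance
# between the log-mel spectrograms of a reconstruction and a reference waveform.
#
#   loss_fn = MelSpectrogramL1Loss(sample_rate=22050)
#   reference = torch.randn(2, 1, 44100)
#   reconstruction = torch.randn(2, 1, 44100)
#   loss = loss_fn(reconstruction, reference)   # scalar tensor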


class MultiScaleMelSpectrogramLoss(nn.Module):
    """Multi-Scale spectrogram loss (msspec).

    Args:
        sample_rate (int): Sample rate.
        range_start (int): Power of 2 to use for the first scale.
        range_end (int): Power of 2 to use for the last scale.
        n_mels (int): Number of mel bins.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        normalized (bool): Whether to normalize the melspectrogram.
        alphas (bool): Whether to use alphas as coefficients or not.
        floor_level (float): Floor level value based on human perception (default=1e-5).
    """
    def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
                 n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
        super().__init__()
        l1s = list()
        l2s = list()
        self.alphas = list()
        self.total = 0
        self.normalized = normalized
        for i in range(range_start, range_end):
            l1s.append(
                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=2 ** i // 4, win_length=2 ** i,
                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                                      log=False, normalized=normalized, floor_level=floor_level))
            l2s.append(
                MelSpectrogramWrapper(n_fft=2 ** i, hop_length=2 ** i // 4, win_length=2 ** i,
                                      n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
                                      log=True, normalized=normalized, floor_level=floor_level))
            if alphas:
                self.alphas.append(np.sqrt(2 ** i - 1))
            else:
                self.alphas.append(1)
            self.total += self.alphas[-1] + 1
        self.l1s = nn.ModuleList(l1s)
        self.l2s = nn.ModuleList(l2s)

    def forward(self, x, y):
        loss = 0.0
        self.l1s.to(x.device)
        self.l2s.to(x.device)
        for i in range(len(self.alphas)):
            s_x_1 = self.l1s[i](x)
            s_y_1 = self.l1s[i](y)
            s_x_2 = self.l2s[i](x)
            s_y_2 = self.l2s[i](y)
            loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
        if self.normalized:
            loss = loss / self.total
        return loss
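
# Example usage (illustrative sketch, not part of the original module): with the defaults
# range_start=6 and range_end=11, the loss runs over n_fft in {64, 128, 256, 512, 1024},
# combining at each scale an L1 term on the linear mel spectrogram with an
# alpha-weighted L2 term on the log mel spectrogram.
#
#   msspec = MultiScaleMelSpectrogramLoss(sample_rate=22050, normalized=True)
#   reference = torch.randn(2, 1, 44100)
#   reconstruction = torch.randn(2, 1, 44100)
#   loss = msspec(reconstruction, reference)   # scalar tensor, divided by self.total when normalized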