"""spectrogram.py"""
import importlib
from typing import Optional, Literal, Dict, Tuple

from packaging.version import parse as VersionParse

import torch
import torch.nn as nn
from einops import rearrange

from model.ops import minmax_normalize
from config.config import audio_cfg as default_audio_cfg

""" |
|
Example usage: |
|
|
|
# MT3 setup |
|
>>> hop = 8 ms or 128 samples |
|
>>> melspec = Melspectrogram(sample_rate=16000, n_fft=2048, hop_length=128, |
|
f_min=50, f_max=8000, n_mels=512) |
|
>>> x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s |
|
>>> y = melspec(x) # (2, 256, 512) (B, T, F) |
|
|
|
# PerceiverTF-like setup |
|
>>> hop = 18.75 ms or 300 samples |
|
>>> spec = Spectrogram(n_fft=2048, hop_length=300) |
|
) |
|
>>> x = torch.randn(2, 1, 95999) # (B, C=1, T): 6.000 s |
|
>>> y = spec(x) # (2, 320, 1024) (B, T, F) |
|
|
|
# Hybrid setup (2.048 seconds segment and spectrogram with hop=300) |
|
>>> hop = 18.75 ms or 300 samples |
|
>>> spec = Spectrogram(n_fft=2048, hop_length=300) |
|
>>> x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s |
|
>>> y = spec(x) # (2, 110, 1024) (B, T, F) |
|
|
|
# PerceiverTF-like setup, hop=256 |
|
>>> hop = 16 ms or 256 samples |
|
>>> spec256 = Spectrogram(sample_rate=16000, n_fft=2048, hop_length=256, |
|
f_min=20, f_max=8000, n_mels=256) |
|
>>> x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s |
|
>>> y = spec256(x) # (2, 128, 1024) (B, T, F) |
|
""" |


def optional_compiler_disable(func):
    """Disable torch.compile for `func` on PyTorch >= 2.1, where
    `torch.compiler.disable` is available; on older versions, return
    `func` unchanged."""
    if VersionParse(torch.__version__) >= VersionParse("2.1"):
        # torch.compiler.disable was introduced in PyTorch 2.1.
        return torch.compiler.disable(func)
    else:
        return func


class Melspectrogram(nn.Module):

    def __init__(
        self,
        audio_backend: Literal['torchaudio', 'nnaudio'] = 'torchaudio',
        sample_rate: int = 16000,
        n_fft: int = 2048,
        hop_length: int = 128,
        f_min: int = 50,
        f_max: Optional[int] = 8000,
        n_mels: int = 512,
        eps: float = 1e-5,
        **kwargs,
    ):
        """
        Log-Melspectrogram

        Args:
            audio_backend (str): 'torchaudio' or 'nnaudio'
            sample_rate (int): sample rate in Hz
            n_fft (int): FFT window size
            hop_length (int): hop length in samples
            f_min (int): minimum frequency in Hz
            f_max (int): maximum frequency in Hz
            n_mels (int): number of mel frequency bins
            eps (float): epsilon for numerical stability

        """
        super().__init__()
        self.audio_backend = audio_backend.lower()

        if self.audio_backend == 'torchaudio':
            torchaudio = importlib.import_module('torchaudio')
            self.mel_stft = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                f_min=f_min,
                f_max=f_max,
                n_mels=n_mels,
            )
        elif self.audio_backend == 'nnaudio':
            nnaudio = importlib.import_module('nnAudio.features')
            self.mel_stft_nnaudio = nnaudio.mel.MelSpectrogram(
                sr=sample_rate,
                win_length=n_fft,
                n_mels=n_mels,
                hop_length=hop_length,
                fmin=f_min,  # was hardcoded to 20, silently ignoring the f_min argument
                fmax=f_max)
        else:
            raise NotImplementedError(audio_backend)
        self.eps = eps

    @optional_compiler_disable
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (B, 1, T)

        Returns:
            torch.Tensor: (B, T, F)

        """
        if self.audio_backend == 'torchaudio':
            x = self.mel_stft(x)  # (B, 1, F, T)
            x = rearrange(x, 'b 1 f t -> b t f')
            x = minmax_normalize(torch.log(x + self.eps))
            return torch.nan_to_num(x)
        elif self.audio_backend == 'nnaudio':
            x = self.mel_stft_nnaudio(x)  # (B, F, T)
            x = rearrange(x, 'b f t -> b t f')
            x = minmax_normalize(torch.log(x + self.eps))
            return x


class Spectrogram(nn.Module):

    def __init__(
        self,
        audio_backend: Literal['torchaudio', 'nnaudio'] = 'torchaudio',
        n_fft: int = 2048,
        hop_length: int = 128,
        eps: float = 1e-5,
        **kwargs,
    ):
        """
        Log-Magnitude Spectrogram

        Args:
            audio_backend (str): 'torchaudio' or 'nnaudio'
            n_fft (int): FFT window size; creates n_fft // 2 + 1 freq-bins,
                of which the DC bin is dropped in forward(), leaving n_fft // 2
            hop_length (int): hop length in samples
            eps (float): epsilon for numerical stability

        """
        super().__init__()
        self.audio_backend = audio_backend.lower()

        if self.audio_backend == 'torchaudio':
            torchaudio = importlib.import_module('torchaudio')
            self.stft = torchaudio.transforms.Spectrogram(
                n_fft=n_fft,
                hop_length=hop_length,
                window_fn=torch.hann_window,
                power=1.)  # power=1. yields a magnitude (not power) spectrogram
        elif self.audio_backend == 'nnaudio':
            raise NotImplementedError(audio_backend)
        else:
            raise NotImplementedError(audio_backend)
        self.eps = eps

    @optional_compiler_disable
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (B, 1, T)

        Returns:
            torch.Tensor: (B, T, F)

        """
        if self.audio_backend == 'torchaudio':
            x = self.stft(x)[:, :, 1:, :]  # (B, 1, F, T); drop the DC bin, F = n_fft // 2
            x = rearrange(x, 'b 1 f t -> b t f')
            x = minmax_normalize(torch.log(x + self.eps))
            return torch.nan_to_num(x)
        elif self.audio_backend == 'nnaudio':
            raise NotImplementedError(self.audio_backend)


def get_spectrogram_layer_from_audio_cfg(audio_cfg: Optional[Dict] = None) -> Tuple[nn.Module, Tuple[int, int]]:
    """Get a mel-/spectrogram layer from config.

    - Used by 'ymt3' to create a spectrogram layer.
    - Also returns the output shape of the spectrogram layer, which is used to
      determine the input shape of the model.

    Args:
        audio_cfg (dict): see config/config.py

    Returns:
        layer (nn.Module): mel-/spectrogram layer
        output_shape (tuple): inferred output shape of the layer, excluding
            the batch dim: (T, F)
    """
    if audio_cfg is None:
        audio_cfg = default_audio_cfg

    if audio_cfg['codec'] == 'melspec':
        layer = Melspectrogram(**audio_cfg)
    elif audio_cfg['codec'] == 'spec':
        layer = Spectrogram(**audio_cfg)
    else:
        raise NotImplementedError(audio_cfg['codec'])

    # Infer (T, F) by running a dummy batch through the layer.
    with torch.no_grad():
        output_shape = layer(torch.randn(1, 1, audio_cfg['input_frames'])).shape[1:]
    return layer, output_shape
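

if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the module's public API).
    # `example_cfg` is an assumed config dict that mirrors the keys consumed
    # above ('codec', 'input_frames', plus Melspectrogram kwargs); the real
    # values live in config/config.py. Shapes follow the MT3 example in the
    # module docstring.
    example_cfg = {
        "codec": "melspec",
        "input_frames": 32767,
        "sample_rate": 16000,
        "n_fft": 2048,
        "hop_length": 128,
        "f_min": 50,
        "f_max": 8000,
        "n_mels": 512,
    }
    layer, output_shape = get_spectrogram_layer_from_audio_cfg(example_cfg)
    print(type(layer).__name__, tuple(output_shape))  # expected: Melspectrogram (256, 512)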