zhzluke96
update
32b2aaa
raw
history blame
No virus
2.06 kB
import numpy as np
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
from .hparams import HParams
class MelSpectrogram(nn.Module):
def __init__(self, hp: HParams):
"""
Torch implementation of Resemble's mel extraction.
Note that the values are NOT identical to librosa's implementation
due to floating point precisions.
"""
super().__init__()
self.hp = hp
self.melspec = TorchMelSpectrogram(
hp.wav_rate,
n_fft=hp.n_fft,
win_length=hp.win_size,
hop_length=hp.hop_size,
f_min=0,
f_max=hp.wav_rate // 2,
n_mels=hp.num_mels,
power=1,
normalized=False,
# NOTE: Folowing librosa's default.
pad_mode="constant",
norm="slaney",
mel_scale="slaney",
)
self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
self.preemphasis = hp.preemphasis
self.hop_size = hp.hop_size
def forward(self, wav, pad=True):
"""
Args:
wav: [B, T]
"""
device = wav.device
if wav.is_mps:
wav = wav.cpu()
self.to(wav.device)
if self.preemphasis > 0:
wav = torch.nn.functional.pad(wav, [1, 0], value=0)
wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
mel = self.melspec(wav)
mel = self._amp_to_db(mel)
mel_normed = self._normalize(mel)
assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size # Sanity check
mel_normed = mel_normed.to(device)
return mel_normed # (M, T)
def _normalize(self, s, headroom_db=15):
return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
def _amp_to_db(self, x):
return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20