import math
import random
from typing import Optional, Union, Tuple

import librosa
import torch
import torch.nn as nn

try:
    import torchaudio

    HAVE_TORCHAUDIO = True
except ModuleNotFoundError:
    HAVE_TORCHAUDIO = False

CONSTANT = 1e-5

def normalize_batch(x, seq_len, normalize_type):
    x_mean = None
    x_std = None
    if normalize_type == "per_feature":
        batch_size = x.shape[0]
        max_time = x.shape[2]

        # item() is not allowed while capturing a CUDA graph because it synchronizes the stream,
        # so this length-1 error check is skipped during graph capture.
        if (
            torch.cuda.is_available()
            and not torch.cuda.is_current_stream_capturing()
            and torch.any(seq_len == 1).item()
        ):
            raise ValueError(
                "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
                "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
                "feature (ex. at least `hop_length` for Mel Spectrograms)."
            )
        time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
        valid_mask = time_steps < seq_len.unsqueeze(1)
        x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
        x_mean_denominator = valid_mask.sum(axis=1)
        x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)

        # Subtract 1 from the denominator (Bessel's correction) to match torch.std().
        x_std = torch.sqrt(
            torch.sum(torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) ** 2, axis=2)
            / (x_mean_denominator.unsqueeze(1) - 1.0)
        )
        # Make sure x_std is not zero.
        x_std += CONSTANT
        return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
    elif normalize_type == "all_features":
        x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
        for i in range(x.shape[0]):
            x_mean[i] = x[i, :, : seq_len[i].item()].mean()
            x_std[i] = x[i, :, : seq_len[i].item()].std()
        # Make sure x_std is not zero.
        x_std += CONSTANT
        return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
    elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
        x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
        x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
        return (
            (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
            x_mean,
            x_std,
        )
    else:
        return x, x_mean, x_std

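# A minimal usage sketch (illustrative, not part of the original module): the helper name and
# tensor sizes below are arbitrary assumptions chosen only to show the shapes returned by
# "per_feature" normalization.
def _example_normalize_batch():
    feats = torch.randn(2, 64, 100)    # [batch, n_mels, time]
    lengths = torch.tensor([100, 80])  # valid frames per batch element
    normed, mean, std = normalize_batch(feats, lengths, normalize_type="per_feature")
    # normed: [2, 64, 100]; mean and std: [2, 64] (one statistic per feature channel)
    return normed, mean, std
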
def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Tensor, fill_value=0.0) -> torch.Tensor:
    """
    Fill spectrogram values outside the length with `fill_value`

    Args:
        spectrogram: Tensor with shape [B, C, L] containing batched spectrograms
        spectrogram_len: Tensor with shape [B] containing the sequence length of each batch element
        fill_value: value to fill with, 0.0 by default

    Returns:
        cleaned spectrogram, tensor with shape equal to `spectrogram`
    """
    device = spectrogram.device
    batch_size, _, max_len = spectrogram.shape
    mask = torch.arange(max_len, device=device)[None, :] >= spectrogram_len[:, None]
    mask = mask.unsqueeze(1).expand_as(spectrogram)
    return spectrogram.masked_fill(mask, fill_value)

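# A short illustrative sketch (hypothetical helper, not from the original source): frames past
# each element's length are overwritten with `fill_value`; valid frames pass through unchanged.
def _example_clean_spectrogram_batch():
    spec = torch.randn(2, 64, 50)      # [B, C, L]
    spec_len = torch.tensor([50, 30])  # second element has 20 padded frames
    cleaned = clean_spectrogram_batch(spec, spec_len, fill_value=0.0)
    # cleaned[1, :, 30:] is all zeros; cleaned[0] is identical to spec[0]
    return cleaned
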
def splice_frames(x, frame_splicing):
    """Stacks frames together across feature dim

    input is batch_size, feature_dim, num_frames
    output is batch_size, feature_dim*frame_splicing, num_frames
    """
    seq = [x]
    for n in range(1, frame_splicing):
        seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
    return torch.cat(seq, dim=1)

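# A shape-only usage sketch (illustrative assumption): with frame_splicing=3 the feature
# dimension grows threefold while the number of frames is unchanged.
def _example_splice_frames():
    feats = torch.randn(2, 64, 100)  # [batch_size, feature_dim, num_frames]
    spliced = splice_frames(feats, frame_splicing=3)
    # spliced: [2, 192, 100]
    return spliced
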
@torch.jit.script_if_tracing
def make_seq_mask_like(
    lengths: torch.Tensor, like: torch.Tensor, time_dim: int = -1, valid_ones: bool = True
) -> torch.Tensor:
    """

    Args:
        lengths: Tensor with shape [B] containing the sequence length of each batch element
        like: The mask will contain the same number of dimensions as this Tensor, and will have the same max
            length in the time dimension of this Tensor.
        time_dim: Time dimension of `like` and the resulting mask. Zero-based.
        valid_ones: If True, valid tokens will contain value `1` and padding will be `0`. Else, invert.

    Returns:
        A :class:`torch.Tensor` containing 1's and 0's for valid and invalid tokens, respectively, if `valid_ones`, else
        vice-versa. Mask will have the same number of dimensions as `like`. Batch and time dimensions will match
        the `like`. All other dimensions will be singletons. E.g., if `like.shape == [3, 4, 5]` and
        `time_dim == -1`, mask will have shape `[3, 1, 5]`.
    """
    # Mask with shape [B, T]
    mask = torch.arange(like.shape[time_dim], device=like.device).repeat(lengths.shape[0], 1).lt(lengths.view(-1, 1))
    # [B, T] -> [B, *, T] with singleton dims so the mask broadcasts against `like`
    for _ in range(like.dim() - mask.dim()):
        mask = mask.unsqueeze(1)
    # If needed, move the time dimension into place
    if time_dim != -1 and time_dim != mask.dim() - 1:
        mask = mask.transpose(-1, time_dim)
    # Optionally invert so padding positions are True instead of valid positions
    if not valid_ones:
        mask = ~mask
    return mask

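# A minimal sketch (hypothetical helper mirroring the docstring example): for a [3, 4, 5] tensor
# with time in the last dimension, the returned mask broadcasts as [3, 1, 5].
def _example_make_seq_mask_like():
    like = torch.randn(3, 4, 5)  # [B, D, T]
    lengths = torch.tensor([5, 3, 1])
    mask = make_seq_mask_like(lengths=lengths, like=like, time_dim=-1, valid_ones=True)
    # mask: [3, 1, 5]; mask[1, 0] is [True, True, True, False, False]
    return mask
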
class FilterbankFeatures(nn.Module):
    """Featurizer that converts wavs to Mel Spectrograms.
    See AudioToMelSpectrogramPreprocessor for args.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=CONSTANT,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        exact_pad=False,
        pad_value=0,
        mag_power=2.0,
        use_grads=False,
        rng=None,
        nb_augmentation_prob=0.0,
        nb_max_freq=4000,
        mel_norm="slaney",
        stft_exact_pad=False,  # Deprecated; has no effect beyond a warning.
        stft_conv=False,  # Deprecated; has no effect beyond a warning.
    ):
        super().__init__()
        if stft_conv or stft_exact_pad:
            print(
                "Using torch_stft is deprecated and has been removed. The values have been forcibly set to False "
                "for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
                "as needed."
            )
        if exact_pad and n_window_stride % 2 == 1:
            raise NotImplementedError(
                f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0, then "
                "the returned spectrogram would not be of length audio_length // hop_size. Please use an even "
                "hop_size."
            )
        self.log_zero_guard_value = log_zero_guard_value
        if (
            n_window_size is None
            or n_window_stride is None
            or not isinstance(n_window_size, int)
            or not isinstance(n_window_stride, int)
            or n_window_size <= 0
            or n_window_stride <= 0
        ):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints."
            )

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
        self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
        self.exact_pad = exact_pad

        if exact_pad:
            print("STFT using exact pad")
        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
        self.register_buffer("window", window_tensor)

        self.normalize = normalize
        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        filterbanks = torch.tensor(
            librosa.filters.mel(
                sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
            ),
            dtype=torch.float,
        ).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate the maximum sequence length (in frames) for `pad_to == "max"`.
        max_length = self.get_seq_len(torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

        # We want to avoid taking the log of zero: either add a small guard value or clamp to it.
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'."
            )

        self.use_grads = use_grads
        if not use_grads:
            self.forward = torch.no_grad()(self.forward)
        self._rng = random.Random() if rng is None else rng
        self.nb_augmentation_prob = nb_augmentation_prob
        if self.nb_augmentation_prob > 0.0:
            if nb_max_freq >= sample_rate / 2:
                self.nb_augmentation_prob = 0.0
            else:
                self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)

        self.log_zero_guard_type = log_zero_guard_type

    def stft(self, x):
        return torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=False if self.exact_pad else True,
            window=self.window.to(dtype=torch.float),
            return_complex=True,
        )

    def log_zero_guard_value_fn(self, x):
        if isinstance(self.log_zero_guard_value, str):
            if self.log_zero_guard_value == "tiny":
                return torch.finfo(x.dtype).tiny
            elif self.log_zero_guard_value == "eps":
                return torch.finfo(x.dtype).eps
            else:
                raise ValueError(
                    f"{self} received {self.log_zero_guard_value} for the "
                    f"log_zero_guard_value parameter. It must be either a "
                    f"number, 'tiny', or 'eps'"
                )
        else:
            return self.log_zero_guard_value

    def get_seq_len(self, seq_len):
        # When exact_pad is disabled, torch.stft uses center=True and pads n_fft // 2 on each side.
        pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
        seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
        return seq_len.to(dtype=torch.long)

    @property
    def filter_banks(self):
        return self.fb

    def forward(self, x, seq_len, linear_spec=False):
        seq_len = self.get_seq_len(seq_len)

        if self.stft_pad_amount is not None:
            x = torch.nn.functional.pad(
                x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "reflect"
            ).squeeze(1)

        # dither (only in training mode, for eval determinism)
        if self.training and self.dither > 0:
            x += self.dither * torch.randn_like(x)

        # do preemphasis
        if self.preemph is not None:
            x = torch.cat((x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1)

        # disable autocast to get the full range of stft values
        with torch.amp.autocast(x.device.type, enabled=False):
            x = self.stft(x)

        # torch.stft returns a complex tensor; convert it to magnitude.
        # A small guard is added inside the sqrt when gradients flow through, to avoid NaNs at zero.
        guard = 0 if not self.use_grads else CONSTANT
        x = torch.view_as_real(x)
        x = torch.sqrt(x.pow(2).sum(-1) + guard)

        # narrowband augmentation: randomly zero out frequency bins above the cutoff
        if self.training and self.nb_augmentation_prob > 0.0:
            for idx in range(x.shape[0]):
                if self._rng.random() < self.nb_augmentation_prob:
                    x[idx, self._nb_max_fft_bin :, :] = 0.0

        # get power spectrum
        if self.mag_power != 1.0:
            x = x.pow(self.mag_power)

        # return plain spectrogram if required
        if linear_spec:
            return x, seq_len

        # dot with filterbank energies
        x = torch.matmul(self.fb.to(x.dtype), x)

        # log features if required
        if self.log:
            if self.log_zero_guard_type == "add":
                x = torch.log(x + self.log_zero_guard_value_fn(x))
            elif self.log_zero_guard_type == "clamp":
                x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
            else:
                raise ValueError("log_zero_guard_type was not understood")

        # frame splicing if required
        if self.frame_splicing > 1:
            x = splice_frames(x, self.frame_splicing)

        # normalize if required
        if self.normalize:
            x, _, _ = normalize_batch(x, seq_len, normalize_type=self.normalize)

        # fill frames beyond seq_len with pad_value, then pad time to a multiple of `pad_to`
        max_len = x.size(-1)
        mask = torch.arange(max_len, device=x.device)
        mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
        x = x.masked_fill(mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value)
        del mask
        pad_to = self.pad_to
        if pad_to == "max":
            x = nn.functional.pad(x, (0, self.max_length - x.size(-1)), value=self.pad_value)
        elif pad_to > 0:
            pad_amt = x.size(-1) % pad_to
            if pad_amt != 0:
                x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
        return x, seq_len

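# An end-to-end usage sketch under assumed defaults (16 kHz audio, 64 mel bins). The synthetic
# one-second waveform and the helper name `_example_filterbank_features` are illustrative only.
def _example_filterbank_features():
    featurizer = FilterbankFeatures(sample_rate=16000, n_window_size=320, n_window_stride=160, nfilt=64)
    audio = torch.randn(2, 16000)  # [batch, samples]
    audio_len = torch.tensor([16000, 12000])
    feats, feat_len = featurizer(audio, audio_len)
    # feats: [2, 64, T] log-mel features, padded in time to a multiple of `pad_to`; feat_len: [2]
    return feats, feat_len
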
class FilterbankFeaturesTA(nn.Module):
    """
    Exportable, `torchaudio`-based implementation of Mel Spectrogram extraction.

    See `AudioToMelSpectrogramPreprocessor` for args.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        n_window_size: int = 320,
        n_window_stride: int = 160,
        normalize: Optional[str] = "per_feature",
        nfilt: int = 64,
        n_fft: Optional[int] = None,
        preemph: float = 0.97,
        lowfreq: float = 0,
        highfreq: Optional[float] = None,
        log: bool = True,
        log_zero_guard_type: str = "add",
        log_zero_guard_value: Union[float, str] = 2**-24,
        dither: float = 1e-5,
        window: str = "hann",
        pad_to: int = 0,
        pad_value: float = 0.0,
        mel_norm="slaney",
        # The following arguments are accepted for config compatibility but are not used here.
        use_grads: bool = False,
        max_duration: float = 16.7,
        frame_splicing: int = 1,
        exact_pad: bool = False,
        nb_augmentation_prob: float = 0.0,
        nb_max_freq: int = 4000,
        mag_power: float = 2.0,
        rng: Optional[random.Random] = None,
        stft_exact_pad: bool = False,
        stft_conv: bool = False,
    ):
        super().__init__()
        if not HAVE_TORCHAUDIO:
            raise ValueError(f"Need to install torchaudio to instantiate a {self.__class__.__name__}")

        supported_log_zero_guard_strings = {"eps", "tiny"}
        if isinstance(log_zero_guard_value, str) and log_zero_guard_value not in supported_log_zero_guard_strings:
            raise ValueError(
                f"Log zero guard value must either be a float or a member of {supported_log_zero_guard_strings}"
            )

        self.torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'ones': torch.ones,
            None: torch.ones,
        }

        if window not in self.torch_windows:
            raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self._sample_rate = sample_rate
        self._normalize_strategy = normalize
        self._use_log = log
        self._preemphasis_value = preemph
        self.log_zero_guard_type = log_zero_guard_type
        self.log_zero_guard_value: Union[str, float] = log_zero_guard_value
        self.dither = dither
        self.pad_to = pad_to
        self.pad_value = pad_value
        self.n_fft = n_fft
        self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=self._sample_rate,
            win_length=self.win_length,
            hop_length=self.hop_length,
            n_mels=nfilt,
            window_fn=self.torch_windows[window],
            mel_scale="slaney",
            norm=mel_norm,
            n_fft=n_fft,
            f_max=highfreq,
            f_min=lowfreq,
            wkwargs={"periodic": False},
        )

    @property
    def filter_banks(self):
        """Matches the analogous class"""
        return self._mel_spec_extractor.mel_scale.fb

    def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
        if isinstance(self.log_zero_guard_value, float):
            return self.log_zero_guard_value
        return getattr(torch.finfo(dtype), self.log_zero_guard_value)

    def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
        if self.training and self.dither > 0.0:
            noise = torch.randn_like(signals) * self.dither
            signals = signals + noise
        return signals

    def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
        if self._preemphasis_value is not None:
            padded = torch.nn.functional.pad(signals, (1, 0))
            signals = signals - self._preemphasis_value * padded[:, :-1]
        return signals

    def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
        out_lengths = input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
        return out_lengths

    def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
        # Padding is only applied during training; at inference time the length is left as-is.
        if not self.training or self.pad_to == 0 or features.shape[-1] % self.pad_to == 0:
            return features
        pad_length = self.pad_to - (features.shape[-1] % self.pad_to)
        return torch.nn.functional.pad(features, pad=(0, pad_length), value=self.pad_value)

    def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
        if self._use_log:
            zero_guard = self._resolve_log_zero_guard_value(features.dtype)
            if self.log_zero_guard_type == "add":
                features = features + zero_guard
            elif self.log_zero_guard_type == "clamp":
                features = features.clamp(min=zero_guard)
            else:
                raise ValueError(f"Unsupported log zero guard type: '{self.log_zero_guard_type}'")
            features = features.log()
        return features

    def _extract_spectrograms(self, signals: torch.Tensor) -> torch.Tensor:
        # Disable autocast so the STFT/mel extraction runs in full precision.
        with torch.amp.autocast('cuda', enabled=False):
            features = self._mel_spec_extractor(waveform=signals)
        return features

    def _apply_normalization(self, features: torch.Tensor, lengths: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Zero out padded positions so they do not contribute to the statistics.
        mask: torch.Tensor = make_seq_mask_like(lengths=lengths, like=features, time_dim=-1, valid_ones=False)
        features = features.masked_fill(mask, 0.0)

        if self._normalize_strategy is None:
            return features

        guard_value = self._resolve_log_zero_guard_value(features.dtype)
        if self._normalize_strategy == "per_feature" or self._normalize_strategy == "all_features":
            # 'per_feature' reduces over time only; 'all_features' also reduces over channels.
            reduce_dim = 2
            if self._normalize_strategy == "all_features":
                reduce_dim = [1, 2]
            # [B, D, T] -> [B, D, 1] or [B, 1, 1]
            means = features.sum(dim=reduce_dim, keepdim=True).div(lengths.view(-1, 1, 1))
            stds = (
                features.sub(means)
                .masked_fill(mask, 0.0)
                .pow(2.0)
                .sum(dim=reduce_dim, keepdim=True)
                .div(lengths.view(-1, 1, 1) - 1)
                .clamp(min=guard_value)
                .sqrt()
            )
            features = (features - means) / (stds + eps)
        else:
            # Fixed mean/std normalization is not supported by this implementation.
            raise ValueError(f"Unsupported norm type: '{self._normalize_strategy}'")
        features = features.masked_fill(mask, 0.0)
        return features

    def forward(self, input_signal: torch.Tensor, length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        feature_lengths = self._compute_output_lengths(input_lengths=length)
        signals = self._apply_dithering(signals=input_signal)
        signals = self._apply_preemphasis(signals=signals)
        features = self._extract_spectrograms(signals=signals)
        features = self._apply_log(features=features)
        features = self._apply_normalization(features=features, lengths=feature_lengths)
        features = self._apply_pad_to(features=features)
        return features, feature_lengths
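

# A usage sketch for the torchaudio-backed variant (requires torchaudio at call time). The
# synthetic waveform, batch sizes, and the helper name are illustrative assumptions only;
# n_fft is set explicitly since this variant does not derive it from the window size.
def _example_filterbank_features_ta():
    featurizer = FilterbankFeaturesTA(
        sample_rate=16000, n_window_size=320, n_window_stride=160, nfilt=64, n_fft=512
    )
    audio = torch.randn(2, 16000)  # [batch, samples]
    audio_len = torch.tensor([16000, 12000])
    feats, feat_len = featurizer(audio, audio_len)
    # feats: [2, 64, T]; feat_len: [2] (number of valid frames per element)
    return feats, feat_len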