|
import math |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class Deltas(torch.nn.Module): |
|
"""Computes delta coefficients (time derivatives). |
|
Arguments |
|
--------- |
|
    input_size : int
        Size of the feature (last) dimension of the input; one delta
        kernel is created per feature.
    window_length : int
        Length of the window used to compute the time derivatives.
|
Example |
|
------- |
|
>>> inputs = torch.randn([10, 101, 20]) |
|
>>> compute_deltas = Deltas(input_size=inputs.size(-1)) |
|
>>> features = compute_deltas(inputs) |
|
>>> features.shape |
|
torch.Size([10, 101, 20]) |
|
""" |
|
|
|
def __init__( |
|
self, input_size, window_length=5, |
|
): |
|
super().__init__() |
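        # Half window size; the delta regression uses frames t-n .. t+n and
        # the denominator 2 * sum_{i=1..n} i^2 = n(n+1)(2n+1)/3.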
|
self.n = (window_length - 1) // 2 |
|
self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3 |
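        # Kernel [-n, ..., 0, ..., n], replicated once per input feature so
        # that a grouped conv1d computes deltas for each feature independently.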
|
|
|
self.register_buffer( |
|
"kernel", |
|
torch.arange(-self.n, self.n + 1, dtype=torch.float32,).repeat( |
|
input_size, 1, 1 |
|
), |
|
) |
|
|
|
def forward(self, x): |
|
"""Returns the delta coefficients. |
|
Arguments |
|
--------- |
|
        x : tensor
            A batch of feature tensors with shape (batch, time, features)
            or (batch, time, features, channels).
|
""" |
|
|
|
x = x.transpose(1, 2).transpose(2, -1) |
|
or_shape = x.shape |
|
if len(or_shape) == 4: |
|
x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3]) |
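        # Replicate-pad n frames at both ends of the time axis so the output
        # keeps the same number of frames as the input.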
|
|
|
|
|
x = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate") |
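        # Grouped (per-feature) convolution with the [-n, ..., n] kernel,
        # normalized by the regression denominator.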
|
|
|
|
|
delta_coeff = ( |
|
torch.nn.functional.conv1d( |
|
x, self.kernel.to(x.device), groups=x.shape[1] |
|
) |
|
/ self.denom |
|
) |
|
|
|
|
|
if len(or_shape) == 4: |
|
delta_coeff = delta_coeff.reshape( |
|
or_shape[0], or_shape[1], or_shape[2], or_shape[3], |
|
) |
|
delta_coeff = delta_coeff.transpose(1, -1).transpose(2, -1) |
|
|
|
return delta_coeff |
|
|
|
|
|
class Filterbank(torch.nn.Module): |
|
"""computes filter bank (FBANK) features given spectral magnitudes. |
|
Arguments |
|
--------- |
|
    n_mels : int
|
Number of Mel filters used to average the spectrogram. |
|
log_mel : bool |
|
If True, it computes the log of the FBANKs. |
|
filter_shape : str |
|
Shape of the filters ('triangular', 'rectangular', 'gaussian'). |
|
f_min : int |
|
Lowest frequency for the Mel filters. |
|
f_max : int |
|
Highest frequency for the Mel filters. |
|
n_fft : int |
|
        Number of fft points of the STFT. It defines the frequency resolution
        (n_fft should be >= the window length in samples).
|
sample_rate : int |
|
        Sample rate of the input audio signal (e.g., 16000).
|
power_spectrogram : float |
|
Exponent used for spectrogram computation. |
|
amin : float |
|
Minimum amplitude (used for numerical stability). |
|
ref_value : float |
|
Reference value used for the dB scale. |
|
top_db : float |
|
Minimum negative cut-off in decibels. |
|
freeze : bool |
|
        If False, the central frequency and the band of each filter are
        registered as nn.Parameters (and can thus be learned). If True, the
        standard frozen features are computed.
|
    param_change_factor : float
|
If freeze=False, this parameter affects the speed at which the filter |
|
parameters (i.e., central_freqs and bands) can be changed. When high |
|
(e.g., param_change_factor=1) the filters change a lot during training. |
|
        When low (e.g., param_change_factor=0.1) the filter parameters are
        more stable during training.
|
    param_rand_factor : float
|
This parameter can be used to randomly change the filter parameters |
|
        (i.e., central frequencies and bands) during training. It is thus a
        sort of regularization. param_rand_factor=0 does not affect them, while
|
param_rand_factor=0.15 allows random variations within +-15% of the |
|
standard values of the filter parameters (e.g., if the central freq |
|
is 100 Hz, we can randomly change it from 85 Hz to 115 Hz). |
|
Example |
|
------- |
|
>>> import torch |
|
>>> compute_fbanks = Filterbank() |
|
>>> inputs = torch.randn([10, 101, 201]) |
|
>>> features = compute_fbanks(inputs) |
|
>>> features.shape |
|
torch.Size([10, 101, 40]) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
n_mels=40, |
|
log_mel=True, |
|
filter_shape="triangular", |
|
f_min=0, |
|
f_max=8000, |
|
n_fft=400, |
|
sample_rate=16000, |
|
power_spectrogram=2, |
|
amin=1e-10, |
|
ref_value=1.0, |
|
top_db=80.0, |
|
param_change_factor=1.0, |
|
param_rand_factor=0.0, |
|
freeze=True, |
|
): |
|
super().__init__() |
|
self.n_mels = n_mels |
|
self.log_mel = log_mel |
|
self.filter_shape = filter_shape |
|
self.f_min = f_min |
|
self.f_max = f_max |
|
self.n_fft = n_fft |
|
self.sample_rate = sample_rate |
|
self.power_spectrogram = power_spectrogram |
|
self.amin = amin |
|
self.ref_value = ref_value |
|
self.top_db = top_db |
|
self.freeze = freeze |
|
self.n_stft = self.n_fft // 2 + 1 |
|
self.db_multiplier = math.log10(max(self.amin, self.ref_value)) |
|
self.device_inp = torch.device("cpu") |
|
self.param_change_factor = param_change_factor |
|
self.param_rand_factor = param_rand_factor |
|
|
|
if self.power_spectrogram == 2: |
|
self.multiplier = 10 |
|
else: |
|
self.multiplier = 20 |
|
|
|
|
|
if self.f_min >= self.f_max: |
|
err_msg = "Require f_min: %f < f_max: %f" % ( |
|
self.f_min, |
|
self.f_max, |
|
) |
|
            raise ValueError(err_msg)
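        # Build n_mels + 2 points equally spaced on the mel scale and map them
        # back to Hz; consecutive differences give the filter bands and the
        # inner points give the central frequencies.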
|
|
|
|
|
mel = torch.linspace( |
|
self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2 |
|
) |
|
hz = self._to_hz(mel) |
|
|
|
|
|
band = hz[1:] - hz[:-1] |
|
self.band = band[:-1] |
|
self.f_central = hz[1:-1] |
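        # When the filterbank is learnable, the central frequencies and bands
        # are stored as scaled nn.Parameters; forward() maps them back to Hz,
        # with param_change_factor controlling how fast they can change.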
|
|
|
|
|
if not self.freeze: |
|
self.f_central = torch.nn.Parameter( |
|
self.f_central / (self.sample_rate * self.param_change_factor) |
|
) |
|
self.band = torch.nn.Parameter( |
|
self.band / (self.sample_rate * self.param_change_factor) |
|
) |
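        # Frequency axis of the one-sided spectrogram, repeated once per
        # filter so the filter shapes can be evaluated with tensor operations.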
|
|
|
|
|
all_freqs = torch.linspace(0, self.sample_rate // 2, self.n_stft) |
|
|
|
|
|
self.all_freqs_mat = all_freqs.repeat(self.f_central.shape[0], 1) |
|
|
|
def forward(self, spectrogram): |
|
"""Returns the FBANks. |
|
Arguments |
|
--------- |
|
        spectrogram : tensor
            A batch of spectral magnitude tensors.
|
""" |
|
|
|
f_central_mat = self.f_central.repeat( |
|
self.all_freqs_mat.shape[1], 1 |
|
).transpose(0, 1) |
|
band_mat = self.band.repeat(self.all_freqs_mat.shape[1], 1).transpose( |
|
0, 1 |
|
) |
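        # Learnable filters are rescaled back to Hz; otherwise, during
        # training, an optional random perturbation of +-param_rand_factor
        # is applied as a regularizer.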
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not self.freeze: |
|
f_central_mat = f_central_mat * ( |
|
self.sample_rate |
|
* self.param_change_factor |
|
* self.param_change_factor |
|
) |
|
band_mat = band_mat * ( |
|
self.sample_rate |
|
* self.param_change_factor |
|
* self.param_change_factor |
|
) |
|
|
|
|
|
elif self.param_rand_factor != 0 and self.training: |
|
rand_change = ( |
|
1.0 |
|
+ torch.rand(2) * 2 * self.param_rand_factor |
|
- self.param_rand_factor |
|
) |
|
f_central_mat = f_central_mat * rand_change[0] |
|
band_mat = band_mat * rand_change[1] |
|
|
|
fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat).to( |
|
spectrogram.device |
|
) |
|
|
|
sp_shape = spectrogram.shape |
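        # Multi-channel input (batch, time, freq, channels): fold the channel
        # axis into the batch before applying the filterbank matrix.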
|
|
|
|
|
if len(sp_shape) == 4: |
|
spectrogram = spectrogram.permute(0, 3, 1, 2) |
|
spectrogram = spectrogram.reshape( |
|
sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2] |
|
) |
|
|
|
|
|
fbanks = torch.matmul(spectrogram, fbank_matrix) |
|
if self.log_mel: |
|
fbanks = self._amplitude_to_DB(fbanks) |
|
|
|
|
|
if len(sp_shape) == 4: |
|
fb_shape = fbanks.shape |
|
fbanks = fbanks.reshape( |
|
sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2] |
|
) |
|
fbanks = fbanks.permute(0, 2, 3, 1) |
|
|
|
return fbanks |
|
|
|
@staticmethod |
|
def _to_mel(hz): |
|
"""Returns mel-frequency value corresponding to the input |
|
frequency value in Hz. |
|
Arguments |
|
--------- |
|
        hz : float
            The frequency point in Hz.
|
""" |
|
return 2595 * math.log10(1 + hz / 700) |
|
|
|
@staticmethod |
|
def _to_hz(mel): |
|
"""Returns hz-frequency value corresponding to the input |
|
mel-frequency value. |
|
Arguments |
|
--------- |
|
        mel : float
            The frequency point on the mel scale.
|
""" |
|
return 700 * (10 ** (mel / 2595) - 1) |
|
|
|
def _triangular_filters(self, all_freqs, f_central, band): |
|
"""Returns fbank matrix using triangular filters. |
|
Arguments |
|
--------- |
|
all_freqs : Tensor |
|
Tensor gathering all the frequency points. |
|
f_central : Tensor |
|
Tensor gathering central frequencies of each filter. |
|
band : Tensor |
|
Tensor gathering the bands of each filter. |
|
""" |
|
|
|
|
|
slope = (all_freqs - f_central) / band |
|
left_side = slope + 1.0 |
|
right_side = -slope + 1.0 |
|
|
|
|
|
zero = torch.zeros(1, device=self.device_inp) |
|
fbank_matrix = torch.max( |
|
zero, torch.min(left_side, right_side) |
|
).transpose(0, 1) |
|
|
|
return fbank_matrix |
|
|
|
def _rectangular_filters(self, all_freqs, f_central, band): |
|
"""Returns fbank matrix using rectangular filters. |
|
Arguments |
|
--------- |
|
all_freqs : Tensor |
|
Tensor gathering all the frequency points. |
|
f_central : Tensor |
|
Tensor gathering central frequencies of each filter. |
|
band : Tensor |
|
Tensor gathering the bands of each filter. |
|
""" |
|
|
|
|
|
low_hz = f_central - band |
|
high_hz = f_central + band |
|
|
|
|
|
        left_side = all_freqs.ge(low_hz)
        right_side = all_freqs.le(high_hz)

        fbank_matrix = (left_side * right_side).float().transpose(0, 1)
|
|
|
return fbank_matrix |
|
|
|
def _gaussian_filters( |
|
self, all_freqs, f_central, band, smooth_factor=torch.tensor(2) |
|
): |
|
"""Returns fbank matrix using gaussian filters. |
|
Arguments |
|
--------- |
|
all_freqs : Tensor |
|
Tensor gathering all the frequency points. |
|
f_central : Tensor |
|
Tensor gathering central frequencies of each filter. |
|
band : Tensor |
|
Tensor gathering the bands of each filter. |
|
smooth_factor: Tensor |
|
Smoothing factor of the gaussian filter. It can be used to employ |
|
sharper or flatter filters. |
|
""" |
|
fbank_matrix = torch.exp( |
|
-0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2 |
|
).transpose(0, 1) |
|
|
|
return fbank_matrix |
|
|
|
def _create_fbank_matrix(self, f_central_mat, band_mat): |
|
"""Returns fbank matrix to use for averaging the spectrum with |
|
the set of filter-banks. |
|
Arguments |
|
--------- |
|
        f_central_mat : Tensor
            Tensor gathering central frequencies of each filter.
        band_mat : Tensor
            Tensor gathering the bands of each filter.
|
""" |
|
if self.filter_shape == "triangular": |
|
fbank_matrix = self._triangular_filters( |
|
self.all_freqs_mat, f_central_mat, band_mat |
|
) |
|
|
|
elif self.filter_shape == "rectangular": |
|
fbank_matrix = self._rectangular_filters( |
|
self.all_freqs_mat, f_central_mat, band_mat |
|
) |
|
|
|
else: |
|
fbank_matrix = self._gaussian_filters( |
|
self.all_freqs_mat, f_central_mat, band_mat |
|
) |
|
|
|
return fbank_matrix |
|
|
|
def _amplitude_to_DB(self, x): |
|
"""Converts linear-FBANKs to log-FBANKs. |
|
Arguments |
|
--------- |
|
x : Tensor |
|
A batch of linear FBANK tensors. |
|
""" |
|
|
|
x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin)) |
|
x_db -= self.multiplier * self.db_multiplier |
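        # Limit the dynamic range: values more than top_db below the per-item
        # maximum are clipped to that floor.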
|
|
|
|
|
|
|
new_x_db_max = x_db.amax(dim=(-2, -1)) - self.top_db |
|
|
|
|
|
|
|
x_db = torch.max(x_db, new_x_db_max.view(x_db.shape[0], 1, 1)) |
|
|
|
return x_db |
|
|
|
|
|
class STFT(torch.nn.Module): |
|
"""computes the Short-Term Fourier Transform (STFT). |
|
This class computes the Short-Term Fourier Transform of an audio signal. |
|
It supports multi-channel audio inputs (batch, time, channels). |
|
Arguments |
|
--------- |
|
sample_rate : int |
|
        Sample rate of the input audio signal (e.g., 16000).
|
win_length : float |
|
Length (in ms) of the sliding window used to compute the STFT. |
|
hop_length : float |
|
        Length (in ms) of the hop of the sliding window used to compute
|
the STFT. |
|
n_fft : int |
|
        Number of fft points of the STFT. It defines the frequency resolution
        (n_fft should be >= the window length in samples).
|
window_fn : function |
|
A function that takes an integer (number of samples) and outputs a |
|
tensor to be multiplied with each window before fft. |
|
normalized_stft : bool |
|
If True, the function returns the normalized STFT results, |
|
i.e., multiplied by win_length^-0.5 (default is False). |
|
center : bool |
|
If True (default), the input will be padded on both sides so that the |
|
t-th frame is centered at time t×hop_length. Otherwise, the t-th frame |
|
begins at time t×hop_length. |
|
pad_mode : str |
|
        It can be 'constant', 'reflect', 'replicate', or 'circular'
        (default is 'constant'). 'constant' pads the input tensor boundaries
        with a constant value. 'reflect' pads the input tensor using the
        reflection of the input boundary. 'replicate' pads the input tensor
        using replication of the input boundary. 'circular' pads using
        circular replication.
|
    onesided : bool
        If True (default), only n_fft // 2 + 1 frequency bins are returned.
        The remaining bins are redundant due to the conjugate symmetry of
        the Fourier transform of a real signal.
|
Example |
|
------- |
|
>>> import torch |
|
>>> compute_STFT = STFT( |
|
... sample_rate=16000, win_length=25, hop_length=10, n_fft=400 |
|
... ) |
|
>>> inputs = torch.randn([10, 16000]) |
|
>>> features = compute_STFT(inputs) |
|
>>> features.shape |
|
torch.Size([10, 101, 201, 2]) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
sample_rate, |
|
win_length=25, |
|
hop_length=10, |
|
n_fft=400, |
|
window_fn=torch.hamming_window, |
|
normalized_stft=False, |
|
center=True, |
|
pad_mode="constant", |
|
onesided=True, |
|
): |
|
super().__init__() |
|
self.sample_rate = sample_rate |
|
self.win_length = win_length |
|
self.hop_length = hop_length |
|
self.n_fft = n_fft |
|
self.normalized_stft = normalized_stft |
|
self.center = center |
|
self.pad_mode = pad_mode |
|
self.onesided = onesided |
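        # Convert win_length and hop_length from milliseconds to samples.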
|
|
|
|
|
self.win_length = int( |
|
round((self.sample_rate / 1000.0) * self.win_length) |
|
) |
|
self.hop_length = int( |
|
round((self.sample_rate / 1000.0) * self.hop_length) |
|
) |
|
|
|
self.window = window_fn(self.win_length) |
|
|
|
def forward(self, x): |
|
"""Returns the STFT generated from the input waveforms. |
|
Arguments |
|
--------- |
|
x : tensor |
|
A batch of audio signals to transform. |
|
""" |
|
|
|
|
|
or_shape = x.shape |
|
if len(or_shape) == 3: |
|
x = x.transpose(1, 2) |
|
x = x.reshape(or_shape[0] * or_shape[2], or_shape[1]) |
|
|
|
stft = torch.stft( |
|
x, |
|
self.n_fft, |
|
self.hop_length, |
|
self.win_length, |
|
self.window.to(x.device), |
|
self.center, |
|
self.pad_mode, |
|
self.normalized_stft, |
|
self.onesided, |
|
return_complex=True, |
|
) |
|
|
|
stft = torch.view_as_real(stft) |
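        # Restore the channel axis and return (batch, time, freq, 2[, channels]).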
|
|
|
|
|
if len(or_shape) == 3: |
|
stft = stft.reshape( |
|
or_shape[0], |
|
or_shape[2], |
|
stft.shape[1], |
|
stft.shape[2], |
|
stft.shape[3], |
|
) |
|
stft = stft.permute(0, 3, 2, 4, 1) |
|
else: |
|
|
|
stft = stft.transpose(2, 1) |
|
|
|
return stft |
|
|
|
|
|
def spectral_magnitude( |
|
    stft, power: float = 1, log: bool = False, eps: float = 1e-14
|
): |
|
"""Returns the magnitude of a complex spectrogram. |
|
Arguments |
|
--------- |
|
stft : torch.Tensor |
|
A tensor, output from the stft function. |
|
    power : float
|
What power to use in computing the magnitude. |
|
Use power=1 for the power spectrogram. |
|
Use power=0.5 for the magnitude spectrogram. |
|
    log : bool
        Whether to apply log to the spectral features.
    eps : float
        A small value added for numerical stability (before fractional
        powers and before the log).
|
Example |
|
------- |
|
>>> a = torch.Tensor([[3, 4]]) |
|
>>> spectral_magnitude(a, power=0.5) |
|
tensor([5.]) |
|
""" |
|
spectr = stft.pow(2).sum(-1) |
|
|
|
|
|
if power < 1: |
|
spectr = spectr + eps |
|
spectr = spectr.pow(power) |
|
|
|
if log: |
|
return torch.log(spectr + eps) |
|
return spectr |
|
|
|
|
|
class ContextWindow(torch.nn.Module): |
|
"""Computes the context window. |
|
This class applies a context window by gathering multiple time steps |
|
in a single feature vector. The operation is performed with a |
|
convolutional layer based on a fixed kernel designed for that. |
|
Arguments |
|
--------- |
|
left_frames : int |
|
        Number of left frames (i.e., past frames) to collect.
|
right_frames : int |
|
        Number of right frames (i.e., future frames) to collect.
|
Example |
|
------- |
|
>>> import torch |
|
>>> compute_cw = ContextWindow(left_frames=5, right_frames=5) |
|
>>> inputs = torch.randn([10, 101, 20]) |
|
>>> features = compute_cw(inputs) |
|
>>> features.shape |
|
torch.Size([10, 101, 220]) |
|
""" |
|
|
|
def __init__( |
|
self, left_frames=0, right_frames=0, |
|
): |
|
super().__init__() |
|
self.left_frames = left_frames |
|
self.right_frames = right_frames |
|
self.context_len = self.left_frames + self.right_frames + 1 |
|
self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1 |
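        # Identity kernel: each output row of the grouped convolution copies
        # one frame of the context window. When the right context is larger
        # than the left one, the kernel is rolled so the gathered window
        # stays aligned with the current frame.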
|
|
|
|
|
self.kernel = torch.eye(self.context_len, self.kernel_len) |
|
|
|
if self.right_frames > self.left_frames: |
|
lag = self.right_frames - self.left_frames |
|
self.kernel = torch.roll(self.kernel, lag, 1) |
|
|
|
self.first_call = True |
|
|
|
def forward(self, x): |
|
"""Returns the tensor with the surrounding context. |
|
Arguments |
|
--------- |
|
x : tensor |
|
A batch of tensors. |
|
""" |
|
|
|
x = x.transpose(1, 2) |
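        # On the first call, replicate the kernel once per input feature
        # channel so the grouped conv1d gathers the context of every feature
        # independently.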
|
|
|
        if self.first_call:
|
self.first_call = False |
|
self.kernel = ( |
|
self.kernel.repeat(x.shape[1], 1, 1) |
|
.view(x.shape[1] * self.context_len, self.kernel_len,) |
|
.unsqueeze(1) |
|
) |
|
|
|
|
|
or_shape = x.shape |
|
if len(or_shape) == 4: |
|
x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3]) |
|
|
|
|
|
cw_x = torch.nn.functional.conv1d( |
|
x, |
|
self.kernel.to(x.device), |
|
groups=x.shape[1], |
|
padding=max(self.left_frames, self.right_frames), |
|
) |
|
|
|
|
|
if len(or_shape) == 4: |
|
cw_x = cw_x.reshape( |
|
or_shape[0], cw_x.shape[1], or_shape[2], cw_x.shape[-1] |
|
) |
|
|
|
cw_x = cw_x.transpose(1, 2) |
|
|
|
return cw_x |
|
|
|
|
|
class Fbank(torch.nn.Module): |
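    """Generates FBANK features from raw waveforms.

    This class chains STFT, spectral_magnitude, and Filterbank, and can
    optionally append delta coefficients (Deltas) and a context window
    (ContextWindow).

    Arguments
    ---------
    deltas : bool
        Whether or not to append derivatives and second derivatives of the
        features.
    context : bool
        Whether or not to append the surrounding context to the features.
    requires_grad : bool
        If True, the filterbank central frequencies and bands are learnable
        (the Filterbank is created with freeze=False).
    sample_rate : int
        Sample rate of the input audio signal (e.g., 16000).
    f_min : int
        Lowest frequency for the Mel filters.
    f_max : int
        Highest frequency for the Mel filters. If None, sample_rate / 2 is
        used.
    n_fft : int
        Number of fft points of the STFT.
    n_mels : int
        Number of Mel filters.
    filter_shape : str
        Shape of the filters ('triangular', 'rectangular', 'gaussian').
    param_change_factor : float
        See the Filterbank documentation.
    param_rand_factor : float
        See the Filterbank documentation.
    left_frames : int
        Number of past frames collected by the context window.
    right_frames : int
        Number of future frames collected by the context window.
    win_length : float
        Length (in ms) of the sliding window used to compute the STFT.
    hop_length : float
        Length (in ms) of the hop of the sliding window used to compute
        the STFT.

    Example
    -------
    >>> import torch
    >>> inputs = torch.randn([10, 16000])
    >>> feature_maker = Fbank()
    >>> feats = feature_maker(inputs)
    >>> feats.shape
    torch.Size([10, 101, 40])
    """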
|
|
|
def __init__( |
|
self, |
|
deltas=False, |
|
context=False, |
|
requires_grad=False, |
|
sample_rate=16000, |
|
f_min=0, |
|
f_max=None, |
|
n_fft=400, |
|
n_mels=40, |
|
filter_shape="triangular", |
|
param_change_factor=1.0, |
|
param_rand_factor=0.0, |
|
left_frames=5, |
|
right_frames=5, |
|
win_length=25, |
|
hop_length=10, |
|
): |
|
super().__init__() |
|
self.deltas = deltas |
|
self.context = context |
|
self.requires_grad = requires_grad |
|
|
|
if f_max is None: |
|
f_max = sample_rate / 2 |
|
|
|
self.compute_STFT = STFT( |
|
sample_rate=sample_rate, |
|
n_fft=n_fft, |
|
win_length=win_length, |
|
hop_length=hop_length, |
|
) |
|
self.compute_fbanks = Filterbank( |
|
sample_rate=sample_rate, |
|
n_fft=n_fft, |
|
n_mels=n_mels, |
|
f_min=f_min, |
|
f_max=f_max, |
|
freeze=not requires_grad, |
|
filter_shape=filter_shape, |
|
param_change_factor=param_change_factor, |
|
param_rand_factor=param_rand_factor, |
|
) |
|
self.compute_deltas = Deltas(input_size=n_mels) |
|
self.context_window = ContextWindow( |
|
left_frames=left_frames, right_frames=right_frames, |
|
) |
|
|
|
def forward(self, wav): |
|
"""Returns a set of features generated from the input waveforms. |
|
Arguments |
|
--------- |
|
wav : tensor |
|
A batch of audio signals to transform to features. |
|
""" |
|
        stft = self.compute_STFT(wav)
        mag = spectral_magnitude(stft)
|
fbanks = self.compute_fbanks(mag) |
|
if self.deltas: |
|
delta1 = self.compute_deltas(fbanks) |
|
delta2 = self.compute_deltas(delta1) |
|
fbanks = torch.cat([fbanks, delta1, delta2], dim=2) |
|
if self.context: |
|
fbanks = self.context_window(fbanks) |
|
return fbanks |