from librosa.util import pad_center, tiny
from scipy.signal import get_window
from torch import Tensor
from torch.autograd import Variable
from typing import Optional, Tuple
import librosa
import librosa.util as librosa_util
import math
import numpy as np
import scipy
import torch
import torch.nn.functional as F
import warnings

def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``),
        i.e., the number of frequency bins by the number of mel filterbanks.
        Each column is a filterbank so that, assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * create_fb_matrix(A.size(-1), ...)``.
    """
    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    # Equivalent filterbank construction to Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
    m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
    m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
    f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
    # create overlapping triangles
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.min(down_slopes, up_slopes)
    fb = torch.clamp(fb, 1e-6, 1)

    if norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)
    return fb
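
# Usage sketch (not part of the original module): build a mel filterbank for a
# 1025-bin spectrogram (n_fft = 2048) at 16 kHz and apply it with a matmul, as
# described in the create_fb_matrix docstring. Shapes and values here are
# assumptions chosen for illustration only.
def _demo_create_fb_matrix() -> Tensor:
    specgram = torch.randn(1, 1025, 100)  # (channel, n_freqs, time)
    fb = create_fb_matrix(n_freqs=1025, f_min=0.0, f_max=8000.0,
                          n_mels=80, sample_rate=16000)  # (n_freqs, n_mels)
    # (channel, time, n_freqs) @ (n_freqs, n_mels) -> (channel, n_mels, time)
    return torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)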

def lfilter(
        waveform: Tensor,
        a_coeffs: Tensor,
        b_coeffs: Tensor,
        clamp: bool = True,
) -> Tensor:
    r"""Perform an IIR filter by evaluating the difference equation.

    Args:
        waveform (Tensor): audio waveform of dimension of ``(..., time)``.  Must be normalized to -1 to 1.
        a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
            Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
            Must be same size as b_coeffs (pad with 0's as necessary).
        b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
            Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
            Must be same size as a_coeffs (pad with 0's as necessary).
        clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)

    Returns:
        Tensor: Waveform with dimension of ``(..., time)``.
    """
    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    assert a_coeffs.size(0) == b_coeffs.size(0)
    assert len(waveform.size()) == 2
    assert waveform.device == a_coeffs.device
    assert b_coeffs.device == a_coeffs.device

    device = waveform.device
    dtype = waveform.dtype
    n_channel, n_sample = waveform.size()
    n_order = a_coeffs.size(0)
    n_sample_padded = n_sample + n_order - 1
    assert n_order > 0

    # Pad the input and create output
    padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
    padded_waveform[:, (n_order - 1):] = waveform
    padded_output_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)

    # Set up the coefficients matrix
    # Flip coefficients' order
    a_coeffs_flipped = a_coeffs.flip(0)
    b_coeffs_flipped = b_coeffs.flip(0)

    # calculate windowed_input_signal in parallel
    # create indices of original with shape (n_channel, n_order, n_sample)
    window_idxs = torch.arange(n_sample, device=device).unsqueeze(0) \
        + torch.arange(n_order, device=device).unsqueeze(1)
    window_idxs = window_idxs.repeat(n_channel, 1, 1)
    window_idxs += (torch.arange(n_channel, device=device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
    window_idxs = window_idxs.long()
    # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
    input_signal_windows = torch.matmul(b_coeffs_flipped, torch.take(padded_waveform, window_idxs))

    input_signal_windows.div_(a_coeffs[0])
    a_coeffs_flipped.div_(a_coeffs[0])
    for i_sample, o0 in enumerate(input_signal_windows.t()):
        windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
        o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
        padded_output_waveform[:, i_sample + n_order - 1] = o0

    output = padded_output_waveform[:, (n_order - 1):]

    if clamp:
        output = torch.clamp(output, min=-1., max=1.)

    # unpack batch
    output = output.reshape(shape[:-1] + output.shape[-1:])
    return output
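
# Usage sketch (not part of the original module): run lfilter as a one-pole
# recursive lowpass, y[n] = 0.5 * x[n] + 0.5 * y[n-1]. The coefficient values
# are assumptions chosen for illustration; note that b_coeffs is zero-padded
# to match the size of a_coeffs, as the docstring requires.
def _demo_lfilter() -> Tensor:
    waveform = torch.rand(2, 8000) * 2 - 1  # two channels in [-1, 1]
    a_coeffs = torch.tensor([1.0, -0.5])    # [a0, a1]
    b_coeffs = torch.tensor([0.5, 0.0])     # [b0, b1], padded with 0
    return lfilter(waveform, a_coeffs, b_coeffs)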

def biquad(
        waveform: Tensor,
        b0: float,
        b1: float,
        b2: float,
        a0: float,
        a1: float,
        a2: float
) -> Tensor:
    r"""Perform a biquad filter of input tensor.  Initial conditions set to 0.
    https://en.wikipedia.org/wiki/Digital_biquad_filter

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        b0 (float): numerator coefficient of current input, x[n]
        b1 (float): numerator coefficient of input one time step ago x[n-1]
        b2 (float): numerator coefficient of input two time steps ago x[n-2]
        a0 (float): denominator coefficient of current output y[n], typically 1
        a1 (float): denominator coefficient of output one time step ago y[n-1]
        a2 (float): denominator coefficient of output two time steps ago y[n-2]

    Returns:
        Tensor: Waveform with dimension of `(..., time)`
    """
    device = waveform.device
    dtype = waveform.dtype

    output_waveform = lfilter(
        waveform,
        torch.tensor([a0, a1, a2], dtype=dtype, device=device),
        torch.tensor([b0, b1, b2], dtype=dtype, device=device)
    )
    return output_waveform
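
# Usage sketch (not part of the original module): with a1 = a2 = 0 and a0 = 1,
# the biquad difference equation reduces to a 3-tap moving average,
# y[n] = (x[n] + x[n-1] + x[n-2]) / 3. Values are illustrative assumptions.
def _demo_biquad() -> Tensor:
    waveform = torch.rand(1, 8000) * 2 - 1
    return biquad(waveform, b0=1 / 3, b1=1 / 3, b2=1 / 3, a0=1.0, a1=0.0, a2=0.0)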

def _dB2Linear(x: float) -> float:
    # Convert a gain in decibels to a linear amplitude ratio: 10 ** (x / 20).
    return math.exp(x * math.log(10) / 20.0)

def highpass_biquad(
        waveform: Tensor,
        sample_rate: int,
        cutoff_freq: float,
        Q: float = 0.707
) -> Tensor:
    r"""Design biquad highpass filter and perform filtering.  Similar to SoX implementation.

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        cutoff_freq (float): filter cutoff frequency
        Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)

    Returns:
        Tensor: Waveform dimension of `(..., time)`
    """
    w0 = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(w0) / 2. / Q

    b0 = (1 + math.cos(w0)) / 2
    b1 = -1 - math.cos(w0)
    b2 = b0
    a0 = 1 + alpha
    a1 = -2 * math.cos(w0)
    a2 = 1 - alpha
    return biquad(waveform, b0, b1, b2, a0, a1, a2)

def lowpass_biquad(
        waveform: Tensor,
        sample_rate: int,
        cutoff_freq: float,
        Q: float = 0.707
) -> Tensor:
    r"""Design biquad lowpass filter and perform filtering.  Similar to SoX implementation.

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        cutoff_freq (float): filter cutoff frequency
        Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)

    Returns:
        Tensor: Waveform of dimension of `(..., time)`
    """
    w0 = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(w0) / 2 / Q

    b0 = (1 - math.cos(w0)) / 2
    b1 = 1 - math.cos(w0)
    b2 = b0
    a0 = 1 + alpha
    a1 = -2 * math.cos(w0)
    a2 = 1 - alpha
    return biquad(waveform, b0, b1, b2, a0, a1, a2)
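
# Usage sketch (not part of the original module): band-limit a waveform by
# chaining the SoX-style biquads above. The 300 Hz / 3 kHz corner frequencies
# and the 16 kHz sample rate are illustrative assumptions.
def _demo_band_limit(waveform: Tensor, sample_rate: int = 16000) -> Tensor:
    out = highpass_biquad(waveform, sample_rate, cutoff_freq=300.0)
    out = lowpass_biquad(out, sample_rate, cutoff_freq=3000.0)
    return out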

def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """Compute the sum-square envelope of a window function at a given hop length.

    Adapted from librosa 0.6.  This is used to estimate modulation effects
    induced by windowing observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : int or None [optional]
        The length of the window function.  By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output
    norm : float or None
        Normalization applied to the window, as in `librosa.util.normalize`

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)  # keyword arg for librosa >= 0.10

    # Fill the envelope by accumulating the squared window at each hop
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x
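
# Usage sketch (not part of the original module): sum-square envelope of a Hann
# window for 10 frames with this module's defaults (hop 200, n_fft 800); the
# frame count is an illustrative assumption. The result has length
# 800 + 200 * 9 = 2600 samples.
def _demo_window_sumsquare() -> np.ndarray:
    return window_sumsquare('hann', n_frames=10, hop_length=200,
                            win_length=800, n_fft=800)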

class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is on (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``24000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 24000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
        return mel_specgram
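
# Usage sketch (not part of the original module): project a linear spectrogram
# onto 80 mel bands. The spectrogram shape is an illustrative assumption; since
# n_stft is None here, the filter bank is built lazily from the first input.
def _demo_mel_scale() -> Tensor:
    mel_scale = MelScale(n_mels=80, sample_rate=24000)
    specgram = torch.rand(1, 513, 100)  # (channel, freq, time), n_fft = 1024
    return mel_scale(specgram)          # (1, 80, 100)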

class TorchSTFT(torch.nn.Module):
    def __init__(self, fft_size, hop_size, win_size,
                 normalized=False, domain='linear',
                 mel_scale=False, ref_level_db=20, min_level_db=-100):
        super().__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_size = win_size
        self.ref_level_db = ref_level_db
        self.min_level_db = min_level_db
        self.window = torch.hann_window(win_size)
        self.normalized = normalized
        self.domain = domain
        self.mel_scale = MelScale(n_mels=(fft_size // 2 + 1),
                                  n_stft=(fft_size // 2 + 1)) if mel_scale else None

    def transform(self, x):
        # return_complex=True is required by recent PyTorch; split the complex
        # result into real/imaginary parts to keep the original code paths.
        x_stft = torch.stft(x, self.fft_size, self.hop_size, self.win_size,
                            self.window.type_as(x), normalized=self.normalized,
                            return_complex=True)
        real = x_stft.real
        imag = x_stft.imag
        mag = torch.clamp(real ** 2 + imag ** 2, min=1e-7)
        mag = torch.sqrt(mag)
        phase = torch.atan2(imag, real)

        if self.mel_scale is not None:
            mag = self.mel_scale(mag)

        if self.domain == 'log':
            mag = 20 * torch.log10(mag) - self.ref_level_db
            mag = torch.clamp((mag - self.min_level_db) / -self.min_level_db, 0, 1)
            return mag, phase
        elif self.domain == 'linear':
            return mag, phase
        elif self.domain == 'double':
            log_mag = 20 * torch.log10(mag) - self.ref_level_db
            log_mag = torch.clamp((log_mag - self.min_level_db) / -self.min_level_db, 0, 1)
            return torch.cat((mag, log_mag), dim=1), phase

    def complex(self, x):
        x_stft = torch.stft(x, self.fft_size, self.hop_size, self.win_size,
                            self.window.type_as(x), normalized=self.normalized,
                            return_complex=True)
        real = x_stft.real
        imag = x_stft.imag
        return real, imag
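
# Usage sketch (not part of the original module): log-magnitude analysis of one
# second of audio at 16 kHz. The FFT/hop/window sizes are illustrative
# assumptions; with domain='log' the magnitude is normalized into [0, 1].
def _demo_torch_stft():
    stft = TorchSTFT(fft_size=1024, hop_size=256, win_size=1024, domain='log')
    x = torch.randn(1, 16000)
    mag, phase = stft.transform(x)  # mag shape (1, 513, frames)
    return mag, phase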

class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, size=filter_length)  # keyword arg for librosa >= 0.10
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())
    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase
    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # trim the reflect padding added in transform()
        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length / 2)]

        return inverse_transform
    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
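
# Usage sketch (not part of the original module): analysis/synthesis round trip
# with the default 800/200 configuration; the input length is an illustrative
# assumption. The output closely reconstructs the input after padding is trimmed.
def _demo_stft_round_trip() -> Tensor:
    stft = STFT(filter_length=800, hop_length=200, win_length=800, window='hann')
    audio = torch.randn(1, 16000)  # (batch, time)
    return stft(audio)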