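# NOTE(review): the three lines that stood here read "Spaces:", "Build error",
# "Build error". They appear to be residue from the page this file was copied
# from and are not part of the module; kept here only as a comment.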
import math | |
import torch | |
from typing import List | |
def butter(fc, fs: float = 2.0):
    """Design a 1-pole Butterworth lowpass filter with the bilinear transform.

    Recall the Butterworth polynomials:
        N = 1   s + 1
        N = 2   s^2 + sqrt(2)s + 1
        N = 3   (s^2 + s + 1)(s + 1)
        N = 4   (s^2 + 0.76536s + 1)(s^2 + 1.84776s + 1)

    Scaling:
        LP to LP:   s -> s/w_c
        LP to HP:   s -> w_c/s

    Bilinear transform:
        s = 2/T_d * (1 - z^-1)/(1 + z^-1)

    For the 1-pole Butterworth lowpass:
        1 / (s + 1)                                      1-pole prototype
        1 / (s/w_c + 1)                                  LP to LP
        1 / ((2/T_d * (1 - z^-1)/(1 + z^-1))/w_c + 1)    Bilinear transform

    Multiplying numerator and denominator by T_d * w_c * (1 + z^-1) gives the
    coefficients computed below:
        H(z) = T_d*w_c * (1 + z^-1) / ((2 + T_d*w_c) + (T_d*w_c - 2) * z^-1)

    Args:
        fc: Cutoff frequency, in the same units as ``fs``. Either a
            single-element tensor or a Python number; a Python number is
            converted to a tensor of the default floating dtype.
        fs: Sampling rate. With the default of 2.0, ``fc = 1.0`` corresponds
            to the Nyquist frequency.

    Returns:
        Tuple ``(b, a)`` of flattened tensors holding the numerator and
        denominator coefficients, both divided by ``a0`` so that ``a[0] == 1``.
    """
    # torch.tan and Tensor.type_as below need a tensor; accept plain numbers too.
    if not isinstance(fc, torch.Tensor):
        fc = torch.as_tensor(fc, dtype=torch.get_default_dtype())

    T_d = 1 / fs

    # digital cutoff in radians per sample
    w_d = (2 * math.pi * fc) / fs

    # apply pre-warping to the cutoff so the bilinear transform lands on w_d
    w_c = (2 / T_d) * torch.tan(w_d / 2)

    # T_d * w_c appears in every coefficient
    k = T_d * w_c
    a0 = 2 + k
    a1 = k - 2

    b = torch.stack([k, k], dim=0).view(-1)
    a = torch.stack([a0, a1], dim=0).view(-1)

    # normalize so the leading denominator coefficient is 1
    b = b.type_as(fc) / a0
    a = a.type_as(fc) / a0

    return b, a
def biqaud(
    gain_dB: torch.Tensor,
    cutoff_freq: torch.Tensor,
    q_factor: torch.Tensor,
    sample_rate: float,
    filter_type: str = "peaking",
):
    """Compute normalized biquad (second-order IIR) filter coefficients.

    The formulas follow the Audio EQ Cookbook (R. Bristow-Johnson) for the
    peaking, low-shelf and high-shelf filters. The function name keeps its
    existing spelling so that current callers continue to work.

    Args:
        gain_dB: Gain in decibels. A single-element tensor or a Python number.
        cutoff_freq: Center/corner frequency, in the same units as
            ``sample_rate``. A single-element tensor or a Python number.
        q_factor: Q factor. A single-element tensor or a Python number.
        sample_rate: Sampling rate.
        filter_type: One of ``"peaking"``, ``"low_shelf"`` or ``"high_shelf"``.

    Tensor arguments are expected to share one shape, because the
    coefficients derived from them are stacked together.

    Returns:
        Tuple ``(b, a)`` of flattened 3-element tensors holding the numerator
        and denominator coefficients, both divided by ``a0`` so ``a[0] == 1``.

    Raises:
        ValueError: If ``filter_type`` is not one of the supported names.
    """
    # torch.sqrt / torch.sin / type_as below need tensors; accept plain
    # numbers for gain and cutoff as well (q_factor is only used as a divisor,
    # so a plain number already works there).
    if not isinstance(gain_dB, torch.Tensor):
        gain_dB = torch.as_tensor(gain_dB, dtype=torch.get_default_dtype())
    if not isinstance(cutoff_freq, torch.Tensor):
        cutoff_freq = torch.as_tensor(cutoff_freq, dtype=torch.get_default_dtype())

    A = 10 ** (gain_dB / 40.0)
    w0 = 2 * math.pi * (cutoff_freq / sample_rate)  # radians per sample
    alpha = torch.sin(w0) / (2 * q_factor)
    cos_w0 = torch.cos(w0)

    if filter_type == "high_shelf":
        sqrt_A = torch.sqrt(A)
        b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
        b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
        b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
        a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
        a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
        a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
    elif filter_type == "low_shelf":
        sqrt_A = torch.sqrt(A)
        b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
        b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
        b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
        a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
        a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
        a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
    elif filter_type == "peaking":
        b0 = 1 + alpha * A
        b1 = -2 * cos_w0
        b2 = 1 - alpha * A
        a0 = 1 + (alpha / A)
        a1 = -2 * cos_w0
        a2 = 1 - (alpha / A)
    else:
        raise ValueError(f"Invalid filter_type: {filter_type}.")

    b = torch.stack([b0, b1, b2], dim=0).view(-1)
    a = torch.stack([a0, a1, a2], dim=0).view(-1)

    # normalize so the leading denominator coefficient is 1
    b = b.type_as(gain_dB) / a0
    a = a.type_as(gain_dB) / a0

    return b, a
def freqz(b, a, n_fft: int = 512):
    """Sample the complex frequency response of a digital filter.

    Args:
        b: Numerator coefficients; the FFT is taken along the last dimension.
        a: Denominator coefficients; the FFT is taken along the last dimension.
        n_fft: FFT length used to evaluate both polynomials.

    Returns:
        The one-sided response ``rfft(b) / rfft(a)``, with ``n_fft // 2 + 1``
        bins along the last dimension.
    """
    numerator_spectrum, denominator_spectrum = (
        torch.fft.rfft(coeffs, n=n_fft) for coeffs in (b, a)
    )
    return numerator_spectrum / denominator_spectrum
def freq_domain_filter(x, H, n_fft):
    """Filter a signal by multiplying its spectrum with a frequency response.

    Args:
        x: Real input signal; transformed along its last dimension.
        H: One-sided frequency response, broadcastable against the
            ``n_fft // 2 + 1`` bins of the signal spectrum.
        n_fft: FFT length used for both the forward and inverse transforms.

    Returns:
        The filtered signal, of length ``n_fft`` along the last dimension.
    """
    spectrum = torch.fft.rfft(x, n=n_fft)

    # match the response's dtype and device to the signal spectrum
    response = H.type_as(spectrum)

    return torch.fft.irfft(spectrum * response, n=n_fft)
def approx_iir_filter(b, a, x):
    """Approximate the application of an IIR filter.

    The filter is applied in the frequency domain: the response ``B / A`` is
    sampled on a finite FFT grid and multiplied with the spectrum of ``x``.
    Because the response is sampled at a finite number of bins, the result
    only approximates the output of the recursive filter.

    Args:
        b (Tensor): The numerator coefficients (flattened before use).
        a (Tensor): The denominator coefficients (flattened before use).
        x (Tensor): The input signal; filtering runs along the last dimension.

    Returns:
        Tensor: The filtered signal, cropped to the length of ``x`` along the
        last dimension. An empty input yields an empty output.
    """
    n_samples = x.shape[-1]

    # nothing to filter; also avoids computing log2 of a negative size
    if n_samples == 0:
        return x.clone()

    # round 2 * n_samples - 1 up to the nearest power of 2 for the FFT,
    # using exact integer arithmetic instead of a floating-point log2
    n_fft = 1 << (2 * n_samples - 2).bit_length()

    # move coefficients to same dtype/device as x
    b = b.type_as(x).view(-1)
    a = a.type_as(x).view(-1)

    # compute complex response
    H = freqz(b, a, n_fft=n_fft).view(-1)

    # apply filter
    y = freq_domain_filter(x, H, n_fft)

    # crop along the time (last) dimension so leading dimensions are untouched
    return y[..., :n_samples]
def approx_iir_filter_cascade(
    b_s: List[torch.Tensor],
    a_s: List[torch.Tensor],
    x: torch.Tensor,
):
    """Apply a cascade of IIR filters.

    The frequency responses of all sections are multiplied together and the
    combined response is applied to ``x`` in the frequency domain, so the
    result approximates running the recursive filters one after another.

    Args:
        b_s (list[Tensor]): Numerator coefficients, one tensor per section,
            e.g. of shape (3). All tensors must have the same shape because
            they are stacked.
        a_s (list[Tensor]): Denominator coefficients, one tensor per section,
            with the same layout as ``b_s``.
        x (torch.Tensor): Input signal; filtering runs along the last
            dimension.

    Returns:
        torch.Tensor: The filtered signal, cropped to the length of ``x``
        along the last dimension. An empty input yields an empty output.

    Raises:
        RuntimeError: If ``b_s`` and ``a_s`` have different lengths.
    """
    if len(b_s) != len(a_s):
        raise RuntimeError(
            f"Must have same number of coefficients. Got b: {len(b_s)} and a: {len(a_s)}."
        )

    n_samples = x.shape[-1]

    # nothing to filter; also avoids computing log2 of a negative size
    if n_samples == 0:
        return x.clone()

    # round 2 * n_samples - 1 up to the nearest power of 2 for the FFT,
    # using exact integer arithmetic instead of a floating-point log2
    n_fft = 1 << (2 * n_samples - 2).bit_length()

    # evaluate every section's response in one batched call
    b = torch.stack(b_s, dim=0).type_as(x)
    a = torch.stack(a_s, dim=0).type_as(x)
    H = freqz(b, a, n_fft=n_fft)

    # a cascade multiplies the individual responses
    H = torch.prod(H, dim=0).view(-1)

    # apply filter
    y = freq_domain_filter(x, H, n_fft)

    # crop along the time (last) dimension so leading dimensions are untouched
    return y[..., :n_samples]