|
|
|
|
|
""" |
|
FIR windowed sinc lowpass filters. |
|
""" |
|
|
|
import math |
|
from typing import Sequence, Optional |
|
|
|
import torch |
|
from torch.nn import functional as F |
|
|
|
from .core import sinc |
|
from .fftconv import fft_conv1d |
|
from .utils import simple_repr |
|
|
|
|
|
class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters. Note that a high pass or band pass filter can easily
    be implemented by subtracting a same signal processed with low pass filters with different
    frequencies (see `julius.bands.SplitBands` for instance).
    This uses a windowed sinc filter, very similar to the one used in
    `julius.resample`. However, because we do not change the sample rate here,
    this filter can be much more efficiently implemented using the FFT convolution from
    `julius.fftconv`.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        stride (int): how much to decimate the output. Keep in mind that decimation
            of the output is only acceptable if the cutoff frequency is under `1/ (2 * stride)`
            of the original sampling rate.
        pad (bool): if True, appropriately pad the input over both edges by replicating
            the first and last samples. If `stride=1`, the output will have the same
            length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For lowpass filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most use. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size (FFT is used when the filter half size
            exceeds 32 taps).

    Raises:
        ValueError: if `cutoffs` is empty, or if any cutoff is negative or above 0.5.

    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        non-zero frequency provided. If you combine a lot of filters with very diverse
        frequencies, it might be more efficient to split them over multiple modules
        with similar frequencies.

    .. note::
        A lowpass with a cutoff frequency of 0 is defined as the null function
        by convention here. This allows for a highpass with a cutoff of 0 to
        be equal to identity, as defined in `julius.filters.HighPassFilters`.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
          `F` is the number of cutoff frequencies.

    >>> lowpass = LowPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(lowpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        self.cutoffs = list(cutoffs)
        if not self.cutoffs:
            raise ValueError("At least one cutoff frequency must be provided.")
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be larger than or equal to zero.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros
        # The shared filter length is driven by the lowest non-zero cutoff. Zero
        # cutoffs are the null filter and need no support, so a bank made only of
        # zero cutoffs gets a single-tap filter instead of crashing on an empty min().
        positive = [c for c in self.cutoffs if c > 0]
        self.half_size = int(zeros / min(positive) / 2) if positive else 0
        if fft is None:
            fft = self.half_size > 32
        self.fft = fft
        window = torch.hann_window(2 * self.half_size + 1, periodic=False)
        time = torch.arange(-self.half_size, self.half_size + 1)
        filters = []
        # Iterate the materialized list: `cutoffs` may be a one-shot iterable
        # already consumed by `list(cutoffs)` above.
        for cutoff in self.cutoffs:
            if cutoff == 0:
                # Null filter by convention. Built from `window` (float) rather than
                # `time` (integer) so the buffer keeps a floating point dtype.
                filter_ = torch.zeros_like(window)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize the filter to unit sum so the constant (DC) component has
                # unit gain; otherwise a small leakage of it would appear.
                filter_ /= filter_.sum()
            filters.append(filter_)
        self.register_buffer("filters", torch.stack(filters)[:, None])

    def forward(self, input):
        """
        Apply every filter of the bank to `input`.

        Args:
            input (torch.Tensor): tensor of shape `[*, T]`.

        Returns:
            torch.Tensor: tensor of shape `[F, *, T']` with `F = len(self.cutoffs)`.
        """
        shape = list(input.shape)
        # Flatten all leading dims into a batch of single-channel signals for conv1d.
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        input = input.reshape(-1, 1, shape[-1])
        if self.pad and self.half_size > 0:
            input = F.pad(input, (self.half_size, self.half_size), mode='replicate')
        if self.fft:
            out = fft_conv1d(input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(input, self.filters, stride=self.stride)
        # `out` is `[B, F, T']`; put the filter axis first and restore the leading dims.
        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.permute(1, 0, 2).reshape(shape)

    def __repr__(self):
        return simple_repr(self)
|
|
|
|
|
class LowPassFilter(torch.nn.Module):
    """
    Single low pass filter: a thin wrapper around a one-element `LowPassFilters`
    bank whose output drops the leading filter dimension.

    Constructor arguments have the same meaning as for `LowPassFilters`, except that
    `cutoff` is a single frequency instead of a list.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> lowpass = LowPassFilter(1/4, stride=2)
    >>> x = torch.randn(4, 124)
    >>> list(lowpass(x).shape)
    [4, 62]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        bank = LowPassFilters([cutoff], stride=stride, pad=pad, zeros=zeros, fft=fft)
        self._lowpasses = bank

    # The read-only properties below mirror the constructor arguments by reading
    # them back from the wrapped bank.

    @property
    def cutoff(self):
        """The cutoff frequency, stored as the only entry of the bank's cutoffs."""
        bank = self._lowpasses
        return bank.cutoffs[0]

    @property
    def stride(self):
        """Decimation factor applied to the output."""
        bank = self._lowpasses
        return bank.stride

    @property
    def pad(self):
        """Whether the input is padded before filtering."""
        bank = self._lowpasses
        return bank.pad

    @property
    def zeros(self):
        """Number of zero crossings kept in the filter."""
        bank = self._lowpasses
        return bank.zeros

    @property
    def fft(self):
        """Whether FFT convolution is used (as resolved by the wrapped bank)."""
        bank = self._lowpasses
        return bank.fft

    def forward(self, input):
        # The bank returns `[1, *, T']`; select its only filter output.
        filtered = self._lowpasses(input)
        return filtered[0]

    def __repr__(self):
        return simple_repr(self)
|
|
|
|
|
def lowpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `LowPassFilters`, refer to this class for more information.

    A fresh filter bank is built on every call and moved to the device and dtype
    of `input` before being applied.
    """
    bank = LowPassFilters(cutoffs, stride, pad, zeros, fft)
    bank = bank.to(input)
    return bank(input)
|
|
|
|
|
def lowpass_filter(input: torch.Tensor, cutoff: float,
                   stride: int = 1, pad: bool = True,
                   zeros: float = 8, fft: Optional[bool] = None):
    """
    Same as `lowpass_filters` but with a single cutoff frequency.
    Output will not have a dimension inserted in the front.
    """
    outputs = lowpass_filters(input, [cutoff], stride=stride, pad=pad,
                              zeros=zeros, fft=fft)
    # Drop the leading filter dimension, which has size 1 here.
    return outputs[0]
|
|