Spaces:
Running
Running
# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details. | |
# Author: adefossez, 2020 | |
""" | |
FIR windowed sinc lowpass filters. | |
""" | |
import math | |
from typing import Sequence, Optional | |
import torch | |
from torch.nn import functional as F | |
from .core import sinc | |
from .fftconv import fft_conv1d | |
from .utils import simple_repr | |
class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters. Note that a high pass or band pass filter can easily
    be implemented by subtracting a same signal processed with low pass filters with different
    frequencies (see `julius.bands.SplitBands` for instance).
    This uses a windowed sinc filter, very similar to the one used in
    `julius.resample`. However, because we do not change the sample rate here,
    this filter can be much more efficiently implemented using the FFT convolution from
    `julius.fftconv`.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        stride (int): how much to decimate the output. Keep in mind that decimation
            of the output is only acceptable if the cutoff frequency is under `1/ (2 * stride)`
            of the original sampling rate.
        pad (bool): if True, pad the input over the edges by replicating its first and
            last values (`replicate` padding). If `stride=1`, the output will have the
            same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For lowpass filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default value of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    Raises:
        ValueError: if `cutoffs` is empty, or if any cutoff lies outside of [0, 0.5].

    ..warning::
        All the filters will use the same filter size, aligned on the lowest
        frequency provided. If you combine a lot of filters with very diverse frequencies, it might
        be more efficient to split them over multiple modules with similar frequencies.

    ..note::
        A lowpass with a cutoff frequency of 0 is defined as the null function
        by convention here. This allows for a highpass with a cutoff of 0 to
        be equal to identity, as defined in `julius.filters.HighPassFilters`.
        If every cutoff is 0, the filters have a single tap and the output is all zeros.

    Shape:
        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
          `F` is the number of cutoff frequencies.

    >>> lowpass = LowPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(lowpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # Materialize once: `cutoffs` may be a one-shot iterable, and it is read
        # several times below.
        self.cutoffs = list(cutoffs)
        if not self.cutoffs:
            raise ValueError("At least one cutoff frequency must be provided.")
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be larger than or equal to zero.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros
        # The support is sized for the lowest non zero cutoff, which needs the longest
        # filter to fit `zeros` zero crossings. Zero cutoffs are the null filter and
        # need no support, so a bank made only of them gets a single tap.
        positive_cutoffs = [c for c in self.cutoffs if c > 0]
        if positive_cutoffs:
            self.half_size = int(zeros / min(positive_cutoffs) / 2)
        else:
            self.half_size = 0
        if fft is None:
            fft = self.half_size > 32
        self.fft = fft
        window = torch.hann_window(2 * self.half_size + 1, periodic=False)
        time = torch.arange(-self.half_size, self.half_size + 1)
        filters = []
        for cutoff in self.cutoffs:
            if cutoff == 0:
                # Match the window's floating dtype. `time` is an integer tensor, and
                # zeros of that dtype would make the buffer dtype depend on the cutoffs.
                filter_ = torch.zeros_like(window)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize filter to have sum = 1, otherwise we will have a small leakage
                # of the constant component in the input signal.
                filter_ /= filter_.sum()
            filters.append(filter_)
        # Shape [F, 1, 2 * half_size + 1], i.e. F output channels for one input channel.
        self.register_buffer("filters", torch.stack(filters)[:, None])

    def forward(self, input):
        """Filter `input` of shape `[*, T]` and return a tensor of shape `[F, *, T']`."""
        shape = list(input.shape)
        # Fold all leading dimensions into the batch. `reshape` (rather than `view`)
        # also accepts non-contiguous inputs such as transposed tensors.
        input = input.reshape(-1, 1, shape[-1])
        if self.pad and self.half_size > 0:
            input = F.pad(input, (self.half_size, self.half_size), mode='replicate')
        if self.fft:
            out = fft_conv1d(input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(input, self.filters, stride=self.stride)
        # out is [B, F, T']; move the filter axis to the front and restore the
        # original leading dimensions.
        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.permute(1, 0, 2).reshape(shape)

    def __repr__(self):
        return simple_repr(self)
class LowPassFilter(torch.nn.Module):
    """
    Same as `LowPassFilters` but applies a single low pass filter.

    Args:
        cutoff (float): cutoff frequency in [0, 0.5], expressed as `f/f_s`.
        stride, pad, zeros, fft: forwarded unchanged to `LowPassFilters`.

    The constructor arguments are exposed back as read-only attributes
    (`cutoff`, `stride`, `pad`, `zeros`, `fft`), like the attributes of `LowPassFilters`.

    Shape:
        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> lowpass = LowPassFilter(1/4, stride=2)
    >>> x = torch.randn(4, 124)
    >>> list(lowpass(x).shape)
    [4, 62]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # All the work is delegated to a one-filter bank.
        self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft)

    @property
    def cutoff(self):
        """Cutoff frequency of the filter, as a fraction of the sample rate."""
        return self._lowpasses.cutoffs[0]

    @property
    def stride(self):
        """Decimation factor applied to the output."""
        return self._lowpasses.stride

    @property
    def pad(self):
        """Whether the input is padded over its edges before filtering."""
        return self._lowpasses.pad

    @property
    def zeros(self):
        """Number of zero crossings kept in the windowed sinc."""
        return self._lowpasses.zeros

    @property
    def fft(self):
        """Whether the FFT convolution is used (resolved if `None` was passed)."""
        return self._lowpasses.fft

    def forward(self, input):
        # Drop the leading filter axis added by `LowPassFilters`.
        return self._lowpasses(input)[0]

    def __repr__(self):
        return simple_repr(self)
def lowpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `LowPassFilters`, refer to this class for more information.

    A fresh filter bank is built on every call, moved to the dtype and device of
    `input`, and applied once.
    """
    bank = LowPassFilters(cutoffs, stride, pad, zeros, fft)
    # `Module.to(tensor)` takes both the dtype and the device from `input`.
    bank = bank.to(input)
    return bank(input)
def lowpass_filter(input: torch.Tensor, cutoff: float,
                   stride: int = 1, pad: bool = True,
                   zeros: float = 8, fft: Optional[bool] = None):
    """
    Same as `lowpass_filters` but with a single cutoff frequency.
    Output will not have a dimension inserted in the front.
    """
    filtered = lowpass_filters(input, [cutoff], stride=stride, pad=pad,
                               zeros=zeros, fft=fft)
    # Only one cutoff was requested: strip the leading filter axis.
    return filtered[0]