r3gm's picture
Upload 288 files
7bc29af
raw
history blame
6.92 kB
# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2020
"""
FIR windowed sinc lowpass filters.
"""
import math
from typing import Sequence, Optional
import torch
from torch.nn import functional as F
from .core import sinc
from .fftconv import fft_conv1d
from .utils import simple_repr
class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters. Note that a high pass or band pass filter can easily
    be implemented by subtracting a same signal processed with low pass filters with different
    frequencies (see `julius.bands.SplitBands` for instance).
    This uses a windowed sinc filter, very similar to the one used in
    `julius.resample`. However, because we do not change the sample rate here,
    this filter can be much more efficiently implemented using the FFT convolution from
    `julius.fftconv`.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`. At least one cutoff must be provided.
        stride (int): how much to decimate the output. Keep in mind that decimation
            of the output is only acceptable if the cutoff frequency is under `1/ (2 * stride)`
            of the original sampling rate.
        pad (bool): if True, appropriately pad the input over the edge by replicating
            its first and last samples. If `stride=1`, the output will have the same
            length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For lowpass filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        frequency provided. If you combine a lot of filters with very diverse frequencies, it might
        be more efficient to split them over multiple modules with similar frequencies.

    .. note::
        A lowpass with a cutoff frequency of 0 is defined as the null function
        by convention here. This allows for a highpass with a cutoff of 0 to
        be equal to identity, as defined in `julius.filters.HighPassFilters`.
        If every cutoff is 0, a single-tap (all-zero) filter is used.

    Shape:
        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
          `F` is the number of cutoff frequencies.

    >>> lowpass = LowPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(lowpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # Materialize once: everything below (including the filter loop) reads
        # `self.cutoffs`, so an iterator argument is never consumed twice.
        self.cutoffs = list(cutoffs)
        if not self.cutoffs:
            raise ValueError("At least one cutoff frequency must be provided.")
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros
        # The shared filter length is driven by the lowest non-zero cutoff.
        # Zero cutoffs are null filters needing no support, so when every cutoff
        # is zero a single tap (half_size = 0) is enough.
        positive = [c for c in self.cutoffs if c > 0]
        self.half_size = int(zeros / min(positive) / 2) if positive else 0
        if fft is None:
            # FFT convolution is only chosen automatically for long filters.
            fft = self.half_size > 32
        self.fft = fft
        window = torch.hann_window(2 * self.half_size + 1, periodic=False)
        time = torch.arange(-self.half_size, self.half_size + 1)
        filters = []
        for cutoff in self.cutoffs:
            if cutoff == 0:
                # Use the (floating point) window as template; `time` is an
                # integer tensor and would give an integer filter.
                filter_ = torch.zeros_like(window)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize filter to have sum = 1, otherwise we will have a small leakage
                # of the constant component in the input signal.
                filter_ /= filter_.sum()
            filters.append(filter_)
        # Shape [F, 1, 2 * half_size + 1], i.e. F single-channel conv kernels.
        self.register_buffer("filters", torch.stack(filters)[:, None])

    def forward(self, input):
        """Apply every filter of the bank to `input`; see the class docstring for shapes."""
        shape = list(input.shape)
        # Fold all leading dims into the batch dim. `reshape` (unlike `view`)
        # also accepts non-contiguous inputs.
        input = input.reshape(-1, 1, shape[-1])
        if self.pad and self.half_size > 0:
            # Zero-width padding would be a no-op, so it is skipped.
            input = F.pad(input, (self.half_size, self.half_size), mode='replicate')
        if self.fft:
            out = fft_conv1d(input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(input, self.filters, stride=self.stride)
        # out is [B, F, T']; move the filter dim to the front and restore leading dims.
        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.permute(1, 0, 2).reshape(shape)

    def __repr__(self):
        return simple_repr(self)
class LowPassFilter(torch.nn.Module):
    """
    Single low pass filter: a thin wrapper around `LowPassFilters` holding exactly
    one cutoff, whose output has no leading filter dimension.

    Shape:
        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> lowpass = LowPassFilter(1/4, stride=2)
    >>> x = torch.randn(4, 124)
    >>> list(lowpass(x).shape)
    [4, 62]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # The one-element bank does the real work; the properties below forward to it.
        self._lowpasses = LowPassFilters(
            cutoffs=[cutoff], stride=stride, pad=pad, zeros=zeros, fft=fft)

    cutoff = property(lambda self: self._lowpasses.cutoffs[0],
                      doc="The single cutoff frequency of the wrapped bank.")
    stride = property(lambda self: self._lowpasses.stride,
                      doc="Output decimation factor of the wrapped bank.")
    pad = property(lambda self: self._lowpasses.pad,
                   doc="Whether the wrapped bank pads its input.")
    zeros = property(lambda self: self._lowpasses.zeros,
                     doc="Number of zero crossings kept by the wrapped bank.")
    fft = property(lambda self: self._lowpasses.fft,
                   doc="Whether the wrapped bank uses FFT convolution.")

    def forward(self, input):
        banked = self._lowpasses(input)
        # The bank prepends a filter dimension of size 1; select its only entry.
        return banked[0]

    def __repr__(self):
        return simple_repr(self)
def lowpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `LowPassFilters`, refer to this class for more information.

    A fresh filter bank is built for each call and moved (via `Module.to`) to
    match `input` before being applied.
    """
    bank = LowPassFilters(cutoffs, stride, pad, zeros, fft)
    bank = bank.to(input)
    return bank(input)
def lowpass_filter(input: torch.Tensor, cutoff: float,
                   stride: int = 1, pad: bool = True,
                   zeros: float = 8, fft: Optional[bool] = None):
    """
    Same as `lowpass_filters` but with a single cutoff frequency.
    Output will not have a dimension inserted in the front.
    """
    stacked = lowpass_filters(input, [cutoff], stride, pad, zeros, fft)
    # Only one filter was requested, so drop the leading filter dimension.
    return stacked[0]