Spaces:
Restarting
Restarting
File size: 9,094 Bytes
7bc29af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2021
"""
FIR windowed sinc highpass and bandpass filters.
Those are convenience wrappers around the filters defined in `julius.lowpass`.
"""
from typing import Sequence, Optional
import torch
# Import all lowpass filters for consistency.
from .lowpass import lowpass_filter, lowpass_filters, LowPassFilter, LowPassFilters # noqa
from .utils import simple_repr
class HighPassFilters(torch.nn.Module):
    """
    Bank of high pass filters. See `julius.lowpass.LowPassFilters` for more
    details on the implementation.

    Each high pass output is obtained by subtracting the matching low pass
    output from the input (cropped and decimated the same way as the low
    pass outputs).

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        stride (int): how much to decimate the output. Probably not a good idea
            to do so with a high pass filter though...
        pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
            the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        frequency provided. If you combine a lot of filters with very diverse frequencies, it might
        be more efficient to split them over multiple modules with similar frequencies.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
          `F` is the number of cutoff frequencies.

    >>> highpass = HighPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(highpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # All the filtering work is delegated to a bank of low pass filters;
        # the high pass outputs are derived from it in `forward`.
        self._lowpasses = LowPassFilters(cutoffs, stride, pad, zeros, fft)

    @property
    def cutoffs(self):
        """Cutoff frequencies, read from the wrapped low pass bank."""
        return self._lowpasses.cutoffs

    @property
    def stride(self):
        """Output decimation factor, read from the wrapped low pass bank."""
        return self._lowpasses.stride

    @property
    def pad(self):
        """Whether the input is zero padded, read from the wrapped low pass bank."""
        return self._lowpasses.pad

    @property
    def zeros(self):
        """Number of zero crossings kept, read from the wrapped low pass bank."""
        return self._lowpasses.zeros

    @property
    def fft(self):
        """FFT convolution setting, read from the wrapped low pass bank."""
        return self._lowpasses.fft

    def forward(self, input):
        lows = self._lowpasses(input)

        # We need to extract the portion of the input matching the low pass
        # outputs in case `pad` is False (outputs are cropped by `half_size`
        # on each side) or `stride` > 1 (outputs are decimated).
        length = input.shape[-1]
        if self.pad:
            start, end = 0, length
        else:
            start = self._lowpasses.half_size
            # Use an explicit end index rather than `-start`: when
            # `half_size == 0`, `-start` is 0 and the slice would be empty.
            end = length - start
        input = input[..., start:end:self.stride]
        highs = input - lows
        return highs

    def __repr__(self):
        return simple_repr(self)
class HighPassFilter(torch.nn.Module):
    """
    Single high pass filter, implemented as a one-element `HighPassFilters`
    bank whose leading filter dimension is dropped on output.
    See `HighPassFilters` for the meaning of each argument.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> highpass = HighPassFilter(1/4, stride=1)
    >>> x = torch.randn(4, 124)
    >>> list(highpass(x).shape)
    [4, 124]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        # A bank holding exactly one filter.
        self._highpasses = HighPassFilters([cutoff], stride, pad, zeros, fft)

    @property
    def cutoff(self):
        """The single cutoff frequency, expressed as `f/f_s`."""
        bank_cutoffs = self._highpasses.cutoffs
        return bank_cutoffs[0]

    @property
    def stride(self):
        """Output decimation factor."""
        bank = self._highpasses
        return bank.stride

    @property
    def pad(self):
        """Whether the input is zero padded over the edges."""
        bank = self._highpasses
        return bank.pad

    @property
    def zeros(self):
        """Number of zero crossings kept in the FIR filter."""
        bank = self._highpasses
        return bank.zeros

    @property
    def fft(self):
        """FFT convolution setting (True, False or None for automatic)."""
        bank = self._highpasses
        return bank.fft

    def forward(self, input):
        # The bank output has a leading dimension of size 1; select that filter.
        bank_output = self._highpasses(input)
        return bank_output[0]

    def __repr__(self):
        return simple_repr(self)
def highpass_filters(input: torch.Tensor, cutoffs: Sequence[float],
                     stride: int = 1, pad: bool = True,
                     zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `HighPassFilters`, refer to this class for more information.

    A filter bank is built for this call only, moved to the device (and
    floating point dtype) of `input` via `Module.to`, then applied to `input`.
    The output has a leading dimension indexing the cutoff frequencies.
    """
    bank = HighPassFilters(cutoffs, stride, pad, zeros, fft)
    bank = bank.to(input)
    return bank(input)
def highpass_filter(input: torch.Tensor, cutoff: float,
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `HighPassFilter`, refer to this class for more information.

    Unlike `highpass_filters`, the output does not get an extra filter
    dimension inserted in front: it has the same rank as `input`.
    """
    single_bank_output = highpass_filters(input, [cutoff], stride, pad, zeros, fft)
    return single_bank_output[0]
class BandPassFilter(torch.nn.Module):
    """
    Single band pass filter, implemented as the difference of two low pass filters.

    Args:
        cutoff_low (float): lower cutoff frequency, in [0, 0.5] expressed as `f/f_s` where
            f_s is the samplerate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        cutoff_high (float): higher cutoff frequency, in [0, 0.5] expressed as `f/f_s`.
            This must be greater than or equal to `cutoff_low`. Note that because
            the filters are not perfect, the output will be non zero even if
            `cutoff_high == cutoff_low`.
        stride (int): how much to decimate the output.
        pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
            the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.

    Raises:
        ValueError: if `cutoff_low` is greater than `cutoff_high`.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    .. note::
        There is no BandPassFilters (bank of bandpasses) because its
        signification would be the same as `julius.bands.SplitBands`.

    >>> bandpass = BandPassFilter(1/4, 1/3)
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(bandpass(x).shape)
    [4, 12, 21, 1024]
    """

    def __init__(self, cutoff_low: float, cutoff_high: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        if cutoff_low > cutoff_high:
            raise ValueError(f"Lower cutoff {cutoff_low} should be less than "
                             f"higher cutoff {cutoff_high}.")
        # Index 0 holds the lower cutoff and index 1 the higher one; `forward`
        # and the `cutoff_low` / `cutoff_high` properties rely on this order.
        self._lowpasses = LowPassFilters([cutoff_low, cutoff_high], stride, pad, zeros, fft)

    @property
    def cutoff_low(self):
        """Lower cutoff frequency, expressed as `f/f_s`."""
        return self._lowpasses.cutoffs[0]

    @property
    def cutoff_high(self):
        """Higher cutoff frequency, expressed as `f/f_s`."""
        return self._lowpasses.cutoffs[1]

    @property
    def stride(self):
        """Output decimation factor."""
        return self._lowpasses.stride

    @property
    def pad(self):
        """Whether the input is zero padded over the edges."""
        return self._lowpasses.pad

    @property
    def zeros(self):
        """Number of zero crossings kept in the FIR filters."""
        return self._lowpasses.zeros

    @property
    def fft(self):
        """FFT convolution setting (True, False or None for automatic)."""
        return self._lowpasses.fft

    def forward(self, input):
        lows = self._lowpasses(input)
        # Low pass at `cutoff_high` minus low pass at `cutoff_low` keeps the
        # band between the two cutoffs.
        return lows[1] - lows[0]

    def __repr__(self):
        return simple_repr(self)
def bandpass_filter(input: torch.Tensor, cutoff_low: float, cutoff_high: float,
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `BandPassFilter`, refer to this class for more information.

    The filter is built for this call only and moved to the device (and
    floating point dtype) of `input` via `Module.to`. The output has the same
    rank as `input`: no filter dimension is inserted in front.
    """
    module = BandPassFilter(cutoff_low, cutoff_high, stride, pad, zeros, fft)
    module = module.to(input)
    return module(input)
|