Spaces:
Running
on
Zero
Running
on
Zero
# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details. | |
# Author: adefossez, 2020 | |
""" | |
Implementation of a FFT based 1D convolution in PyTorch. | |
While FFT is used in CUDNN for small kernel sizes, it is not the case for long ones, e.g. 512. | |
This module implements efficient FFT based convolutions for such convolutions. A typical | |
application is for evaluationg FIR filters with a long receptive field, typically | |
evaluated with a stride of 1. | |
""" | |
from typing import Optional | |
import torch | |
try: | |
import torch.fft as new_fft | |
except ImportError: | |
new_fft = None # type: ignore | |
from torch.nn import functional as F | |
from .core import pad_to, unfold | |
from .utils import simple_repr | |
# This is quite verbose, but sadly needed to make TorchScript happy. | |
def _new_rfft(x: torch.Tensor): | |
z = new_fft.rfft(x, dim=-1) | |
return torch.view_as_real(z) | |
def _old_rfft(x: torch.Tensor): | |
return torch.rfft(x, 1) # type: ignore | |
def _old_irfft(x: torch.Tensor, length: int): | |
result = torch.irfft(x, 1, signal_sizes=(length,)) # type: ignore | |
return result | |
def _new_irfft(x: torch.Tensor, length: int): | |
x = torch.view_as_complex(x) | |
return new_fft.irfft(x, length, dim=-1) | |
if new_fft is None: | |
_rfft = _old_rfft | |
_irfft = _old_irfft | |
else: | |
_rfft = _new_rfft | |
_irfft = _new_irfft | |
def _compl_mul_conjugate(a: torch.Tensor, b: torch.Tensor): | |
""" | |
Given a and b two tensors of dimension 4 | |
with the last dimension being the real and imaginary part, | |
returns a multiplied by the conjugate of b, the multiplication | |
being with respect to the second dimension. | |
""" | |
# PyTorch 1.7 supports complex number, but not for all operations. | |
# Once the support is widespread, this can likely go away. | |
op = "bcft,dct->bdft" | |
return torch.stack([ | |
torch.einsum(op, a[..., 0], b[..., 0]) + torch.einsum(op, a[..., 1], b[..., 1]), | |
torch.einsum(op, a[..., 1], b[..., 0]) - torch.einsum(op, a[..., 0], b[..., 1]) | |
], | |
dim=-1) | |
def fft_conv1d( | |
input: torch.Tensor, weight: torch.Tensor, | |
bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0, | |
block_ratio: float = 5): | |
""" | |
Same as `torch.nn.functional.conv1d` but using FFT for the convolution. | |
Please check PyTorch documentation for more information. | |
Args: | |
input (Tensor): input signal of shape `[B, C, T]`. | |
weight (Tensor): weight of the convolution `[D, C, K]` with `D` the number | |
of output channels. | |
bias (Tensor or None): if not None, bias term for the convolution. | |
stride (int): stride of convolution. | |
padding (int): padding to apply to the input. | |
block_ratio (float): can be tuned for speed. The input is splitted in chunks | |
with a size of `int(block_ratio * kernel_size)`. | |
Shape: | |
- Inputs: `input` is `[B, C, T]`, `weight` is `[D, C, K]` and bias is `[D]`. | |
- Output: `(*, T)` | |
..note:: | |
This function is faster than `torch.nn.functional.conv1d` only in specific cases. | |
Typically, the kernel size should be of the order of 256 to see any real gain, | |
for a stride of 1. | |
..Warning:: | |
Dilation and groups are not supported at the moment. This function might use | |
more memory than the default Conv1d implementation. | |
""" | |
input = F.pad(input, (padding, padding)) | |
batch, channels, length = input.shape | |
out_channels, _, kernel_size = weight.shape | |
if length < kernel_size: | |
raise RuntimeError(f"Input should be at least as large as the kernel size {kernel_size}, " | |
f"but it is only {length} samples long.") | |
if block_ratio < 1: | |
raise RuntimeError("Block ratio must be greater than 1.") | |
# We are going to process the input blocks by blocks, as for some reason it is faster | |
# and less memory intensive (I think the culprit is `torch.einsum`. | |
block_size: int = min(int(kernel_size * block_ratio), length) | |
fold_stride = block_size - kernel_size + 1 | |
weight = pad_to(weight, block_size) | |
weight_z = _rfft(weight) | |
# We pad the input and get the different frames, on which | |
frames = unfold(input, block_size, fold_stride) | |
frames_z = _rfft(frames) | |
out_z = _compl_mul_conjugate(frames_z, weight_z) | |
out = _irfft(out_z, block_size) | |
# The last bit is invalid, because FFT will do a circular convolution. | |
out = out[..., :-kernel_size + 1] | |
out = out.reshape(batch, out_channels, -1) | |
out = out[..., ::stride] | |
target_length = (length - kernel_size) // stride + 1 | |
out = out[..., :target_length] | |
if bias is not None: | |
out += bias[:, None] | |
return out | |
class FFTConv1d(torch.nn.Module): | |
""" | |
Same as `torch.nn.Conv1d` but based on `fft_conv1d`. | |
Please check PyTorch documentation for more information. | |
Args: | |
in_channels (int): number of input channels. | |
out_channels (int): number of output channels. | |
kernel_size (int): kernel size of convolution. | |
stride (int): stride of convolution. | |
padding (int): padding to apply to the input. | |
bias (bool): if True, use a bias term. | |
..note:: | |
This module is faster than `torch.nn.Conv1d` only in specific cases. | |
Typically, `kernel_size` should be of the order of 256 to see any real gain, | |
for a stride of 1. | |
..warning:: | |
Dilation and groups are not supported at the moment. This module might use | |
more memory than the default Conv1d implementation. | |
>>> fftconv = FFTConv1d(12, 24, 128, 4) | |
>>> x = torch.randn(4, 12, 1024) | |
>>> print(list(fftconv(x).shape)) | |
[4, 24, 225] | |
""" | |
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, | |
stride: int = 1, padding: int = 0, bias: bool = True): | |
super().__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.kernel_size = kernel_size | |
self.stride = stride | |
self.padding = padding | |
conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias) | |
self.weight = conv.weight | |
self.bias = conv.bias | |
def forward(self, input: torch.Tensor): | |
return fft_conv1d( | |
input, self.weight, self.bias, self.stride, self.padding) | |
def __repr__(self): | |
return simple_repr(self, overrides={"bias": self.bias is not None}) | |