|
|
|
""" SincNet model """ |
|
from functools import lru_cache |
|
import numpy as np |
|
import logging |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
class SincNetFilterConvLayer(nn.Module):
    """SincNet fast convolution filter layer.

    A 1-D convolution whose kernels are band-pass filters built as the
    difference of two windowed sinc functions. Each filter is described by
    a low cut-off frequency and a bandwidth (in Hz). The cut-offs are
    initialised equally spaced on the mel scale.
    """

    def __init__(self, out_channels: int, kernel_size: int, sample_rate=16000,
                 stride=1, padding=0, dilation=1, min_low_hz=50, min_band_hz=50,
                 in_channels=1, requires_grad=False):
        """
        Args:
            out_channels : `int` number of filters.
            kernel_size : `int` filter length. An even value is increased by
                one so that every filter has a single centre tap.
            sample_rate : `int`, optional sample rate. Defaults to 16000.
            stride, padding, dilation : forwarded to `F.conv1d`.
            min_low_hz : lower bound added to every low cut-off (Hz).
            min_band_hz : lower bound added to every bandwidth (Hz).
            in_channels : must be 1.
            requires_grad : whether the cut-off frequencies are trainable.
                Defaults to False.

        Raises:
            ValueError: if `in_channels` is not 1.
        """
        super().__init__()

        if in_channels != 1:
            raise ValueError(f"SincNetFilterConvLayer only support in_channels = 1, was in_channels = {in_channels}")

        self._out_channels = out_channels
        self._kernel_size = kernel_size

        # Force an odd length so the filter is symmetric around one centre tap.
        if kernel_size % 2 == 0:
            self._kernel_size += 1

        self._stride = stride
        self._padding = padding
        self._dilation = dilation
        self._sample_rate = sample_rate
        self._min_low_hz = min_low_hz
        self._min_band_hz = min_band_hz

        # Initial band edges, equally spaced on the mel scale.
        low_hz = 30
        high_hz = self._sample_rate / 2 - (self._min_low_hz + self._min_band_hz)
        # out_channels bands need out_channels + 1 edges. Consecutive edges
        # give each band's low cut-off and its width.
        mel = np.linspace(
            2595 * np.log10(1 + low_hz / 700),
            2595 * np.log10(1 + high_hz / 700),
            self._out_channels + 1,
        )
        hz = 700 * (10 ** (mel / 2595) - 1)

        self._low_hz = nn.Parameter(
            torch.Tensor(hz[:-1]).view(-1, 1),
            requires_grad=requires_grad,
        )
        self._band_hz = nn.Parameter(
            torch.Tensor(np.diff(hz)).view(-1, 1),
            requires_grad=requires_grad,
        )

        # Left half of a Hamming window. The right half is obtained by
        # flipping the left half of the filter in `filters`.
        self.register_buffer(
            "_window",
            torch.from_numpy(np.hamming(self._kernel_size)[: self._kernel_size // 2]).float(),
        )
        # 2*pi*n / sample_rate for n = -(kernel_size // 2) .. -1, shape (1, kernel_size // 2).
        self.register_buffer(
            "_n",
            (2 * np.pi * torch.arange(-(self._kernel_size // 2), 0.0).view(1, -1) / self._sample_rate),
        )

    @property
    def filters(self) -> torch.Tensor:
        """Band-pass kernels of shape (out_channels, 1, kernel_size).

        The kernels are recomputed on every access, so they always reflect
        the current parameter values, device and dtype (after an optimizer
        step, `load_state_dict`, `.to(...)`), and each forward pass gets a
        fresh autograd graph.
        """
        low = self._min_low_hz + torch.abs(self._low_hz)
        high = torch.clamp(
            low + self._min_band_hz + torch.abs(self._band_hz),
            self._min_low_hz,
            self._sample_rate / 2,
        )
        band = (high - low)[:, 0]

        f_times_t_low = torch.matmul(low, self._n)
        f_times_t_high = torch.matmul(high, self._n)

        # Difference of two sinc low-pass responses, windowed (left half only).
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low)) / (self._n / 2)) * self._window
        # The centre tap is the t -> 0 limit of the expression above: 2 * band.
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        # Divide by 2 * band so that every filter's centre tap equals 1.
        band_pass = band_pass / (2 * band[:, None])
        return band_pass.view(self._out_channels, 1, self._kernel_size)

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms : (batch_size, 1, n_samples) batch of waveforms.

        Returns:
            features : (batch_size, out_channels, n_samples_out) absolute value
                of the sinc filter activations.
        """
        # Out-of-place abs: avoids mutating an autograd output in place.
        return F.conv1d(
            waveforms,
            self.filters,
            stride=self._stride,
            padding=self._padding,
            dilation=self._dilation,
        ).abs()
|
|
|
class SincNet(nn.Module):
    """SincNet waveform encoder.

    The raw waveform is first instance-normalised. It then goes through
    three stages, each computing
    ``leaky_relu(instance_norm(max_pool(conv(x))))``.

    - Stage 1 uses a `SincNetFilterConvLayer`.
    - Stages 2 and 3 use plain `nn.Conv1d` layers.
    """

    def __init__(
        self,
        num_sinc_filters: int = 80,
        sinc_filter_length: int = 251,
        num_conv_filters: int = 60,
        conv_filter_length: int = 5,
        pool_kernel_size: int = 3,
        pool_stride: int = 3,
        sample_rate: int = 16000,
        sinc_filter_stride: int = 10,
        sinc_filter_padding: int = 0,
        sinc_filter_dilation: int = 1,
        min_low_hz: int = 50,
        min_band_hz: int = 50,
        sinc_filter_in_channels: int = 1,
        num_wavform_channels: int = 1,
    ):
        super().__init__()

        if sample_rate != 16000:
            raise NotImplementedError(f"SincNet only supports 16kHz audio (sample_rate = 16000), was sample_rate = {sample_rate}")

        # Normalisation applied to the raw input waveform.
        self.wav_norm1d = nn.InstanceNorm1d(num_wavform_channels, affine=True)

        # The convolutions are built in the same order as before (sinc layer,
        # then the two Conv1d layers), so random initialisation is unchanged.
        sinc_stage = SincNetFilterConvLayer(
            num_sinc_filters,
            sinc_filter_length,
            sample_rate=sample_rate,
            stride=sinc_filter_stride,
            padding=sinc_filter_padding,
            dilation=sinc_filter_dilation,
            min_low_hz=min_low_hz,
            min_band_hz=min_band_hz,
            in_channels=sinc_filter_in_channels,
        )
        second_stage = nn.Conv1d(num_sinc_filters, num_conv_filters, conv_filter_length)
        third_stage = nn.Conv1d(num_conv_filters, num_conv_filters, conv_filter_length)
        self.conv1d = nn.ModuleList([sinc_stage, second_stage, third_stage])

        # Number of channels produced by each stage's convolution.
        stage_channels = (num_sinc_filters, num_conv_filters, num_conv_filters)

        self.pool1d = nn.ModuleList(
            [nn.MaxPool1d(pool_kernel_size, stride=pool_stride) for _ in stage_channels]
        )
        self.norm1d = nn.ModuleList(
            [nn.InstanceNorm1d(channels, affine=True) for channels in stage_channels]
        )

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms : (batch, channel, sample)

        Returns:
            Activations produced by the last stage.
        """
        features = self.wav_norm1d(waveforms)
        for conv, pool, norm in zip(self.conv1d, self.pool1d, self.norm1d):
            features = F.leaky_relu(norm(pool(conv(features))))
        return features
|
|