|
"""CausalConv1d module definition for custom decoder.""" |
|
|
|
import torch |
|
|
|
|
|
class CausalConv1d(torch.nn.Module):
    """Causal 1D convolution module for custom decoder.

    The time axis is zero-padded on the left only, by
    ``(kernel_size - 1) * dilation`` frames, so that output frame ``j`` is
    computed from input frames ``j * stride`` and earlier, never from later
    ones.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        kernel_size (int): size of convolving kernel
        stride (int): stride of the convolution
        dilation (int): spacing between the kernel points
        groups (int): number of blocked connections from ichannels to ochannels
        bias (bool): whether to add a learnable bias to the output

    """

    def __init__(
        self, idim, odim, kernel_size, stride=1, dilation=1, groups=1, bias=True
    ):
        """Construct a CausalConv1d object."""
        super().__init__()

        # Number of past frames every output frame needs in front of it.
        self._pad = (kernel_size - 1) * dilation

        # Padding is applied explicitly in forward() on the left side only.
        # Letting Conv1d pad both sides and trimming the output afterwards is
        # only correct for stride == 1: with a larger stride the trim removes
        # valid frames, and the right-side frames are wasted computation.
        self.causal_conv1d = torch.nn.Conv1d(
            idim,
            odim,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x, x_mask, cache=None):
        """Apply the causal convolution to x.

        Args:
            x (torch.Tensor): input sequences (B, U, idim)
            x_mask (torch.Tensor): mask associated with x; documented upstream
                as (B, 1, U). It is not read by this module.
            cache: unused; accepted so the signature matches the other
                decoder blocks.

        Returns:
            x (torch.Tensor): output sequences (B, ceil(U / stride), odim)
            x_mask (torch.Tensor): the mask that was passed in, returned
                unchanged (it is NOT subsampled when stride > 1).

        """
        if x.size(1) == 0:
            # Nothing to convolve: keep the empty-in / empty-out behaviour
            # instead of letting Conv1d reject an input shorter than its kernel.
            empty = x.new_zeros((x.size(0), 0, self.causal_conv1d.out_channels))
            return empty, x_mask

        # Conv1d expects channels first: (B, U, idim) -> (B, idim, U).
        x = x.permute(0, 2, 1)

        if self._pad != 0:
            # (left, right) padding of the last (time) dimension.
            x = torch.nn.functional.pad(x, (self._pad, 0))

        x = self.causal_conv1d(x)

        # Back to time-major layout: (B, odim, U') -> (B, U', odim).
        x = x.permute(0, 2, 1)

        return x, x_mask
|
|