from typing import Tuple

import torch

from espnet2.enh.encoder.abs_encoder import AbsEncoder


class ConvEncoder(AbsEncoder):
    """Convolutional encoder for speech enhancement and separation.

    A single bias-free 1-D convolution followed by a ReLU maps a raw
    single-channel waveform to a sequence of non-negative feature frames.

    Args:
        channel: number of convolution output channels, i.e. the feature
            dimension of every output frame.
        kernel_size: length of the convolution window, in samples.
        stride: hop between consecutive windows, in samples.
    """

    def __init__(
        self,
        channel: int,
        kernel_size: int,
        stride: int,
    ):
        super().__init__()
        # One input channel: the waveform is expanded to [Batch, 1, sample]
        # in forward() before the convolution is applied.
        self.conv1d = torch.nn.Conv1d(
            1, channel, kernel_size=kernel_size, stride=stride, bias=False
        )
        # Kept so forward() can derive output lengths from input lengths.
        self.stride = stride
        self.kernel_size = kernel_size
        self._output_dim = channel

    @property
    def output_dim(self) -> int:
        """Feature dimension of the encoder output (equals ``channel``)."""
        return self._output_dim

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward.

        Args:
            input (torch.Tensor): mixed speech [Batch, sample]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            feature (torch.Tensor): mixed feature after encoder
                [Batch, flens, channel]
            flens (torch.Tensor): number of feature frames per batch item
                [Batch]

        Raises:
            AssertionError: if ``input`` is not a 2-D tensor.
        """
        assert input.dim() == 2, "Currently only support single channel input"

        # [Batch, sample] -> [Batch, 1, sample]: Conv1d expects a channel axis.
        input = torch.unsqueeze(input, 1)

        feature = self.conv1d(input)
        feature = torch.nn.functional.relu(feature)

        # [Batch, channel, frames] -> [Batch, frames, channel]
        feature = feature.transpose(1, 2)

        # Output-length formula of an unpadded, undilated convolution.
        flens = (ilens - self.kernel_size) // self.stride + 1

        return feature, flens