|
from collections import OrderedDict |
|
from typing import List |
|
from typing import Tuple |
|
from typing import Union |
|
|
|
import torch |
|
from torch_complex.tensor import ComplexTensor |
|
|
|
from espnet2.enh.layers.tcn import TemporalConvNet |
|
from espnet2.enh.separator.abs_separator import AbsSeparator |
|
|
|
|
|
class TCNSeparator(AbsSeparator):
    """Mask-based separator backed by a ``TemporalConvNet``.

    The network estimates a set of masks from the encoded feature; every
    mask is multiplied with the original input to produce one separated
    feature per mask.
    """

    def __init__(
        self,
        input_dim: int,
        num_spk: int = 2,
        layer: int = 8,
        stack: int = 3,
        bottleneck_dim: int = 128,
        hidden_dim: int = 512,
        kernel: int = 3,
        causal: bool = False,
        norm_type: str = "gLN",
        nonlinear: str = "relu",
    ):
        """Temporal Convolution Separator

        Args:
            input_dim: input feature dimension
            num_spk: number of speakers
            layer: int, number of layers in each stack.
            stack: int, number of stacks
            bottleneck_dim: bottleneck dimension
            hidden_dim: number of convolution channel
            kernel: int, kernel size.
            causal: bool, default False.
            norm_type: str, choose from 'BN', 'gLN', 'cLN'
            nonlinear: the nonlinear function for mask estimation,
                       select from 'relu', 'tanh', 'sigmoid'

        Raises:
            ValueError: if ``nonlinear`` is not one of the supported names.
        """
        super().__init__()

        self._num_spk = num_spk

        # Reject unsupported activations up front, before building the network.
        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.tcn = TemporalConvNet(
            N=input_dim,
            B=bottleneck_dim,
            H=hidden_dim,
            P=kernel,
            X=layer,
            R=stack,
            C=num_spk,
            norm_type=norm_type,
            causal=causal,
            mask_nonlinear=nonlinear,
        )

    def forward(
        self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
            ilens (torch.Tensor): (B,), passed through unchanged
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
            ]

        Raises:
            ValueError: if the feature is not 3-dimensional.
        """
        # Masks are estimated from the magnitude when the input is complex;
        # they are still applied to the original (possibly complex) input below.
        feature = abs(input) if isinstance(input, ComplexTensor) else input

        # Explicit rank check (previously an implicit side effect of unpacking
        # the shape into three unused locals).
        if len(feature.shape) != 3:
            raise ValueError(
                "Expected a 3-dim feature [B, T, N], but got shape {}".format(
                    tuple(feature.shape)
                )
            )

        # [B, T, N] -> [B, N, T] before handing the feature to the TCN.
        masks = self.tcn(feature.transpose(1, 2))
        # Swap the last two axes back, then split along dim 1 into one mask
        # per entry (presumably dim 1 indexes speakers -- confirm against
        # TemporalConvNet).
        masks = masks.transpose(2, 3).unbind(dim=1)

        masked = [input * mask for mask in masks]

        others = OrderedDict(
            ("mask_spk{}".format(idx), mask)
            for idx, mask in enumerate(masks, start=1)
        )

        return masked, ilens, others

    @property
    def num_spk(self) -> int:
        """Number of speakers given at construction time."""
        return self._num_spk
|
|