# espnet2/enh/layers/dnn_wpe.py
from typing import List, Tuple, Union

from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet2.enh.layers.mask_estimator import MaskEstimator


class DNN_WPE(torch.nn.Module):
    """Weighted Prediction Error (WPE) dereverberation with a DNN mask estimator."""
    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        nmask: int = 1,
        nonlinear: str = "sigmoid",
        iterations: int = 1,
        normalization: bool = False,
        eps: float = 1e-6,
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres: float = 1e-6,
        use_torch_solver: bool = True,
    ):
        super().__init__()
        self.iterations = iterations
        # WPE linear-prediction filter length (taps) and prediction delay in frames
        self.taps = taps
        self.delay = delay
        self.eps = eps

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        self.inverse_power = True
        self.diagonal_loading = diagonal_loading
        self.diag_eps = diag_eps
        self.mask_flooring = mask_flooring
        self.flooring_thres = flooring_thres
        self.use_torch_solver = use_torch_solver

        if self.use_dnn_mask:
            self.nmask = nmask
            self.mask_est = MaskEstimator(
                wtype,
                widim,
                wlayers,
                wunits,
                wprojs,
                dropout_rate,
                nmask=nmask,
                nonlinear=nonlinear,
            )
        else:
            self.nmask = 1

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[
        Union[ComplexTensor, List[ComplexTensor]],
        torch.LongTensor,
        Union[torch.Tensor, List[torch.Tensor], None],
        List[torch.Tensor],
    ]:
"""DNN_WPE forward function.
Notation:
B: Batch
C: Channel
T: Time or Sequence length
F: Freq or Some dimension of the feature vector
Args:
data: (B, T, C, F)
ilens: (B,)
Returns:
enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F)
ilens: (B,)
masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F)
power (List[torch.Tensor]): (B, F, T)
"""
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        enhanced = [data for i in range(self.nmask)]
        masks = None
        power = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = [enh.real ** 2 + enh.imag ** 2 for enh in enhanced]
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                masks, _ = self.mask_est(data, ilens)
                # floor masks to increase numerical stability
                if self.mask_flooring:
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Normalize along T
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = [p * masks[i] for i, p in enumerate(power)]

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE(kamo): Calculate in double precision
            enhanced = [
                wpe_one_iteration(
                    data.contiguous().double(),
                    p.double(),
                    taps=self.taps,
                    delay=self.delay,
                    inverse_power=self.inverse_power,
                )
                for p in power
            ]
            enhanced = [
                enh.to(dtype=data.dtype).masked_fill(make_pad_mask(ilens, enh.real), 0)
                for enh in enhanced
            ]

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced]
        if masks is not None:
            masks = (
                [m.transpose(-1, -3) for m in masks]
                if self.nmask > 1
                else masks[0].transpose(-1, -3)
            )
        if self.nmask == 1:
            enhanced = enhanced[0]

        return enhanced, ilens, masks, power

    def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
"""Predict mask for WPE dereverberation.
Args:
data (ComplexTensor): (B, T, C, F), double precision
ilens (torch.Tensor): (B,)
Returns:
masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F)
ilens (torch.Tensor): (B,)
"""
        if self.use_dnn_mask:
            masks, ilens = self.mask_est(data.permute(0, 3, 2, 1).float(), ilens)
            # (B, F, C, T) -> (B, T, C, F)
            masks = [m.transpose(-1, -3) for m in masks]
            if self.nmask == 1:
                masks = masks[0]
        else:
            masks = None
        return masks, ilens
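

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how DNN_WPE might be driven, assuming espnet,
# torch_complex, and pytorch_wpe are installed. Random values stand in for a
# multi-channel STFT of shape (B, T, C, F); the batch/time/channel sizes below
# are arbitrary illustration values, only widim must match F.
if __name__ == "__main__":
    B, T, C, F = 2, 120, 2, 257
    # ComplexTensor wraps separate real and imaginary parts.
    stft = ComplexTensor(torch.randn(B, T, C, F), torch.randn(B, T, C, F))
    ilens = torch.LongTensor([120, 96])  # valid frame count per utterance

    wpe = DNN_WPE(widim=F, taps=5, delay=3, use_dnn_mask=True, iterations=1)
    enhanced, olens, masks, power = wpe(stft, ilens)
    # enhanced: dereverberated STFT (B, T, C, F); masks: (B, T, C, F);
    # power: list of per-mask power estimates, each (B, F, T).
    # The shape of the complex output is read via its real part.
    print(enhanced.real.shape, masks.shape, power[0].shape)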