conex / espnet2 /layers /label_aggregation.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
2.33 kB
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
class LabelAggregate(torch.nn.Module):
    """Aggregate sample-level labels into frame-level labels.

    The label sequence is cut into overlapping windows of ``win_length``
    samples taken every ``hop_length`` samples.  A frame's label is 1.0 when
    the sum of the labels inside its window is strictly greater than
    ``win_length // 2`` and 0.0 otherwise.

    Args:
        win_length: Number of samples in each window.
        hop_length: Number of samples between the starts of two consecutive
            windows.
        center: If True, zero-pad ``win_length // 2`` samples on both sides
            of the time axis before framing.
    """

    def __init__(
        self,
        win_length: int = 512,
        hop_length: int = 128,
        center: bool = True,
    ):
        assert check_argument_types()
        super().__init__()

        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center

    def extra_repr(self):
        """Return the configuration string shown in the module's repr."""
        return (
            f"win_length={self.win_length}, "
            f"hop_length={self.hop_length}, "
            f"center={self.center}, "
        )

    def forward(
        self, input: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Label_dim), float tensor of 0.0 / 1.0
            olens: (Batch) number of frames per sequence, or None when
                ``ilens`` is None
        """
        bs = input.size(0)
        max_length = input.size(1)
        label_dim = input.size(2)

        # NOTE(jiatong):
        #   The default behaviour of label aggregation is compatible with
        #   torch.stft about framing and padding.

        # Step1: center padding (zeros on both sides of the time axis)
        if self.center:
            pad = self.win_length // 2
            max_length = max_length + 2 * pad
            input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0)
        nframe = (max_length - self.win_length) // self.hop_length + 1

        # Step2: framing
        # The strides below describe a dense row-major (B, T, D) layout, so
        # the tensor must be contiguous.  Without this, a non-contiguous
        # input (e.g. a transposed view) combined with center=False would be
        # framed from the wrong memory locations.  This is a no-op for
        # tensors that are already contiguous.
        input = input.contiguous()
        output = input.as_strided(
            (bs, nframe, self.win_length, label_dim),
            (max_length * label_dim, self.hop_length * label_dim, label_dim, 1),
        )

        # Step3: aggregate label
        # A frame is active when more than half of its window is labelled.
        output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2)
        output = output.float()

        # Step4: process lengths
        if ilens is not None:
            if self.center:
                pad = self.win_length // 2
                ilens = ilens + 2 * pad

            olens = (ilens - self.win_length) // self.hop_length + 1
            # Zero out frames that lie beyond each sequence's own length.
            output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
        else:
            olens = None

        return output, olens