# Source: conex/espnet2/layers/mask_along_axis.py
# (Hugging Face file view: commit ad16788 "Initial commit" by tobiasc, 3.73 kB)
import torch
from typeguard import check_argument_types
from typing import Sequence
from typing import Union
def mask_along_axis(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
):
    """Apply random band masks along the specified axis.

    For every utterance, ``num_mask`` bands are drawn independently. Each band
    has a random width in ``[mask_width_range[0], mask_width_range[1])`` (the
    upper bound is exclusive, as in ``torch.randint``) and a random start
    position. The bands of one utterance are OR-ed together, and the covered
    positions are overwritten with ``0.0`` or with the global mean of ``spec``.

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq).
        spec_lengths: Passed through unchanged; lengths are not used in this
            implementation.
        mask_width_range: (low, high) range from which each band width is drawn.
        dim: Axis to mask, relative to the (Batch, Length, Freq) layout:
            1 (or -2) masks Length, 2 (or -1) masks Freq.
        num_mask: Number of bands drawn per utterance (0 disables masking).
        replace_with_zero: Fill with 0.0 if True, otherwise with ``spec.mean()``.

    Returns:
        Tuple of the masked spectrogram (same shape as the input) and
        ``spec_lengths``.

    Raises:
        ValueError: If ``dim`` does not select the Length or Freq axis.

    Note:
        When ``spec`` does not require grad, the fill is done in place. For a
        contiguous input, the caller's tensor is therefore modified as well.
    """
    org_size = spec.size()
    if spec.dim() == 4:
        # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq).
        # reshape instead of view so non-contiguous inputs do not raise. For
        # contiguous inputs it still returns a view sharing storage.
        spec = spec.reshape(-1, spec.size(2), spec.size(3))

    # Normalize negative axes against the 3-D layout. Any axis other than
    # Length (1) or Freq (2) would broadcast the mask against the wrong axes.
    if dim < 0:
        dim += spec.dim()
    if dim not in (1, 2):
        raise ValueError(
            f"dim must select the Length (1) or Freq (2) axis, got {dim}"
        )

    batch = spec.shape[0]
    # D = Length or Freq
    D = spec.shape[dim]

    # mask_length: (Batch, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (batch, num_mask),
        device=spec.device,
    ).unsqueeze(2)

    # Guard the empty case (num_mask == 0): .max() of an empty tensor raises.
    max_width = int(mask_length.max()) if mask_length.numel() > 0 else 0
    # mask_pos: (Batch, num_mask, 1). Keep at least one valid start position.
    mask_pos = torch.randint(
        0, max(1, D - max_width), (batch, num_mask), device=spec.device
    ).unsqueeze(2)

    # aran: (1, 1, D)
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # (Batch, num_mask, D) -> (Batch, D): a position is masked if any band covers it.
    mask = ((mask_pos <= aran) & (aran < mask_pos + mask_length)).any(dim=1)
    # Add a singleton axis so the mask broadcasts over the other (unmasked) axis:
    # (Batch, Length, 1) for time masking, (Batch, 1, Freq) for freq masking.
    mask = mask.unsqueeze(2) if dim == 1 else mask.unsqueeze(1)

    value = 0.0 if replace_with_zero else spec.mean()

    if spec.requires_grad:
        spec = spec.masked_fill(mask, value)
    else:
        # In-place fill to avoid an extra allocation when autograd is not involved.
        spec = spec.masked_fill_(mask, value)

    return spec.reshape(org_size), spec_lengths
class MaskAlongAxis(torch.nn.Module):
    """Randomly mask bands of a spectrogram along the time or frequency axis.

    Thin ``torch.nn.Module`` wrapper that stores the masking configuration and
    delegates the actual work to ``mask_along_axis``.
    """

    def __init__(
        self,
        mask_width_range: Union[int, Sequence[int]] = (0, 30),
        num_mask: int = 2,
        dim: Union[int, str] = "time",
        replace_with_zero: bool = True,
    ):
        """Build the masking module.

        Args:
            mask_width_range: (low, high) range of band widths. A single int
                ``w`` is shorthand for ``(0, w)``.
            num_mask: Number of bands drawn per utterance.
            dim: ``"time"`` / ``"freq"``, or an integer axis of the
                (Batch, Length, Freq) layout (1 = time, 2 = freq).
            replace_with_zero: Fill masked positions with 0.0 instead of the
                input mean.

        Raises:
            TypeError: If ``mask_width_range`` does not have exactly two items.
            ValueError: If the width range is empty (high <= low) or ``dim``
                is an unknown string.
        """
        assert check_argument_types()
        # Initialize nn.Module bookkeeping before assigning any attributes.
        super().__init__()

        if isinstance(mask_width_range, int):
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
            )
        # Explicit check instead of a bare assert so it still runs under
        # ``python -O``. torch.randint(low, high) requires high > low.
        if mask_width_range[1] <= mask_width_range[0]:
            raise ValueError(
                "mask_width_range must satisfy high > low: "
                f"{mask_width_range}"
            )

        if isinstance(dim, str):
            name_to_dim = {"time": 1, "freq": 2}
            if dim not in name_to_dim:
                raise ValueError("dim must be int, 'time' or 'freq'")
            dim = name_to_dim[dim]

        # Human-readable axis name, used only by extra_repr.
        self.mask_axis = {1: "time", 2: "freq"}.get(dim, "unknown")
        self.mask_width_range = mask_width_range
        self.num_mask = num_mask
        self.dim = dim
        self.replace_with_zero = replace_with_zero

    def extra_repr(self):
        return (
            f"mask_width_range={self.mask_width_range}, "
            f"num_mask={self.num_mask}, axis={self.mask_axis}"
        )

    def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
        """Mask ``spec`` using this module's configuration.

        Args:
            spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq).
            spec_lengths: Passed through unchanged.

        Returns:
            Tuple of (masked spec, spec_lengths), as returned by
            ``mask_along_axis``. When ``spec`` does not require grad, the
            masking is applied in place.
        """
        return mask_along_axis(
            spec,
            spec_lengths,
            mask_width_range=self.mask_width_range,
            dim=self.dim,
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
        )