import re
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import torch

from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.layers.mask_along_axis import MaskAlongAxis
from espnet2.layers.time_warp import TimeWarp


def _torch_version_at_least(major: int, minor: int) -> bool:
    """Return True if the installed torch version is at least ``major.minor``.

    This replaces ``distutils.version.LooseVersion``. ``distutils`` is
    deprecated (PEP 632) and removed in Python 3.12, so importing it would make
    this module fail to import on modern interpreters.

    Only the leading ``major.minor`` of ``torch.__version__`` is inspected.
    Local or pre-release suffixes such as ``"2.1.0+cu118"`` or ``"1.13.0a0"``
    therefore parse correctly.
    """
    match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
    if match is None:
        # Unrecognised version string: assume a modern build rather than
        # failing at import time.
        return True
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


if _torch_version_at_least(1, 1):
    DEFAULT_TIME_WARP_MODE = "bicubic"
else:
    # pytorch1.0 doesn't implement bicubic
    DEFAULT_TIME_WARP_MODE = "bilinear"


class SpecAug(AbsSpecAug):
    """Implementation of SpecAug.

    Reference:
        Daniel S. Park et al.
        "SpecAugment: A Simple Data
         Augmentation Method for Automatic Speech Recognition"

    .. warning::
        When using cuda mode, time_warp doesn't have reproducibility
        due to `torch.nn.functional.interpolate`.

    Args:
        apply_time_warp: If True, build a ``TimeWarp`` stage.
        time_warp_window: ``window`` argument passed to ``TimeWarp``.
        time_warp_mode: ``mode`` argument passed to ``TimeWarp``. Defaults to
            ``DEFAULT_TIME_WARP_MODE`` ("bicubic" on torch >= 1.1, otherwise
            "bilinear").
        apply_freq_mask: If True, build a ``MaskAlongAxis(dim="freq")`` stage.
        freq_mask_width_range: ``mask_width_range`` for the frequency mask.
        num_freq_mask: ``num_mask`` for the frequency mask.
        apply_time_mask: If True, build a ``MaskAlongAxis(dim="time")`` stage.
        time_mask_width_range: ``mask_width_range`` for the time mask.
        num_time_mask: ``num_mask`` for the time mask.

    Raises:
        ValueError: If all three stages are disabled.
    """

    def __init__(
        self,
        apply_time_warp: bool = True,
        time_warp_window: int = 5,
        time_warp_mode: str = DEFAULT_TIME_WARP_MODE,
        apply_freq_mask: bool = True,
        freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
        num_freq_mask: int = 2,
        apply_time_mask: bool = True,
        time_mask_width_range: Union[int, Sequence[int]] = (0, 100),
        num_time_mask: int = 2,
    ):
        if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
            raise ValueError(
                "Either one of time_warp, time_mask, or freq_mask should be applied",
            )
        super().__init__()
        self.apply_time_warp = apply_time_warp
        self.apply_freq_mask = apply_freq_mask
        self.apply_time_mask = apply_time_mask

        # A disabled stage is stored as None so forward() can skip it.
        if apply_time_warp:
            self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
        else:
            self.time_warp = None

        if apply_freq_mask:
            self.freq_mask = MaskAlongAxis(
                dim="freq",
                mask_width_range=freq_mask_width_range,
                num_mask=num_freq_mask,
            )
        else:
            self.freq_mask = None

        if apply_time_mask:
            self.time_mask = MaskAlongAxis(
                dim="time",
                mask_width_range=time_mask_width_range,
                num_mask=num_time_mask,
            )
        else:
            self.time_mask = None

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the enabled augmentation stages in a fixed order.

        The order is time warp, then frequency mask, then time mask. Stages
        disabled in ``__init__`` are skipped. Each stage receives and returns
        an ``(x, x_lengths)`` pair.

        Args:
            x: Input features. NOTE(review): presumably (batch, time, freq);
                confirm against the TimeWarp / MaskAlongAxis implementations.
            x_lengths: Optional per-utterance lengths, forwarded to every
                stage.

        Returns:
            The ``(x, x_lengths)`` pair produced by the last applied stage.
        """
        for stage in (self.time_warp, self.freq_mask, self.time_mask):
            if stage is not None:
                x, x_lengths = stage(x, x_lengths)
        return x, x_lengths