#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : x.mei@surrey.ac.uk
"""
Implementation of SpecAugment++,
adapted from Qiuqiang Kong's torchlibrosa:
https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
"""
import torch
import torch.nn as nn
class DropStripes:
    """Zero out random stripes of a spectrogram batch along one dimension.

    This is the 'zero_value' mask of SpecAugment: each sample in the batch
    gets ``stripes_num`` stripes of random width (< ``drop_width``) set to 0,
    in-place, along the time or frequency axis.
    """

    def __init__(self, dim, drop_width, stripes_num):
        """ Drop stripes.
        args:
            dim: int, dimension along which to drop (2: time, 3: frequency)
            drop_width: int, maximum width of stripes to drop
            stripes_num: int, how many stripes to drop
        """
        super(DropStripes, self).__init__()
        assert dim in [2, 3]  # dim 2: time; dim 3: frequency
        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channels, time_steps, freq_bins)

        Returns the same tensor, modified in-place.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        for n in range(batch_size):
            self.transform_slice(input[n], total_width)
        return input

    def transform_slice(self, e, total_width):
        """ e: (channels, time_steps, freq_bins)"""
        # Clamp the stripe width so torch.randint never receives high <= 0,
        # which would raise when drop_width exceeds the masked dimension size
        # (original code crashed in that case). Also covers drop_width == 0.
        max_distance = min(self.drop_width, total_width)
        if max_distance <= 0:
            return
        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_distance, size=(1,))[0]
            # distance <= total_width - 1, so the high below is always >= 1
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
            if self.dim == 2:
                e[:, bgn: bgn + distance, :] = 0
            elif self.dim == 3:
                e[:, :, bgn: bgn + distance] = 0
class MixStripes:
    """Mix random stripes of each sample with another sample from the batch.

    This is the 'mixture' mask of SpecAugment++: stripes become the average
    of the sample and a randomly permuted batch partner, in-place.
    """

    def __init__(self, dim, mix_width, stripes_num):
        """ Mix stripes
        args:
            dim: int, dimension along which to mix (2: time, 3: frequency)
            mix_width: int, maximum width of stripes to mix
            stripes_num: int, how many stripes to mix
        """
        super(MixStripes, self).__init__()
        assert dim in [2, 3]
        self.dim = dim
        self.mix_width = mix_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channel, time_steps, freq_bins)

        Returns the same tensor, modified in-place.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        # Advanced indexing copies, so mixing partners keep their pre-mask values.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """input, random_sample: (channels, time_steps, freq_bins)"""
        # Clamp the stripe width so torch.randint never receives high <= 0,
        # which would raise when mix_width exceeds the masked dimension size
        # (original code crashed in that case). Also covers mix_width == 0.
        max_distance = min(self.mix_width, total_width)
        if max_distance <= 0:
            return
        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_distance, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
            if self.dim == 2:
                input[:, bgn: bgn + distance, :] = 0.5 * input[:, bgn: bgn + distance, :] + \
                    0.5 * random_sample[:, bgn: bgn + distance, :]
            elif self.dim == 3:
                input[:, :, bgn: bgn + distance] = 0.5 * input[:, :, bgn: bgn + distance] + \
                    0.5 * random_sample[:, :, bgn: bgn + distance]
class CutStripes:
    """Replace random stripes of each sample with stripes from another sample.

    This is the 'cutting' mask of SpecAugment++: stripes are overwritten with
    the corresponding region of a randomly permuted batch partner, in-place.
    """

    def __init__(self, dim, cut_width, stripes_num):
        """ Cutting stripes with another randomly selected sample in mini-batch.
        args:
            dim: int, dimension along which to cut (2: time, 3: frequency)
            cut_width: int, maximum width of stripes to cut
            stripes_num: int, how many stripes to cut
        """
        super(CutStripes, self).__init__()
        assert dim in [2, 3]
        self.dim = dim
        self.cut_width = cut_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """input: (batch_size, channel, time_steps, freq_bins)

        Returns the same tensor, modified in-place.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        # Advanced indexing copies, so cutting partners keep their pre-mask values.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """input, random_sample: (channels, time_steps, freq_bins)"""
        # Clamp the stripe width so torch.randint never receives high <= 0,
        # which would raise when cut_width exceeds the masked dimension size
        # (original code crashed in that case). Also covers cut_width == 0.
        max_distance = min(self.cut_width, total_width)
        if max_distance <= 0:
            return
        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_distance, size=(1,))[0]
            bgn = torch.randint(low=0, high=total_width - distance, size=(1,))[0]
            if self.dim == 2:
                input[:, bgn: bgn + distance, :] = random_sample[:, bgn: bgn + distance, :]
            elif self.dim == 3:
                input[:, :, bgn: bgn + distance] = random_sample[:, :, bgn: bgn + distance]
class SpecAugmentation:
    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
                 mask_type='mixture'):
        """SpecAugment and SpecAugment++ data augmentation.
        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.
        [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
        Data Augmentation Method for Acoustic Scene Classification. arXiv
        preprint arXiv:2103.16858.
        Args:
            time_drop_width: int, maximum width of a time-axis stripe
            time_stripes_num: int, number of time-axis stripes
            freq_drop_width: int, maximum width of a frequency-axis stripe
            freq_stripes_num: int, number of frequency-axis stripes
            mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)
        Raises:
            NameError: if mask_type is not a supported mask type.
        """
        super(SpecAugmentation, self).__init__()
        # Dispatch table: each entry builds a stripe transform from
        # (dim, maximum stripe width, stripe count).
        builders = {
            'zero_value': lambda dim, width, num: DropStripes(
                dim=dim, drop_width=width, stripes_num=num),
            'mixture': lambda dim, width, num: MixStripes(
                dim=dim, mix_width=width, stripes_num=num),
            'cutting': lambda dim, width, num: CutStripes(
                dim=dim, cut_width=width, stripes_num=num),
        }
        if mask_type not in builders:
            raise NameError('No such mask type in SpecAugment++')
        build = builders[mask_type]
        self.time_augmentator = build(2, time_drop_width, time_stripes_num)
        self.freq_augmentator = build(3, freq_drop_width, freq_stripes_num)

    def __call__(self, inputs):
        """Apply time masking then frequency masking.
        inputs: (batch_size, channel, time_steps, freq_bins); modified in-place.
        """
        return self.freq_augmentator(self.time_augmentator(inputs))