# BuboGPT / bubogpt / processors / audio_augment.py
# ikuinen99's picture
# update
# e4bd7f9
# raw
# history blame
# No virus
# 6.95 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : x.mei@surrey.ac.uk
"""
Implementation of SpecAugment++,
Adapted from Qiuqiang Kong's torchlibrosa:
https://github.com/qiuqiangkong/torchlibrosa/blob/master/torchlibrosa/augmentation.py
"""
import torch
import torch.nn as nn
class DropStripes:
    """Zero out random stripes along the time or frequency axis.

    For every sample in the batch, ``stripes_num`` stripes are drawn
    independently and set to zero in place. A stripe's width is uniform in
    ``[0, drop_width)``, capped so that it always fits inside the axis.
    """

    def __init__(self, dim, drop_width, stripes_num):
        """Drop stripes.

        Args:
            dim: int, dimension along which to drop (2: time, 3: frequency)
            drop_width: int, exclusive upper bound on the width of a stripe
            stripes_num: int, how many stripes to drop per sample
        """
        super().__init__()
        assert dim in [2, 3]  # dim 2: time; dim 3: frequency
        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Mask ``input`` in place and return it.

        Args:
            input: tensor of shape (batch_size, channels, time_steps, freq_bins)

        Returns:
            The same tensor object, with stripes zeroed.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        for n in range(batch_size):
            # input[n] is a view, so the slice assignment writes into `input`.
            self.transform_slice(input[n], total_width)
        return input

    def transform_slice(self, e, total_width):
        """Zero ``stripes_num`` random stripes of a single sample in place.

        Args:
            e: tensor of shape (channels, time_steps, freq_bins)
            total_width: int, size of the axis being masked
        """
        # torch.randint needs high > low. Capping the width bound at the axis
        # size keeps `total_width - distance` >= 1, and a non-positive bound
        # means there is nothing to drop.
        max_width = min(self.drop_width, total_width)
        if max_width <= 0:
            return

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,)).item()
            bgn = torch.randint(
                low=0, high=total_width - distance, size=(1,)).item()
            stripe = slice(bgn, bgn + distance)
            if self.dim == 2:
                e[:, stripe, :] = 0
            elif self.dim == 3:
                e[:, :, stripe] = 0
class MixStripes:
    """Average random stripes with another sample from the same mini-batch.

    For every sample, ``stripes_num`` stripes are drawn along the chosen axis
    and replaced in place by the 50/50 mean of the sample and a randomly
    permuted partner sample. A stripe's width is uniform in
    ``[0, mix_width)``, capped so that it always fits inside the axis.
    """

    def __init__(self, dim, mix_width, stripes_num):
        """Mix stripes.

        Args:
            dim: int, dimension along which to mix (2 or 3)
            mix_width: int, exclusive upper bound on the width of a stripe
            stripes_num: int, how many stripes to mix per sample
        """
        super().__init__()
        assert dim in [2, 3]
        self.dim = dim
        self.mix_width = mix_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Mix stripes of ``input`` in place and return it.

        Args:
            input: tensor of shape (batch_size, channel, time_steps, freq_bins)

        Returns:
            The same tensor object, with stripes mixed.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        # Indexing with a permutation tensor yields a copy, so the partner
        # samples keep their pre-augmentation values while `input` is edited.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """Mix ``stripes_num`` random stripes of one sample in place.

        Args:
            input: tensor of one sample, modified in place
            random_sample: tensor of the partner sample, same shape as ``input``
            total_width: int, size of the axis being mixed
        """
        # torch.randint needs high > low. Capping the width bound at the axis
        # size keeps `total_width - distance` >= 1, and a non-positive bound
        # means there is nothing to mix.
        max_width = min(self.mix_width, total_width)
        if max_width <= 0:
            return

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,)).item()
            bgn = torch.randint(
                low=0, high=total_width - distance, size=(1,)).item()
            stripe = slice(bgn, bgn + distance)
            if self.dim == 2:
                input[:, stripe, :] = (0.5 * input[:, stripe, :]
                                       + 0.5 * random_sample[:, stripe, :])
            elif self.dim == 3:
                input[:, :, stripe] = (0.5 * input[:, :, stripe]
                                       + 0.5 * random_sample[:, :, stripe])
class CutStripes:
    """Replace random stripes with those of another sample in the mini-batch.

    For every sample, ``stripes_num`` stripes are drawn along the chosen axis
    and overwritten in place with the same region of a randomly permuted
    partner sample. A stripe's width is uniform in ``[0, cut_width)``, capped
    so that it always fits inside the axis.
    """

    def __init__(self, dim, cut_width, stripes_num):
        """Cutting stripes with another randomly selected sample in mini-batch.

        Args:
            dim: int, dimension along which to cut (2 or 3)
            cut_width: int, exclusive upper bound on the width of a stripe
            stripes_num: int, how many stripes to cut per sample
        """
        super().__init__()
        assert dim in [2, 3]
        self.dim = dim
        self.cut_width = cut_width
        self.stripes_num = stripes_num

    def __call__(self, input):
        """Cut stripes into ``input`` in place and return it.

        Args:
            input: tensor of shape (batch_size, channel, time_steps, freq_bins)

        Returns:
            The same tensor object, with stripes replaced.
        """
        assert input.ndimension() == 4
        batch_size = input.shape[0]
        total_width = input.shape[self.dim]
        # Indexing with a permutation tensor yields a copy, so the partner
        # samples keep their pre-augmentation values while `input` is edited.
        rand_sample = input[torch.randperm(batch_size)]
        for i in range(batch_size):
            self.transform_slice(input[i], rand_sample[i], total_width)
        return input

    def transform_slice(self, input, random_sample, total_width):
        """Overwrite ``stripes_num`` random stripes of one sample in place.

        Args:
            input: tensor of one sample, modified in place
            random_sample: tensor of the partner sample, same shape as ``input``
            total_width: int, size of the axis being cut
        """
        # torch.randint needs high > low. Capping the width bound at the axis
        # size keeps `total_width - distance` >= 1, and a non-positive bound
        # means there is nothing to cut.
        max_width = min(self.cut_width, total_width)
        if max_width <= 0:
            return

        for _ in range(self.stripes_num):
            distance = torch.randint(low=0, high=max_width, size=(1,)).item()
            bgn = torch.randint(
                low=0, high=total_width - distance, size=(1,)).item()
            stripe = slice(bgn, bgn + distance)
            if self.dim == 2:
                input[:, stripe, :] = random_sample[:, stripe, :]
            elif self.dim == 3:
                input[:, :, stripe] = random_sample[:, :, stripe]
class SpecAugmentation:
    """Time-then-frequency stripe augmentation (SpecAugment / SpecAugment++)."""

    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, freq_stripes_num,
                 mask_type='mixture'):
        """Spec augmentation and SpecAugment++.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D.
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.
        [ref] Wang H, Zou Y, Wang W., 2021. SpecAugment++: A Hidden Space
        Data Augmentation Method for Acoustic Scene Classification. arXiv
        preprint arXiv:2103.16858.

        Args:
            time_drop_width: int, stripe width bound along dim 2 (time)
            time_stripes_num: int, number of stripes along dim 2 (time)
            freq_drop_width: int, stripe width bound along dim 3 (frequency)
            freq_stripes_num: int, number of stripes along dim 3 (frequency)
            mask_type: str, mask type in SpecAugment++ (zero_value, mixture, cutting)

        Raises:
            NameError: if ``mask_type`` is not one of the three known types.
        """
        super().__init__()

        # Pick the stripe class and the name of its width keyword once; the
        # two augmentors then differ only in axis and sizes.
        if mask_type == 'zero_value':
            stripe_cls, width_kw = DropStripes, 'drop_width'
        elif mask_type == 'mixture':
            stripe_cls, width_kw = MixStripes, 'mix_width'
        elif mask_type == 'cutting':
            stripe_cls, width_kw = CutStripes, 'cut_width'
        else:
            raise NameError('No such mask type in SpecAugment++')

        self.time_augmentator = stripe_cls(
            dim=2, stripes_num=time_stripes_num, **{width_kw: time_drop_width})
        self.freq_augmentator = stripe_cls(
            dim=3, stripes_num=freq_stripes_num, **{width_kw: freq_drop_width})

    def __call__(self, inputs):
        """Apply the time augmentor, then the frequency augmentor.

        Args:
            inputs: expected size [batch_size, channel, time_steps, freq_bins]

        Returns:
            The output of the frequency augmentor.
        """
        time_augmented = self.time_augmentator(inputs)
        return self.freq_augmentator(time_augmented)