import os
import sys
import traceback
import types
from functools import wraps
from itertools import chain

import numpy as np
import torch.utils.data
from torch.utils.data import ConcatDataset

from utils.commons.hparams import hparams


def collate_xd(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Dispatch to the collate function matching the dimensionality of the inputs."""
    if len(values[0].shape) == 1:
        return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
    elif len(values[0].shape) == 2:
        return collate_2d(values, pad_idx, left_pad, shift_right, max_len)
    elif len(values[0].shape) == 3:
        return collate_3d(values, pad_idx, left_pad, shift_right, max_len)
    else:
        raise ValueError(f'collate_xd only supports 1d/2d/3d tensors, got {len(values[0].shape)}d')


def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    if len(values[0].shape) == 1:
        return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)
    else:
        return collate_2d(values, pad_idx, left_pad, shift_right, max_len)


def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            # Shift the sequence one step to the right and prepend shift_id.
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
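
# A minimal sanity-check sketch for collate_1d (assumes integer torch
# tensors); the shorter item is right-padded with pad_idx:
#
#   >>> import torch
#   >>> a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
#   >>> collate_1d([a, b])
#   tensor([[1, 2, 3],
#           [4, 5, 0]])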


def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Convert a list of 2d tensors into a padded 3d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
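
# A minimal shape sketch for collate_2d, assuming mel-style [T, C] inputs;
# items are padded along the time axis, the channel axis is untouched:
#
#   >>> import torch
#   >>> a, b = torch.randn(4, 80), torch.randn(2, 80)
#   >>> collate_2d([a, b]).shape
#   torch.Size([2, 4, 80])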


def collate_3d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Convert a list of 3d tensors into a padded 4d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size, values[0].shape[1], values[0].shape[2]).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res


def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    if len(batch) == 0:
        return False
    if len(batch) == max_sentences:
        return True
    if num_tokens > max_tokens:
        return True
    return False


def batch_by_size(
        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
        required_batch_size_multiple=1, distributed=False
):
    """
    Return a list of mini-batches of indices, bucketed by size. Batches may
    contain sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
        distributed (bool, optional): accepted for API compatibility; not
            used in this function (default: False).
    """
    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
    bsz_mult = required_batch_size_multiple

    if isinstance(indices, types.GeneratorType):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    sample_len = 0
    sample_lens = []
    batch = []
    batches = []
    for i in range(len(indices)):
        idx = indices[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert sample_len <= max_tokens, (
            f'sentence at index {idx} of size {sample_len} exceeds max_tokens '
            f'limit of {max_tokens}!'
        )
        # Cost of the batch if idx were added: every item is padded to the
        # length of the longest one.
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            # Emit the largest prefix that satisfies the batch-size multiple;
            # leftover items seed the next batch.
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
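
# A worked sketch of batch_by_size with hypothetical per-index token counts;
# a batch closes once adding the next item would push
# (batch size) * (longest item) over max_tokens:
#
#   >>> lens = [2, 3, 4, 2, 5]
#   >>> batch_by_size(list(range(5)), lambda i: lens[i], max_tokens=8)
#   [[0, 1], [2, 3], [4]]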


def unpack_dict_to_list(samples):
    """Split a batched dict into a list of per-sample dicts."""
    samples_ = []
    bsz = samples.get('outputs').size(0)
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            except Exception:
                # Skip values that cannot be indexed per sample (e.g. scalars).
                pass
        samples_.append(res)
    return samples_
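
# A minimal sketch of unpack_dict_to_list; a batch-first dict becomes one
# dict per sample (assumes an 'outputs' tensor that defines the batch size):
#
#   >>> import torch
#   >>> batch = {'outputs': torch.zeros(2, 3), 'names': ['a', 'b']}
#   >>> [s['names'] for s in unpack_dict_to_list(batch)]
#   ['a', 'b']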


def remove_padding(x, padding_idx=0):
    if x is None:
        return None
    assert len(x.shape) in [1, 2]
    if len(x.shape) == 2:
        # Keep frames whose absolute sum differs from padding_idx.
        return x[np.abs(x).sum(-1) != padding_idx]
    elif len(x.shape) == 1:
        return x[x != padding_idx]
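
# A minimal sketch of remove_padding on a 1d sequence; every entry equal to
# padding_idx (0 by default) is dropped:
#
#   >>> import torch
#   >>> remove_padding(torch.tensor([7, 8, 0, 0]))
#   tensor([7, 8])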


def data_loader(fn):
    """
    Decorator that lazily evaluates the wrapped method and caches its result
    on the instance.
    :param fn: method that builds and returns a data loader
    :return: wrapped method that computes the value once and reuses it
    """
    attr_name = '_lazy_' + fn.__name__

    @wraps(fn)
    def _get_data_loader(self):
        try:
            value = getattr(self, attr_name)
        except AttributeError:
            try:
                value = fn(self)
            except AttributeError as e:
                # Re-raise as RuntimeError so the AttributeError is not
                # mistaken for a missing cache attribute by callers.
                traceback.print_exc()
                error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
                raise RuntimeError(error) from e
            setattr(self, attr_name, value)
        return value

    return _get_data_loader
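
# A minimal usage sketch for data_loader with a hypothetical Task class; the
# wrapped builder runs once and the result is cached on the instance:
#
#   class Task:
#       @data_loader
#       def train_dataloader(self):
#           print('building')  # runs only on the first call
#           return object()    # stand-in for an expensive DataLoader build
#
#   t = Task()
#   assert t.train_dataloader() is t.train_dataloader()  # prints 'building' once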


class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, shuffle):
        super().__init__()
        self.hparams = hparams
        self.shuffle = shuffle
        self.sort_by_len = hparams['sort_by_len']
        self.sizes = None

    @property
    def _sizes(self):
        return self.sizes

    def __getitem__(self, index):
        raise NotImplementedError

    def collater(self, samples):
        raise NotImplementedError

    def __len__(self):
        return len(self._sizes)

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return min(self._sizes[index], hparams['max_frames'])

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
            if self.sort_by_len:
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        return indices

    @property
    def num_workers(self):
        return int(os.getenv('NUM_WORKERS', hparams['num_workers']))
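
# A minimal subclass sketch (hypothetical; assumes hparams has been
# initialized): a concrete dataset populates self.sizes and implements
# __getitem__ and collater:
#
#   class ToyDataset(BaseDataset):
#       def __init__(self, items, shuffle=False):
#           super().__init__(shuffle)
#           self.items = items
#           self.sizes = [len(x) for x in items]
#
#       def __getitem__(self, index):
#           return {'item': self.items[index]}
#
#       def collater(self, samples):
#           return {'items': [s['item'] for s in samples]}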


class BaseConcatDataset(ConcatDataset):
    def collater(self, samples):
        return self.datasets[0].collater(samples)

    @property
    def _sizes(self):
        if not hasattr(self, 'sizes'):
            self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets]))
        return self.sizes

    def size(self, index):
        return min(self._sizes[index], hparams['max_frames'])

    def num_tokens(self, index):
        return self.size(index)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.datasets[0].shuffle:
            indices = np.random.permutation(len(self))
            if self.datasets[0].sort_by_len:
                indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
        else:
            indices = np.arange(len(self))
        return indices

    @property
    def num_workers(self):
        return self.datasets[0].num_workers