artst-demo-asr/artst/data/multitask_dataset.py
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import bisect
import logging

import numpy as np
from torch.utils.data.dataloader import default_collate

from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)


class MultitaskDataset(FairseqDataset):
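    """Concatenation of several FairseqDatasets for joint multitask training.

    Samples are indexed across the concatenated datasets; each retrieved sample
    is tagged with the index of the dataset it came from (``dataset_idx``), so
    that collation and batching can be routed back to the source dataset.
    ``sample_ratios`` re-samples the batches of a dataset, and ``batch_ratio``
    scales ``max_tokens`` per dataset when building batch samplers.
    """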

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            curr_len = len(e)
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1, batch_ratio=None):
        super(MultitaskDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        else:
            logger.info('set sample ratio to ' + str(sample_ratios))
        if batch_ratio is not None:
            logger.info('batch ratio is ' + str(batch_ratio))
            self.batch_ratio = batch_ratio
        else:
            self.batch_ratio = None
        self.sample_ratios = sample_ratios
        self._ordered_indices = None
        self._update_size()

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
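        """Fetch the sample that owns global index ``idx`` from its sub-dataset
        and tag it with the index of that sub-dataset."""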
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        if isinstance(sample, dict):
            sample["dataset_idx"] = dataset_idx
        else:
            sample = sample + (dataset_idx,)
        return sample

    def _update_size(self):
        self.cumulative_sizes = self.cumsum(self.datasets)
        self.real_sizes = [len(d) for d in self.datasets]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
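        """Collate a mini-batch by dispatching to the collater of the sub-dataset
        the samples were drawn from (inferred from the first sample's tag)."""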
        # For now only supports datasets with same underlying collater implementations
        if samples is not None and len(samples) > 0:
            if isinstance(samples[0], dict):
                dataset_idx = samples[0]["dataset_idx"]
            else:
                dataset_idx = samples[0][-1]
                samples = [sample[:-1] for sample in samples]
        else:
            dataset_idx = 0

        if hasattr(self.datasets[dataset_idx], "collater"):
            return self.datasets[dataset_idx].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds in self.datasets:
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(ds.sizes)
            else:
                # Only support underlying datasets with a single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(ds.sizes[0])
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        # ordered_indices = []
        # for i, dataset in enumerate(self.datasets):
        #     indice = dataset.ordered_indices()
        #     ordered_indices.append(indice)
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying sub-datasets directly.
            self._ordered_indices = [
                dataset.ordered_indices()
                for dataset in self.datasets
            ]
        return np.arange(len(self))

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
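        """Build one batch sampler per sub-dataset and concatenate them.

        Each sub-dataset is batched with its own ``batch_by_size``; indices of
        datasets after the first are offset by the cumulative sizes so that they
        remain valid global indices. If ``batch_ratio`` is set, ``max_tokens`` is
        scaled per dataset; if ``sample_ratios[i] != 1.0``, the batches of dataset
        ``i`` are randomly re-sampled to ``len(batches) * sample_ratios[i]``.
        """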
        if not hasattr(self, "max_tokens"):
            self.max_tokens = max_tokens
        if not hasattr(self, "max_sentences"):
            self.max_sentences = max_sentences
        if not hasattr(self, "required_batch_size_multiple"):
            self.required_batch_size_multiple = required_batch_size_multiple
        batch_samplers = []
        for i, dataset in enumerate(self.datasets):
            batch_sampler = dataset.batch_by_size(
                self._ordered_indices[i],
                max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i],
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if i > 0:
                for batch in batch_sampler:
                    batch += self.cumulative_sizes[i - 1]
            if self.sample_ratios[i] != 1.0:
                batch_sampler = np.array(batch_sampler)
                batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i]))
                batch_sampler = list(batch_sampler)
                logger.info(
                    'Adjusted batches by ratio ' + str(self.sample_ratios[i])
                    + ', number of batches is ' + str(len(batch_sampler))
                    + ' for dataset ' + str(i)
                )
            batch_samplers.extend(batch_sampler)
        return batch_samplers

    def filter_indices_by_size(self, indices, max_positions):
        """
        Filter each sub-dataset independently, then update the round robin to work
        on the filtered sub-datasets.
        """
        if not hasattr(self, "max_positions"):
            self.max_positions = max_positions
        ignored_some = False
        for i in range(len(self.datasets)):
            self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size(
                self._ordered_indices[i], self.max_positions[i]
            )
            if len(ignored) > 0:
                ignored_some = True
                logger.warning(
                    f"{len(ignored)} samples from dataset {i} have invalid sizes and will be skipped, "
                    f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}"
                )
        logger.info('update dataset size')
        self._update_size()
        # Since we modify _ordered_indices in place, it is no longer possible to
        # return valid ignored indices. Hopefully the warning above provides enough
        # information to debug. Ideally we would receive ignore_invalid_inputs so
        # that we could raise a proper error message.
        return (np.arange(len(self)), [0] if ignored_some else [])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)

    def shuffle_batches(self, batches, seed):
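        """Shuffle and interleave batches from the speech and text sides.

        ``batches`` is expected to mix chunks of speech batches (Python lists of
        batches) with individual text batches. The text batches are distributed
        evenly across the speech chunks, each chunk is shuffled under the given
        seed, and the flattened result is returned.
        """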
logger.info("shuffle batches")
new_batches_fromlist = []
new_batches_notlist = []
new_batches = []
with data_utils.numpy_seed(seed):
np.random.shuffle(batches)
for batch in batches:
if isinstance(batch, list):
# np.random.shuffle(batch)
new_batches_fromlist.append(batch)
else:
new_batches_notlist.append(batch)
logger.info("Get " + str(len(new_batches_fromlist)) + " chunk from speech sides")
logger.info("Get " + str(sum([len(batch_list) for batch_list in new_batches_fromlist])) + " batches from speech sides")
logger.info("Get " + str(len(new_batches_notlist)) + " batches from text sides")
if len(new_batches_fromlist) == 0:
return new_batches_notlist
st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist))
logger.info("Get st_ratio " + str(st_ratio))
last_idx = 0
for i in range(len(new_batches_fromlist)):
if i == len(new_batches_fromlist) - 1:
new_batches_fromlist[i].extend(new_batches_notlist[last_idx:])
else:
new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio])
np.random.shuffle(new_batches_fromlist[i])
new_batches.extend(new_batches_fromlist[i])
last_idx = last_idx + st_ratio
logger.info("Finish shuffle")
return new_batches

    def reset_batch_sampler(self):
        logger.info("reset batch sampler")
        self._ordered_indices = [
            self.datasets[i].ordered_indices()
            for i in range(len(self.datasets))
        ]
        self.filter_indices_by_size(None, None)

        batch_samplers = self.batch_by_size(
            None,
            self.max_tokens,
            self.max_sentences,
            self.required_batch_size_multiple,
        )
        return batch_samplers
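

# A minimal usage sketch (illustrative only, not part of the original module):
# `speech_dataset` and `text_dataset` stand in for FairseqDataset instances with
# compatible collaters; the ratio values below are arbitrary examples.
#
#     dataset = MultitaskDataset(
#         [speech_dataset, text_dataset],
#         sample_ratios=[1.0, 0.5],   # keep every speech batch, re-sample half of the text batches
#         batch_ratio=[1, 2],         # allow 2x max_tokens per batch on the text side
#     )
#     dataset.ordered_indices()       # caches the per-dataset orderings used by batch_by_size()
#     batches = dataset.batch_by_size(None, max_tokens=4000, max_sentences=32)
#     batches = dataset.shuffle_batches(list(batches), seed=1)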