# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import bisect
import logging

import numpy as np
from torch.utils.data.dataloader import default_collate

from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

logger = logging.getLogger(__name__)


class MultitaskDataset(FairseqDataset):
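    """Concatenation of several FairseqDatasets for multi-task training.

    Samples are indexed in one flat index space across the sub-datasets, and
    each sample is tagged with the index of the sub-dataset it came from so
    that batching and collation stay task-specific.
    """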
    @staticmethod
    def cumsum(sequence):
        # Cumulative lengths of the sub-datasets; these are the boundaries of
        # the flat index space.
        r, s = [], 0
        for e in sequence:
            curr_len = len(e)
            r.append(curr_len + s)
            s += curr_len
        return r
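    # `sample_ratios` rescales how many batches are drawn from each sub-dataset
    # and `batch_ratio`, if given, scales max_tokens per sub-dataset (both are
    # applied in batch_by_size).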
    def __init__(self, datasets, sample_ratios=1, batch_ratio=None):
        super(MultitaskDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        else:
            logger.info('set sample ratio to ' + str(sample_ratios))
        if batch_ratio is not None:
            logger.info('batch ratio is ' + str(batch_ratio))
        self.batch_ratio = batch_ratio
        self.sample_ratios = sample_ratios
        self._ordered_indices = None
        self._update_size()

    def __len__(self):
        return self.cumulative_sizes[-1]
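    # Map a flat index to (sub-dataset, local index) and tag the sample with
    # its originating dataset index so collater() can dispatch on it.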
    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        if isinstance(sample, dict):
            sample["dataset_idx"] = dataset_idx
        else:
            sample = sample + (dataset_idx,)
        return sample

    def _update_size(self):
        self.cumulative_sizes = self.cumsum(self.datasets)
        self.real_sizes = [len(d) for d in self.datasets]
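    # Binary-search the cumulative sizes for the owning sub-dataset, then
    # convert the flat index into an index local to that sub-dataset.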
    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx
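    # A batch is expected to contain samples from a single sub-dataset; the
    # dataset_idx tag added in __getitem__ picks which collater to run.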
    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if samples is not None and len(samples) > 0:
            if isinstance(samples[0], dict):
                dataset_idx = samples[0]["dataset_idx"]
            else:
                dataset_idx = samples[0][-1]
                samples = [sample[:-1] for sample in samples]
        else:
            dataset_idx = 0
        if hasattr(self.datasets[dataset_idx], "collater"):
            return self.datasets[dataset_idx].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))
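    # Look up `attr` on the sub-dataset that owns the given flat index.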
    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds in self.datasets:
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(ds.sizes)
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(ds.sizes[0])
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)
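    # The flat arange returned here is a placeholder for the fairseq API; the
    # cached per-dataset orderings are what batch_by_size() actually consumes.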
    def ordered_indices(self):
        # ordered_indices = []
        # for i, dataset in enumerate(self.datasets):
        #     indice = dataset.ordered_indices()
        #     ordered_indices.append(indice)
        if self._ordered_indices is None:
            # Call the underlying dataset's ordered_indices() here, so that we
            # get the same random ordering as we would have from using the
            # underlying sub-datasets directly.
            self._ordered_indices = [
                dataset.ordered_indices()
                for dataset in self.datasets
            ]
        return np.arange(len(self))
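    # Translate flat indices back into per-sub-dataset local indices before
    # delegating to the sub-datasets that support prefetching.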
    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to
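    # Batches are built per sub-dataset, so no batch ever mixes tasks: each
    # sub-dataset is batched with its own token budget (max_tokens scaled by
    # batch_ratio), batch indices are shifted into the flat index space, and
    # the number of batches per sub-dataset is rescaled by sample_ratios.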
    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        if not hasattr(self, "max_tokens"):
            self.max_tokens = max_tokens
        if not hasattr(self, "max_sentences"):
            self.max_sentences = max_sentences
        if not hasattr(self, "required_batch_size_multiple"):
            self.required_batch_size_multiple = required_batch_size_multiple
        batch_samplers = []
        for i, dataset in enumerate(self.datasets):
            batch_sampler = dataset.batch_by_size(
                self._ordered_indices[i],
                max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i],
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            if i > 0:
                for batch in batch_sampler:
                    batch += self.cumulative_sizes[i - 1]
            if self.sample_ratios[i] != 1.0:
                batch_sampler = np.array(batch_sampler)
                batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i]))
                batch_sampler = list(batch_sampler)
                logger.info(
                    'Adjust batch by ratio ' + str(self.sample_ratios[i])
                    + ' and the number of batch is ' + str(int(len(batch_sampler)))
                    + ' for dataset ' + str(i)
                )
            batch_samplers.extend(batch_sampler)
        return batch_samplers

    def filter_indices_by_size(self, indices, max_positions):
        """
        Filter each sub-dataset independently, then update the cached ordered
        indices and sizes to match the filtered sub-datasets.
        """
        if not hasattr(self, "max_positions"):
            self.max_positions = max_positions
        ignored_some = False
        for i in range(len(self.datasets)):
            # ignored = []
            self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size(
                self._ordered_indices[i], self.max_positions[i]
            )
            if len(ignored) > 0:
                ignored_some = True
                logger.warning(
                    f"{len(ignored)} samples from dataset {i} have invalid sizes and will be skipped, "
                    f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}"
                )
        logger.info('update dataset size')
        self._update_size()
        # Since we are modifying in place the _ordered_indices,
        # it's not possible anymore to return valid ignored indices.
        # Hopefully the extra debug information print above should be enough to debug.
        # Ideally we would receive ignore_invalid_inputs so that we could have
        # a proper error message.
        return (np.arange(len(self)), [0] if ignored_some else [])

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
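    # Speech-side batches arrive grouped into chunks (lists of batches) while
    # text-side batches arrive individually; interleave roughly st_ratio text
    # batches into each speech chunk and shuffle within each chunk.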
    def shuffle_batches(self, batches, seed):
        logger.info("shuffle batches")
        new_batches_fromlist = []
        new_batches_notlist = []
        new_batches = []
        with data_utils.numpy_seed(seed):
            np.random.shuffle(batches)
            for batch in batches:
                if isinstance(batch, list):
                    # np.random.shuffle(batch)
                    new_batches_fromlist.append(batch)
                else:
                    new_batches_notlist.append(batch)
            logger.info("Get " + str(len(new_batches_fromlist)) + " chunk from speech sides")
            logger.info("Get " + str(sum([len(batch_list) for batch_list in new_batches_fromlist])) + " batches from speech sides")
            logger.info("Get " + str(len(new_batches_notlist)) + " batches from text sides")
            if len(new_batches_fromlist) == 0:
                return new_batches_notlist
            st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist))
            logger.info("Get st_ratio " + str(st_ratio))
            last_idx = 0
            for i in range(len(new_batches_fromlist)):
                if i == len(new_batches_fromlist) - 1:
                    new_batches_fromlist[i].extend(new_batches_notlist[last_idx:])
                else:
                    new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio])
                np.random.shuffle(new_batches_fromlist[i])
                new_batches.extend(new_batches_fromlist[i])
                last_idx = last_idx + st_ratio
        logger.info("Finish shuffle")
        return new_batches
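    # Rebuild the cached per-dataset orderings, re-apply the remembered size
    # filter, and re-batch with the remembered batching hyper-parameters.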
    def reset_batch_sampler(self):
        logger.info("reset batch sampler")
        self._ordered_indices = [
            self.datasets[i].ordered_indices()
            for i in range(len(self.datasets))
        ]
        self.filter_indices_by_size(None, None)

        batch_samplers = self.batch_by_size(
            None,
            self.max_tokens,
            self.max_sentences,
            self.required_batch_size_multiple
        )
        return batch_samplers
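

# Minimal usage sketch (illustrative only, not part of the ArTST pipeline):
# real training wraps fairseq speech/text datasets, but any FairseqDataset
# with __getitem__/__len__ works. The _ToyDataset class below is hypothetical.
if __name__ == "__main__":
    class _ToyDataset(FairseqDataset):
        def __init__(self, n, offset=0):
            self.n, self.offset = n, offset

        def __getitem__(self, index):
            return {"id": index + self.offset}

        def __len__(self):
            return self.n

    # Two sub-datasets of 3 and 5 samples -> 8 globally indexed samples.
    multitask = MultitaskDataset([_ToyDataset(3), _ToyDataset(5, offset=100)])
    print(len(multitask))              # 8
    print(multitask[0], multitask[5])  # each sample carries its dataset_idx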