# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import numpy as np
import torch

from fairseq.data import FairseqDataset, plasma_utils
from fairseq.data.indexed_dataset import best_fitting_int_dtype

class TokenBlockDataset(FairseqDataset):
    """Break a Dataset of tokens into blocks.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to break into blocks
        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
        block_size (int): maximum block size (ignored in 'eos' break mode)
        break_mode (str, optional): Mode used for breaking tokens. Values can
            be one of:
            - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contain complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'complete_doc': similar to 'complete' mode, but do not
              cross document boundaries
            - 'eos': each block contains one sentence (block_size is ignored)
        include_targets (bool, optional): return next tokens as targets
            (default: False).
        document_sep_len (int, optional): document separator size (required for
            'complete_doc' break mode). Typically 1 if the sentences have eos
            and 0 otherwise.
    """

    def __init__(
        self,
        dataset,
        sizes,
        block_size,
        pad,
        eos,
        break_mode=None,
        include_targets=False,
        document_sep_len=1,
        use_plasma_view=False,
        split_path=None,
        plasma_path=None,
    ):
        super().__init__()
        self.dataset = dataset
        self.pad = pad
        self.eos = eos
        self.include_targets = include_targets

        assert len(dataset) > 0
        assert len(dataset) == len(sizes)
        _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices(
            sizes, break_mode, document_sep_len, block_size
        )
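        # Note (a sketch of intent; see plasma_utils for exact semantics):
        # PlasmaView publishes the arrays once to a shared Plasma store keyed
        # by plasma_id, so multiple dataloader workers can read them without
        # per-worker copies, while PlasmaArray keeps a process-local array.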
        if use_plasma_view:
            plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset))
            self._slice_indices = plasma_utils.PlasmaView(
                slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path
            )
            self._sizes = plasma_utils.PlasmaView(
                _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
            )
            self._block_to_dataset_index = plasma_utils.PlasmaView(
                block_to_dataset_index,
                split_path,
                (plasma_id, 2),
                plasma_path=plasma_path,
            )
        else:
            self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
            self._sizes = plasma_utils.PlasmaArray(_sizes)
            self._block_to_dataset_index = plasma_utils.PlasmaArray(
                block_to_dataset_index
            )

    @staticmethod
    def _build_slice_indices(
        sizes, break_mode, document_sep_len, block_size
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Use token_block_utils_fast to build arrays for indexing into self.dataset."""
        try:
            from fairseq.data.token_block_utils_fast import (
                _get_slice_indices_fast,
                _get_block_to_dataset_index_fast,
            )
        except ImportError:
            raise ImportError(
                "Please build Cython components with: `pip install --editable .` "
                "or `python setup.py build_ext --inplace`"
            )

        if isinstance(sizes, list):
            sizes = np.array(sizes, dtype=np.int64)
        else:
            if torch.is_tensor(sizes):
                sizes = sizes.numpy()
            sizes = sizes.astype(np.int64)

        break_mode = break_mode if break_mode is not None else "none"

        # For the "eos" break mode, block_size is not a required parameter.
        if break_mode == "eos" and block_size is None:
            block_size = 0

        slice_indices = _get_slice_indices_fast(
            sizes, str(break_mode), block_size, document_sep_len
        )
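        # Illustrative shapes (a sketch, not asserted anywhere): for
        # sizes=[3, 4, 3] with break_mode="none" and block_size=5, the 10-token
        # stream yields slice_indices == [[0, 5], [5, 10]], so _sizes == [5, 5]
        # and block_to_dataset_index (computed below) == [[0, 0, 1], [1, 2, 2]],
        # i.e. rows of (start_ds_idx, start_offset, end_ds_idx).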
        _sizes = slice_indices[:, 1] - slice_indices[:, 0]

        # build index mapping block indices to the underlying dataset indices
        if break_mode == "eos":
            # much faster version for eos break mode
            block_to_dataset_index = np.stack(
                [
                    np.arange(len(sizes)),  # starting index in dataset
                    np.zeros(
                        len(sizes), dtype=np.int64
                    ),  # starting offset within starting index
                    np.arange(len(sizes)),  # ending index in dataset
                ],
                1,
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes,
                slice_indices,
            )

        # Downcast to the smallest dtypes that can hold the values.
        size_dtype = np.uint16 if block_size < 65535 else np.uint32
        num_tokens = slice_indices[-1].max()
        slice_indices_dtype = best_fitting_int_dtype(num_tokens)
        slice_indices = slice_indices.astype(slice_indices_dtype)
        _sizes = _sizes.astype(size_dtype)
        block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype)
        return _sizes, block_to_dataset_index, slice_indices

    @property
    def slice_indices(self):
        return self._slice_indices.array

    @property
    def sizes(self):
        return self._sizes.array

    @property
    def block_to_dataset_index(self):
        return self._block_to_dataset_index.array

    def attr(self, attr: str, index: int):
        start_ds_idx, _, _ = self.block_to_dataset_index[index]
        return self.dataset.attr(attr, start_ds_idx)

    def __getitem__(self, index):
        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
        buffer = torch.cat(
            [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)]
        )
        slice_s, slice_e = self.slice_indices[index]
        length = slice_e - slice_s
        s, e = start_offset, start_offset + length
        item = buffer[s:e]

        if self.include_targets:
            # *target* is the original sentence (=item)
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
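            # e.g. (illustrative) with buffer = [w1, w2, w3, eos] and s == 0:
            #   item        = [w1, w2, w3, eos]
            #   source      = [eos, w1, w2, w3]
            #   past_target = [pad, eos, w1, w2]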
            if s == 0:
                source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]])
                past_target = torch.cat(
                    [item.new([self.pad, self.eos]), buffer[0 : e - 2]]
                )
            else:
                source = buffer[s - 1 : e - 1]
                if s == 1:
                    past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]])
                else:
                    past_target = buffer[s - 2 : e - 2]

            return source, item, past_target

        return item

    def __len__(self):
        return len(self.slice_indices)

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        # Expand each block into the underlying dataset indices it spans and
        # prefetch the union of those indices.
        self.dataset.prefetch(
            {
                ds_idx
                for index in indices
                for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
                for ds_idx in range(start_ds_idx, end_ds_idx + 1)
            }
        )