Spaces:
Running
on
Zero
Running
on
Zero
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
import os | |
import random | |
import struct | |
import numpy as np | |
import torch | |
from torch.utils.data import IterableDataset, get_worker_info | |
# Lookup table from the integer dtype code stored in a chunk file header
# (a single unsigned byte) to the NumPy scalar type of the stored tokens.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.float64,
    8: np.uint16,
}
def code(dtype):
    """Return the integer code registered for ``dtype`` in ``dtypes``.

    This is a reverse lookup over the module-level ``dtypes`` table: entries
    are visited in the table's iteration order and the key of the first value
    that compares equal (``==``) to ``dtype`` is returned.

    Raises:
        ValueError: if no value in ``dtypes`` compares equal to ``dtype``;
            the offending ``dtype`` is the exception argument.
    """
    for key, candidate in dtypes.items():
        if candidate == dtype:
            return key
    raise ValueError(dtype)
# Magic bytes written at the very start of every chunk file.
HDR_MAGIC = b"LITPKDS"
# Total header size in bytes: the magic, followed by the version ("<Q",
# 8 bytes), the dtype code ("<B", 1 byte) and the chunk size ("<Q", 8 bytes).
HDR_SIZE = len(HDR_MAGIC) + 8 + 1 + 8
class PackedDataset(IterableDataset):
    """Iterable dataset over packed chunk files, sharded across workers.

    Each call to ``__iter__`` selects a strided subset of ``filenames`` for
    the current (process, DataLoader worker) pair and returns a
    ``PackedDatasetIterator`` over that subset.

    Args:
        filenames: sequence of chunk file paths.
        n_chunks: forwarded to the iterator as ``n_chunks``.
        block_size: forwarded to the iterator as ``block_size``.
        seed: forwarded to the iterator as ``seed``.
        shuffle: forwarded to the iterator as ``shuffle``.
        wrap: forwarded to the iterator as ``wrap``.
        num_processes: number of processes the files are split across.
        process_rank: index of this process among ``num_processes``.
    """

    def __init__(
        self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0
    ):
        # File selection.
        self._filenames = filenames
        self._num_processes = num_processes
        self._process_rank = process_rank
        # Settings handed to the iterator unchanged.
        self._n_chunks = n_chunks
        self._block_size = block_size
        self._seed = seed
        self._shuffle = shuffle
        self._wrap = wrap

    def __iter__(self):
        """Build the iterator for this process/worker's share of the files."""
        info = get_worker_info()
        if info is None:
            # No worker info available: behave as a single worker with id 0.
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = info.num_workers, info.id

        num_shards = num_workers * self._num_processes
        shard_id = self._process_rank * num_workers + worker_id

        # Only the largest multiple of num_shards files is used, so every
        # shard receives the same number of files; the trailing rest is
        # dropped.
        usable = (len(self._filenames) // num_shards) * num_shards
        shard_files = self._filenames[shard_id:usable:num_shards]

        return PackedDatasetIterator(
            filenames=shard_files,
            n_chunks=self._n_chunks,
            block_size=self._block_size,
            seed=self._seed,
            shuffle=self._shuffle,
            wrap=self._wrap,
        )
class PackedDatasetBuilder(object):
    """Packs token arrays into fixed-size chunk files.

    Arrays passed to ``add_array`` are copied back to back into an in-memory
    buffer of ``chunk_size`` elements. Whenever an array does not fit into the
    remaining room, the buffer is topped up, written to
    ``<outdir>/<prefix>_<counter:010d>.bin`` and reset. Positions that were
    never written hold ``sep_token``.

    File layout: ``HDR_MAGIC``, version (``"<Q"``), dtype code (``"<B"``),
    chunk size (``"<Q"``), then the buffer bytes in C order.

    Args:
        outdir: directory the chunk files are written to.
        prefix: file name prefix of the chunk files.
        chunk_size: number of elements per chunk.
        sep_token: value the buffer is filled with before data is added.
        dtype: element type of the buffer, or ``"auto"`` to choose
            ``np.uint16`` when ``vocab_size < 65500`` and ``np.int32``
            otherwise.
        vocab_size: required when ``dtype`` is ``"auto"``.

    Raises:
        ValueError: if ``dtype`` is ``"auto"`` and ``vocab_size`` is ``None``.
    """

    def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
        if dtype == "auto":
            if vocab_size is None:
                raise ValueError("vocab_size cannot be None when dtype='auto'")
            dtype = np.uint16 if vocab_size < 65500 else np.int32
        self._dtype = dtype

        self._outdir = outdir
        self._prefix = prefix
        self._chunk_size = chunk_size
        self._sep_token = sep_token
        self._version = 1

        # Number of chunk files written so far, and their paths.
        self._counter = 0
        self._filenames = []

        # In-memory chunk buffer and the position of the next free element.
        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
        self._arr.fill(self._sep_token)
        self._idx = 0

    def _write_chunk(self):
        """Write the buffer to the next chunk file, then reset the buffer."""
        path = os.path.join(self._outdir, f"{self._prefix}_{self._counter:010d}.bin")

        with open(path, "wb") as out:
            # Header first, field by field, then the raw payload.
            out.write(HDR_MAGIC)
            out.write(struct.pack("<Q", self._version))
            out.write(struct.pack("<B", code(self._dtype)))
            out.write(struct.pack("<Q", self._chunk_size))
            out.write(self._arr.tobytes(order="C"))

        self._filenames.append(path)
        self._counter += 1

        # Start the next chunk from a clean, separator-filled buffer.
        self._arr.fill(self._sep_token)
        self._idx = 0

    def dtype(self):
        """Return the element type used for the chunk buffer."""
        return self._dtype

    def filenames(self):
        """Return a copy of the list of chunk files written so far."""
        return self._filenames.copy()

    def add_array(self, arr):
        """Append ``arr`` to the buffer, writing out every chunk that overflows.

        A chunk is written only when ``arr`` is strictly larger than the room
        left; an array that fits exactly stays buffered.
        """
        room = self._chunk_size - self._idx
        while arr.shape[0] > room:
            # Top up the current chunk, flush it, and carry on with the rest.
            self._arr[self._idx :] = arr[:room]
            self._write_chunk()
            arr = arr[room:]
            room = self._chunk_size - self._idx

        count = arr.shape[0]
        self._arr[self._idx : self._idx + count] = arr
        self._idx += count

    def write_reminder(self):
        """Write the current buffer contents out as one more chunk file."""
        self._write_chunk()
class PackedDatasetIterator:
    """Iterator that yields fixed-size token blocks from packed chunk files.

    Files are memory-mapped ``n_chunks`` at a time. Each chunk is split into
    ``chunk_size // block_size`` blocks (any tail shorter than ``block_size``
    is never yielded) and the blocks of the loaded chunks are visited either
    in order or in a random permutation. Each block is returned as a 1-D
    ``torch`` tensor of ``int64`` with ``block_size`` elements.

    Args:
        filenames: chunk file paths to read, in order.
        n_chunks: number of files loaded (and shuffled together) at a time.
        block_size: number of elements per yielded block.
        seed: seed for the shuffling generator.
        shuffle: whether to permute the blocks of the loaded chunks.
        wrap: when fewer than ``n_chunks`` files remain, restart from the
            first file instead of stopping.

    Notes:
        When fewer than ``n_chunks`` files are available and ``wrap`` is
        false, construction succeeds and the first ``next()`` raises
        ``StopIteration``. Raising it from the constructor would leak out of
        ``iter(dataset)`` and turn into ``RuntimeError`` inside generators.

        With ``wrap`` true and fewer than ``n_chunks`` files in total,
        loading fails with ``IndexError``.
    """

    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
        self._seed = seed
        self._shuffle = shuffle
        self._rng = np.random.default_rng(seed) if shuffle else None
        self._wrap = wrap

        # TODO: instead of filenames, we could have a single text stream
        #       (or text file) with the sequence of all files to be
        #       fetched/loaded.
        self._filenames = filenames
        self._file_idx = 0

        self._n_chunks = n_chunks

        # Filled in from the header of the first file that gets loaded.
        self._dtype = None
        self._block_size = block_size
        self._n_blocks = None

        self._mmaps = []
        self._buffers = []

        self._block_idxs = []
        self._curr_idx = 0

        try:
            self._load_n_chunks()
        except StopIteration:
            # Not enough files for even one load. Stay in the "nothing
            # loaded" state: the first __next__ retries the load and raises
            # StopIteration from the place the iterator protocol expects.
            pass

    def _read_header(self, path):
        """Validate the header of ``path`` and return ``(dtype, chunk_size)``."""
        with open(path, "rb") as f:
            magic = f.read(len(HDR_MAGIC))
            assert magic == HDR_MAGIC, "File doesn't match expected format."
            version = struct.unpack("<Q", f.read(8))
            assert version == (1,)
            (dtype_code,) = struct.unpack("<B", f.read(1))
            dtype = dtypes[dtype_code]
            (chunk_size,) = struct.unpack("<Q", f.read(8))
        return dtype, chunk_size

    def _close_mmaps(self):
        """Close the underlying mmap of every currently loaded chunk."""
        for mmap in self._mmaps:
            mmap._mmap.close()

    def _load_n_chunks(self):
        """Replace the loaded chunks with the next ``n_chunks`` files.

        Raises:
            StopIteration: if fewer than ``n_chunks`` files remain and
                ``wrap`` is false.
        """
        self._close_mmaps()
        self._mmaps = []
        self._buffers = []

        # _file_idx never exceeds len(_filenames), so this is the length of
        # the remaining slice without building the slice.
        remaining = len(self._filenames) - self._file_idx
        if self._n_chunks > remaining:
            if not self._wrap:
                raise StopIteration
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            if self._dtype is None:
                self._dtype, self._chunk_size = self._read_header(filename)
                self._n_blocks = self._chunk_size // self._block_size
            # TODO: check header matches with previous files
            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)

        self._curr_idx = 0

    def __del__(self):
        self._close_mmaps()
        del self._mmaps
        del self._buffers

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next block, loading further chunks when needed."""
        if self._curr_idx >= len(self._block_idxs):
            # Raises StopIteration once the files are used up (wrap=False).
            self._load_n_chunks()
            # TODO: trigger fetching next next n_chunks if remote
        block_idx = self._block_idxs[self._curr_idx]
        # Global block index -> (loaded chunk, element offset in that chunk).
        chunk_id = block_idx // self._n_blocks
        buffer = self._buffers[chunk_id]
        elem_id = (block_idx % self._n_blocks) * self._block_size
        # The buffer is a byte view, so the offset is expressed in bytes.
        offset = np.dtype(self._dtype).itemsize * elem_id
        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
        self._curr_idx += 1
        # astype copies, so the tensor does not reference the mapped file.
        return torch.from_numpy(arr.astype(np.int64))
class CombinedDataset(IterableDataset):
    """Iterable dataset that draws each sample from one of several datasets.

    Args:
        datasets: sized collection of iterables to sample from.
        seed: seed handed to the iterator's random generator.
        weights: one sampling weight per dataset; when ``None`` every dataset
            gets the same weight ``1 / len(datasets)``.
    """

    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
        self._datasets = datasets
        # The length is taken whether or not weights were supplied.
        count = len(datasets)
        self._weights = [1 / count] * count if weights is None else weights

    def __iter__(self):
        """Return a fresh iterator over the combined datasets."""
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
class CombinedDatasetIterator:
    """Iterator that picks a dataset at random for every sample.

    On each step one of the underlying iterators is chosen with
    ``random.Random(seed).choices`` according to ``weights`` and its next
    item is returned. Iteration ends as soon as the chosen iterator is
    exhausted, because its ``StopIteration`` propagates unchanged.

    Args:
        datasets: iterables to sample from; ``iter()`` is called on each one
            once, at construction time.
        seed: seed for the random generator that selects the dataset.
        weights: one relative sampling weight per dataset.
    """

    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __iter__(self):
        # Required by the iterator protocol: without it iter(), zip() and
        # itertools.islice() reject this object with TypeError.
        return self

    def __next__(self):
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)