import bisect
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset


class ConcatRepeatDataset(Dataset):
    """Concatenate map-style datasets, repeating each dataset a fixed number
    of times: all ``repeats[0]`` passes over the first dataset come first,
    then ``repeats[1]`` passes over the second, and so on."""

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        # Cumulative effective sizes, where each dataset's effective size is
        # its length multiplied by its repeat count.
        r, s = [], 0
        for dataset, repeat in zip(sequence, repeats):
            n = len(dataset) * repeat
            r.append(n + s)
            s += n
        return r

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()

        self.datasets = list(datasets)
        self.repeats = repeats

        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"

        # IterableDataset has no __len__ or __getitem__, so it cannot take
        # part in an index-based concatenation.
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"

        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx

        # Find the dataset whose cumulative range contains idx, then the
        # offset of idx within that range.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        dataset = self.datasets[dataset_idx]
        # The offset runs over `repeats[dataset_idx]` copies of the dataset;
        # the modulo folds it back onto a real sample index.
        return dataset[sample_idx % len(dataset)]
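

# Minimal usage sketch (illustrative only; the toy TensorDatasets and repeat
# counts below are arbitrary examples, not part of the module):
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    a = TensorDataset(torch.arange(3))  # 3 samples
    b = TensorDataset(torch.arange(5))  # 5 samples

    # `a` contributes 3 * 2 = 6 indices and `b` contributes 5, so len == 11.
    ds = ConcatRepeatDataset([a, b], repeats=[2, 1])
    assert len(ds) == 11

    # Index 4 lands in the second pass over `a` and wraps to a[1];
    # index 6 is the first sample of `b`.
    print(ds[4])  # (tensor(1),)
    print(ds[6])  # (tensor(0),)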