File size: 1,498 Bytes
882ea5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import bisect
import random
from typing import Iterable
from torch.utils.data import Dataset, IterableDataset
class ConcatRepeatDataset(Dataset):
    """Concatenation of map-style datasets where each one is repeated.

    Dataset ``i`` occupies ``len(datasets[i]) * repeats[i]`` consecutive
    indices. Within that span the samples cycle through the underlying
    dataset in order, so index ``k`` of the span maps to sample
    ``k % len(datasets[i])``.

    Args:
        datasets: Map-style datasets to concatenate, in order.
        repeats: How many times each dataset is repeated; one non-negative
            integer per dataset.

    Raises:
        AssertionError: If ``datasets`` is empty, if ``datasets`` and
            ``repeats`` differ in length, if a repeat count is negative, or
            if any dataset is an ``IterableDataset``.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        """Return the running totals of ``len(dataset) * repeat``.

        Entry ``i`` of the result is the exclusive end index of dataset ``i``
        within the concatenation.
        """
        totals, running = [], 0
        for dataset, repeat in zip(sequence, repeats):
            running += len(dataset) * repeat
            totals.append(running)
        return totals

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()
        self.datasets = list(datasets)
        # Copy so that later mutation of the caller's list cannot get out of
        # sync with the cumulative sizes computed below.
        self.repeats = list(repeats)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            self.repeats
        ), "datasets and repeats should have the same length"
        # A negative repeat would make cumulative_sizes non-monotonic, which
        # breaks the bisect lookup in __getitem__.
        assert all(
            repeat >= 0 for repeat in self.repeats
        ), "repeats should be non-negative"
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        """Return the total number of indices across all repeated datasets."""
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` of the concatenation.

        Negative indices count from the end, as with built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Rebind rather than mutate in place, in case idx is a mutable
        # integer-like object owned by the caller.
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {total}"
            )
        # First dataset whose exclusive end index lies beyond `position`.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, position)
        start = self.cumulative_sizes[dataset_idx - 1] if dataset_idx else 0
        dataset = self.datasets[dataset_idx]
        # The modulo folds the repeated span back onto the real samples.
        return dataset[(position - start) % len(dataset)]
|