PoTaTo721's picture
update to 1.2
69e8a46
import bisect
import random
from typing import Iterable
from torch.utils.data import Dataset, IterableDataset
class ConcatRepeatDataset(Dataset):
datasets: list[Dataset]
cumulative_sizes: list[int]
repeats: list[int]
@staticmethod
def cumsum(sequence, repeats):
r, s = [], 0
for dataset, repeat in zip(sequence, repeats):
l = len(dataset) * repeat
r.append(l + s)
s += l
return r
def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
super().__init__()
self.datasets = list(datasets)
self.repeats = repeats
assert len(self.datasets) > 0, "datasets should not be an empty iterable"
assert len(self.datasets) == len(
repeats
), "datasets and repeats should have the same length"
for d in self.datasets:
assert not isinstance(
d, IterableDataset
), "ConcatRepeatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
dataset = self.datasets[dataset_idx]
return dataset[sample_idx % len(dataset)]