File size: 1,130 Bytes
753e275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from torch.utils.data import Dataset, ConcatDataset
from diffab.utils.transforms import get_transform


# Registry mapping a dataset type name to the class or factory callable that
# was registered under that name with `register_dataset`; `get_dataset` looks
# entries up here by `cfg.type`.
_DATASET_DICT = {}


def register_dataset(name):
    """Return a decorator that records its target in the dataset registry.

    Args:
        name: Key under which the decorated object is stored in
            ``_DATASET_DICT``. An existing entry with the same key is
            overwritten.

    Returns:
        A decorator that stores the object it receives under ``name`` and
        hands the same object back unchanged.
    """
    def _store(dataset_cls):
        # Registration is the only side effect; the target is not wrapped.
        _DATASET_DICT[name] = dataset_cls
        return dataset_cls

    return _store


def get_dataset(cfg):
    """Instantiate the dataset described by ``cfg`` via the registry.

    Args:
        cfg: Config object supporting ``'transform' in cfg`` and attribute
            access. ``cfg.type`` selects the registered entry; when a
            ``transform`` key is present, ``cfg.transform`` is passed to
            ``get_transform`` to build the transform.

    Returns:
        Whatever the registered entry returns when called as
        ``entry(cfg, transform=transform)``.

    Raises:
        KeyError: If ``cfg.type`` has not been registered.
    """
    # The transform is optional: configs without the key get ``None``.
    if 'transform' in cfg:
        transform = get_transform(cfg.transform)
    else:
        transform = None

    dataset_factory = _DATASET_DICT[cfg.type]
    return dataset_factory(cfg, transform=transform)


@register_dataset('concat')
def get_concat_dataset(cfg, transform=None):
    """Build a ``ConcatDataset`` from the configs listed in ``cfg.datasets``.

    Args:
        cfg: Config object whose ``datasets`` attribute is an iterable of
            per-dataset configs, each of which is passed to ``get_dataset``.
        transform: Must be ``None``. The parameter exists because
            ``get_dataset`` calls every registered entry with a
            ``transform=`` keyword; without it this factory could not be
            invoked through the registry at all.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over the constructed datasets.

    Raises:
        ValueError: If ``transform`` is not ``None``.
    """
    # A transform on the concatenation itself is not applied anywhere, so
    # refuse it loudly instead of silently dropping it.
    if transform is not None:
        raise ValueError('transform is not supported.')
    datasets = [get_dataset(d) for d in cfg.datasets]
    return ConcatDataset(datasets)


@register_dataset('balanced_concat')
class BalancedConcatDataset(Dataset):
    """Concatenation in which every member dataset gets the same number of slots.

    Each member is allotted ``max_size`` consecutive indices, where
    ``max_size`` is the length of the longest member. Members shorter than
    ``max_size`` are cycled (the index is taken modulo the member's own
    length) to fill their slots.
    """

    def __init__(self, cfg, transform=None):
        """Build the member datasets listed in ``cfg.datasets``.

        Args:
            cfg: Config object whose ``datasets`` attribute is an iterable of
                per-dataset configs, each of which is passed to
                ``get_dataset``.
            transform: Must be ``None``; a transform on the concatenation
                itself is not supported.

        Raises:
            AssertionError: If ``transform`` is not ``None``.
            ValueError: If ``cfg.datasets`` is empty (raised by ``max``).
        """
        super().__init__()
        # NOTE: kept as an assert so the exception type seen by existing
        # callers does not change; be aware it is skipped under ``python -O``.
        assert transform is None, 'transform is not supported.'
        self.datasets = [get_dataset(d) for d in cfg.datasets]
        self.max_size = max(len(d) for d in self.datasets)

    def __len__(self):
        """Return ``max_size`` multiplied by the number of member datasets."""
        return self.max_size * len(self.datasets)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``.

        Args:
            idx: Integer index. Negative values count from the end.

        Returns:
            The item from the member dataset that owns slot ``idx``.

        Raises:
            IndexError: If ``idx`` falls outside ``[-len(self), len(self))``.
        """
        total = len(self)
        # Normalise negative indices first: floor-division and modulo on a
        # negative value would otherwise select a different element than the
        # equivalent positive index whenever ``total`` is not a multiple of
        # the member's length.
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(
                f'index out of range for dataset of length {total}'
            )
        member = self.datasets[idx // self.max_size]
        # The full index (not the offset within the slot range) is wrapped by
        # the member's length, which cycles through shorter members.
        return member[idx % len(member)]