File size: 3,495 Bytes
71f183c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torchvision
import torchvision.transforms as T
from mmpretrain import datasets as mmdatasets
from mmpretrain.registry import TRANSFORMS
from mmengine.dataset import Compose

from torch import nn
from torch.utils.data import Dataset as TorchDataset

# Maps dataset name (str) -> zero-argument factory returning (train_data, val_data)
DATASET_REGISTRY = {}
DATASET_PATH = "./datasets"

class MMPretrainWrapper(TorchDataset):
    """Expose an mmpretrain dataset through the plain torch Dataset API.

    Every sample is run through a fixed evaluation pipeline (load image,
    resize short edge to 256, center-crop to 224, pack) and returned as an
    ``(image, label)`` tuple with the image scaled to [0, 1].
    """

    def __init__(self, mmdataset) -> None:
        super().__init__()
        self.mmdataset = mmdataset

        # Deterministic evaluation-time preprocessing configuration.
        eval_pipeline_cfg = [
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=256, edge='short'),
            dict(type='CenterCrop', crop_size=224),
            dict(type='PackInputs'),
        ]
        self.pipeline = self.init_pipeline(eval_pipeline_cfg)

    def init_pipeline(self, pipeline_cfg):
        """Build a runnable ``Compose`` from a list of transform config dicts."""
        transforms = [TRANSFORMS.build(cfg) for cfg in pipeline_cfg]
        return Compose(transforms)

    @property
    def classes(self):
        """Class names as reported by the wrapped mmpretrain dataset."""
        return self.mmdataset.CLASSES

    def __len__(self):
        return len(self.mmdataset)

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, image as floats in [0, 1]."""
        raw = self.mmdataset[index]
        processed = self.pipeline(raw)

        # Our interface expects images in [0, 1], not uint8 [0, 255].
        image = processed["inputs"].float() / 255
        label = processed["data_samples"].gt_label.item()
        return image, label


def register_torchvision_dataset(dataset_name, dataset_cls, dataset_kwargs_train=None, dataset_kwargs_val=None):
    """Register a torchvision dataset factory under ``dataset_name``.

    The factory, when called, instantiates the train and val/test splits
    (downloading into ``DATASET_PATH`` if needed) and returns them as a
    ``(train_data, val_data)`` tuple.

    Args:
        dataset_name: Key under which the factory is stored in ``DATASET_REGISTRY``.
        dataset_cls: torchvision dataset class, e.g. ``torchvision.datasets.CIFAR10``.
        dataset_kwargs_train: Extra keyword args forwarded to the train-split
            constructor. (Previously accepted but silently ignored — now honored.)
        dataset_kwargs_val: Extra keyword args forwarded to the val-split constructor.
    """
    # Use None sentinels instead of mutable `={}` defaults, and copy so that
    # caller-side mutation after registration cannot change the factory.
    dataset_kwargs_train = dict(dataset_kwargs_train or {})
    dataset_kwargs_val = dict(dataset_kwargs_val or {})

    def instantiate_dataset():
        """Build and return ``(train_data, val_data)``, downloading if needed."""
        train_data = dataset_cls(
            root=DATASET_PATH,
            train=True,
            download=True,
            transform=T.ToTensor(),
            **dataset_kwargs_train,
        )

        val_data = dataset_cls(
            root=DATASET_PATH,
            train=False,
            download=True,
            transform=T.ToTensor(),
            **dataset_kwargs_val,
        )

        return train_data, val_data

    DATASET_REGISTRY[dataset_name] = instantiate_dataset

def register_mmpretrain_dataset(dataset_name, dataset_cls, dataset_kwargs_train=None, dataset_kwargs_val=None):
    """Register an mmpretrain dataset factory under ``dataset_name``.

    The factory, when called, instantiates the train and val datasets from
    their keyword-argument configs, wraps each in ``MMPretrainWrapper`` so
    they expose the standard torch Dataset interface, and returns them as a
    ``(train_data, val_data)`` tuple.

    Args:
        dataset_name: Key under which the factory is stored in ``DATASET_REGISTRY``.
        dataset_cls: mmpretrain dataset class, e.g. ``mmdatasets.ImageNet``.
        dataset_kwargs_train: Keyword args for the train-split constructor.
        dataset_kwargs_val: Keyword args for the val-split constructor.
    """
    # Use None sentinels instead of mutable `={}` defaults, and copy so that
    # caller-side mutation after registration cannot change the factory.
    dataset_kwargs_train = dict(dataset_kwargs_train or {})
    dataset_kwargs_val = dict(dataset_kwargs_val or {})

    def instantiate_dataset():
        """Build ``(train_data, val_data)`` wrapped for the torch Dataset API."""
        train_data = MMPretrainWrapper(dataset_cls(**dataset_kwargs_train))
        val_data = MMPretrainWrapper(dataset_cls(**dataset_kwargs_val))
        return train_data, val_data

    DATASET_REGISTRY[dataset_name] = instantiate_dataset

def register_default_datasets():
    """Populate ``DATASET_REGISTRY`` with the built-in dataset choices."""
    register_torchvision_dataset("cifar10", torchvision.datasets.CIFAR10)
    register_torchvision_dataset("cifar100", torchvision.datasets.CIFAR100)

    # NOTE(review): both the "train" and "val" configs point at the ImageNet
    # *val* split — presumably deliberate (evaluation-only use); confirm
    # before relying on the train split for actual training.
    imagenet_split = dict(
        data_root="data/imagenet",
        data_prefix="val",
        ann_file="meta/val.txt",
    )
    register_mmpretrain_dataset(
        "imagenet",
        mmdatasets.ImageNet,
        dataset_kwargs_train=dict(imagenet_split),
        dataset_kwargs_val=dict(imagenet_split),
    )

def get_dataset(dataset_name):
    """Instantiate and return the datasets registered under *dataset_name*.

    Args:
        dataset_name: Registry key of the desired dataset.

    Returns:
        Whatever the registered factory produces — for the factories in this
        module, a ``(train_data, val_data)`` tuple.

    Raises:
        Exception: if ``dataset_name`` was never registered.
    """
    # EAFP: single dict lookup instead of `in` check followed by indexing.
    try:
        factory = DATASET_REGISTRY[dataset_name]
    except KeyError:
        # Name the offending key and the known keys so the failure is actionable.
        raise Exception(
            f"Requested dataset {dataset_name!r} not in registry; "
            f"available: {sorted(DATASET_REGISTRY)}"
        ) from None

    return factory()