Spaces:
Sleeping
Sleeping
import torch | |
import torchvision | |
import torchvision.transforms as T | |
from mmpretrain import datasets as mmdatasets | |
from mmpretrain.registry import TRANSFORMS | |
from mmengine.dataset import Compose | |
from torch import nn | |
from torch.utils.data import Dataset as TorchDataset | |
# This holds dataset instantiation functions by (dataset_name) tuple keys | |
DATASET_REGISTRY = {} | |
DATASET_PATH = "./datasets" | |
class MMPretrainWrapper(TorchDataset): | |
def __init__(self, mmdataset) -> None: | |
super().__init__() | |
self.mmdataset = mmdataset | |
test_pipeline = [ | |
dict(type='LoadImageFromFile'), | |
dict(type='ResizeEdge', scale=256, edge='short'), | |
dict(type='CenterCrop', crop_size=224), | |
dict(type='PackInputs'), | |
] | |
self.pipeline = self.init_pipeline(test_pipeline) | |
def init_pipeline(self, pipeline_cfg): | |
pipeline = Compose( | |
[TRANSFORMS.build(t) for t in pipeline_cfg]) | |
return pipeline | |
def classes(self): | |
return self.mmdataset.CLASSES | |
def __len__(self): | |
return len(self.mmdataset) | |
def __getitem__(self, index): | |
sample = self.mmdataset[index] | |
sample = self.pipeline(sample) | |
# Our interface expects images in [0-1] | |
img = sample["inputs"].float() / 255 | |
return img, sample["data_samples"].gt_label.item() | |
def register_torchvision_dataset(dataset_name, dataset_cls, dataset_kwargs_train={}, dataset_kwargs_val={}): | |
def instantiate_dataset(): | |
train_data = dataset_cls( | |
root=DATASET_PATH, | |
train=True, | |
download=True, | |
transform=T.ToTensor() | |
) | |
val_data = dataset_cls( | |
root=DATASET_PATH, | |
train=False, | |
download=True, | |
transform=T.ToTensor() | |
) | |
return train_data, val_data | |
DATASET_REGISTRY[dataset_name] = instantiate_dataset | |
def register_mmpretrain_dataset(dataset_name, dataset_cls, dataset_kwargs_train={}, dataset_kwargs_val={}): | |
def instantiate_dataset(): | |
train_data = dataset_cls(**dataset_kwargs_train) | |
val_data = dataset_cls(**dataset_kwargs_val) | |
train_data = MMPretrainWrapper(train_data) | |
val_data = MMPretrainWrapper(val_data) | |
return train_data, val_data | |
DATASET_REGISTRY[dataset_name] = instantiate_dataset | |
def register_default_datasets(): | |
register_torchvision_dataset("cifar10", torchvision.datasets.CIFAR10) | |
register_torchvision_dataset("cifar100", torchvision.datasets.CIFAR100) | |
register_mmpretrain_dataset("imagenet", mmdatasets.ImageNet, | |
dataset_kwargs_train=dict( | |
data_root = "data/imagenet", | |
data_prefix = "val", | |
ann_file = "meta/val.txt" | |
), | |
dataset_kwargs_val=dict( | |
data_root = "data/imagenet", | |
data_prefix = "val", | |
ann_file = "meta/val.txt" | |
)) | |
def get_dataset(dataset_name): | |
""" | |
Returns an instance of a dataset | |
dataset_name: Name of desired dataset | |
""" | |
if dataset_name not in DATASET_REGISTRY: | |
raise Exception("Requested dataset not in registry") | |
return DATASET_REGISTRY[dataset_name]() | |