thomaspaniagua
QuadAttack release
71f183c
raw
history blame
No virus
3.5 kB
import torch
import torchvision
import torchvision.transforms as T
from mmpretrain import datasets as mmdatasets
from mmpretrain.registry import TRANSFORMS
from mmengine.dataset import Compose
from torch import nn
from torch.utils.data import Dataset as TorchDataset
# This holds dataset instantiation functions by (dataset_name) tuple keys
DATASET_REGISTRY = {}
DATASET_PATH = "./datasets"
class MMPretrainWrapper(TorchDataset):
def __init__(self, mmdataset) -> None:
super().__init__()
self.mmdataset = mmdataset
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ResizeEdge', scale=256, edge='short'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
self.pipeline = self.init_pipeline(test_pipeline)
def init_pipeline(self, pipeline_cfg):
pipeline = Compose(
[TRANSFORMS.build(t) for t in pipeline_cfg])
return pipeline
@property
def classes(self):
return self.mmdataset.CLASSES
def __len__(self):
return len(self.mmdataset)
def __getitem__(self, index):
sample = self.mmdataset[index]
sample = self.pipeline(sample)
# Our interface expects images in [0-1]
img = sample["inputs"].float() / 255
return img, sample["data_samples"].gt_label.item()
def register_torchvision_dataset(dataset_name, dataset_cls, dataset_kwargs_train={}, dataset_kwargs_val={}):
def instantiate_dataset():
train_data = dataset_cls(
root=DATASET_PATH,
train=True,
download=True,
transform=T.ToTensor()
)
val_data = dataset_cls(
root=DATASET_PATH,
train=False,
download=True,
transform=T.ToTensor()
)
return train_data, val_data
DATASET_REGISTRY[dataset_name] = instantiate_dataset
def register_mmpretrain_dataset(dataset_name, dataset_cls, dataset_kwargs_train={}, dataset_kwargs_val={}):
def instantiate_dataset():
train_data = dataset_cls(**dataset_kwargs_train)
val_data = dataset_cls(**dataset_kwargs_val)
train_data = MMPretrainWrapper(train_data)
val_data = MMPretrainWrapper(val_data)
return train_data, val_data
DATASET_REGISTRY[dataset_name] = instantiate_dataset
def register_default_datasets():
register_torchvision_dataset("cifar10", torchvision.datasets.CIFAR10)
register_torchvision_dataset("cifar100", torchvision.datasets.CIFAR100)
register_mmpretrain_dataset("imagenet", mmdatasets.ImageNet,
dataset_kwargs_train=dict(
data_root = "data/imagenet",
data_prefix = "val",
ann_file = "meta/val.txt"
),
dataset_kwargs_val=dict(
data_root = "data/imagenet",
data_prefix = "val",
ann_file = "meta/val.txt"
))
def get_dataset(dataset_name):
"""
Returns an instance of a dataset
dataset_name: Name of desired dataset
"""
if dataset_name not in DATASET_REGISTRY:
raise Exception("Requested dataset not in registry")
return DATASET_REGISTRY[dataset_name]()