|
from .utils import check_integrity, download_and_extract_archive |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
import os |
|
import os.path |
|
import numpy as np |
|
import pickle |
|
from typing import Any, Callable, Optional, Tuple |
|
|
|
import torchvision.transforms.functional as TF |
|
|
|
|
|
class UpsideDownDataset(Dataset):
    """CIFAR-10 images relabelled for a binary "is it upside down?" task.

    Adapted from the torchvision CIFAR-10 dataset source code.

    Every sample at an even index is flipped vertically (turned upside
    down) and returned with label ``1``; every sample at an odd index is
    returned unchanged with label ``0``.  The class labels stored in the
    batch files are still loaded into ``self.targets`` but are never
    returned by ``__getitem__``.

    Args:
        root: Directory that contains (or will receive) the
            ``cifar-10-batches-py`` folder.
        train: If True, load the five training batches; otherwise load the
            test batch.
        transform: Optional callable applied to the PIL image after the
            (possible) flip.
        target_transform: Optional callable applied to the 0/1 label.
        download: If True, download and extract the archive into ``root``
            before the batch files are read.
    """

    # Location, name and checksum of the CIFAR-10 archive.
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'

    # [file name, md5] pairs for the training split.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    # [file name, md5] pairs for the test split.
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    # Metadata file description (not read anywhere in this class).
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """Optionally download the archive, then load one split into memory.

        Raises whatever ``open`` / ``pickle.load`` raise when a batch file
        is missing or unreadable (e.g. ``FileNotFoundError``).
        """
        self.train = train
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        downloaded_list = self.train_list if self.train else self.test_list

        self.data: Any = []
        self.targets = []

        # The md5 listed next to each file name is not verified here.
        for file_name, _checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            # NOTE(security): pickle can execute arbitrary code while
            # loading; only point ``root`` at trusted batch files.
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
            self.data.append(entry['data'])
            # Some batch files store their labels under 'fine_labels'.
            label_key = 'labels' if 'labels' in entry else 'fine_labels'
            self.targets.extend(entry[label_key])

        # Stack all batches into (N, 3, 32, 32), then move the size-3 axis
        # last so each sample is a (32, 32, 3) array for Image.fromarray.
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the (possibly flipped) image and its binary flip label.

        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is 1 if the image was
            flipped vertically (even index) and 0 otherwise (odd index).
        """
        # Convert to a PIL image so the flip and any user transform see the
        # same kind of object.
        img = Image.fromarray(self.data[index])

        if index % 2 == 0:
            img = TF.vflip(img)
            target = 1
        else:
            target = 0

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of loaded images."""
        return len(self.data)

    def download(self) -> None:
        """Download the CIFAR-10 archive into ``self.root`` and extract it."""
        download_and_extract_archive(
            self.url, self.root, filename=self.filename, md5=self.tgz_md5
        )

    def extra_repr(self) -> str:
        """Return a short description of which split this instance holds."""
        return "Split: {}".format("Train" if self.train is True else "Test")