File size: 6,307 Bytes
81d747c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
""" Quick n Simple Image Folder, Tarfile based DataSet
Hacked together by / Copyright 2019, Ross Wightman
"""
import io
import logging
from typing import Optional
import torch
import torch.utils.data as data
from PIL import Image
from .readers import create_reader
_logger = logging.getLogger(__name__)
# Consecutive sample-load failure limit for ImageDataset.__getitem__: each failure increments a
# counter and skips ahead to the next index; once the counter reaches this value the last
# exception is re-raised instead. A successful load resets the counter to 0.
_ERROR_RETRY = 50
class ImageDataset(data.Dataset):
    """Map-style dataset that loads samples from an indexable reader.

    Indexing the reader yields a ``(sample, target)`` pair. The sample is either
    read as raw bytes (``load_bytes=True``, via ``.read()``) or opened with
    ``PIL.Image.open`` and optionally converted to ``input_img_mode``, after which
    the optional transforms are applied.
    """

    def __init__(
            self,
            root,
            reader=None,
            split='train',
            class_map=None,
            load_bytes=False,
            input_img_mode='RGB',
            transform=None,
            target_transform=None,
            **kwargs,
    ):
        """
        Args:
            root: Forwarded to ``create_reader`` when the reader is built here.
            reader: A reader instance used as-is, or a reader name (str) / None that is
                resolved through ``create_reader`` (None is passed as the empty name '').
            split: Forwarded to ``create_reader``.
            class_map: Forwarded to ``create_reader``.
            load_bytes: When True, return the bytes read from each sample instead of a PIL image.
            input_img_mode: PIL mode to convert opened images to; a falsy value skips
                conversion. Not used when ``load_bytes`` is True.
            transform: Optional callable applied to the loaded sample.
            target_transform: Optional callable applied to targets that are not None.
            **kwargs: Extra keyword arguments forwarded to ``create_reader``.
        """
        must_build = reader is None or isinstance(reader, str)
        if must_build:
            reader_name = reader or ''
            reader = create_reader(
                reader_name,
                root=root,
                split=split,
                class_map=class_map,
                **kwargs,
            )
        self.reader = reader
        self.load_bytes = load_bytes
        self.input_img_mode = input_img_mode
        self.transform = transform
        self.target_transform = target_transform
        # Back-to-back load failures, shared across __getitem__ calls; reset after any success.
        self._consecutive_errors = 0

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``.

        When loading the sample raises, a warning is logged and the next index
        (wrapping around the end) is tried instead; once ``_ERROR_RETRY``
        consecutive failures have been counted, the exception is re-raised.
        A ``None`` target becomes ``-1`` and bypasses ``target_transform``.
        """
        sample, label = self.reader[index]
        try:
            if self.load_bytes:
                sample = sample.read()
            else:
                sample = Image.open(sample)
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.reader.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors >= _ERROR_RETRY:
                raise e
            # Go through self.__getitem__ so subclass overrides still apply to the retried sample.
            next_index = (index + 1) % len(self.reader)
            return self.__getitem__(next_index)
        self._consecutive_errors = 0

        if not self.load_bytes and self.input_img_mode:
            sample = sample.convert(self.input_img_mode)
        if self.transform is not None:
            sample = self.transform(sample)

        if label is None:
            return sample, -1
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label

    def __len__(self):
        """Number of samples, as reported by the reader."""
        return len(self.reader)

    def filename(self, index, basename=False, absolute=False):
        """Delegate filename lookup for ``index`` to the reader."""
        return self.reader.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate listing of all filenames to the reader."""
        return self.reader.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
    """Iterable-style dataset that streams ``(sample, target)`` pairs from a reader.

    No decoding happens in this class: ``input_img_mode`` and the other reader
    options are only forwarded to ``create_reader``. Iteration applies the optional
    ``transform`` / ``target_transform`` to each pair yielded by the reader.
    """

    def __init__(
            self,
            root,
            reader=None,
            split='train',
            class_map=None,
            is_training=False,
            batch_size=1,
            num_samples=None,
            seed=42,
            repeats=0,
            download=False,
            input_img_mode='RGB',
            input_key=None,
            target_key=None,
            transform=None,
            target_transform=None,
            max_steps=None,
            **kwargs,
    ):
        """
        Args:
            reader: A reader name (str) resolved through ``create_reader``, or a reader
                instance that is used as-is. Required.
            root, split, class_map, is_training, batch_size, num_samples, seed, repeats,
            download, input_img_mode, input_key, target_key, max_steps, **kwargs:
                Forwarded to ``create_reader`` when ``reader`` is a string; unused otherwise.
            transform: Optional callable applied to every yielded sample.
            target_transform: Optional callable applied to every yielded target
                (including None targets, unlike ImageDataset).

        Raises:
            ValueError: If ``reader`` is None.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if reader is None:
            raise ValueError('IterableImageDataset requires a reader name or reader instance.')
        if isinstance(reader, str):
            reader = create_reader(
                reader,
                root=root,
                split=split,
                class_map=class_map,
                is_training=is_training,
                batch_size=batch_size,
                num_samples=num_samples,
                seed=seed,
                repeats=repeats,
                download=download,
                input_img_mode=input_img_mode,
                input_key=input_key,
                target_key=target_key,
                max_steps=max_steps,
                **kwargs,
            )
        self.reader = reader
        self.transform = transform
        self.target_transform = target_transform
        # Not used by this class; kept so the attribute set matches ImageDataset.
        self._consecutive_errors = 0

    def __iter__(self):
        """Yield each reader pair with the optional transforms applied."""
        for img, target in self.reader:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            yield img, target

    def __len__(self):
        """Reader length when it defines ``__len__``, else 0."""
        # NOTE(review): 0 presumably stands for "unknown length" for pure stream readers — confirm.
        return len(self.reader) if hasattr(self.reader, '__len__') else 0

    def set_epoch(self, count):
        """Forward the epoch count to the reader if it supports ``set_epoch``."""
        # TFDS and WDS need external epoch count for deterministic cross process shuffle
        if hasattr(self.reader, 'set_epoch'):
            self.reader.set_epoch(count)

    def set_loader_cfg(
            self,
            num_workers: Optional[int] = None,
    ):
        """Forward loader configuration to the reader if it supports ``set_loader_cfg``."""
        # TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
        if hasattr(self.reader, 'set_loader_cfg'):
            self.reader.set_loader_cfg(num_workers=num_workers)

    def filename(self, index, basename=False, absolute=False):
        """Unsupported for streaming readers.

        Raises:
            NotImplementedError: Always; use ``filenames()`` instead. (Previously an
                ``assert False``, which silently returned None under ``python -O``.)
        """
        raise NotImplementedError('Filename lookup by index not supported, use filenames().')

    def filenames(self, basename=False, absolute=False):
        """Delegate listing of all filenames to the reader."""
        return self.reader.filenames(basename, absolute)
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    Expects the wrapped dataset's ``transform`` to be a 3-element list/tuple of
    ``(base, augmentation, normalize)``. The base transform is installed on the
    wrapped dataset; each item then yields ``num_splits`` views of the base-transformed
    sample: the first only normalized (the 'clean' split), the rest augmented then
    normalized.
    """

    def __init__(self, dataset, num_splits=2):
        """
        Args:
            dataset: Map-style dataset exposing a mutable ``transform`` attribute.
                If its transform is not None it must be a 3-element list/tuple.
            num_splits: Number of views returned per item (1 clean + ``num_splits - 1`` augmented).

        Raises:
            ValueError: If ``dataset.transform`` is set but is not a 3-element list/tuple.
        """
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        """Split ``x`` into base / augmentation / normalize transforms.

        Raises:
            ValueError: If ``x`` is not a list/tuple of exactly 3 transforms.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not isinstance(x, (list, tuple)) or len(x) != 3:
            raise ValueError('Expecting a tuple/list of 3 transforms')
        base, augmentation, normalize = x
        self.dataset.transform = base
        self.augmentation = augmentation
        self.normalize = normalize

    @property
    def transform(self):
        """The base transform currently installed on the wrapped dataset."""
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        # Accepts the full 3-element spec (not just the base transform) — see _set_transforms.
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        """Return ``(tuple_of_num_splits_views, target)`` for index ``i``."""
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        # NOTE(review): requires an augmentation transform when num_splits > 1 (None is not callable).
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        return len(self.dataset)
|