|
import sys |
|
import torch.utils.data as data |
|
from os import listdir |
|
from utils.tools import default_loader, is_image_file, normalize |
|
import os |
|
|
|
import torchvision.transforms as transforms |
|
|
|
|
|
class Dataset(data.Dataset):
    """Image dataset that yields fixed-size, normalized image tensors.

    Images are read either from the files directly inside ``data_path`` or,
    when ``with_subfolder`` is true, from every file found recursively below
    each immediate sub-directory of ``data_path``.

    Args:
        data_path (string): Root directory of the dataset.
        image_shape (sequence): Target shape; the last entry is dropped and
            the first two remaining entries are used as (height, width).
        with_subfolder (bool): Collect images from sub-directories instead of
            from ``data_path`` itself.
        random_crop (bool): If true, take a random crop of the target size
            (upscaling first when the image is too small); otherwise resize
            the image to the target size.
        return_name (bool): If true, ``__getitem__`` returns
            ``(sample_name, tensor)`` instead of just the tensor.
    """

    def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
        super(Dataset, self).__init__()
        if with_subfolder:
            self.samples = self._find_samples_in_subfolders(data_path)
        else:
            self.samples = [x for x in listdir(data_path) if is_image_file(x)]
        self.data_path = data_path
        # Remembered because the two discovery modes store different kinds of
        # entries in ``self.samples`` (see ``_sample_path``).
        self.with_subfolder = with_subfolder
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name

    def __getitem__(self, index):
        """Load, resize/crop and normalize the sample at ``index``.

        Returns:
            The normalized image tensor, or ``(sample_name, tensor)`` when
            ``return_name`` is set.
        """
        img = default_loader(self._sample_path(index))

        if self.random_crop:
            img = self._upscale_to_cover(img)
        else:
            img = transforms.Resize(self.image_shape)(img)
        # Applied in both modes, exactly as before; after an exact resize the
        # image already has the crop size.
        img = transforms.RandomCrop(self.image_shape)(img)

        img = transforms.ToTensor()(img)
        img = normalize(img)

        if self.return_name:
            return self.samples[index], img
        return img

    def __len__(self):
        return len(self.samples)

    def _sample_path(self, index):
        """Return the filesystem path of the sample at ``index``.

        Samples found in sub-folders are stored as paths that already start
        with ``data_path``; joining ``data_path`` onto them again produced a
        non-existent ``data_path/data_path/...`` location whenever
        ``data_path`` was relative. Flat-directory samples are bare file
        names and still need the join.
        """
        sample = self.samples[index]
        if self.with_subfolder:
            return sample
        return os.path.join(self.data_path, sample)

    def _upscale_to_cover(self, img):
        """Return ``img`` enlarged, if needed, so a full-size crop fits.

        ``img`` must expose ``size`` as ``(width, height)``. Images that are
        already large enough are returned untouched.
        """
        target_h, target_w = self.image_shape[0], self.image_shape[1]
        imgw, imgh = img.size
        if imgh >= target_h and imgw >= target_w:
            return img

        # Original behaviour: scale the shorter image edge to the smaller
        # target dimension. Always sufficient for square targets.
        resized = transforms.Resize(min(self.image_shape))(img)
        new_w, new_h = resized.size
        if new_h >= target_h and new_w >= target_w:
            return resized

        # With a non-square target the resize above can leave one side too
        # small, which made RandomCrop fail. Resize the original image (to
        # avoid resampling twice) to the smallest aspect-preserving size that
        # covers the target.
        return transforms.Resize(self._cover_size(imgw, imgh, target_h, target_w))(img)

    @staticmethod
    def _cover_size(imgw, imgh, target_h, target_w):
        """Smallest aspect-preserving ``(height, width)`` covering the target.

        All arguments are expected to be positive integers.
        """
        if target_h * imgw >= target_w * imgh:
            # Height needs the larger scale factor; round the width up.
            return target_h, -(-imgw * target_h // imgh)
        # Width needs the larger scale factor; round the height up.
        return -(-imgh * target_w // imgw), target_w

    def _find_samples_in_subfolders(self, dir):
        """Collect the image files stored below the sub-folders of ``dir``.

        Every immediate sub-directory of ``dir`` is walked recursively, in
        sorted order, and each file accepted by ``is_image_file`` is kept.
        Files located directly in ``dir`` are not included.

        Args:
            dir (string): Root directory path.

        Returns:
            list: Image paths, each one prefixed with ``dir``.
        """
        subfolders = sorted(d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)))
        samples = []
        for subfolder in subfolders:
            for root, _, fnames in sorted(os.walk(os.path.join(dir, subfolder))):
                for fname in sorted(fnames):
                    if is_image_file(fname):
                        samples.append(os.path.join(root, fname))
        return samples
|
|