Spaces:
Runtime error
Runtime error
from torch.utils import data | |
from PIL import Image | |
import os | |
class Dataset(data.Dataset):
    """Characterizes an unlabeled dataset of image files for PyTorch.

    Recursively collects every file under ``path`` whose name ends with one of
    the accepted extensions (case-insensitive) and serves them one at a time
    as RGB ``PIL.Image`` objects, optionally passed through ``transform``.
    ``__getitem__`` returns only the image; there are no labels.

    Args:
        path: Root directory searched recursively with ``os.walk``.
        transform: Optional callable applied to each loaded RGB image.
        extensions: Optional suffix (str) or iterable of suffixes such as
            ``('.jpg',)`` to accept. Defaults to ``IMAGE_EXTENSIONS``.
    """

    # Suffixes accepted by default; compared case-insensitively against the
    # end of each file name (so "photo.JPG" matches but "jpg_notes.txt" does not).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, path, transform=None, extensions=None):
        """Index the image files under ``path`` and store ``transform``."""
        self.file_names = self.get_filenames(path, extensions)
        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply ``transform`` if set.

        Returns:
            The RGB ``PIL.Image``, or whatever ``transform`` returns for it.
        """
        # The context manager closes the underlying file promptly; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(self.file_names[index]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def get_filenames(self, data_path, extensions=None):
        """Return the sorted paths of image files found recursively under data_path.

        A file is kept when its name ends with one of ``extensions`` (compared
        case-insensitively). The list is sorted so sample indices do not
        depend on the filesystem's directory-listing order.

        Args:
            data_path: Root directory to walk.
            extensions: Optional suffix or iterable of suffixes; ``None``
                means ``IMAGE_EXTENSIONS``.

        Returns:
            A sorted list of joined file paths (empty if nothing matches or
            ``data_path`` does not exist, since ``os.walk`` then yields nothing).
        """
        if extensions is None:
            extensions = self.IMAGE_EXTENSIONS
        if isinstance(extensions, str):
            # Treat a single suffix string as one suffix, not as its characters.
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        images = []
        for dirpath, _subdirs, files in os.walk(data_path):
            for name in files:
                if not name.lower().endswith(suffixes):
                    continue
                filename = os.path.join(dirpath, name)
                # Keep only paths that resolve to regular files
                # (e.g. skips broken symlinks that os.walk may list).
                if os.path.isfile(filename):
                    images.append(filename)
        images.sort()
        return images