|
|
|
""" |
|
Module containing wrapper classes for PyTorch Datasets |
|
Author: Shilpaj Bhalerao |
|
Date: Jun 25, 2023 |
|
""" |
|
|
|
from typing import Callable, Optional, Tuple

from torchvision import datasets, transforms
|
|
|
|
|
class AlbumDataset(datasets.CIFAR10):
    """
    Wrapper class to use albumentations library with PyTorch Dataset

    The transform is not applied the way a plain torchvision transform is:
    it is called with the image as the ``image`` keyword argument and the
    transformed image is read back from the ``"image"`` key of the result.
    """

    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        download: bool = True,
        transform: Optional[Callable] = None,
    ):
        """
        Constructor

        :param root: Directory at which data is stored
        :param train: Param to distinguish if data is training or test
        :param download: Param to download the dataset from source
        :param transform: Callable invoked as ``transform(image=...)`` whose
            result is indexable by ``"image"`` (e.g. a composed albumentations
            pipeline). ``None`` disables transformation.
        """
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index: int) -> Tuple:
        """
        Method to return image and its label

        :param index: Index of image and label in the dataset
        :return: Tuple of ``(image, label)``; the image is the transformed one
            when a transform is set, otherwise the stored image as-is
        """
        # `data` and `targets` are not assigned by this class; they are
        # expected to be provided by the base dataset.
        image, label = self.data[index], self.targets[index]

        # Truthiness (not `is not None`) is kept on purpose so that any falsy
        # transform is skipped exactly as before instead of being called.
        if not self.transform:
            return image, label

        # The transform takes keyword arguments and returns a mapping, so the
        # image must be passed by name and unpacked from the result.
        transformed = self.transform(image=image)
        return transformed["image"], label
|
|