File size: 1,286 Bytes
71c714a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#!/usr/bin/env python3
"""
Module containing wrapper classes for PyTorch Datasets
Author: Shilpaj Bhalerao
Date: Jun 25, 2023
"""
# Standard Library Imports
from typing import Callable, Optional, Tuple

# Third-Party Imports
from torchvision import datasets, transforms


class AlbumDataset(datasets.CIFAR10):
    """
    CIFAR-10 dataset wrapper that applies an albumentations-style transform.

    ``__getitem__`` is overridden so the image transform is called as
    ``transform(image=...)`` and the transformed image is read from the
    ``"image"`` key of the returned mapping (the albumentations calling
    convention), instead of being called positionally.
    """
    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        download: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """
        Constructor
        :param root: Directory at which data is stored
        :param train: If True, load the training split; otherwise the test split
        :param download: If True, download the dataset from source
        :param transform: Albumentations-style callable (e.g. an
                          ``albumentations.Compose``), invoked as
                          ``transform(image=img)`` and expected to return a
                          mapping containing an ``"image"`` key
        :param target_transform: Optional callable applied to each label;
                                 forwarded to the parent constructor, which
                                 stores it as ``self.target_transform``
        """
        super().__init__(
            root=root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index: int) -> Tuple:
        """
        Method to return image and its label
        :param index: Index of image and label in the dataset
        :return: ``(image, label)`` where ``image`` is the transform's
                 ``"image"`` output when a transform is set, otherwise the raw
                 entry of ``self.data``; ``label`` is passed through
                 ``target_transform`` when one is set
        """
        # The raw entry of self.data is handed to the transform directly (no PIL
        # conversion). NOTE(review): presumably an HWC uint8 numpy array, as
        # torchvision stores CIFAR-10 -- confirm if the parent class changes.
        image, label = self.data[index], self.targets[index]

        # Compare against None instead of relying on truthiness: a callable
        # object that defines __len__/__bool__ could evaluate falsy and be
        # silently skipped.
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        # Overriding __getitem__ bypasses the parent's own label handling, so
        # target_transform must be applied here explicitly.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label