File size: 5,379 Bytes
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from abc import ABCMeta, abstractmethod
from argparse import ArgumentParser
from typing import Callable, Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import (
    DataLoader,
    Dataset,
    RandomSampler,
    SequentialSampler,
    default_collate,
)
from torchvision import transforms

from .transformations import AddGaussianNoise


class ImageDataModule(LightningDataModule, metaclass=ABCMeta):
    """Abstract PyTorch Lightning DataModule for image datasets.

    Concrete subclasses must implement :meth:`prepare_data` and
    :meth:`setup`, and are expected to populate ``train_data``,
    ``val_data`` and ``test_data`` with :class:`~torch.utils.data.Dataset`
    instances before the corresponding dataloader is requested.
    """

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Register the data-module command-line arguments.

        Args:
            parent_parser (ArgumentParser): parser to extend with a
                "Data Modules" argument group.

        Returns:
            ArgumentParser: the same *parent_parser*, for chaining.
        """
        parser = parent_parser.add_argument_group("Data Modules")
        parser.add_argument(
            "--data_dir",
            type=str,
            default="data/",
            help="The directory where the data is stored.",
        )
        parser.add_argument(
            "--batch_size",
            type=int,
            default=32,
            help="The batch size to use.",
        )
        parser.add_argument(
            "--add_noise",
            action="store_true",
            help="Use gaussian noise augmentation.",
        )
        parser.add_argument(
            "--add_rotation",
            action="store_true",
            help="Use rotation augmentation.",
        )
        parser.add_argument(
            "--add_blur",
            action="store_true",
            help="Use blur augmentation.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=4,
            help="Number of workers to use for data loading.",
        )
        return parent_parser

    # Dataset splits — assigned by the subclass's setup(), not here.
    train_data: Dataset
    val_data: Dataset
    test_data: Dataset

    def __init__(
        self,
        feature_extractor: Optional[Callable] = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """Abstract Pytorch Lightning DataModule for image datasets.

        Args:
            feature_extractor (Callable): feature extractor instance,
                prepended to the transform pipeline when given
            data_dir (str): directory to store the dataset
            batch_size (int): batch size for the train/val/test dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add random rotation to the images
            add_blur (bool): whether to add blur to the images
            num_workers (int): number of workers for train/val/test dataloaders
        """
        super().__init__()

        # Store hyperparameters
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.feature_extractor = feature_extractor
        self.num_workers = num_workers

        # Set the transforms
        # If the feature_extractor is None, then we do not split the images into features
        init_transforms = [feature_extractor] if feature_extractor else []
        self.transform = transforms.Compose(init_transforms)
        self._add_transforms(add_noise, add_rotation, add_blur)

        # Set the collate function and the samplers
        # These can be adapted in a child datamodule class to have a different behavior
        self.collate_fn = default_collate
        self.shuffled_sampler = RandomSampler
        self.sequential_sampler = SequentialSampler

    def _add_transforms(self, noise: bool, rotation: bool, blur: bool):
        """Append the requested augmentations to the transform pipeline.

        Args:
            noise (bool): whether to add noise to the images
            rotation (bool): whether to add random rotation to the images
            blur (bool): whether to add blur to the images
        """
        # TODO:
        # - Which order to add the transforms in?
        # - Applied in both train and test or just test?
        # - Check what transforms are applied by the model
        if noise:
            self.transform.transforms.append(AddGaussianNoise(0.0, 1.0))
        if rotation:
            self.transform.transforms.append(transforms.RandomRotation(20))
        if blur:
            self.transform.transforms.append(transforms.GaussianBlur(3))

    @abstractmethod
    def prepare_data(self):
        """Download / prepare the dataset. Must be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def setup(self, stage: Optional[str] = None):
        """Create the train/val/test splits. Must be implemented by subclasses."""
        raise NotImplementedError()

    def _make_dataloader(self, dataset: Dataset, sampler) -> DataLoader:
        """Build a DataLoader with this module's shared settings.

        Args:
            dataset (Dataset): the dataset split to wrap.
            sampler: an instantiated sampler for *dataset*.

        Returns:
            DataLoader: loader using the module's batch size, worker
            count and collate function.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            sampler=sampler,
        )

    # noinspection PyTypeChecker
    def train_dataloader(self) -> DataLoader:
        """Return the training DataLoader (shuffled sampling)."""
        return self._make_dataloader(
            self.train_data, self.shuffled_sampler(self.train_data)
        )

    # noinspection PyTypeChecker
    def val_dataloader(self) -> DataLoader:
        """Return the validation DataLoader (sequential sampling)."""
        return self._make_dataloader(
            self.val_data, self.sequential_sampler(self.val_data)
        )

    # noinspection PyTypeChecker
    def test_dataloader(self) -> DataLoader:
        """Return the test DataLoader (sequential sampling)."""
        return self._make_dataloader(
            self.test_data, self.sequential_sampler(self.test_data)
        )