from abc import ABCMeta, abstractmethod
from argparse import ArgumentParser
from typing import Callable, Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import (
    DataLoader,
    Dataset,
    RandomSampler,
    SequentialSampler,
    default_collate,
)
from torchvision import transforms

from .transformations import AddGaussianNoise
class ImageDataModule(LightningDataModule, metaclass=ABCMeta):
    """Abstract Pytorch Lightning DataModule for image datasets.

    Concrete subclasses must implement :meth:`prepare_data` and
    :meth:`setup`, and are expected to assign ``train_data``,
    ``val_data`` and ``test_data`` during ``setup``.
    """

    # Declare datasets that will be initialized later (in a subclass's setup()).
    train_data: Dataset
    val_data: Dataset
    test_data: Dataset

    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Register the data-module command-line options on ``parent_parser``.

        Args:
            parent_parser (ArgumentParser): parser to extend in place

        Returns:
            ArgumentParser: the same parser, to allow chaining
        """
        parser = parent_parser.add_argument_group("Data Modules")
        parser.add_argument(
            "--data_dir",
            type=str,
            default="data/",
            help="The directory where the data is stored.",
        )
        parser.add_argument(
            "--batch_size",
            type=int,
            default=32,
            help="The batch size to use.",
        )
        parser.add_argument(
            "--add_noise",
            action="store_true",
            help="Use gaussian noise augmentation.",
        )
        parser.add_argument(
            "--add_rotation",
            action="store_true",
            help="Use rotation augmentation.",
        )
        parser.add_argument(
            "--add_blur",
            action="store_true",
            help="Use blur augmentation.",
        )
        parser.add_argument(
            "--num_workers",
            type=int,
            default=4,
            help="Number of workers to use for data loading.",
        )
        return parent_parser

    def __init__(
        self,
        feature_extractor: Optional[Callable] = None,
        data_dir: str = "data/",
        batch_size: int = 32,
        add_noise: bool = False,
        add_rotation: bool = False,
        add_blur: bool = False,
        num_workers: int = 4,
    ):
        """Abstract Pytorch Lightning DataModule for image datasets.

        Args:
            feature_extractor (Callable): feature extractor instance; applied
                as the first transform when given
            data_dir (str): directory to store the dataset
            batch_size (int): batch size for the train/val/test dataloaders
            add_noise (bool): whether to add noise to the images
            add_rotation (bool): whether to add random rotation to the images
            add_blur (bool): whether to add blur to the images
            num_workers (int): number of workers for train/val/test dataloaders
        """
        super().__init__()
        # Store hyperparameters
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.feature_extractor = feature_extractor
        self.num_workers = num_workers
        # Set the transforms
        # If the feature_extractor is None, then we do not split the images into features
        init_transforms = [feature_extractor] if feature_extractor else []
        self.transform = transforms.Compose(init_transforms)
        self._add_transforms(add_noise, add_rotation, add_blur)
        # Set the collate function and the samplers
        # These can be adapted in a child datamodule class to have a different behavior
        self.collate_fn = default_collate
        self.shuffled_sampler = RandomSampler
        self.sequential_sampler = SequentialSampler

    def _add_transforms(self, noise: bool, rotation: bool, blur: bool):
        """Add transforms to the module's transformations list.

        Args:
            noise (bool): whether to add noise to the images
            rotation (bool): whether to add random rotation to the images
            blur (bool): whether to add blur to the images
        """
        # TODO:
        #  - Which order to add the transforms in?
        #  - Applied in both train and test or just test?
        #  - Check what transforms are applied by the model
        if noise:
            self.transform.transforms.append(AddGaussianNoise(0.0, 1.0))
        if rotation:
            self.transform.transforms.append(transforms.RandomRotation(20))
        if blur:
            self.transform.transforms.append(transforms.GaussianBlur(3))

    @abstractmethod
    def prepare_data(self):
        raise NotImplementedError()

    @abstractmethod
    def setup(self, stage: Optional[str] = None):
        raise NotImplementedError()

    # noinspection PyTypeChecker
    def _dataloader(self, dataset: Dataset, sampler_factory: Callable) -> DataLoader:
        """Build a DataLoader using the module's batch size, workers and collate_fn.

        Args:
            dataset (Dataset): dataset to load from
            sampler_factory (Callable): called with ``dataset`` to build the sampler

        Returns:
            DataLoader: configured dataloader over ``dataset``
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            sampler=sampler_factory(dataset),
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled dataloader over the training set."""
        return self._dataloader(self.train_data, self.shuffled_sampler)

    def val_dataloader(self) -> DataLoader:
        """Return a sequential dataloader over the validation set."""
        return self._dataloader(self.val_data, self.sequential_sampler)

    def test_dataloader(self) -> DataLoader:
        """Return a sequential dataloader over the test set."""
        return self._dataloader(self.test_data, self.sequential_sampler)