File size: 5,857 Bytes
dcb5590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de7d21e
b0bdbcf
dcb5590
 
 
 
 
 
 
 
de7d21e
b0bdbcf
dcb5590
 
 
 
 
 
 
 
 
 
 
 
 
 
de7d21e
b0bdbcf
dcb5590
 
de7d21e
dcb5590
 
 
 
 
0f27535
b0bdbcf
 
0f27535
dcb5590
 
de7d21e
 
 
dcb5590
 
 
 
 
 
 
0f27535
 
 
 
 
 
 
 
 
 
b0bdbcf
 
dcb5590
 
0f27535
b0bdbcf
dcb5590
 
 
 
 
 
 
 
 
 
0f27535
dcb5590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bdbcf
 
 
dcb5590
 
b0bdbcf
dcb5590
b0bdbcf
dcb5590
 
b0bdbcf
dcb5590
 
b0bdbcf
dcb5590
b0bdbcf
 
dcb5590
de7d21e
dcb5590
 
 
 
 
 
 
b0bdbcf
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from pathlib import Path
from typing import Union, Tuple, Optional, List
import os
import lightning as L
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive
from loguru import logger


class CatDogImageDataModule(L.LightningDataModule):
    """DataModule for Cat and Dog Image Classification using ImageFolder.

    Downloads the ``cats_and_dogs_filtered`` archive on first use, then
    builds a train/validation split from the ``train`` folder and a test
    set from the ``test`` folder.
    """

    def __init__(
        self,
        root_dir: Union[str, Path] = "data",
        data_dir: Union[str, Path] = "cats_and_dogs_filtered",
        batch_size: int = 32,
        num_workers: int = 4,
        train_val_split: Tuple[float, float] = (0.8, 0.2),
        pin_memory: bool = False,
        image_size: int = 224,
        url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
    ):
        """Configure the datamodule.

        Args:
            root_dir: Directory under which the dataset archive is extracted.
            data_dir: Name of the extracted dataset folder inside ``root_dir``.
            batch_size: Batch size for all dataloaders.
            num_workers: Worker processes per dataloader.
            train_val_split: Fractions for the train/validation split; only
                the first element is used (validation gets the remainder).
            pin_memory: Passed through to each ``DataLoader``.
            image_size: Square side length images are resized to.
            url: Download URL for the dataset archive.
        """
        super().__init__()
        self.root_dir = Path(root_dir)
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Copy into a tuple so the stored split can never be mutated and the
        # default is immutable (avoids the mutable-default-argument pitfall).
        self.train_val_split = tuple(train_val_split)
        self.pin_memory = pin_memory
        self.image_size = image_size
        self.url = url

        # Resolve the dataset path eagerly: under distributed training,
        # prepare_data() and setup() may run in different processes, so
        # setup() must not rely on attributes prepare_data() sets.
        self.dataset_path = self.root_dir / self.data_dir

        # Populated by setup(); initialized here so accessors never raise
        # AttributeError when called before setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.class_names: List[str] = []

    def prepare_data(self):
        """Download and extract the dataset if it doesn't exist (idempotent)."""
        if not self.dataset_path.exists():
            logger.info("Downloading and extracting dataset.")
            download_and_extract_archive(
                url=self.url, download_root=self.root_dir, remove_finished=True
            )
            logger.info("Download completed.")

    def setup(self, stage: Optional[str] = None):
        """Set up the train, validation, and test datasets.

        Args:
            stage: "fit", "test", or None (None sets up everything).
        """
        # Kept for direct (non-Trainer) use, e.g. the __main__ smoke test;
        # prepare_data() is a no-op once the dataset exists on disk.
        self.prepare_data()

        # Light augmentation for training; ImageNet normalization statistics.
        train_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomRotation(5),  # small rotation for stability
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # No augmentation for evaluation — resize + normalize only.
        test_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        train_path = self.dataset_path / "train"
        test_path = self.dataset_path / "test"

        if stage == "fit" or stage is None:
            full_train_dataset = ImageFolder(root=train_path, transform=train_transform)
            self.class_names = full_train_dataset.classes
            train_size = int(self.train_val_split[0] * len(full_train_dataset))
            # Validation takes the remainder so the two sizes always sum
            # exactly to the dataset length (no rounding loss).
            val_size = len(full_train_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(
                full_train_dataset, [train_size, val_size]
            )
            logger.info(
                f"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images."
            )

        if stage == "test" or stage is None:
            self.test_dataset = ImageFolder(root=test_path, transform=test_transform)
            logger.info(f"Test dataset size: {len(self.test_dataset)} images.")

    def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Helper function to create a DataLoader with the shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training dataloader (shuffled)."""
        return self._create_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the validation dataloader."""
        return self._create_dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Return the test dataloader."""
        return self._create_dataloader(self.test_dataset)

    def get_class_names(self) -> List[str]:
        """Return class names discovered by setup("fit"); empty before then."""
        return self.class_names


if __name__ == "__main__":
    # Smoke-test the CatDogImageDataModule against the project's Hydra config.
    import hydra
    from omegaconf import DictConfig, OmegaConf
    import rootutils

    root = rootutils.setup_root(__file__, indicator=".project-root")

    @hydra.main(
        config_path=str(root / "configs"), version_base="1.3", config_name="train"
    )
    def test_datamodule(cfg: DictConfig):
        logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
        data_cfg = cfg.data
        dm = CatDogImageDataModule(
            root_dir=data_cfg.root_dir,
            data_dir=data_cfg.data_dir,
            batch_size=data_cfg.batch_size,
            num_workers=data_cfg.num_workers,
            train_val_split=data_cfg.train_val_split,
            pin_memory=data_cfg.pin_memory,
            image_size=data_cfg.image_size,
        )

        # Build every loader, then report batch counts and classes.
        dm.setup(stage="fit")
        loaders = {
            "Train": dm.train_dataloader(),
            "Validation": dm.val_dataloader(),
        }
        dm.setup(stage="test")
        loaders["Test"] = dm.test_dataloader()
        names = dm.get_class_names()

        for label, loader in loaders.items():
            logger.info(f"{label} loader: {len(loader)} batches")
        logger.info(f"Class names: {names}")

    test_datamodule()