Soutrik commited on
Commit
b0bdbcf
1 Parent(s): aeaa968

datamodule new tested

Browse files
configs/data/catdog.yaml CHANGED
@@ -1,8 +1,9 @@
1
  _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
2
- data_dir: ${paths.data_dir}
 
3
  url: ${paths.data_url}
4
  num_workers: 4
5
  batch_size: 32
6
  train_val_split: [0.8, 0.2]
7
  pin_memory: False
8
- image_size: 160
 
1
  _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
2
+ root_dir: ${paths.data_dir}
3
+ data_dir: "cats_and_dogs_filtered"
4
  url: ${paths.data_url}
5
  num_workers: 4
6
  batch_size: 32
7
  train_val_split: [0.8, 0.2]
8
  pin_memory: False
9
+ image_size: 224
configs/experiment/catdog_experiment.yaml CHANGED
@@ -18,10 +18,11 @@ seed: 42
18
  name: "catdog_experiment"
19
 
20
  data:
21
- batch_size: 64
 
22
  num_workers: 8
23
  pin_memory: True
24
- image_size: 160
25
 
26
  model:
27
  lr: 1e-3
 
18
  name: "catdog_experiment"
19
 
20
  data:
21
+ dataset: "cats_and_dogs_filtered"
22
+ batch_size: 32
23
  num_workers: 8
24
  pin_memory: True
25
+ image_size: 224
26
 
27
  model:
28
  lr: 1e-3
notebooks/datamodule_lightning.ipynb CHANGED
@@ -53,13 +53,222 @@
53
  }
54
  ],
55
  "source": [
56
- "\n",
57
  "import os\n",
58
  "\n",
59
  "os.chdir(\"..\")\n",
60
  "print(os.getcwd())"
61
  ]
62
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  {
64
  "cell_type": "code",
65
  "execution_count": null,
 
53
  }
54
  ],
55
  "source": [
 
56
  "import os\n",
57
  "\n",
58
  "os.chdir(\"..\")\n",
59
  "print(os.getcwd())"
60
  ]
61
  },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 3,
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "name": "stderr",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
72
+ " from .autonotebook import tqdm as notebook_tqdm\n"
73
+ ]
74
+ }
75
+ ],
76
+ "source": [
77
+ "from pathlib import Path\n",
78
+ "from typing import Union, Tuple, Optional, List\n",
79
+ "import os\n",
80
+ "import lightning as L\n",
81
+ "from torch.utils.data import DataLoader, random_split\n",
82
+ "from torchvision import transforms\n",
83
+ "from torchvision.datasets import ImageFolder\n",
84
+ "from torchvision.datasets.utils import download_and_extract_archive\n",
85
+ "from loguru import logger"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": 32,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "class CatDogImageDataModule(L.LightningDataModule):\n",
95
+ " \"\"\"DataModule for Cat and Dog Image Classification using ImageFolder.\"\"\"\n",
96
+ "\n",
97
+ " def __init__(\n",
98
+ " self,\n",
99
+ " data_root: Union[str, Path] = \"data\",\n",
100
+ " data_dir: Union[str, Path] = \"cats_and_dogs_filtered\",\n",
101
+ " batch_size: int = 32,\n",
102
+ " num_workers: int = 4,\n",
103
+ " train_val_split: List[float] = [0.8, 0.2],\n",
104
+ " pin_memory: bool = False,\n",
105
+ " image_size: int = 224,\n",
106
+ " url: str = \"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n",
107
+ " ):\n",
108
+ " super().__init__()\n",
109
+ " self.data_root = Path(data_root)\n",
110
+ " self.data_dir = data_dir\n",
111
+ " self.batch_size = batch_size\n",
112
+ " self.num_workers = num_workers\n",
113
+ " self.train_val_split = train_val_split\n",
114
+ " self.pin_memory = pin_memory\n",
115
+ " self.image_size = image_size\n",
116
+ " self.url = url\n",
117
+ "\n",
118
+ " # Initialize variables for datasets\n",
119
+ " self.train_dataset = None\n",
120
+ " self.val_dataset = None\n",
121
+ " self.test_dataset = None\n",
122
+ "\n",
123
+ " def prepare_data(self):\n",
124
+ " \"\"\"Download the dataset if it doesn't exist.\"\"\"\n",
125
+ " self.dataset_path = self.data_root / self.data_dir\n",
126
+ " if not self.dataset_path.exists():\n",
127
+ " logger.info(\"Downloading and extracting dataset.\")\n",
128
+ " download_and_extract_archive(\n",
129
+ " url=self.url, download_root=self.data_root, remove_finished=True\n",
130
+ " )\n",
131
+ " logger.info(\"Download completed.\")\n",
132
+ "\n",
133
+ " def setup(self, stage: Optional[str] = None):\n",
134
+ " \"\"\"Set up the train, validation, and test datasets.\"\"\"\n",
135
+ "\n",
136
+ " train_transform = transforms.Compose(\n",
137
+ " [\n",
138
+ " transforms.Resize((self.image_size, self.image_size)),\n",
139
+ " transforms.RandomHorizontalFlip(0.1),\n",
140
+ " transforms.RandomRotation(10),\n",
141
+ " transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),\n",
142
+ " transforms.RandomAutocontrast(0.1),\n",
143
+ " transforms.RandomAdjustSharpness(2, 0.1),\n",
144
+ " transforms.ToTensor(),\n",
145
+ " transforms.Normalize(\n",
146
+ " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
147
+ " ),\n",
148
+ " ]\n",
149
+ " )\n",
150
+ "\n",
151
+ " test_transform = transforms.Compose(\n",
152
+ " [\n",
153
+ " transforms.Resize((self.image_size, self.image_size)),\n",
154
+ " transforms.ToTensor(),\n",
155
+ " transforms.Normalize(\n",
156
+ " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
157
+ " ),\n",
158
+ " ]\n",
159
+ " )\n",
160
+ "\n",
161
+ " train_path = self.dataset_path / \"train\"\n",
162
+ " test_path = self.dataset_path / \"test\"\n",
163
+ "\n",
164
+ " self.prepare_data()\n",
165
+ "\n",
166
+ " if stage == \"fit\" or stage is None:\n",
167
+ " full_train_dataset = ImageFolder(root=train_path, transform=train_transform)\n",
168
+ " self.class_names = full_train_dataset.classes\n",
169
+ " train_size = int(self.train_val_split[0] * len(full_train_dataset))\n",
170
+ " val_size = len(full_train_dataset) - train_size\n",
171
+ " self.train_dataset, self.val_dataset = random_split(\n",
172
+ " full_train_dataset, [train_size, val_size]\n",
173
+ " )\n",
174
+ " logger.info(\n",
175
+ " f\"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images.\"\n",
176
+ " )\n",
177
+ "\n",
178
+ " if stage == \"test\" or stage is None:\n",
179
+ " self.test_dataset = ImageFolder(root=test_path, transform=test_transform)\n",
180
+ " logger.info(f\"Test dataset size: {len(self.test_dataset)} images.\")\n",
181
+ "\n",
182
+ " def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:\n",
183
+ " \"\"\"Helper function to create a DataLoader.\"\"\"\n",
184
+ " return DataLoader(\n",
185
+ " dataset=dataset,\n",
186
+ " batch_size=self.batch_size,\n",
187
+ " num_workers=self.num_workers,\n",
188
+ " pin_memory=self.pin_memory,\n",
189
+ " shuffle=shuffle,\n",
190
+ " )\n",
191
+ "\n",
192
+ " def train_dataloader(self) -> DataLoader:\n",
193
+ " return self._create_dataloader(self.train_dataset, shuffle=True)\n",
194
+ "\n",
195
+ " def val_dataloader(self) -> DataLoader:\n",
196
+ " return self._create_dataloader(self.val_dataset)\n",
197
+ "\n",
198
+ " def test_dataloader(self) -> DataLoader:\n",
199
+ " return self._create_dataloader(self.test_dataset)\n",
200
+ "\n",
201
+ " def get_class_names(self) -> List[str]:\n",
202
+ " return self.class_names"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 33,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "datamodule = CatDogImageDataModule(\n",
212
+ " data_root=\"data\",\n",
213
+ " data_dir=\"cats_and_dogs_filtered\",\n",
214
+ " batch_size=32,\n",
215
+ " num_workers=4,\n",
216
+ " train_val_split=[0.8, 0.2],\n",
217
+ " pin_memory=True,\n",
218
+ " image_size=224,\n",
219
+ " url=\"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n",
220
+ ")"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": 35,
226
+ "metadata": {},
227
+ "outputs": [
228
+ {
229
+ "name": "stderr",
230
+ "output_type": "stream",
231
+ "text": [
232
+ "\u001b[32m2024-11-10 05:37:17.840\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m81\u001b[0m - \u001b[1mTrain/Validation split: 2241 train, 561 validation images.\u001b[0m\n"
233
+ ]
234
+ },
235
+ {
236
+ "name": "stderr",
237
+ "output_type": "stream",
238
+ "text": [
239
+ "\u001b[32m2024-11-10 05:37:17.910\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m87\u001b[0m - \u001b[1mTest dataset size: 198 images.\u001b[0m\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "datamodule.prepare_data()\n",
245
+ "datamodule.setup()\n",
246
+ "class_names = datamodule.get_class_names()\n",
247
+ "train_dataloader = datamodule.train_dataloader()\n",
248
+ "val_dataloader= datamodule.val_dataloader()\n",
249
+ "test_dataloader= datamodule.test_dataloader()"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 36,
255
+ "metadata": {},
256
+ "outputs": [
257
+ {
258
+ "data": {
259
+ "text/plain": [
260
+ "['cats', 'dogs']"
261
+ ]
262
+ },
263
+ "execution_count": 36,
264
+ "metadata": {},
265
+ "output_type": "execute_result"
266
+ }
267
+ ],
268
+ "source": [
269
+ "class_names"
270
+ ]
271
+ },
272
  {
273
  "cell_type": "code",
274
  "execution_count": null,
src/datamodules/catdog_datamodule.py CHANGED
@@ -14,7 +14,8 @@ class CatDogImageDataModule(L.LightningDataModule):
14
 
15
  def __init__(
16
  self,
17
- data_dir: Union[str, Path] = "data",
 
18
  batch_size: int = 32,
19
  num_workers: int = 4,
20
  train_val_split: List[float] = [0.8, 0.2],
@@ -23,7 +24,8 @@ class CatDogImageDataModule(L.LightningDataModule):
23
  url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
24
  ):
25
  super().__init__()
26
- self.data_dir = Path(data_dir)
 
27
  self.batch_size = batch_size
28
  self.num_workers = num_workers
29
  self.train_val_split = train_val_split
@@ -38,21 +40,27 @@ class CatDogImageDataModule(L.LightningDataModule):
38
 
39
  def prepare_data(self):
40
  """Download the dataset if it doesn't exist."""
41
- dataset_path = self.data_dir / "cats_and_dogs_filtered"
42
- if not dataset_path.exists():
43
  logger.info("Downloading and extracting dataset.")
44
  download_and_extract_archive(
45
- url=self.url, download_root=self.data_dir, remove_finished=True
46
  )
47
  logger.info("Download completed.")
48
 
49
  def setup(self, stage: Optional[str] = None):
50
  """Set up the train, validation, and test datasets."""
51
 
 
 
52
  train_transform = transforms.Compose(
53
  [
54
  transforms.Resize((self.image_size, self.image_size)),
55
- transforms.RandomHorizontalFlip(),
 
 
 
 
56
  transforms.ToTensor(),
57
  transforms.Normalize(
58
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -70,11 +78,12 @@ class CatDogImageDataModule(L.LightningDataModule):
70
  ]
71
  )
72
 
73
- train_path = self.data_dir / "cats_and_dogs_filtered" / "train"
74
- test_path = self.data_dir / "cats_and_dogs_filtered" / "validation"
75
 
76
  if stage == "fit" or stage is None:
77
  full_train_dataset = ImageFolder(root=train_path, transform=train_transform)
 
78
  train_size = int(self.train_val_split[0] * len(full_train_dataset))
79
  val_size = len(full_train_dataset) - train_size
80
  self.train_dataset, self.val_dataset = random_split(
@@ -107,43 +116,42 @@ class CatDogImageDataModule(L.LightningDataModule):
107
  def test_dataloader(self) -> DataLoader:
108
  return self._create_dataloader(self.test_dataset)
109
 
 
 
 
110
 
111
  if __name__ == "__main__":
112
- from omegaconf import DictConfig, OmegaConf
113
  import hydra
 
114
  import rootutils
115
 
116
- # Setup root directory
117
- root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
118
- logger.info(f"Root directory: {root}")
119
 
120
  @hydra.main(
121
- version_base="1.3",
122
- config_path=str(root / "configs"),
123
- config_name="train",
124
  )
125
- def main(cfg: DictConfig):
126
- # Log configuration
127
- logger.info("Config:\n" + OmegaConf.to_yaml(cfg))
128
-
129
- # Initialize DataModule
130
  datamodule = CatDogImageDataModule(
 
131
  data_dir=cfg.data.data_dir,
132
  batch_size=cfg.data.batch_size,
133
  num_workers=cfg.data.num_workers,
134
  train_val_split=cfg.data.train_val_split,
135
  pin_memory=cfg.data.pin_memory,
136
  image_size=cfg.data.image_size,
137
- url=cfg.data.url,
138
- )
139
- datamodule.prepare_data()
140
- datamodule.setup()
141
-
142
- # Log DataLoader sizes
143
- logger.info(f"Train DataLoader: {len(datamodule.train_dataloader())} batches")
144
- logger.info(
145
- f"Validation DataLoader: {len(datamodule.val_dataloader())} batches"
146
  )
147
- logger.info(f"Test DataLoader: {len(datamodule.test_dataloader())} batches")
148
-
149
- main()
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def __init__(
16
  self,
17
+ data_root: Union[str, Path] = "data",
18
+ data_dir: Union[str, Path] = "cats_and_dogs_filtered",
19
  batch_size: int = 32,
20
  num_workers: int = 4,
21
  train_val_split: List[float] = [0.8, 0.2],
 
24
  url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
25
  ):
26
  super().__init__()
27
+ self.data_root = Path(data_root)
28
+ self.data_dir = data_dir
29
  self.batch_size = batch_size
30
  self.num_workers = num_workers
31
  self.train_val_split = train_val_split
 
40
 
41
  def prepare_data(self):
42
  """Download the dataset if it doesn't exist."""
43
+ self.dataset_path = self.data_root / self.data_dir
44
+ if not self.dataset_path.exists():
45
  logger.info("Downloading and extracting dataset.")
46
  download_and_extract_archive(
47
+ url=self.url, download_root=self.data_root, remove_finished=True
48
  )
49
  logger.info("Download completed.")
50
 
51
  def setup(self, stage: Optional[str] = None):
52
  """Set up the train, validation, and test datasets."""
53
 
54
+ self.prepare_data()
55
+
56
  train_transform = transforms.Compose(
57
  [
58
  transforms.Resize((self.image_size, self.image_size)),
59
+ transforms.RandomHorizontalFlip(0.1),
60
+ transforms.RandomRotation(10),
61
+ transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
62
+ transforms.RandomAutocontrast(0.1),
63
+ transforms.RandomAdjustSharpness(2, 0.1),
64
  transforms.ToTensor(),
65
  transforms.Normalize(
66
  mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
 
78
  ]
79
  )
80
 
81
+ train_path = self.dataset_path / "train"
82
+ test_path = self.dataset_path / "test"
83
 
84
  if stage == "fit" or stage is None:
85
  full_train_dataset = ImageFolder(root=train_path, transform=train_transform)
86
+ self.class_names = full_train_dataset.classes
87
  train_size = int(self.train_val_split[0] * len(full_train_dataset))
88
  val_size = len(full_train_dataset) - train_size
89
  self.train_dataset, self.val_dataset = random_split(
 
116
  def test_dataloader(self) -> DataLoader:
117
  return self._create_dataloader(self.test_dataset)
118
 
119
+ def get_class_names(self) -> List[str]:
120
+ return self.class_names
121
+
122
 
123
  if __name__ == "__main__":
124
+ # Test the CatDogImageDataModule
125
  import hydra
126
+ from omegaconf import DictConfig, OmegaConf
127
  import rootutils
128
 
129
+ root = rootutils.setup_root(__file__, indicator=".project-root")
 
 
130
 
131
  @hydra.main(
132
+ config_path=str(root / "configs"), version_base="1.3", config_name="train"
 
 
133
  )
134
+ def test_datamodule(cfg: DictConfig):
135
+ logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
 
 
 
136
  datamodule = CatDogImageDataModule(
137
+ data_root=cfg.paths.data_dir,
138
  data_dir=cfg.data.data_dir,
139
  batch_size=cfg.data.batch_size,
140
  num_workers=cfg.data.num_workers,
141
  train_val_split=cfg.data.train_val_split,
142
  pin_memory=cfg.data.pin_memory,
143
  image_size=cfg.data.image_size,
 
 
 
 
 
 
 
 
 
144
  )
145
+ datamodule.setup(stage="fit")
146
+ train_loader = datamodule.train_dataloader()
147
+ val_loader = datamodule.val_dataloader()
148
+ datamodule.setup(stage="test")
149
+ test_loader = datamodule.test_dataloader()
150
+ class_names = datamodule.get_class_names()
151
+
152
+ logger.info(f"Train loader: {len(train_loader)} batches")
153
+ logger.info(f"Validation loader: {len(val_loader)} batches")
154
+ logger.info(f"Test loader: {len(test_loader)} batches")
155
+ logger.info(f"Class names: {class_names}")
156
+
157
+ test_datamodule()
src/train_new.py CHANGED
@@ -122,7 +122,7 @@ def run_test_module(
122
  return test_metrics[0] if test_metrics else {}
123
 
124
 
125
- @hydra.main(config_path="../configs", config_name="train", version_base="1.1")
126
  def setup_run_trainer(cfg: DictConfig):
127
  """Set up and run the Trainer for training and testing."""
128
  # Display configuration
 
122
  return test_metrics[0] if test_metrics else {}
123
 
124
 
125
+ @hydra.main(config_path="../configs", config_name="train", version_base="1.3")
126
  def setup_run_trainer(cfg: DictConfig):
127
  """Set up and run the Trainer for training and testing."""
128
  # Display configuration