Commit b0bdbcf by Soutrik: "datamodule new tested"
Parent: aeaa968
configs/data/catdog.yaml (CHANGED)

@@ -1,8 +1,9 @@
 _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
-
+root_dir: ${paths.data_dir}
+data_dir: "cats_and_dogs_filtered"
 url: ${paths.data_url}
 num_workers: 4
 batch_size: 32
 train_val_split: [0.8, 0.2]
 pin_memory: False
-image_size:
+image_size: 224
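Read together with the datamodule changes later in this commit, the updated config corresponds roughly to the constructor call sketched below. This is only a sketch: the ${paths.*} interpolations are replaced with placeholder values, and note that the key added here is root_dir while the constructor argument added in src/datamodules/catdog_datamodule.py is data_root, so the two names may still need to be reconciled.

    # Sketch only: direct construction equivalent to the values in this config.
    # Placeholder values stand in for the ${paths.data_dir} / ${paths.data_url} interpolations.
    from src.datamodules.catdog_datamodule import CatDogImageDataModule

    datamodule = CatDogImageDataModule(
        data_root="data",  # placeholder for ${paths.data_dir}
        data_dir="cats_and_dogs_filtered",
        url="https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",  # placeholder for ${paths.data_url}
        num_workers=4,
        batch_size=32,
        train_val_split=[0.8, 0.2],
        pin_memory=False,
        image_size=224,
    )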
configs/experiment/catdog_experiment.yaml (CHANGED)

@@ -18,10 +18,11 @@ seed: 42
 name: "catdog_experiment"
 
 data:
-
+  dataset: "cats_and_dogs_filtered"
+  batch_size: 32
   num_workers: 8
   pin_memory: True
-  image_size:
+  image_size: 224
 
 model:
   lr: 1e-3
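For reference, an experiment config like this is normally selected from the command line with Hydra's override syntax. A hypothetical invocation (it assumes the repo's defaults list exposes an experiment group, which this commit does not show):

    python src/train_new.py experiment=catdog_experiment data.batch_size=64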
notebooks/datamodule_lightning.ipynb (CHANGED)

@@ -53,13 +53,222 @@
     }
    ],
    "source": [
-    "\n",
     "import os\n",
     "\n",
     "os.chdir(\"..\")\n",
     "print(os.getcwd())"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/anaconda/envs/emlo_env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    }
+   ],
+   "source": [
+    "from pathlib import Path\n",
+    "from typing import Union, Tuple, Optional, List\n",
+    "import os\n",
+    "import lightning as L\n",
+    "from torch.utils.data import DataLoader, random_split\n",
+    "from torchvision import transforms\n",
+    "from torchvision.datasets import ImageFolder\n",
+    "from torchvision.datasets.utils import download_and_extract_archive\n",
+    "from loguru import logger"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CatDogImageDataModule(L.LightningDataModule):\n",
+    "    \"\"\"DataModule for Cat and Dog Image Classification using ImageFolder.\"\"\"\n",
+    "\n",
+    "    def __init__(\n",
+    "        self,\n",
+    "        data_root: Union[str, Path] = \"data\",\n",
+    "        data_dir: Union[str, Path] = \"cats_and_dogs_filtered\",\n",
+    "        batch_size: int = 32,\n",
+    "        num_workers: int = 4,\n",
+    "        train_val_split: List[float] = [0.8, 0.2],\n",
+    "        pin_memory: bool = False,\n",
+    "        image_size: int = 224,\n",
+    "        url: str = \"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n",
+    "    ):\n",
+    "        super().__init__()\n",
+    "        self.data_root = Path(data_root)\n",
+    "        self.data_dir = data_dir\n",
+    "        self.batch_size = batch_size\n",
+    "        self.num_workers = num_workers\n",
+    "        self.train_val_split = train_val_split\n",
+    "        self.pin_memory = pin_memory\n",
+    "        self.image_size = image_size\n",
+    "        self.url = url\n",
+    "\n",
+    "        # Initialize variables for datasets\n",
+    "        self.train_dataset = None\n",
+    "        self.val_dataset = None\n",
+    "        self.test_dataset = None\n",
+    "\n",
+    "    def prepare_data(self):\n",
+    "        \"\"\"Download the dataset if it doesn't exist.\"\"\"\n",
+    "        self.dataset_path = self.data_root / self.data_dir\n",
+    "        if not self.dataset_path.exists():\n",
+    "            logger.info(\"Downloading and extracting dataset.\")\n",
+    "            download_and_extract_archive(\n",
+    "                url=self.url, download_root=self.data_root, remove_finished=True\n",
+    "            )\n",
+    "            logger.info(\"Download completed.\")\n",
+    "\n",
+    "    def setup(self, stage: Optional[str] = None):\n",
+    "        \"\"\"Set up the train, validation, and test datasets.\"\"\"\n",
+    "\n",
+    "        train_transform = transforms.Compose(\n",
+    "            [\n",
+    "                transforms.Resize((self.image_size, self.image_size)),\n",
+    "                transforms.RandomHorizontalFlip(0.1),\n",
+    "                transforms.RandomRotation(10),\n",
+    "                transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),\n",
+    "                transforms.RandomAutocontrast(0.1),\n",
+    "                transforms.RandomAdjustSharpness(2, 0.1),\n",
+    "                transforms.ToTensor(),\n",
+    "                transforms.Normalize(\n",
+    "                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
+    "                ),\n",
+    "            ]\n",
+    "        )\n",
+    "\n",
+    "        test_transform = transforms.Compose(\n",
+    "            [\n",
+    "                transforms.Resize((self.image_size, self.image_size)),\n",
+    "                transforms.ToTensor(),\n",
+    "                transforms.Normalize(\n",
+    "                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
+    "                ),\n",
+    "            ]\n",
+    "        )\n",
+    "\n",
+    "        train_path = self.dataset_path / \"train\"\n",
+    "        test_path = self.dataset_path / \"test\"\n",
+    "\n",
+    "        self.prepare_data()\n",
+    "\n",
+    "        if stage == \"fit\" or stage is None:\n",
+    "            full_train_dataset = ImageFolder(root=train_path, transform=train_transform)\n",
+    "            self.class_names = full_train_dataset.classes\n",
+    "            train_size = int(self.train_val_split[0] * len(full_train_dataset))\n",
+    "            val_size = len(full_train_dataset) - train_size\n",
+    "            self.train_dataset, self.val_dataset = random_split(\n",
+    "                full_train_dataset, [train_size, val_size]\n",
+    "            )\n",
+    "            logger.info(\n",
+    "                f\"Train/Validation split: {len(self.train_dataset)} train, {len(self.val_dataset)} validation images.\"\n",
+    "            )\n",
+    "\n",
+    "        if stage == \"test\" or stage is None:\n",
+    "            self.test_dataset = ImageFolder(root=test_path, transform=test_transform)\n",
+    "            logger.info(f\"Test dataset size: {len(self.test_dataset)} images.\")\n",
+    "\n",
+    "    def _create_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:\n",
+    "        \"\"\"Helper function to create a DataLoader.\"\"\"\n",
+    "        return DataLoader(\n",
+    "            dataset=dataset,\n",
+    "            batch_size=self.batch_size,\n",
+    "            num_workers=self.num_workers,\n",
+    "            pin_memory=self.pin_memory,\n",
+    "            shuffle=shuffle,\n",
+    "        )\n",
+    "\n",
+    "    def train_dataloader(self) -> DataLoader:\n",
+    "        return self._create_dataloader(self.train_dataset, shuffle=True)\n",
+    "\n",
+    "    def val_dataloader(self) -> DataLoader:\n",
+    "        return self._create_dataloader(self.val_dataset)\n",
+    "\n",
+    "    def test_dataloader(self) -> DataLoader:\n",
+    "        return self._create_dataloader(self.test_dataset)\n",
+    "\n",
+    "    def get_class_names(self) -> List[str]:\n",
+    "        return self.class_names"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "datamodule = CatDogImageDataModule(\n",
+    "    data_root=\"data\",\n",
+    "    data_dir=\"cats_and_dogs_filtered\",\n",
+    "    batch_size=32,\n",
+    "    num_workers=4,\n",
+    "    train_val_split=[0.8, 0.2],\n",
+    "    pin_memory=True,\n",
+    "    image_size=224,\n",
+    "    url=\"https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2024-11-10 05:37:17.840\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m81\u001b[0m - \u001b[1mTrain/Validation split: 2241 train, 561 validation images.\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2024-11-10 05:37:17.910\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msetup\u001b[0m:\u001b[36m87\u001b[0m - \u001b[1mTest dataset size: 198 images.\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "datamodule.prepare_data()\n",
+    "datamodule.setup()\n",
+    "class_names = datamodule.get_class_names()\n",
+    "train_dataloader = datamodule.train_dataloader()\n",
+    "val_dataloader = datamodule.val_dataloader()\n",
+    "test_dataloader = datamodule.test_dataloader()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 36,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "['cats', 'dogs']"
+      ]
+     },
+     "execution_count": 36,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "class_names"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
src/datamodules/catdog_datamodule.py (CHANGED)

@@ -14,7 +14,8 @@ class CatDogImageDataModule(L.LightningDataModule):
 
     def __init__(
         self,
-
+        data_root: Union[str, Path] = "data",
+        data_dir: Union[str, Path] = "cats_and_dogs_filtered",
         batch_size: int = 32,
         num_workers: int = 4,
         train_val_split: List[float] = [0.8, 0.2],
@@ -23,7 +24,8 @@ class CatDogImageDataModule(L.LightningDataModule):
         url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
     ):
         super().__init__()
-        self.
+        self.data_root = Path(data_root)
+        self.data_dir = data_dir
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.train_val_split = train_val_split
@@ -38,21 +40,27 @@ class CatDogImageDataModule(L.LightningDataModule):
 
     def prepare_data(self):
         """Download the dataset if it doesn't exist."""
-        dataset_path = self.
-        if not dataset_path.exists():
+        self.dataset_path = self.data_root / self.data_dir
+        if not self.dataset_path.exists():
             logger.info("Downloading and extracting dataset.")
             download_and_extract_archive(
-                url=self.url, download_root=self.
+                url=self.url, download_root=self.data_root, remove_finished=True
             )
             logger.info("Download completed.")
 
     def setup(self, stage: Optional[str] = None):
         """Set up the train, validation, and test datasets."""
 
+        self.prepare_data()
+
         train_transform = transforms.Compose(
             [
                 transforms.Resize((self.image_size, self.image_size)),
-                transforms.RandomHorizontalFlip(),
+                transforms.RandomHorizontalFlip(0.1),
+                transforms.RandomRotation(10),
+                transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
+                transforms.RandomAutocontrast(0.1),
+                transforms.RandomAdjustSharpness(2, 0.1),
                 transforms.ToTensor(),
                 transforms.Normalize(
                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -70,11 +78,12 @@ class CatDogImageDataModule(L.LightningDataModule):
             ]
         )
 
-        train_path = self.
-        test_path = self.
+        train_path = self.dataset_path / "train"
+        test_path = self.dataset_path / "test"
 
         if stage == "fit" or stage is None:
             full_train_dataset = ImageFolder(root=train_path, transform=train_transform)
+            self.class_names = full_train_dataset.classes
             train_size = int(self.train_val_split[0] * len(full_train_dataset))
             val_size = len(full_train_dataset) - train_size
             self.train_dataset, self.val_dataset = random_split(
@@ -107,43 +116,42 @@ class CatDogImageDataModule(L.LightningDataModule):
     def test_dataloader(self) -> DataLoader:
         return self._create_dataloader(self.test_dataset)
 
+    def get_class_names(self) -> List[str]:
+        return self.class_names
+
 
 if __name__ == "__main__":
-
+    # Test the CatDogImageDataModule
     import hydra
+    from omegaconf import DictConfig, OmegaConf
     import rootutils
 
-
-    root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-    logger.info(f"Root directory: {root}")
+    root = rootutils.setup_root(__file__, indicator=".project-root")
 
     @hydra.main(
-        version_base="1.3",
-        config_path=str(root / "configs"),
-        config_name="train",
+        config_path=str(root / "configs"), version_base="1.3", config_name="train"
     )
-    def
-
-        logger.info("Config:\n" + OmegaConf.to_yaml(cfg))
-
-        # Initialize DataModule
+    def test_datamodule(cfg: DictConfig):
+        logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
         datamodule = CatDogImageDataModule(
+            data_root=cfg.paths.data_dir,
             data_dir=cfg.data.data_dir,
             batch_size=cfg.data.batch_size,
             num_workers=cfg.data.num_workers,
             train_val_split=cfg.data.train_val_split,
             pin_memory=cfg.data.pin_memory,
            image_size=cfg.data.image_size,
-            url=cfg.data.url,
-        )
-        datamodule.prepare_data()
-        datamodule.setup()
-
-        # Log DataLoader sizes
-        logger.info(f"Train DataLoader: {len(datamodule.train_dataloader())} batches")
-        logger.info(
-            f"Validation DataLoader: {len(datamodule.val_dataloader())} batches"
         )
-
-
-
+        datamodule.setup(stage="fit")
+        train_loader = datamodule.train_dataloader()
+        val_loader = datamodule.val_dataloader()
+        datamodule.setup(stage="test")
+        test_loader = datamodule.test_dataloader()
+        class_names = datamodule.get_class_names()
+
+        logger.info(f"Train loader: {len(train_loader)} batches")
+        logger.info(f"Validation loader: {len(val_loader)} batches")
+        logger.info(f"Test loader: {len(test_loader)} batches")
+        logger.info(f"Class names: {class_names}")
+
+    test_datamodule()
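Beyond the __main__ smoke test above, the datamodule is meant to be handed straight to a Lightning Trainer. The following is a minimal sketch, not part of this commit; TinyCatDogClassifier is a hypothetical model that exists only to show the wiring:

    # Sketch: hand the datamodule to a Lightning Trainer (the model below is hypothetical).
    import lightning as L
    import torch
    import torch.nn.functional as F
    from torchvision import models

    from src.datamodules.catdog_datamodule import CatDogImageDataModule


    class TinyCatDogClassifier(L.LightningModule):
        """Minimal demo classifier: ResNet-18 backbone with a 2-class head."""

        def __init__(self, num_classes: int = 2, lr: float = 1e-3):
            super().__init__()
            self.net = models.resnet18(weights=None)
            self.net.fc = torch.nn.Linear(self.net.fc.in_features, num_classes)
            self.lr = lr

        def training_step(self, batch, batch_idx):
            images, labels = batch
            loss = F.cross_entropy(self.net(images), labels)
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            images, labels = batch
            self.log("val_loss", F.cross_entropy(self.net(images), labels))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)


    datamodule = CatDogImageDataModule(data_root="data", data_dir="cats_and_dogs_filtered")
    model = TinyCatDogClassifier()
    trainer = L.Trainer(max_epochs=1)
    # The Trainer calls prepare_data()/setup() and pulls the train/val dataloaders itself.
    trainer.fit(model, datamodule=datamodule)
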
src/train_new.py (CHANGED)

@@ -122,7 +122,7 @@ def run_test_module(
     return test_metrics[0] if test_metrics else {}
 
 
-@hydra.main(config_path="../configs", config_name="train", version_base="1.
+@hydra.main(config_path="../configs", config_name="train", version_base="1.3")
 def setup_run_trainer(cfg: DictConfig):
     """Set up and run the Trainer for training and testing."""
     # Display configuration
|