Soutrik committed
Commit 969321c • 1 Parent(s): 8931c9e

configs back at target level

configs/callbacks/early_stopping.yaml CHANGED
@@ -1,6 +1,7 @@
 # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
 
 early_stopping:
+  _target_: pytorch_lightning.callbacks.EarlyStopping
   monitor: val_loss # quantity to be monitored, must be specified !!!
   min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
   patience: 3 # number of checks with no improvement after which training will be stopped
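
With `_target_` present, Hydra can build the callback object straight from this config. A minimal sketch of that instantiation (the config is inlined here rather than composed from the file; assumes hydra-core and pytorch_lightning are installed):

# Sketch: how Hydra turns a `_target_` config into a callback object.
# The dict below mirrors configs/callbacks/early_stopping.yaml.
from hydra.utils import instantiate
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import EarlyStopping

cfg = OmegaConf.create(
    {
        "early_stopping": {
            "_target_": "pytorch_lightning.callbacks.EarlyStopping",
            "monitor": "val_loss",
            "min_delta": 0.0,
            "patience": 3,
        }
    }
)

cb = instantiate(cfg.early_stopping)
assert isinstance(cb, EarlyStopping)  # the config now builds the object itself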
configs/callbacks/model_checkpoint.yaml CHANGED
@@ -1,6 +1,7 @@
 # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
 
 model_checkpoint:
+  _target_: pytorch_lightning.callbacks.ModelCheckpoint
   dirpath: null # directory to save the model file
   filename: best-checkpoint # checkpoint filename
   monitor: val_loss # name of the logged metric which determines when model is improving
configs/callbacks/rich_model_summary.yaml CHANGED
@@ -1,2 +1,3 @@
 rich_model_summary:
+  _target_: pytorch_lightning.callbacks.RichModelSummary
   max_depth: 1
configs/callbacks/rich_progress_bar.yaml CHANGED
@@ -1,2 +1,3 @@
 rich_progress_bar:
+  _target_: pytorch_lightning.callbacks.RichProgressBar
   refresh_rate: 1
configs/data/catdog.yaml CHANGED
@@ -1,5 +1,4 @@
 _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
-
 data_dir: ${paths.data_dir}
 url: ${paths.data_url}
 num_workers: 4
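
The `${paths.data_dir}` and `${paths.data_url}` values are OmegaConf interpolations, resolved against the composed config at access time. A small sketch with made-up path values (the real ones come from the `paths` config group):

# Sketch of how ${paths.data_dir} resolves; the path values here are
# stand-ins for illustration, not the project's actual paths config.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "paths": {
            "data_dir": "/tmp/data",
            "data_url": "https://example.com/catdog.zip",
        },
        "data": {
            "_target_": "src.datamodules.catdog_datamodule.CatDogImageDataModule",
            "data_dir": "${paths.data_dir}",
            "url": "${paths.data_url}",
            "num_workers": 4,
        },
    }
)

print(cfg.data.data_dir)  # -> /tmp/data, resolved on access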
configs/logger/aim.yaml CHANGED
@@ -1,4 +1,5 @@
 aim:
+  _target_: aim.pytorch_lightning.AimLogger
   experiment: ${name}
   train_metric_prefix: train_
   test_metric_prefix: test_
configs/logger/csv.yaml CHANGED
@@ -1,6 +1,7 @@
 # csv logger built in lightning
 
 csv:
+  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
   save_dir: "${paths.output_dir}"
   name: "csv/"
   prefix: ""
configs/logger/mlflow.yaml CHANGED
@@ -1,6 +1,7 @@
 # MLflow logger configuration
 
 mlflow:
+  _target_: lightning.pytorch.loggers.MLFlowLogger
   experiment_name: ${name}
   tracking_uri: file:${paths.log_dir}/mlruns
   save_dir: ${paths.log_dir}/mlruns
configs/logger/tensorboard.yaml CHANGED
@@ -1,9 +1,10 @@
 # https://www.tensorflow.org/tensorboard/
 
 tensorboard:
+  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
   save_dir: "${paths.output_dir}/tensorboard/"
   name: null
   log_graph: False
   default_hp_metric: True
   prefix: ""
-  # version: ""
+  # version: ""
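
Once the logger configs carry `_target_`, they are instantiated the same way as the callbacks and handed to the Trainer. A hedged sketch (assumes lightning and tensorboard are installed; `save_dir` here is a stand-in for `${paths.output_dir}`):

# Sketch: instantiating a logger config and wiring it into a Trainer.
import lightning as L
from hydra.utils import instantiate
from omegaconf import OmegaConf

tb_cfg = OmegaConf.create(
    {
        "_target_": "lightning.pytorch.loggers.tensorboard.TensorBoardLogger",
        "save_dir": "outputs/tensorboard/",
        "name": None,
        "log_graph": False,
        "default_hp_metric": True,
        "prefix": "",
    }
)

tb_logger = instantiate(tb_cfg)
trainer = L.Trainer(logger=tb_logger, max_epochs=1)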
notebooks/datamodule_lightning.ipynb ADDED
@@ -0,0 +1,92 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "source": [
+    "In this notebook, we discuss the PyTorch Lightning DataModule with images stored in a folder structure where the folder names are the class labels. We use the cats and dogs dataset from Kaggle, which can be downloaded from [here](https://www.kaggle.com/c/dogs-vs-cats/data). The dataset contains 25000 images of cats and dogs; we use 20000 images for training and 5000 for validation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/javascript": "IPython.notebook.set_autosave_interval(300000)"
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Autosaving every 300 seconds\n"
+     ]
+    }
+   ],
+   "source": [
+    "%autosave 300\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "%reload_ext autoreload\n",
+    "%config Completer.use_jedi = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "import os\n",
+    "\n",
+    "os.chdir(\"..\")\n",
+    "print(os.getcwd())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "emlo_env",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
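
The notebook stops at the setup cells, so for orientation here is a hypothetical skeleton of the folder-as-labels DataModule it sets out to explore; the real implementation lives in src/datamodules/catdog_datamodule.py and will differ in detail:

# Hypothetical skeleton of a folder-as-labels DataModule; class names,
# split folder names, and transforms are illustrative assumptions.
import lightning as L
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder


class FolderImageDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor()]
        )

    def setup(self, stage=None):
        # ImageFolder infers class labels from the sub-folder names.
        if stage in (None, "fit"):
            self.train_set = ImageFolder(f"{self.data_dir}/train", self.transform)
            self.val_set = ImageFolder(f"{self.data_dir}/validation", self.transform)
        if stage in (None, "test"):
            self.test_set = ImageFolder(f"{self.data_dir}/test", self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.train_set, self.batch_size, shuffle=True, num_workers=self.num_workers
        )

    def val_dataloader(self):
        return DataLoader(self.val_set, self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, self.batch_size, num_workers=self.num_workers)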
notebooks/{training_lightning.ipynb → training_lightning_tests.ipynb} RENAMED
File without changes
src/{hydra_test.py → checks/hydra_test.py} RENAMED
File without changes
src/{hydra_test2.py → checks/hydra_test2.py} RENAMED
File without changes
src/train.py CHANGED
@@ -1,3 +1,9 @@
+"""
+Train and evaluate a model using PyTorch Lightning.
+Initializes the DataModule, Model, Trainer, and runs training and testing.
+Initializes loggers and callbacks from the Hydra configuration, using a more modular approach without direct instantiation.
+"""
+
 import os
 import shutil
 from pathlib import Path
src/train_new.py ADDED
@@ -0,0 +1,196 @@
+"""
+Train and evaluate a model using PyTorch Lightning.
+Initializes the DataModule, Model, Trainer, and runs training and testing.
+Initializes loggers and callbacks via Hydra, using the _target_ paths in the configuration.
+"""
+
+import os
+import shutil
+from pathlib import Path
+from typing import List
+import torch
+import lightning as L
+from dotenv import load_dotenv, find_dotenv
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from src.utils.logging_utils import setup_logger, task_wrapper
+from loguru import logger
+import rootutils
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.callbacks import Callback
+
+# Load environment variables
+load_dotenv(find_dotenv(".env"))
+
+# Setup root directory
+root = rootutils.setup_root(__file__, indicator=".project-root")
+
+
+def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
+    """Instantiate and return a list of callbacks from the configuration."""
+    callbacks_ls: List[Callback] = []
+
+    if not callback_cfg:
+        logger.warning("No callback configs found! Skipping..")
+        return callbacks_ls  # nothing to instantiate
+
+    if not isinstance(callback_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callback_cfg.items():
+        if "_target_" in cb_conf:
+            logger.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks_ls.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks_ls
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiate and return a list of loggers from the configuration."""
+    loggers_ls: List[Logger] = []
+
+    if not logger_cfg:
+        logger.warning("No logger configs found! Skipping..")
+        return loggers_ls
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if "_target_" in lg_conf:
+            logger.info(f"Instantiating logger <{lg_conf._target_}>")
+            loggers_ls.append(hydra.utils.instantiate(lg_conf))
+
+    return loggers_ls
+
+
+def load_checkpoint_if_available(ckpt_path: str) -> str:
+    """Return the checkpoint path if available, else None."""
+    if ckpt_path and Path(ckpt_path).exists():
+        logger.info(f"Using checkpoint: {ckpt_path}")
+        return ckpt_path
+    logger.warning(f"Checkpoint not found at {ckpt_path}. Using current model weights.")
+    return None
+
+
+def clear_checkpoint_directory(ckpt_dir: str):
+    """Clear checkpoint directory contents without removing the directory."""
+    ckpt_dir_path = Path(ckpt_dir)
+    if not ckpt_dir_path.exists():
+        logger.info(f"Creating checkpoint directory: {ckpt_dir}")
+        ckpt_dir_path.mkdir(parents=True, exist_ok=True)
+    else:
+        logger.info(f"Clearing checkpoint directory: {ckpt_dir}")
+        for item in ckpt_dir_path.iterdir():
+            try:
+                item.unlink() if item.is_file() else shutil.rmtree(item)
+            except Exception as e:
+                logger.error(f"Failed to delete {item}: {e}")
+
+
+@task_wrapper
+def train_module(
+    data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
+):
+    """Train the model and log metrics."""
+    logger.info("Starting training")
+    trainer.fit(model, data_module)
+    train_metrics = trainer.callback_metrics
+    train_acc = train_metrics.get("train_acc")
+    val_acc = train_metrics.get("val_acc")
+    logger.info(
+        f"Training completed. Metrics - train_acc: {train_acc}, val_acc: {val_acc}"
+    )
+    return train_metrics
+
+
+@task_wrapper
+def run_test_module(
+    cfg: DictConfig,
+    datamodule: L.LightningDataModule,
+    model: L.LightningModule,
+    trainer: L.Trainer,
+):
+    """Test the model using the best checkpoint or current model weights."""
+    logger.info("Starting testing")
+    datamodule.setup(stage="test")
+    test_metrics = trainer.test(
+        model, datamodule, ckpt_path=load_checkpoint_if_available(cfg.ckpt_path)
+    )
+    logger.info(f"Test metrics: {test_metrics}")
+    return test_metrics[0] if test_metrics else {}
+
+
+@hydra.main(config_path="../configs", config_name="train", version_base="1.1")
+def setup_run_trainer(cfg: DictConfig):
+    """Set up and run the Trainer for training and testing."""
+    # Display configuration
+    logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
+
+    # Initialize logger
+    log_path = Path(cfg.paths.log_dir) / (
+        "train.log" if cfg.task_name == "train" else "eval.log"
+    )
+    setup_logger(log_path)
+
+    # Display key paths
+    for path_name in [
+        "root_dir",
+        "data_dir",
+        "log_dir",
+        "ckpt_dir",
+        "artifact_dir",
+        "output_dir",
+    ]:
+        logger.info(
+            f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
+        )
+
+    # Initialize DataModule and Model
+    logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
+    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
+    logger.info(f"Instantiating model <{cfg.model._target_}>")
+    model: L.LightningModule = hydra.utils.instantiate(cfg.model)
+
+    # Check GPU availability and set seed for reproducibility
+    logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
+    L.seed_everything(cfg.seed, workers=True)
+
+    # Set up callbacks, loggers, and Trainer
+    callbacks = instantiate_callbacks(cfg.callbacks)
+    logger.info(f"Callbacks: {callbacks}")
+    loggers = instantiate_loggers(cfg.loggers)
+    logger.info(f"Loggers: {loggers}")
+    trainer: L.Trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=loggers
+    )
+
+    # Training phase
+    train_metrics = {}
+    if cfg.get("train"):
+        clear_checkpoint_directory(cfg.paths.ckpt_dir)
+        train_metrics = train_module(datamodule, model, trainer)
+        (Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
+            "Training completed.\n"
+        )
+
+    # Testing phase
+    test_metrics = {}
+    if cfg.get("test"):
+        test_metrics = run_test_module(cfg, datamodule, model, trainer)
+
+    # Combine metrics and extract optimization metric
+    all_metrics = {**train_metrics, **test_metrics}
+    optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
+    if optimization_metric == 0.0:
+        logger.warning(
+            f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
+        )
+    else:
+        logger.info(f"Optimization metric: {optimization_metric}")
+
+    return optimization_metric
+
+
+if __name__ == "__main__":
+    setup_run_trainer()
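
A quick way to sanity-check `instantiate_callbacks` is to feed it an in-memory config shaped like the composed callbacks group; a hedged sketch (the import assumes the snippet runs from the project root with the `.project-root` indicator present, as `rootutils` expects, and that pytorch_lightning is installed):

# Hedged usage sketch: exercise instantiate_callbacks with an in-memory
# config mirroring the configs/callbacks group above.
from omegaconf import OmegaConf

from src.train_new import instantiate_callbacks

cb_cfg = OmegaConf.create(
    {
        "early_stopping": {
            "_target_": "pytorch_lightning.callbacks.EarlyStopping",
            "monitor": "val_loss",
            "patience": 3,
        },
        "model_checkpoint": {
            "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
            "dirpath": None,
            "filename": "best-checkpoint",
            "monitor": "val_loss",
        },
    }
)

callbacks = instantiate_callbacks(cb_cfg)
print(callbacks)  # e.g. [EarlyStopping(...), ModelCheckpoint(...)]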