Soutrik committed on
Commit
3fa4d71
·
1 Parent(s): 8d4131e

local training check

Browse files
artifacts/image_prediction.png CHANGED
docker-compose.yaml CHANGED
@@ -5,7 +5,7 @@ services:
5
  build:
6
  context: .
7
  command: |
8
- python -m src.train_new experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
9
  touch /app/checkpoints/train_done.flag
10
  volumes:
11
  - ./data:/app/data
@@ -25,7 +25,7 @@ services:
25
  build:
26
  context: .
27
  command: |
28
- sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_new experiment=catdog_experiment ++task_name=eval ++train=False ++test=True'
29
  volumes:
30
  - ./data:/app/data
31
  - ./checkpoints:/app/checkpoints
 
5
  build:
6
  context: .
7
  command: |
8
+ python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
9
  touch /app/checkpoints/train_done.flag
10
  volumes:
11
  - ./data:/app/data
 
25
  build:
26
  context: .
27
  command: |
28
+ sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=eval ++train=False ++test=True'
29
  volumes:
30
  - ./data:/app/data
31
  - ./checkpoints:/app/checkpoints
src/train_optuna_callbacks.py CHANGED
@@ -15,8 +15,10 @@ from src.utils.logging_utils import setup_logger, task_wrapper
15
  from loguru import logger
16
  import rootutils
17
  from lightning.pytorch.loggers import Logger
 
18
  import optuna
19
  from lightning.pytorch import Trainer
 
20
 
21
  # Load environment variables
22
  load_dotenv(find_dotenv(".env"))
@@ -25,7 +27,7 @@ load_dotenv(find_dotenv(".env"))
25
  root = rootutils.setup_root(__file__, indicator=".project-root")
26
 
27
 
28
- def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
29
  """Instantiate and return a list of callbacks from the configuration."""
30
  callbacks: List[L.Callback] = []
31
 
@@ -125,7 +127,7 @@ def run_test_module(
125
  return test_metrics[0] if test_metrics else {}
126
 
127
 
128
- def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Callback]):
129
  """Objective function for Optuna hyperparameter tuning."""
130
 
131
  # Sample hyperparameters for the model
@@ -144,9 +146,6 @@ def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Call
144
  # Trainer configuration with passed callbacks
145
  trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
146
 
147
- # Clear checkpoint directory
148
- clear_checkpoint_directory(cfg.paths.ckpt_dir)
149
-
150
  # Train and get val_acc for each epoch
151
  val_accuracies = train_module(data_module, model, trainer)
152
 
@@ -177,13 +176,16 @@ def setup_trainer(cfg: DictConfig):
177
  logger.info(f"Callbacks: {callbacks}")
178
 
179
  if cfg.get("train", False):
 
 
 
180
  pruner = optuna.pruners.MedianPruner()
181
  study = optuna.create_study(
182
  direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
183
  )
184
  study.optimize(
185
  lambda trial: objective(trial, cfg, callbacks),
186
- n_trials=5,
187
  show_progress_bar=True,
188
  )
189
 
@@ -194,7 +196,26 @@ def setup_trainer(cfg: DictConfig):
194
  for key, value in best_trial.params.items():
195
  logger.info(f" {key}: {value}")
196
 
 
 
 
 
 
 
 
197
  if cfg.get("test", False):
 
 
 
 
 
 
 
 
 
 
 
 
198
  data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
199
  model: L.LightningModule = hydra.utils.instantiate(cfg.model)
200
  trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
 
15
  from loguru import logger
16
  import rootutils
17
  from lightning.pytorch.loggers import Logger
18
+ from lightning.pytorch.callbacks import Callback
19
  import optuna
20
  from lightning.pytorch import Trainer
21
+ import json
22
 
23
  # Load environment variables
24
  load_dotenv(find_dotenv(".env"))
 
27
  root = rootutils.setup_root(__file__, indicator=".project-root")
28
 
29
 
30
+ def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
31
  """Instantiate and return a list of callbacks from the configuration."""
32
  callbacks: List[L.Callback] = []
33
 
 
127
  return test_metrics[0] if test_metrics else {}
128
 
129
 
130
+ def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[Callback]):
131
  """Objective function for Optuna hyperparameter tuning."""
132
 
133
  # Sample hyperparameters for the model
 
146
  # Trainer configuration with passed callbacks
147
  trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
148
 
 
 
 
149
  # Train and get val_acc for each epoch
150
  val_accuracies = train_module(data_module, model, trainer)
151
 
 
176
  logger.info(f"Callbacks: {callbacks}")
177
 
178
  if cfg.get("train", False):
179
+ # Clear checkpoint directory
180
+ clear_checkpoint_directory(cfg.paths.ckpt_dir)
181
+ # find the best hyperparameters using Optuna and train the model
182
  pruner = optuna.pruners.MedianPruner()
183
  study = optuna.create_study(
184
  direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
185
  )
186
  study.optimize(
187
  lambda trial: objective(trial, cfg, callbacks),
188
+ n_trials=3,
189
  show_progress_bar=True,
190
  )
191
 
 
196
  for key, value in best_trial.params.items():
197
  logger.info(f" {key}: {value}")
198
 
199
+ # write the best hyperparameters to the config
200
+ best_hyperparams = {key: value for key, value in best_trial.params.items()}
201
+ best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
202
+ with open(best_hyperparams_path, "w") as f:
203
+ json.dump(best_hyperparams, f)
204
+ logger.info(f"Best hyperparameters saved to {best_hyperparams_path}")
205
+
206
  if cfg.get("test", False):
207
+ best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
208
+ if best_hyperparams_path.exists():
209
+ with open(best_hyperparams_path, "r") as f:
210
+ best_hyperparams = json.load(f)
211
+ cfg.model.update(best_hyperparams)
212
+ logger.info(f"Loaded best hyperparameters for testing: {best_hyperparams}")
213
+ else:
214
+ logger.error(
215
+ "Best hyperparameters not found! Using default hyperparameters."
216
+ )
217
+ raise FileNotFoundError("Best hyperparameters not found!")
218
+
219
  data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
220
  model: L.LightningModule = hydra.utils.instantiate(cfg.model)
221
  trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))