Commit 3fa4d71: local training check
Soutrik committed · Parent(s): 8d4131e

Files changed:
- artifacts/image_prediction.png +0 -0
- docker-compose.yaml +2 -2
- src/train_optuna_callbacks.py +27 -6
artifacts/image_prediction.png CHANGED (binary image; no text diff shown)
docker-compose.yaml CHANGED
@@ -5,7 +5,7 @@ services:
     build:
       context: .
     command: |
-      python -m src.
+      python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
       touch /app/checkpoints/train_done.flag
     volumes:
       - ./data:/app/data
@@ -25,7 +25,7 @@ services:
     build:
       context: .
     command: |
-      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.
+      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_optuna_callbacks experiment=catdog_experiment ++task_name=eval ++train=False ++test=True'
     volumes:
       - ./data:/app/data
       - ./checkpoints:/app/checkpoints
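The two services coordinate through a flag file on the shared ./checkpoints volume: the train service runs training, then touches /app/checkpoints/train_done.flag, while the eval service polls every 10 seconds until the flag exists before it starts testing. A minimal Python sketch of that handshake follows (wait_for_flag is a hypothetical helper for illustration, not code from this repo):

# Sketch of the eval container's wait loop, assuming the flag path and
# 10-second poll interval from docker-compose.yaml above.
import time
from pathlib import Path

def wait_for_flag(flag: Path, poll_seconds: float = 10.0) -> None:
    """Block until the training service creates the done-flag file."""
    while not flag.exists():
        time.sleep(poll_seconds)

wait_for_flag(Path("/app/checkpoints/train_done.flag"))
print("train_done.flag found; evaluation can start")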
src/train_optuna_callbacks.py CHANGED
@@ -15,8 +15,10 @@ from src.utils.logging_utils import setup_logger, task_wrapper
 from loguru import logger
 import rootutils
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.callbacks import Callback
 import optuna
 from lightning.pytorch import Trainer
+import json
 
 # Load environment variables
 load_dotenv(find_dotenv(".env"))
@@ -25,7 +27,7 @@ load_dotenv(find_dotenv(".env"))
 root = rootutils.setup_root(__file__, indicator=".project-root")
 
 
-def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
+def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
     """Instantiate and return a list of callbacks from the configuration."""
     callbacks: List[L.Callback] = []
 
@@ -125,7 +127,7 @@ def run_test_module(
     return test_metrics[0] if test_metrics else {}
 
 
-def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Callback]):
+def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[Callback]):
     """Objective function for Optuna hyperparameter tuning."""
 
     # Sample hyperparameters for the model
@@ -144,9 +146,6 @@ def objective(trial: optuna.trial.Trial, cfg: DictConfig, callbacks: List[L.Callback]):
     # Trainer configuration with passed callbacks
     trainer = Trainer(**cfg.trainer, logger=loggers, callbacks=callbacks)
 
-    # Clear checkpoint directory
-    clear_checkpoint_directory(cfg.paths.ckpt_dir)
-
     # Train and get val_acc for each epoch
     val_accuracies = train_module(data_module, model, trainer)
 
@@ -177,13 +176,16 @@ def setup_trainer(cfg: DictConfig):
     logger.info(f"Callbacks: {callbacks}")
 
     if cfg.get("train", False):
+        # Clear checkpoint directory
+        clear_checkpoint_directory(cfg.paths.ckpt_dir)
+        # find the best hyperparameters using Optuna and train the model
         pruner = optuna.pruners.MedianPruner()
         study = optuna.create_study(
             direction="maximize", pruner=pruner, study_name="pytorch_lightning_optuna"
         )
         study.optimize(
             lambda trial: objective(trial, cfg, callbacks),
-            n_trials=
+            n_trials=3,
             show_progress_bar=True,
         )
 
@@ -194,7 +196,26 @@ def setup_trainer(cfg: DictConfig):
         for key, value in best_trial.params.items():
             logger.info(f" {key}: {value}")
 
+        # write the best hyperparameters to the config
+        best_hyperparams = {key: value for key, value in best_trial.params.items()}
+        best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
+        with open(best_hyperparams_path, "w") as f:
+            json.dump(best_hyperparams, f)
+        logger.info(f"Best hyperparameters saved to {best_hyperparams_path}")
+
     if cfg.get("test", False):
+        best_hyperparams_path = Path(cfg.paths.ckpt_dir) / "best_hyperparams.json"
+        if best_hyperparams_path.exists():
+            with open(best_hyperparams_path, "r") as f:
+                best_hyperparams = json.load(f)
+            cfg.model.update(best_hyperparams)
+            logger.info(f"Loaded best hyperparameters for testing: {best_hyperparams}")
+        else:
+            logger.error(
+                "Best hyperparameters not found! Using default hyperparameters."
+            )
+            raise FileNotFoundError("Best hyperparameters not found!")
+
         data_module: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
         model: L.LightningModule = hydra.utils.instantiate(cfg.model)
         trainer = Trainer(**cfg.trainer, logger=instantiate_loggers(cfg.logger))
|