import os
import shutil
from pathlib import Path
import torch
import lightning as L
from lightning.pytorch.loggers import Logger
from typing import List, Optional
from src.datamodules.dogbreed_datamodule import main_dataloader
from src.utils.logging_utils import setup_logger, task_wrapper
from loguru import logger
from dotenv import load_dotenv, find_dotenv
import rootutils
import hydra
from omegaconf import DictConfig, OmegaConf

# Load environment variables
load_dotenv(find_dotenv(".env"))

# Setup root directory
root = rootutils.setup_root(__file__, indicator=".project-root")
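# (rootutils searches upward from this file for the ".project-root" marker file
# and treats that directory as the project root.)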


def instantiate_callbacks(callback_cfg: DictConfig) -> List[L.Callback]:
"""Instantiate and return a list of callbacks from the configuration."""
callbacks: List[L.Callback] = []
if not callback_cfg:
logger.warning("No callback configs found! Skipping..")
return callbacks
if not isinstance(callback_cfg, DictConfig):
raise TypeError("Callbacks config must be a DictConfig!")
    for _, cb_conf in callback_cfg.items():
        # Instantiate only entries that actually define a Hydra _target_
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
logger.info(f"Instantiating callback <{cb_conf._target_}>")
callbacks.append(hydra.utils.instantiate(cb_conf))
return callbacks
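

# Example of a callbacks config this function can consume (hypothetical YAML;
# adapt the names and paths to your configs/callbacks group):
#   model_checkpoint:
#     _target_: lightning.pytorch.callbacks.ModelCheckpoint
#     dirpath: ${paths.ckpt_dir}
#     monitor: val_acc
#     mode: max
#   early_stopping:
#     _target_: lightning.pytorch.callbacks.EarlyStopping
#     monitor: val_acc
#     mode: max
#     patience: 3
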
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
"""Instantiate and return a list of loggers from the configuration."""
loggers_ls: List[Logger] = []
if not logger_cfg:
logger.warning("No logger configs found! Skipping..")
return loggers_ls
if not isinstance(logger_cfg, DictConfig):
raise TypeError("Logger config must be a DictConfig!")
    for _, lg_conf in logger_cfg.items():
        # Instantiate only entries that actually define a Hydra _target_
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
logger.info(f"Instantiating logger <{lg_conf._target_}>")
loggers_ls.append(hydra.utils.instantiate(lg_conf))
return loggers_ls
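

# Example of a logger config this function can consume (hypothetical YAML;
# adapt to your configs/logger group):
#   tensorboard:
#     _target_: lightning.pytorch.loggers.TensorBoardLogger
#     save_dir: ${paths.output_dir}/tensorboard
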

def load_checkpoint_if_available(ckpt_path: Optional[str]) -> Optional[str]:
    """Return the checkpoint path if it exists, otherwise None."""
if ckpt_path and Path(ckpt_path).exists():
logger.info(f"Checkpoint found: {ckpt_path}")
return ckpt_path
else:
logger.warning(
f"No checkpoint found at {ckpt_path}. Using current model weights."
)
return None


def clear_checkpoint_directory(ckpt_dir: str):
"""Clear all contents of the checkpoint directory without deleting the directory itself."""
ckpt_dir_path = Path(ckpt_dir)
if ckpt_dir_path.exists() and ckpt_dir_path.is_dir():
logger.info(f"Clearing checkpoint directory: {ckpt_dir}")
# Iterate over all files and directories in the checkpoint directory and remove them
for item in ckpt_dir_path.iterdir():
try:
if item.is_file() or item.is_symlink():
item.unlink() # Remove file or symlink
elif item.is_dir():
shutil.rmtree(item) # Remove directory
except Exception as e:
logger.error(f"Failed to delete {item}: {e}")
logger.info(f"Checkpoint directory cleared: {ckpt_dir}")
else:
logger.info(
f"Checkpoint directory does not exist. Creating directory: {ckpt_dir}"
)
os.makedirs(ckpt_dir_path, exist_ok=True)


@task_wrapper
def train_module(
cfg: DictConfig,
data_module: L.LightningDataModule,
model: L.LightningModule,
trainer: L.Trainer,
):
"""Train the model using the provided Trainer and DataModule."""
logger.info("Training the model")
trainer.fit(model, data_module)
train_metrics = trainer.callback_metrics
    try:
        logger.info(
            f"Training completed with metrics - train_acc: {train_metrics['train_acc'].item()}, "
            f"val_acc: {train_metrics['val_acc'].item()}"
        )
    except KeyError:
        logger.info(f"Training completed with the following metrics: {train_metrics}")
return train_metrics


@task_wrapper
def run_test_module(
cfg: DictConfig,
datamodule: L.LightningDataModule,
model: L.LightningModule,
trainer: L.Trainer,
):
"""Test the model using the best checkpoint or the current model weights."""
logger.info("Testing the model")
datamodule.setup(stage="test")
ckpt_path = load_checkpoint_if_available(cfg.ckpt_path)
    # If no checkpoint is available, Lightning falls back to the in-memory weights
    test_metrics = trainer.test(model, datamodule=datamodule, ckpt_path=ckpt_path)
logger.info(f"Test metrics:\n{test_metrics}")
return test_metrics[0] if test_metrics else {}
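
# Note: trainer.test returns a list with one metrics dict per test dataloader,
# e.g. [{"test_acc": 0.91, "test_loss": 0.23}] (illustrative values only).
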
@hydra.main(config_path="../configs", config_name="train", version_base="1.1")
def setup_run_trainer(cfg: DictConfig):
"""Set up and run the Trainer for training and testing the model."""
    # Log the full composed config
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
# Initialize logger
if cfg.task_name == "train":
log_path = Path(cfg.paths.log_dir) / "train.log"
else:
log_path = Path(cfg.paths.log_dir) / "eval.log"
setup_logger(log_path)
    # Resolve key paths from the config
    root_dir = cfg.paths.root_dir
    logger.info(f"Root directory: {root_dir}")
    logger.info(f"Root directory contents: {os.listdir(root_dir)}")
    ckpt_dir = cfg.paths.ckpt_dir
    logger.info(f"Checkpoint directory: {ckpt_dir}")
# the path to the data directory
data_dir = cfg.paths.data_dir
logger.info(f"Data directory: {data_dir}")
# the path to the log directory
log_dir = cfg.paths.log_dir
logger.info(f"Log directory: {log_dir}")
# the path to the artifact directory
artifact_dir = cfg.paths.artifact_dir
logger.info(f"Artifact directory: {artifact_dir}")
# output directory
output_dir = cfg.paths.output_dir
logger.info(f"Output directory: {output_dir}")
# name of the experiment
experiment_name = cfg.name
logger.info(f"Experiment name: {experiment_name}")
# Initialize DataModule
if experiment_name == "dogbreed_experiment":
logger.info("Setting up the DataModule")
dataset_df, datamodule = main_dataloader(cfg)
        num_classes = dataset_df.label.nunique()
        logger.info(f"Number of classes: {num_classes}")
os.makedirs(cfg.paths.artifact_dir, exist_ok=True)
dataset_df.to_csv(
Path(cfg.paths.artifact_dir) / "dogbreed_dataset.csv", index=False
)
    elif experiment_name in ("catdog_experiment", "catdog_experiment_convnext"):
        # Initialize DataModule from the Hydra config
        logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
        datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
    else:
        # Fail fast instead of hitting a NameError on an undefined datamodule later
        raise ValueError(f"Unknown experiment name: {experiment_name}")
# Check for GPU availability
logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
# Set seed for reproducibility
L.seed_everything(cfg.seed, workers=True)
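    # (workers=True also seeds dataloader worker processes for reproducibility)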
# Initialize model
logger.info(f"Instantiating model <{cfg.model._target_}>")
model: L.LightningModule = hydra.utils.instantiate(cfg.model)
logger.info(f"Model summary:\n{model}")
# Set up callbacks and loggers
logger.info("Setting up callbacks and loggers")
callbacks: List[L.Callback] = instantiate_callbacks(cfg.get("callbacks"))
logger.info(f"Callbacks: {callbacks}")
loggers: List[Logger] = instantiate_loggers(cfg.get("logger"))
logger.info(f"Loggers: {loggers}")
# Initialize Trainer
logger.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: L.Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=loggers
)
# Train and test the model based on config settings
train_metrics = {}
if cfg.get("train"):
# clear the checkpoint directory
clear_checkpoint_directory(cfg.paths.ckpt_dir)
logger.info("Training the model")
train_metrics = train_module(cfg, datamodule, model, trainer)
# Write training done flag using Hydra paths config
done_flag_path = Path(cfg.paths.ckpt_dir) / "train_done.flag"
with done_flag_path.open("w") as f:
f.write("Training completed.\n")
logger.info(f"Training completion flag written to: {done_flag_path}")
logger.info(
f"Training completed. Checkpoint directory: {os.listdir(cfg.paths.ckpt_dir)}"
)
test_metrics = {}
if cfg.get("test"):
logger.info(f"Checkpoint directory: {os.listdir(cfg.paths.ckpt_dir)}")
test_metrics = run_test_module(cfg, datamodule, model, trainer)
# Combine metrics
all_metrics = {**train_metrics, **test_metrics}
    # Extract and return the optimization metric as a plain float
    optimization_metric = all_metrics.get(cfg.get("optimization_metric"))
    if optimization_metric is None:
        logger.warning(
            f"Optimization metric '{cfg.get('optimization_metric')}' not found in metrics. Returning 0."
        )
        return 0.0
    # callback_metrics values are tensors; unwrap before returning
    if isinstance(optimization_metric, torch.Tensor):
        optimization_metric = optimization_metric.item()
    return optimization_metric
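
# Note: when this script runs under Hydra's Optuna sweeper (if one is configured
# for multirun), the scalar returned above is the objective value being optimized.
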
if __name__ == "__main__":
setup_run_trainer()
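
# Example invocations (hypothetical; assumes this file is saved as train.py and,
# for the multirun case, that the hydra-optuna-sweeper plugin is installed):
#   python train.py task_name=train train=True test=True
#   python train.py --multirun hydra/sweeper=optuna optimization_metric=val_acc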