Spaces · Runtime error

Soutrik committed · Commit 969321c
1 Parent(s): 8931c9e

configs back at target level
- configs/callbacks/early_stopping.yaml +1 -0
- configs/callbacks/model_checkpoint.yaml +1 -0
- configs/callbacks/rich_model_summary.yaml +1 -0
- configs/callbacks/rich_progress_bar.yaml +1 -0
- configs/data/catdog.yaml +0 -1
- configs/logger/aim.yaml +1 -0
- configs/logger/csv.yaml +1 -0
- configs/logger/mlflow.yaml +1 -0
- configs/logger/tensorboard.yaml +2 -1
- notebooks/datamodule_lightning.ipynb +92 -0
- notebooks/{training_lightning.ipynb → training_lightning_tests.ipynb} +0 -0
- src/{hydra_test.py → checks/hydra_test.py} +0 -0
- src/{hydra_test2.py → checks/hydra_test2.py} +0 -0
- src/train.py +6 -0
- src/train_new.py +198 -0
configs/callbacks/early_stopping.yaml
CHANGED
@@ -1,6 +1,7 @@
 # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
 
 early_stopping:
+  _target_: pytorch_lightning.callbacks.EarlyStopping
   monitor: val_loss # quantity to be monitored, must be specified !!!
   min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
   patience: 3 # number of checks with no improvement after which training will be stopped
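Adding _target_ under each callback key is what lets Hydra build the object straight from the config. A minimal standalone sketch of the mechanism, outside the project (assumes hydra-core and pytorch-lightning are installed):

import hydra
from omegaconf import OmegaConf

# Mirrors configs/callbacks/early_stopping.yaml after this commit.
cfg = OmegaConf.create(
    {
        "early_stopping": {
            "_target_": "pytorch_lightning.callbacks.EarlyStopping",
            "monitor": "val_loss",
            "min_delta": 0.0,
            "patience": 3,
        }
    }
)

# instantiate() imports the class named by _target_ and calls it with the
# remaining keys as keyword arguments.
cb = hydra.utils.instantiate(cfg.early_stopping)
print(type(cb))  # <class 'pytorch_lightning.callbacks.early_stopping.EarlyStopping'>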
configs/callbacks/model_checkpoint.yaml
CHANGED
@@ -1,6 +1,7 @@
 # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
 
 model_checkpoint:
+  _target_: pytorch_lightning.callbacks.ModelCheckpoint
   dirpath: null # directory to save the model file
   filename: best-checkpoint # checkpoint filename
   monitor: val_loss # name of the logged metric which determines when model is improving
configs/callbacks/rich_model_summary.yaml
CHANGED
@@ -1,2 +1,3 @@
 rich_model_summary:
+  _target_: pytorch_lightning.callbacks.RichModelSummary
   max_depth: 1
configs/callbacks/rich_progress_bar.yaml
CHANGED
@@ -1,2 +1,3 @@
 rich_progress_bar:
+  _target_: pytorch_lightning.callbacks.RichProgressBar
   refresh_rate: 1
configs/data/catdog.yaml
CHANGED
@@ -1,5 +1,4 @@
 _target_: src.datamodules.catdog_datamodule.CatDogImageDataModule
-
 data_dir: ${paths.data_dir}
 url: ${paths.data_url}
 num_workers: 4
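The ${paths.data_dir} and ${paths.data_url} values are OmegaConf interpolations, resolved against the composed config at access time. A small sketch (the paths and URL here are placeholders, not the project's real values):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "paths": {"data_dir": "data/", "data_url": "https://example.com/catdog.zip"},
        "data": {
            "_target_": "src.datamodules.catdog_datamodule.CatDogImageDataModule",
            "data_dir": "${paths.data_dir}",
            "url": "${paths.data_url}",
            "num_workers": 4,
        },
    }
)

print(cfg.data.data_dir)  # "data/" -- the interpolation resolves on access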
configs/logger/aim.yaml
CHANGED
@@ -1,4 +1,5 @@
 aim:
+  _target_: aim.pytorch_lightning.AimLogger
   experiment: ${name}
   train_metric_prefix: train_
   test_metric_prefix: test_
configs/logger/csv.yaml
CHANGED
@@ -1,6 +1,7 @@
 # csv logger built in lightning
 
 csv:
+  _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
   save_dir: "${paths.output_dir}"
   name: "csv/"
   prefix: ""
configs/logger/mlflow.yaml
CHANGED
@@ -1,6 +1,7 @@
 # MLflow logger configuration
 
 mlflow:
+  _target_: lightning.pytorch.loggers.MLFlowLogger
   experiment_name: ${name}
   tracking_uri: file:${paths.log_dir}/mlruns
   save_dir: ${paths.log_dir}/mlruns
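For reference, this is roughly what mlflow.yaml resolves to once instantiated; a file: tracking URI keeps MLflow runs on the local filesystem. A hedged sketch with placeholder values standing in for ${name} and ${paths.log_dir} (assumes mlflow is installed):

from lightning.pytorch.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(
    experiment_name="catdog_experiment",   # stands in for ${name}
    tracking_uri="file:./logs/mlruns",     # stands in for file:${paths.log_dir}/mlruns
    save_dir="./logs/mlruns",
)
mlf_logger.log_hyperparams({"lr": 1e-3})  # runs land under ./logs/mlruns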
configs/logger/tensorboard.yaml
CHANGED
@@ -1,9 +1,10 @@
 # https://www.tensorflow.org/tensorboard/
 
 tensorboard:
+  _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
   save_dir: "${paths.output_dir}/tensorboard/"
   name: null
   log_graph: False
   default_hp_metric: True
   prefix: ""
-# version: ""
+# version: ""
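With every logger config now carrying a _target_, the training script can turn the whole logger group into a list of Logger objects and hand it to the Trainer, which is what instantiate_loggers in src/train_new.py below does. A condensed sketch of that pattern (save_dir values are placeholders):

import hydra
import lightning as L
from omegaconf import OmegaConf

logger_cfg = OmegaConf.create(
    {
        "tensorboard": {
            "_target_": "lightning.pytorch.loggers.tensorboard.TensorBoardLogger",
            "save_dir": "outputs/tensorboard/",
        },
        "csv": {
            "_target_": "lightning.pytorch.loggers.csv_logs.CSVLogger",
            "save_dir": "outputs/",
            "name": "csv/",
        },
    }
)

# One instantiate() call per entry that declares a _target_.
loggers = [
    hydra.utils.instantiate(lg_conf)
    for lg_conf in logger_cfg.values()
    if "_target_" in lg_conf
]
trainer = L.Trainer(max_epochs=1, logger=loggers)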
notebooks/datamodule_lightning.ipynb
ADDED
@@ -0,0 +1,92 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "source": [
+    "In this notebook, we discuss the PyTorch Lightning DataModule with images stored in a folder structure whose folder names are the class labels. We use the cats and dogs dataset from Kaggle, which can be downloaded from [here](https://www.kaggle.com/c/dogs-vs-cats/data). The dataset contains 25000 images of cats and dogs; we use 20000 images for training and 5000 for validation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/javascript": "IPython.notebook.set_autosave_interval(300000)"
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Autosaving every 300 seconds\n"
+     ]
+    }
+   ],
+   "source": [
+    "%autosave 300\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "%reload_ext autoreload\n",
+    "%config Completer.use_jedi = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/mnt/batch/tasks/shared/LS_root/mounts/clusters/soutrik-vm-dev/code/Users/Soutrik.Chowdhury/pytorch-template-aws\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.chdir(\"..\")\n",
+    "print(os.getcwd())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "emlo_env",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
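The project's CatDogImageDataModule itself is not part of this commit, but a minimal folder-per-class DataModule of the kind the notebook describes could look like the following hypothetical sketch (the class name, directory layout, and arguments are illustrative, not the project's):

import lightning as L
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder


class FolderImageDataModule(L.LightningDataModule):
    """Loads images from <data_dir>/train and <data_dir>/validation,
    where each sub-folder name is a class label."""

    def __init__(self, data_dir: str, batch_size: int = 32, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor()]
        )

    def setup(self, stage=None):
        # ImageFolder assigns integer labels from the sub-folder names.
        if stage in (None, "fit"):
            self.train_ds = ImageFolder(f"{self.data_dir}/train", self.transform)
            self.val_ds = ImageFolder(f"{self.data_dir}/validation", self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers
        )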
notebooks/{training_lightning.ipynb → training_lightning_tests.ipynb}
RENAMED
File without changes
src/{hydra_test.py → checks/hydra_test.py}
RENAMED
File without changes
src/{hydra_test2.py → checks/hydra_test2.py}
RENAMED
File without changes
src/train.py
CHANGED
@@ -1,3 +1,9 @@
+"""
+Train and evaluate a model using PyTorch Lightning.
+Initializes the DataModule, Model, Trainer, and runs training and testing.
+Initializes loggers and callbacks from the Hydra configuration with a more modular approach, without direct instantiation.
+"""
+
 import os
 import shutil
 from pathlib import Path
src/train_new.py
ADDED
@@ -0,0 +1,196 @@
+"""
+Train and evaluate a model using PyTorch Lightning.
+Initializes the DataModule, Model, Trainer, and runs training and testing.
+Initializes loggers and callbacks using Hydra and the _target_ paths in the configuration.
+"""
+
+import os
+import shutil
+from pathlib import Path
+from typing import List, Optional
+import torch
+import lightning as L
+from dotenv import load_dotenv, find_dotenv
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from src.utils.logging_utils import setup_logger, task_wrapper
+from loguru import logger
+import rootutils
+from lightning.pytorch.loggers import Logger
+from lightning.pytorch.callbacks import Callback
+
+# Load environment variables
+load_dotenv(find_dotenv(".env"))
+
+# Setup root directory
+root = rootutils.setup_root(__file__, indicator=".project-root")
+
+
+def instantiate_callbacks(callback_cfg: DictConfig) -> List[Callback]:
+    """Instantiate and return a list of callbacks from the configuration."""
+    callbacks_ls: List[Callback] = []
+
+    if not callback_cfg:
+        logger.warning("No callback configs found! Skipping..")
+        return callbacks_ls
+
+    if not isinstance(callback_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callback_cfg.items():
+        if "_target_" in cb_conf:
+            logger.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks_ls.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks_ls
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiate and return a list of loggers from the configuration."""
+    loggers_ls: List[Logger] = []
+
+    if not logger_cfg:
+        logger.warning("No logger configs found! Skipping..")
+        return loggers_ls
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if "_target_" in lg_conf:
+            logger.info(f"Instantiating logger <{lg_conf._target_}>")
+            loggers_ls.append(hydra.utils.instantiate(lg_conf))
+
+    return loggers_ls
+
+
+def load_checkpoint_if_available(ckpt_path: str) -> Optional[str]:
+    """Return the checkpoint path if available, else None."""
+    if ckpt_path and Path(ckpt_path).exists():
+        logger.info(f"Using checkpoint: {ckpt_path}")
+        return ckpt_path
+    logger.warning(f"Checkpoint not found at {ckpt_path}. Using current model weights.")
+    return None
+
+
+def clear_checkpoint_directory(ckpt_dir: str):
+    """Clear checkpoint directory contents without removing the directory."""
+    ckpt_dir_path = Path(ckpt_dir)
+    if not ckpt_dir_path.exists():
+        logger.info(f"Creating checkpoint directory: {ckpt_dir}")
+        ckpt_dir_path.mkdir(parents=True, exist_ok=True)
+    else:
+        logger.info(f"Clearing checkpoint directory: {ckpt_dir}")
+        for item in ckpt_dir_path.iterdir():
+            try:
+                item.unlink() if item.is_file() else shutil.rmtree(item)
+            except Exception as e:
+                logger.error(f"Failed to delete {item}: {e}")
+
+
+@task_wrapper
+def train_module(
+    data_module: L.LightningDataModule, model: L.LightningModule, trainer: L.Trainer
+):
+    """Train the model and log metrics."""
+    logger.info("Starting training")
+    trainer.fit(model, data_module)
+    train_metrics = trainer.callback_metrics
+    train_acc = train_metrics.get("train_acc")
+    val_acc = train_metrics.get("val_acc")
+    logger.info(
+        f"Training completed. Metrics - train_acc: {train_acc}, val_acc: {val_acc}"
+    )
+    return train_metrics
+
+
+@task_wrapper
+def run_test_module(
+    cfg: DictConfig,
+    datamodule: L.LightningDataModule,
+    model: L.LightningModule,
+    trainer: L.Trainer,
+):
+    """Test the model using the best checkpoint or current model weights."""
+    logger.info("Starting testing")
+    datamodule.setup(stage="test")
+    test_metrics = trainer.test(
+        model, datamodule, ckpt_path=load_checkpoint_if_available(cfg.ckpt_path)
+    )
+    logger.info(f"Test metrics: {test_metrics}")
+    return test_metrics[0] if test_metrics else {}
+
+
+@hydra.main(config_path="../configs", config_name="train", version_base="1.1")
+def setup_run_trainer(cfg: DictConfig):
+    """Set up and run the Trainer for training and testing."""
+    # Display configuration
+    logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
+
+    # Initialize logger
+    log_path = Path(cfg.paths.log_dir) / (
+        "train.log" if cfg.task_name == "train" else "eval.log"
+    )
+    setup_logger(log_path)
+
+    # Display key paths
+    for path_name in [
+        "root_dir",
+        "data_dir",
+        "log_dir",
+        "ckpt_dir",
+        "artifact_dir",
+        "output_dir",
+    ]:
+        logger.info(
+            f"{path_name.replace('_', ' ').capitalize()}: {cfg.paths[path_name]}"
+        )
+
+    # Initialize DataModule and Model
+    logger.info(f"Instantiating datamodule <{cfg.data._target_}>")
+    datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.data)
+    logger.info(f"Instantiating model <{cfg.model._target_}>")
+    model: L.LightningModule = hydra.utils.instantiate(cfg.model)
+
+    # Check GPU availability and set seed for reproducibility
+    logger.info("GPU available" if torch.cuda.is_available() else "No GPU available")
+    L.seed_everything(cfg.seed, workers=True)
+
+    # Set up callbacks, loggers, and Trainer
+    callbacks = instantiate_callbacks(cfg.callbacks)
+    logger.info(f"Callbacks: {callbacks}")
+    loggers = instantiate_loggers(cfg.loggers)
+    logger.info(f"Loggers: {loggers}")
+    trainer: L.Trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=loggers
+    )
+
+    # Training phase
+    train_metrics = {}
+    if cfg.get("train"):
+        clear_checkpoint_directory(cfg.paths.ckpt_dir)
+        train_metrics = train_module(datamodule, model, trainer)
+        (Path(cfg.paths.ckpt_dir) / "train_done.flag").write_text(
+            "Training completed.\n"
+        )
+
+    # Testing phase
+    test_metrics = {}
+    if cfg.get("test"):
+        test_metrics = run_test_module(cfg, datamodule, model, trainer)
+
+    # Combine metrics and extract optimization metric
+    all_metrics = {**train_metrics, **test_metrics}
+    optimization_metric = all_metrics.get(cfg.get("optimization_metric"), 0.0)
+    if optimization_metric == 0.0:
+        logger.warning(
+            f"Optimization metric '{cfg.get('optimization_metric')}' not found. Defaulting to 0."
+        )
+    else:
+        logger.info(f"Optimization metric: {optimization_metric}")
+
+    return optimization_metric
+
+
+if __name__ == "__main__":
+    setup_run_trainer()
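Because setup_run_trainer is decorated with @hydra.main(config_path="../configs", config_name="train"), the script composes configs/train.yaml at launch and accepts dotted command-line overrides. The composed config can also be inspected programmatically; a sketch using Hydra's compose API (it assumes configs/train.yaml defines the keys the script reads, such as train, test, and optimization_metric, which are not shown in this commit):

from hydra import compose, initialize
from omegaconf import OmegaConf

# config_path is relative to this file, matching the decorator in train_new.py.
with initialize(config_path="../configs", version_base="1.1"):
    cfg = compose(config_name="train", overrides=["trainer.max_epochs=5"])
    print(OmegaConf.to_yaml(cfg))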