from os.path import join

import hydra
import lightning as L
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import DictConfig

from src.data_module import DRDataModule
from src.model import DRModel
from src.utils import generate_run_id


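# The Hydra config (conf/config.yaml) is expected to define the keys read via
# cfg.* below. A minimal sketch with illustrative values only; the actual
# defaults live in the config file, not here:
#
#   seed: 42
#   train_csv_path: data/train.csv
#   val_csv_path: data/val.csv
#   image_size: 224
#   batch_size: 32
#   num_workers: 4
#   use_class_weighting: true
#   use_weighted_sampler: false
#   model_name: resnet50
#   learning_rate: 1e-4
#   use_scheduler: true
#   max_epochs: 50
#   logs_dir: logs
#   checkpoint_dirpath: checkpoints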
@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig) -> None:
    # Unique identifier for this run; used to group logs and checkpoints.
    run_id = generate_run_id()

    # Seed everything (including dataloader workers) for reproducibility.
    L.seed_everything(cfg.seed, workers=True)

    # Data module for the train/val CSV datasets. setup() is called up front so
    # num_classes and class_weights are available when the model is built.
    dm = DRDataModule(
        train_csv_path=cfg.train_csv_path,
        val_csv_path=cfg.val_csv_path,
        image_size=cfg.image_size,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        use_class_weighting=cfg.use_class_weighting,
        use_weighted_sampler=cfg.use_weighted_sampler,
    )
    dm.setup()

    # Model; the class weights computed by the data module are passed in so the
    # training objective can account for class imbalance.
    model = DRModel(
        num_classes=dm.num_classes,
        model_name=cfg.model_name,
        learning_rate=cfg.learning_rate,
        class_weights=dm.class_weights,
        use_scheduler=cfg.use_scheduler,
    )

    # TensorBoard logger; each run is versioned by its run_id.
    logger = TensorBoardLogger(save_dir=cfg.logs_dir, name="", version=run_id)

    # Keep the two checkpoints with the lowest validation loss; filenames embed
    # the epoch, step, and the tracked validation metrics.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        dirpath=join(cfg.checkpoint_dirpath, run_id),
        filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
    )

    # Log the learning rate at every optimizer step.
    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Stop training when val_loss has not improved for 10 consecutive validation checks.
    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=10,
        verbose=True,
        mode="min",
    )

    trainer = L.Trainer(
        max_epochs=cfg.max_epochs,
        accelerator="auto",
        devices="auto",
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stopping],
    )

    trainer.fit(model, dm)


if __name__ == "__main__":
    train()