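"""Training entry point.

Fits a DRModel on the data provided by DRDataModule using PyTorch Lightning,
with configuration managed by Hydra (conf/config.yaml).

Example usage (standard Hydra CLI overrides; the values shown are illustrative,
but the keys correspond to config fields referenced in this script):

    python train.py max_epochs=20 batch_size=32 learning_rate=1e-4
"""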
from os.path import join

import hydra
import lightning as L
import torch
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import DictConfig
from src.data_module import DRDataModule
from src.model import DRModel
from src.utils import generate_run_id


@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig) -> None:
    # Generate a unique run ID based on the current date and time
    run_id = generate_run_id()

    # Seed everything for reproducibility
    L.seed_everything(cfg.seed, workers=True)
    # Allow TF32 matmuls on supported GPUs for faster training
    torch.set_float32_matmul_precision("high")

    # Initialize DataModule
    dm = DRDataModule(
        train_csv_path=cfg.train_csv_path,
        val_csv_path=cfg.val_csv_path,
        image_size=cfg.image_size,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        use_class_weighting=cfg.use_class_weighting,
        use_weighted_sampler=cfg.use_weighted_sampler,
    )
    dm.setup()

    # Init model from datamodule's attributes
    model = DRModel(
        num_classes=dm.num_classes,
        model_name=cfg.model_name,
        learning_rate=cfg.learning_rate,
        class_weights=dm.class_weights,
        use_scheduler=cfg.use_scheduler,
    )

    # Init logger
    logger = TensorBoardLogger(save_dir=cfg.logs_dir, name="", version=run_id)
    # Init callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        dirpath=join(cfg.checkpoint_dirpath, run_id),
        filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
    )

    # Init LearningRateMonitor
    lr_monitor = LearningRateMonitor(logging_interval="step")

    # Early stopping on validation loss
    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=10,
        verbose=True,
        mode="min",
    )

    # Initialize Trainer
    trainer = L.Trainer(
        max_epochs=cfg.max_epochs,
        accelerator="auto",
        devices="auto",
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor, early_stopping],
    )

    # Train the model
    trainer.fit(model, dm)


if __name__ == "__main__":
    train()