bhimrazy committed
Commit e5d6e03
1 Parent(s): a2f341c

Adds support for hydra

Files changed (1)
  1. train.py +68 -46
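The commit replaces the script's hard-coded hyperparameters with a Hydra-managed config: the training logic moves into a train() function decorated with @hydra.main, which reads every knob from conf/config.yaml. That config file is not part of this diff, so the sketch below is a reconstruction of the keys train() accesses. Where the old script had a hard-coded value (seed, batch size, workers, learning rate, epochs, log and checkpoint paths) it is reused; the remaining values (the CSV paths, image_size, model_name, and the use_* toggles) are hypothetical placeholders.

# Reconstruction of the options conf/config.yaml must expose, expressed with
# OmegaConf (the config library Hydra builds on) so it runs as plain Python.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "seed": 42,  # old script's hard-coded SEED
        "train_csv_path": "data/train.csv",  # hypothetical path
        "val_csv_path": "data/val.csv",  # hypothetical path
        "image_size": 224,  # hypothetical; the old script had no such option
        "batch_size": 128,  # old hard-coded default
        "num_workers": 24,  # old hard-coded default
        "use_class_weighting": True,  # hypothetical toggle
        "use_weighted_sampler": False,  # hypothetical toggle
        "model_name": "resnet50",  # hypothetical; the old model was fixed
        "learning_rate": 3e-4,  # old hard-coded default
        "use_scheduler": True,  # hypothetical toggle
        "logs_dir": "logs",  # old TensorBoardLogger save_dir
        "checkpoint_dirpath": "artifacts/checkpoints",  # old dirpath
        "max_epochs": 50,  # old hard-coded default
    }
)
print(OmegaConf.to_yaml(cfg))  # emits the equivalent YAML for conf/config.yaml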
train.py CHANGED
@@ -1,61 +1,83 @@
+from os.path import join
+
+import hydra
 import lightning as L
-import torch
 from lightning.pytorch.callbacks import (
-    ModelCheckpoint,
-    LearningRateMonitor,
     EarlyStopping,
+    LearningRateMonitor,
+    ModelCheckpoint,
 )
 from lightning.pytorch.loggers import TensorBoardLogger
-
-from src.dataset import DRDataModule
+from omegaconf import DictConfig
+from src.data_module import DRDataModule
 from src.model import DRModel
+from src.utils import generate_run_id
 
-# seed everything for reproducibility
-SEED = 42
-L.seed_everything(SEED, workers=True)
-torch.set_float32_matmul_precision("high")
 
+@hydra.main(version_base=None, config_path="conf", config_name="config")
+def train(cfg: DictConfig) -> None:
+    # generate unique run id based on current date & time
+    run_id = generate_run_id()
 
-# Init DataModule
-dm = DRDataModule(batch_size=128, num_workers=24)
-dm.setup()
+    # Seed everything for reproducibility
+    L.seed_everything(cfg.seed, workers=True)
+    # torch.set_float32_matmul_precision("high")
 
-# Init model from datamodule's attributes
-model = DRModel(
-    num_classes=dm.num_classes, learning_rate=3e-4, class_weights=dm.class_weights
-)
+    # Initialize DataModule
+    dm = DRDataModule(
+        train_csv_path=cfg.train_csv_path,
+        val_csv_path=cfg.val_csv_path,
+        image_size=cfg.image_size,
+        batch_size=cfg.batch_size,
+        num_workers=cfg.num_workers,
+        use_class_weighting=cfg.use_class_weighting,
+        use_weighted_sampler=cfg.use_weighted_sampler,
+    )
+    dm.setup()
 
-# Init logger
-logger = TensorBoardLogger(save_dir="logs",name="")
-# Init callbacks
-checkpoint_callback = ModelCheckpoint(
-    monitor="val_loss",
-    mode="min",
-    save_top_k=2,
-    dirpath="artifacts/checkpoints",
-    filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
-)
+    # Init model from datamodule's attributes
+    model = DRModel(
+        num_classes=dm.num_classes,
+        model_name=cfg.model_name,
+        learning_rate=cfg.learning_rate,
+        class_weights=dm.class_weights,
+        use_scheduler=cfg.use_scheduler,
+    )
 
-# Init LearningRateMonitor
-lr_monitor = LearningRateMonitor(logging_interval="step")
+    # Init logger
+    logger = TensorBoardLogger(save_dir=cfg.logs_dir, name="", version=run_id)
+    # Init callbacks
+    checkpoint_callback = ModelCheckpoint(
+        monitor="val_loss",
+        mode="min",
+        save_top_k=2,
+        dirpath=join(cfg.checkpoint_dirpath, run_id),
+        filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
+    )
 
-# early stopping
-early_stopping = EarlyStopping(
-    monitor="val_loss",
-    patience=10,
-    verbose=True,
-    mode="min",
-)
+    # Init LearningRateMonitor
+    lr_monitor = LearningRateMonitor(logging_interval="step")
+
+    # early stopping
+    early_stopping = EarlyStopping(
+        monitor="val_loss",
+        patience=10,
+        verbose=True,
+        mode="min",
+    )
+
+    # Initialize Trainer
+    trainer = L.Trainer(
+        max_epochs=cfg.max_epochs,
+        accelerator="auto",
+        devices="auto",
+        logger=logger,
+        callbacks=[checkpoint_callback, lr_monitor, early_stopping],
+    )
+
+    # Train the model
+    trainer.fit(model, dm)
 
-# Init trainer
-trainer = L.Trainer(
-    max_epochs=50,
-    accelerator="auto",
-    devices="auto",
-    logger=logger,
-    callbacks=[checkpoint_callback, lr_monitor, early_stopping],
-    # check_val_every_n_epoch=4,
-)
 
-# Pass the datamodule as arg to trainer.fit to override model hooks :)
-trainer.fit(model, dm)
+if __name__ == "__main__":
+    train()
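The new entry point also depends on generate_run_id from src.utils, which this commit does not show. Judging only from its call sites, where the comment describes it as based on the current date and time and the value serves as both the TensorBoard version and a checkpoint sub-directory, a minimal stand-in could look like this:

# Hypothetical stand-in for src/utils.py's generate_run_id; the real helper is
# not in this diff and may format the id differently.
from datetime import datetime


def generate_run_id() -> str:
    # Filesystem-safe timestamp, e.g. "20240518-143005", usable as a
    # TensorBoard logger version and as a checkpoint directory name.
    return datetime.now().strftime("%Y%m%d-%H%M%S")

With @hydra.main in place, any of the config keys can be overridden per run from the command line in the standard Hydra key=value style, e.g. python train.py max_epochs=5 batch_size=64.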