bhimrazy committed
Commit 30df46a · 1 Parent(s): 31d1d47

Adds lr monitor to train

Files changed (1): train.py (+7 -4)
train.py CHANGED
@@ -13,12 +13,12 @@ torch.set_float32_matmul_precision("high")
 
 
 # Init DataModule
-dm = DRDataModule(batch_size=128, num_workers=8)
+dm = DRDataModule(batch_size=96, num_workers=8)
 dm.setup()
 
 # Init model from datamodule's attributes
 model = DRModel(
-    num_classes=dm.num_classes, learning_rate=3e-4, class_weights=dm.class_weights
+    num_classes=dm.num_classes, learning_rate=3e-5, class_weights=dm.class_weights
 )
 
 # Init logger
@@ -32,14 +32,17 @@ checkpoint_callback = ModelCheckpoint(
     dirpath="checkpoints",
 )
 
+# Init LearningRateMonitor
+lr_monitor = LearningRateMonitor(logging_interval="step")
+
 # Init trainer
 trainer = L.Trainer(
     max_epochs=20,
     accelerator="auto",
     devices="auto",
     logger=logger,
-    callbacks=[checkpoint_callback],
-    enable_checkpointing=True
+    callbacks=[checkpoint_callback, lr_monitor],
+    enable_checkpointing=True,
 )
 
 # Pass the datamodule as arg to trainer.fit to override model hooks :)
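
For context, a minimal sketch of how the added callback plugs into a Lightning Trainer. The diff itself does not show the imports or the rest of train.py, so the import path, the assumption of lightning >= 2.0, and the omission of DRDataModule/DRModel and the logger are mine, not the commit's.

# Minimal sketch (not part of the commit): wiring LearningRateMonitor into a
# Lightning Trainer. DRDataModule/DRModel and the logger live elsewhere in
# train.py and are omitted here.
import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# Log the current learning rate on every optimizer step
# (logging_interval="epoch" would log once per epoch instead).
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints")

trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    callbacks=[checkpoint_callback, lr_monitor],
    enable_checkpointing=True,
)
# trainer.fit(model, datamodule=dm) would then record the learning rate
# (typically under a key like "lr-Adam") alongside the other logged metrics.

Logging per step rather than per epoch makes sense here if a scheduler adjusts the rate within an epoch; otherwise per-epoch logging keeps the logs smaller.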