import lightning as L
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

from src.dataset import DRDataModule
from src.model import DRModel
# Seed everything for reproducibility
SEED = 42
L.seed_everything(SEED, workers=True)
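# Allow TF32 matmul kernels on GPUs that support them (e.g. Ampere and newer),
# trading a little float32 matmul precision for a noticeable speedup.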
torch.set_float32_matmul_precision("high")
# Init DataModule and set it up eagerly so dataset-derived attributes
# (num_classes, class_weights) are available before the model is built
dm = DRDataModule(batch_size=128, num_workers=8)
dm.setup()
# Init model from datamodule's attributes
model = DRModel(
    num_classes=dm.num_classes, learning_rate=3e-4, class_weights=dm.class_weights
)
# Init logger
logger = TensorBoardLogger("lightning_logs", name="dr_model")
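# View training curves with: tensorboard --logdir lightning_logs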
# Init callbacks
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,
    dirpath="checkpoints",
)
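# Log the current learning rate once per epoch so the optimizer schedule is
# visible in TensorBoard
lr_monitor = LearningRateMonitor(logging_interval="epoch")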
# Init trainer
trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    logger=logger,
    callbacks=[checkpoint_callback, lr_monitor],
    enable_checkpointing=True,
)
# Fit the model; passing the DataModule makes it supply the train/val
# dataloaders instead of the LightningModule's own *_dataloader hooks
trainer.fit(model, datamodule=dm)
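# Optional sketch (assumes DRModel implements test_step and DRDataModule
# provides a test dataloader): evaluate the best checkpoint after training.
# trainer.test(model, datamodule=dm, ckpt_path="best")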