"""Training entry point: fit ``DRModel`` on ``DRDataModule`` with Lightning.

Run this file as a script. All work happens inside :func:`main`, behind an
``if __name__ == "__main__"`` guard, so that DataLoader worker processes which
re-import this module (multiprocessing ``spawn`` start method) do not start
training again.
"""

import lightning as L
import torch
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from src.dataset import DRDataModule
from src.model import DRModel

# Seed used for every RNG that Lightning knows how to seed.
SEED = 42


def main() -> None:
    """Build the datamodule, model, logger and trainer, then run training."""
    # Seed everything for reproducibility; workers=True also asks Lightning
    # to derive seeds for the DataLoader worker processes.
    L.seed_everything(SEED, workers=True)

    # Let float32 matmuls use faster, slightly lower-precision kernels where
    # the hardware supports them.
    torch.set_float32_matmul_precision("high")

    # Init DataModule. setup() is called by hand because the model below is
    # built from the datamodule's attributes.
    # NOTE(review): presumably setup() is what populates num_classes and
    # class_weights -- confirm in src.dataset. The Trainer calls setup()
    # again during fit, so it should be safe to run more than once.
    dm = DRDataModule(batch_size=128, num_workers=8)
    dm.setup()

    # Init model from the datamodule's attributes.
    model = DRModel(
        num_classes=dm.num_classes,
        learning_rate=3e-4,
        class_weights=dm.class_weights,
    )

    # Init logger.
    logger = TensorBoardLogger("lightning_logs", name="dr_model")

    # Init callbacks: keep the three checkpoints with the lowest "val_loss".
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=3,
        dirpath="checkpoints",
    )

    # Init trainer.
    trainer = L.Trainer(
        max_epochs=20,
        accelerator="auto",
        devices="auto",
        logger=logger,
        callbacks=[checkpoint_callback],
        enable_checkpointing=True,
    )

    # Pass the datamodule to trainer.fit so its dataloaders are the ones used
    # (they take the place of any dataloader hooks defined on the model).
    trainer.fit(model, dm)


if __name__ == "__main__":
    main()