import lightning as L
import torch
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    LearningRateMonitor,
    EarlyStopping,
)
from lightning.pytorch.loggers import TensorBoardLogger

from src.dataset import DRDataModule
from src.model import DRModel

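# Seed Python, NumPy, and torch for reproducibility; workers=True also
# seeds the DataLoader worker processes.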
SEED = 42
L.seed_everything(SEED, workers=True)

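# Trade a little float32 matmul precision for speed on Tensor Core GPUs.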
torch.set_float32_matmul_precision("high")

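# Init DataModule. setup() is called manually here (trainer.fit would
# otherwise call it later) so that num_classes and class_weights are already
# populated when the model is constructed below.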
dm = DRDataModule(batch_size=128, num_workers=24)
dm.setup()

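# Init model with the class count and class weights computed by the DataModule.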
model = DRModel(
    num_classes=dm.num_classes, learning_rate=3e-4, class_weights=dm.class_weights
)

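# Log metrics to TensorBoard; the empty name stores runs directly under
# logs/version_N instead of a per-experiment subdirectory.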
logger = TensorBoardLogger(save_dir="logs", name="")

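# Keep the two best checkpoints by validation loss; the filename pattern is
# filled in from metrics the model logs during validation.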
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=2,
    dirpath="artifacts/checkpoints",
    filename="{epoch}-{step}-{val_loss:.2f}-{val_acc:.2f}-{val_kappa:.2f}",
)

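# Record the learning rate at every optimizer step, so any schedule defined
# in DRModel.configure_optimizers shows up in TensorBoard.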
lr_monitor = LearningRateMonitor(logging_interval="step")

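# Stop training if val_loss has not improved for 10 consecutive validation
# rounds.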
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=10,
    verbose=True,
    mode="min",
)

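# Train for at most 50 epochs; accelerator/devices "auto" picks up whatever
# hardware (CPU, GPU, etc.) is available.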
trainer = L.Trainer(
    max_epochs=50,
    accelerator="auto",
    devices="auto",
    logger=logger,
    callbacks=[checkpoint_callback, lr_monitor, early_stopping],
)

trainer.fit(model, dm)
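
# After training, the two best checkpoints are in artifacts/checkpoints. A
# minimal follow-up sketch, assuming DRModel and DRDataModule also implement
# the test hooks (test_step / test_dataloader), which this script does not
# show:
#
# trainer.test(model, datamodule=dm, ckpt_path="best")
# best_model = DRModel.load_from_checkpoint(checkpoint_callback.best_model_path)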