import lightning as L
import torch
from torch import nn
from torchmetrics.functional import accuracy, cohen_kappa
from torchvision import models


class DRModel(L.LightningModule):
    def __init__(
        self, num_classes: int, learning_rate: float = 2e-4, class_weights=None
    ):
        super().__init__()
        self.save_hyperparameters()
        self.num_classes = num_classes
        self.learning_rate = learning_rate

        # Pretrained ViT-B/16 backbone with ImageNet weights.
        self.model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

        # Freeze the backbone so only the new classification head is trained.
        for param in self.model.parameters():
            param.requires_grad = False

        # Replace the default head; 768 is the ViT-B/16 embedding dimension.
        in_features = 768
        self.model.heads = nn.Sequential(
            nn.Linear(in_features, in_features // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features // 2, num_classes),
        )

        # Cross-entropy with optional per-class weights for imbalanced data.
        self.criterion = nn.CrossEntropyLoss(weight=class_weights)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
        # Quadratic-weighted Cohen's kappa penalizes predictions that land far
        # from the true ordinal severity grade more heavily than near misses.
        kappa = cohen_kappa(
            preds,
            y,
            task="multiclass",
            num_classes=self.num_classes,
            weights="quadratic",
        )
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        self.log("val_kappa", kappa, on_step=True, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        # AdamW over all parameters; the frozen backbone receives no gradients,
        # so effectively only the new head is updated.
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=0.05
        )
        # Lower the learning rate whenever the monitored validation loss plateaus.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.1,
            patience=5,
            threshold=0.001,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val_loss",
            },
        }
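

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of wiring
# DRModel into a Lightning Trainer. The random TensorDataset below is only a
# stand-in for a real retinopathy DataLoader, and num_classes=5, the batch
# size, and max_epochs=1 are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Dummy 224x224 RGB images with random labels for five severity grades.
    images = torch.randn(8, 3, 224, 224)
    labels = torch.randint(0, 5, (8,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    model = DRModel(num_classes=5)
    trainer = L.Trainer(max_epochs=1, accelerator="auto", logger=False)
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)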