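# NOTE(review): the three lines below look like page-scrape residue
# (repository owner / commit message / short commit hash) -- confirm.
# diegulio
# 🐶🧡🐱
# 3be99bb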
import torch
from torch import nn
import lightning as L
import torch.nn.functional as F
from torch import optim
from torchmetrics import Accuracy
from torch.optim.lr_scheduler import ReduceLROnPlateau
class PetClassificationModel(L.LightningModule):
    """Lightning module that trains ``base_model`` as a multiclass classifier.

    The forward pass returns the raw output of ``base_model``; those outputs
    are treated as logits and scored with cross-entropy. A multiclass accuracy
    metric is tracked separately for the train, validation and test stages.

    Args:
        base_model: Backbone network. It must be callable on an input batch
            and expose ``get_classifier()`` returning a layer with an
            ``out_features`` attribute (presumably a ``timm`` model --
            TODO confirm against the caller).
        config: Configuration object providing ``idx_to_class`` (its length
            is the number of classes) and ``LEARNING_RATE``.
    """

    def __init__(self, base_model, config):
        super().__init__()
        self.config = config
        self.num_classes = len(self.config.idx_to_class)

        # One independent metric instance per stage so their running states
        # never mix.
        metric = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.train_acc = metric.clone()
        self.val_acc = metric.clone()
        self.test_acc = metric.clone()

        # Retained as an informational attribute only. Layers are no longer
        # pinned to it: Lightning moves the whole module to the right device.
        self.device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.pretrained_model = base_model
        out_features = self.pretrained_model.get_classifier().out_features
        # NOTE: this head is currently not used by ``forward``. It is kept so
        # that the module's state-dict keys stay unchanged.
        self.custom_layers = nn.Sequential(
            nn.Linear(out_features, 512),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(512, self.num_classes),
        )

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.pretrained_model(x)

    def _shared_step(self, batch, stage):
        """Compute and log loss and accuracy for one batch of ``stage``.

        Args:
            batch: An ``(inputs, targets)`` pair.
            stage: One of ``"train"``, ``"val"`` or ``"test"``; selects the
                metric attribute and the prefix of the logged keys.

        Returns:
            The cross-entropy loss for the batch.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log(f"{stage}_loss", loss)

        # Update the metric per batch and log the metric object itself, so
        # the epoch-level value is aggregated without keeping every batch's
        # logits (and, during training, their autograd graphs) in memory.
        metric = getattr(self, f"{stage}_acc")
        metric(logits, y)
        self.log(
            f"{stage}_acc",
            metric,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss`` and ``train_acc``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` and ``val_acc``."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Run one test batch; logs ``test_loss`` and ``test_acc``."""
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx=0, dataloader_idx=0):
        """Return the model output for a prediction batch.

        Args:
            batch: Either an ``(inputs, targets)`` pair -- targets are
                ignored -- or the inputs alone.
            batch_idx: Index of the batch; unused.
            dataloader_idx: Index of the dataloader; unused.
        """
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self(x)

    def configure_optimizers(self):
        """Build Adam plus a ``ReduceLROnPlateau`` schedule on ``val_loss``.

        Returns:
            A dict with the optimizer and a scheduler configuration that is
            stepped once per epoch and monitors the logged ``val_loss``.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.config.LEARNING_RATE)
        lr_scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=3)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "epoch",
                # Must match the key logged during validation.
                "monitor": "val_loss",
            },
        }