import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

import lightning as L
from torchmetrics import Accuracy


class PetClassificationModel(L.LightningModule):
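  """Pet image classifier: a pretrained backbone plus a small custom head, trained with Lightning."""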
  def __init__(self, base_model, config):
    super().__init__()
    self.config = config
    self.num_classes = len(self.config.idx_to_class)

    # Separate Accuracy instances per stage so their internal state never mixes.
    metric = Accuracy(task="multiclass", num_classes=self.num_classes)
    self.train_acc = metric.clone()
    self.val_acc = metric.clone()
    self.test_acc = metric.clone()

    # Per-epoch buffers, filled in the *_step hooks and drained in on_*_epoch_end.
    self.training_step_outputs = []
    self.validation_step_outputs = []
    self.test_step_outputs = []

    # Lightning moves the module to the right device at fit time; this is only
    # used to build the custom head on the detected device up front.
    self.device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pretrained backbone plus a small head mapping its classifier outputs to our classes.
    self.pretrained_model = base_model
    out_features = self.pretrained_model.get_classifier().out_features
    self.custom_layers = nn.Sequential(
        nn.Linear(out_features, 512, device=self.device_),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(512, self.num_classes, device=self.device_),
    )

  def forward(self, x):
    # Backbone output followed by the custom classification head, so the
    # logits match self.num_classes as expected by the loss and the metrics.
    x = self.pretrained_model(x)
    x = self.custom_layers(x)
    return x


  def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    self.log_dict({'train_loss': loss})
    # Detach before buffering so the autograd graph is not kept alive for the whole epoch.
    self.training_step_outputs.append({'loss': loss.detach(), 'logits': logits.detach(), 'y': y})
    return loss

  def on_train_epoch_end(self):
    # Concat batches
    outputs = self.training_step_outputs
    logits = torch.cat([x['logits'] for x in outputs])
    y = torch.cat([x['y'] for x in outputs])
    self.train_acc(logits, y)
    self.log_dict(
        {'train_acc': self.train_acc},
        on_step=False,
        on_epoch=True,
        prog_bar=True)
    self.training_step_outputs.clear()
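    # Note: an equivalent, lower-memory pattern would be to update self.train_acc inside
    # training_step and log it there with on_epoch=True, avoiding these per-epoch buffers.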

  def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    self.log_dict({'val_loss': loss})
    self.validation_step_outputs.append({'loss': loss, 'logits': logits, 'y': y})
    return loss

  def on_validation_epoch_end(self):
    # Concat batches
    outputs = self.validation_step_outputs
    logits = torch.cat([x['logits'] for x in outputs])
    y = torch.cat([x['y'] for x in outputs])
    self.val_acc(logits, y)
    self.log_dict(
        {'val_acc': self.val_acc},
        on_step=False,
        on_epoch=True,
        prog_bar=True)
    self.validation_step_outputs.clear()

  def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    self.log_dict({'test_loss': loss})
    self.test_step_outputs.append({'loss': loss, 'logits': logits, 'y': y})
    return loss

  def on_test_epoch_end(self):
    # Concat batches
    outputs = self.test_step_outputs
    logits = torch.cat([x['logits'] for x in outputs])
    y = torch.cat([x['y'] for x in outputs])
    self.test_acc(logits, y)
    self.log_dict(
        {'test_acc': self.test_acc},
        on_step=False,
        on_epoch=True,
        prog_bar=True)
    self.test_step_outputs.clear()

  def predict_step(self, batch, batch_idx=0):
    # Prediction batches still carry labels here; only the inputs are used.
    x, y = batch
    return self(x)

  def configure_optimizers(self):
    optimizer = optim.Adam(self.parameters(), lr=self.config.LEARNING_RATE)
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3)
    lr_scheduler_dict = {
        "scheduler": lr_scheduler,
        "interval": "epoch",
        "monitor": "val_loss",
    }
    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_dict}
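

# Usage sketch (assumptions: a timm backbone, since get_classifier() is used above, and a
# config object exposing idx_to_class and LEARNING_RATE; the dataloaders are defined elsewhere):
#
#   import timm
#
#   base_model = timm.create_model("resnet34", pretrained=True)
#   model = PetClassificationModel(base_model, config)
#   trainer = L.Trainer(max_epochs=10)
#   trainer.fit(model, train_dataloader, val_dataloader)
#   trainer.test(model, test_dataloader)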