"""
Parts of this file have been adapted from
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
"""

import pytorch_lightning as pl
import torch.nn.functional as F

from argparse import ArgumentParser
from torch import Tensor
from torch.optim import AdamW, Optimizer, RAdam
from transformers import get_scheduler, PreTrainedModel


class ImageClassificationNet(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = parent_parser.add_argument_group("Classification Model")
        parser.add_argument(
            "--optimizer",
            type=str,
            default="AdamW",
            choices=["AdamW", "RAdam"],
            help="The optimizer to use to train the model.",
        )
        parser.add_argument(
            "--weight_decay",
            type=float,
            default=1e-2,
            help="The optimizer's weight decay.",
        )
        parser.add_argument(
            "--lr",
            type=float,
            default=5e-5,
            help="The initial learning rate for the model.",
        )
        return parent_parser
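
    # Illustrative CLI wiring (a sketch of the calling training script, which
    # is an assumption and not part of this file):
    #
    #     parser = ImageClassificationNet.add_model_specific_args(ArgumentParser())
    #     args = parser.parse_args()
    #     # -> args.optimizer, args.weight_decay, args.lr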

    def __init__(
        self,
        model: PreTrainedModel,
        num_train_steps: int,
        optimizer: str = "AdamW",
        weight_decay: float = 1e-2,
        lr: float = 5e-5,
    ):
        """A PyTorch Lightning Module for a HuggingFace model used for image classification.

        Args:
            model (PreTrainedModel): a pretrained model for image classification
            num_train_steps (int): number of training steps
            optimizer (str): optimizer to use
            weight_decay (float): weight decay for optimizer
            lr (float): the learning rate used for training
        """
        super().__init__()

        # Save the hyperparameters and the model
        self.save_hyperparameters(ignore=["model"])
        self.model = model

    def forward(self, x: Tensor) -> Tensor:
        # The HuggingFace model returns an output object; expose only the logits
        return self.model(x).logits

    def configure_optimizers(self) -> tuple[list[Optimizer], list[dict]]:
        # Set the optimizer class based on the hyperparameter
        if self.hparams.optimizer == "AdamW":
            optim_class = AdamW
        elif self.hparams.optimizer == "RAdam":
            optim_class = RAdam
        else:
            raise ValueError(f"Unknown optimizer {self.hparams.optimizer}")

        # Create the optimizer and the learning rate scheduler
        optimizer = optim_class(
            self.parameters(),
            weight_decay=self.hparams.weight_decay,
            lr=self.hparams.lr,
        )
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=self.hparams.num_train_steps,
        )

        # Step the scheduler every training step rather than every epoch
        # (Lightning's default), since it is parameterized in training steps
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

    def _calculate_loss(self, batch: tuple[Tensor, Tensor], mode: str) -> Tensor:
        imgs, labels = batch

        # Compute the cross-entropy loss and the accuracy over the batch
        preds = self.model(imgs).logits
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Log the metrics under the given mode prefix ("train", "val" or "test")
        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_acc", acc)

        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], _: int) -> Tensor:
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch: tuple[Tensor, Tensor], _: int):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch: tuple[Tensor, Tensor], _: int):
        self._calculate_loss(batch, mode="test")
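

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the ViT
    # checkpoint name and the random dummy data below are illustrative
    # assumptions, not anything this file prescribes.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from transformers import ViTForImageClassification

    vit = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k", num_labels=10
    )
    data = TensorDataset(
        torch.randn(8, 3, 224, 224),  # random "images"
        torch.randint(0, 10, (8,)),  # random labels
    )
    loader = DataLoader(data, batch_size=4)

    max_steps = 10
    net = ImageClassificationNet(vit, num_train_steps=max_steps)
    trainer = pl.Trainer(max_steps=max_steps)
    trainer.fit(net, loader)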