bhimrazy commited on
Commit
30fdb7f
1 Parent(s): c118196

Adds model factory with several model supports

Browse files
Files changed (2) hide show
  1. src/model.py +24 -41
  2. src/models/factory.py +141 -0
src/model.py CHANGED
@@ -2,42 +2,26 @@ import lightning as L
2
  import torch
3
  from torch import nn
4
  from torchmetrics.functional import accuracy, cohen_kappa
5
- from torchvision import models
6
 
7
 
8
  class DRModel(L.LightningModule):
9
  def __init__(
10
- self, num_classes: int, learning_rate: float = 2e-4, class_weights=None
 
 
 
 
 
11
  ):
12
  super().__init__()
13
  self.save_hyperparameters()
14
  self.num_classes = num_classes
15
  self.learning_rate = learning_rate
 
16
 
17
  # Define the model
18
- # self.model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
19
- # self.model = models.densenet169(weights=models.DenseNet169_Weights.DEFAULT)
20
- # self.model = models.densenet161(weights=models.DenseNet161_Weights.DEFAULT)
21
- self.model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
22
- # self.model = models.vit_b_32(weights=models.ViT_B_32_Weights.DEFAULT)
23
-
24
- # freeze the feature extractor
25
- for param in self.model.parameters():
26
- param.requires_grad = False
27
-
28
- # self.model.head.weight.requires_grad = True
29
- # self.model.head.bias.requires_grad = True
30
-
31
- # Change the output layer to have the number of classes
32
- # in_features = self.model.classifier.in_features
33
- in_features = 768
34
- self.model.heads = nn.Sequential(
35
- # self.model.classifier = nn.Sequential(
36
- nn.Linear(in_features, in_features // 2),
37
- nn.ReLU(),
38
- nn.Dropout(0.5),
39
- nn.Linear(in_features // 2, num_classes),
40
- )
41
 
42
  # Define the loss function
43
  self.criterion = nn.CrossEntropyLoss(weight=class_weights)
@@ -70,16 +54,20 @@ class DRModel(L.LightningModule):
70
  self.log("val_kappa", kappa, on_step=True, on_epoch=True, prog_bar=True)
71
 
72
  def configure_optimizers(self):
73
- # optimizer = torch.optim.Adam(
74
- # self.parameters(), lr=self.learning_rate, weight_decay=1e-4
75
- # )
76
-
77
  optimizer = torch.optim.AdamW(
78
  self.parameters(), lr=self.learning_rate, weight_decay=0.05
79
  )
80
- scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
81
- # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
82
- scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
 
 
 
 
 
 
 
 
83
  optimizer,
84
  mode="min", # or "max" if you're maximizing a metric
85
  factor=0.1, # factor by which the learning rate will be reduced
@@ -87,12 +75,7 @@ class DRModel(L.LightningModule):
87
  verbose=True, # print a message when learning rate is reduced
88
  threshold=0.001, # threshold for measuring the new optimum, to only focus on significant changes
89
  )
90
- return {
91
- "optimizer": optimizer,
92
- "lr_scheduler": {
93
- "scheduler": scheduler,
94
- "interval": "epoch",
95
- "monitor": "val_loss",
96
- },
97
- }
98
- # return optimizer
 
2
  import torch
3
  from torch import nn
4
  from torchmetrics.functional import accuracy, cohen_kappa
5
+ from src.models.factory import ModelFactory
6
 
7
 
8
  class DRModel(L.LightningModule):
9
  def __init__(
10
+ self,
11
+ num_classes: int,
12
+ model_name: str = "densenet121",
13
+ learning_rate: float = 3e-4,
14
+ class_weights=None,
15
+ use_scheduler: bool = True,
16
  ):
17
  super().__init__()
18
  self.save_hyperparameters()
19
  self.num_classes = num_classes
20
  self.learning_rate = learning_rate
21
+ self.use_scheduler = use_scheduler
22
 
23
  # Define the model
24
+ self.model = ModelFactory(name=model_name, num_classes=num_classes)()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Define the loss function
27
  self.criterion = nn.CrossEntropyLoss(weight=class_weights)
 
54
  self.log("val_kappa", kappa, on_step=True, on_epoch=True, prog_bar=True)
55
 
56
  def configure_optimizers(self):
 
 
 
 
57
  optimizer = torch.optim.AdamW(
58
  self.parameters(), lr=self.learning_rate, weight_decay=0.05
59
  )
60
+
61
+ configuration = {
62
+ "optimizer": optimizer,
63
+ "monitor": "val_loss", # monitor validation loss
64
+ }
65
+
66
+ if self.use_scheduler:
67
+ # Add lr scheduler
68
+ # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
69
+ # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
70
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
71
  optimizer,
72
  mode="min", # or "max" if you're maximizing a metric
73
  factor=0.1, # factor by which the learning rate will be reduced
 
75
  verbose=True, # print a message when learning rate is reduced
76
  threshold=0.001, # threshold for measuring the new optimum, to only focus on significant changes
77
  )
78
+
79
+ configuration["lr_scheduler"] = scheduler
80
+
81
+ return configuration
 
 
 
 
 
src/models/factory.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import models
2
+ from torch import nn
3
+
4
+ model_mapping = {
5
+ "densenet121": (
6
+ models.densenet121,
7
+ {"weights": models.DenseNet121_Weights.DEFAULT, "family": "densenet"},
8
+ ),
9
+ "densenet161": (
10
+ models.densenet161,
11
+ {"weights": models.DenseNet161_Weights.DEFAULT, "family": "densenet"},
12
+ ),
13
+ "densenet169": (
14
+ models.densenet169,
15
+ {"weights": models.DenseNet169_Weights.DEFAULT, "family": "densenet"},
16
+ ),
17
+ "densenet201": (
18
+ models.densenet201,
19
+ {"weights": models.DenseNet201_Weights.DEFAULT, "family": "densenet"},
20
+ ),
21
+ "resnet50": (
22
+ models.resnet50,
23
+ {"weights": models.ResNet50_Weights.IMAGENET1K_V2, "family": "resnet"},
24
+ ),
25
+ "resnet101": (
26
+ models.resnet101,
27
+ {"weights": models.ResNet101_Weights.IMAGENET1K_V2, "family": "resnet"},
28
+ ),
29
+ "resnet152": (
30
+ models.resnet152,
31
+ {"weights": models.ResNet152_Weights.IMAGENET1K_V2, "family": "resnet"},
32
+ ),
33
+ "vit-b-16": (
34
+ models.vit_b_16,
35
+ {"weights": models.ViT_B_16_Weights.DEFAULT, "family": "vit"},
36
+ ),
37
+ "vit-b-32": (
38
+ models.vit_b_32,
39
+ {"weights": models.ViT_B_32_Weights.DEFAULT, "family": "vit"},
40
+ ),
41
+ # Add more models as needed with their respective configurations.
42
+ }
43
+
44
+
45
+ class Model(nn.Module):
46
+ """Moodel definition."""
47
+
48
+ def __init__(self, model_name: str, num_classes: int):
49
+ """
50
+ Initialize Model instance.
51
+
52
+ Args:
53
+ model_name (str): Name of the model architecture.
54
+ num_classes (int): Number of output classes.
55
+ """
56
+ super(Model, self).__init__()
57
+
58
+ model_class, model_config = model_mapping[model_name]
59
+ self.model = model_class(weights=model_config["weights"])
60
+
61
+ # Freeze model parameters
62
+ for param in self.model.parameters():
63
+ param.requires_grad = False
64
+
65
+ in_features = self._get_in_features(model_config["family"])
66
+
67
+ if model_config["family"] == "densenet":
68
+ self.model.classifier = self._create_classifier(in_features, num_classes)
69
+ elif model_config["family"] == "resnet":
70
+ self.model.fc = self._create_classifier(in_features, num_classes)
71
+ elif model_config["family"] == "vit":
72
+ self.model.heads = self._create_classifier(in_features, num_classes)
73
+
74
+ def forward(self, x):
75
+ """Forward pass through the model."""
76
+ return self.model(x)
77
+
78
+ def _get_in_features(self, family: str) -> int:
79
+ """Return the number of input features for the classifier."""
80
+ if family == "densenet":
81
+ return self.model.classifier.in_features
82
+ elif family == "resnet":
83
+ return self.model.fc.in_features
84
+ elif family == "vit":
85
+ return self.model.heads.head.in_features
86
+
87
+ def _create_classifier(self, in_features: int, num_classes: int) -> nn.Sequential:
88
+ """Create the classifier module."""
89
+ return nn.Sequential(
90
+ nn.Linear(in_features, in_features // 2),
91
+ nn.ReLU(),
92
+ nn.Dropout(0.5),
93
+ nn.Linear(in_features // 2, num_classes),
94
+ )
95
+
96
+
97
+ class ModelFactory:
98
+ """
99
+ Factory for creating different models based on their names.
100
+
101
+ Args:
102
+ name (str): The name of the model factory.
103
+ num_classes (int): The number of output classes.
104
+
105
+ Raises:
106
+ ValueError: If the specified model factory is not implemented.
107
+ """
108
+
109
+ def __init__(self, name: str, num_classes: int):
110
+ """
111
+ Initialize ModelFactory instance.
112
+
113
+ Args:
114
+ name (str): The name of the model.
115
+ num_classes (int): The number of output classes.
116
+ """
117
+ self.name = name
118
+ self.num_classes = num_classes
119
+
120
+ def __call__(self):
121
+ """
122
+ Create a model instance based on the provided name.
123
+
124
+ Args:
125
+ model_name (str): Name of the model architecture.
126
+ num_classes (int): Number of output classes.
127
+
128
+ Returns:
129
+ Model: An instance of the selected model.
130
+ """
131
+ if self.name not in model_mapping:
132
+ valid_options = ", ".join(model_mapping.keys())
133
+ raise ValueError(
134
+ f"Invalid model name: '{self.name}'. Available options: {valid_options}"
135
+ )
136
+
137
+ return Model(self.name, self.num_classes)
138
+
139
+
140
+ if __name__ == "__main__":
141
+ model = ModelFactory("resnet50", 5)()