# File size: 4,464 Bytes
# 30fdb7f
from torchvision import models
from torch import nn
# One row per supported architecture:
# (public name, torchvision builder, weights enum member, family tag).
# The family tag selects which head attribute ``Model`` replaces.
_MODEL_SPECS = (
    ("densenet121", models.densenet121, models.DenseNet121_Weights.DEFAULT, "densenet"),
    ("densenet161", models.densenet161, models.DenseNet161_Weights.DEFAULT, "densenet"),
    ("densenet169", models.densenet169, models.DenseNet169_Weights.DEFAULT, "densenet"),
    ("densenet201", models.densenet201, models.DenseNet201_Weights.DEFAULT, "densenet"),
    ("resnet50", models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2, "resnet"),
    ("resnet101", models.resnet101, models.ResNet101_Weights.IMAGENET1K_V2, "resnet"),
    ("resnet152", models.resnet152, models.ResNet152_Weights.IMAGENET1K_V2, "resnet"),
    ("vit-b-16", models.vit_b_16, models.ViT_B_16_Weights.DEFAULT, "vit"),
    ("vit-b-32", models.vit_b_32, models.ViT_B_32_Weights.DEFAULT, "vit"),
    # Add more models as needed with their respective configurations.
)

# Maps a model name to ``(builder, {"weights": ..., "family": ...})``.
model_mapping = {
    model_name: (builder, {"weights": weights, "family": family})
    for model_name, builder, weights, family in _MODEL_SPECS
}
class Model(nn.Module):
    """Frozen torchvision backbone with a replaced classification head.

    The backbone is looked up in ``model_mapping``, built with the weights
    configured there, and every one of its parameters is frozen. The
    attribute holding the family's classification head is then replaced by
    a new classifier sized for ``num_classes``.
    """

    # Name of the attribute on the backbone that holds the classification
    # head, keyed by the "family" value stored in ``model_mapping``.
    _HEAD_ATTRIBUTES = {
        "densenet": "classifier",
        "resnet": "fc",
        "vit": "heads",
    }

    def __init__(self, model_name: str, num_classes: int):
        """
        Initialize Model instance.

        Args:
            model_name (str): Name of the model architecture; must be a key
                of ``model_mapping``.
            num_classes (int): Number of output classes.

        Raises:
            KeyError: If ``model_name`` is not a key of ``model_mapping``.
            ValueError: If the configured family has no known head attribute.
        """
        super().__init__()
        model_class, model_config = model_mapping[model_name]
        family = model_config["family"]

        # Reject unknown families before building the backbone; otherwise the
        # head would silently never be replaced and the whole model would
        # stay frozen.
        if family not in self._HEAD_ATTRIBUTES:
            raise ValueError(f"Unsupported model family: '{family}'")

        self.model = model_class(weights=model_config["weights"])

        # Freeze model parameters. The replacement head is created after this
        # loop, so its parameters are not touched by it.
        for param in self.model.parameters():
            param.requires_grad = False

        in_features = self._get_in_features(family)
        setattr(
            self.model,
            self._HEAD_ATTRIBUTES[family],
            self._create_classifier(in_features, num_classes),
        )

    def forward(self, x):
        """Forward pass through the wrapped model."""
        return self.model(x)

    def _get_in_features(self, family: str) -> int:
        """Return the number of input features of the current head.

        Args:
            family (str): One of "densenet", "resnet" or "vit".

        Raises:
            ValueError: If ``family`` is not one of the supported values.
        """
        if family == "densenet":
            return self.model.classifier.in_features
        if family == "resnet":
            return self.model.fc.in_features
        if family == "vit":
            # The ViT head is a container; the linear layer is its ``head``.
            return self.model.heads.head.in_features
        raise ValueError(f"Unsupported model family: '{family}'")

    def _create_classifier(self, in_features: int, num_classes: int) -> nn.Sequential:
        """Create the classifier module.

        The classifier is Linear -> ReLU -> Dropout(0.5) -> Linear, with a
        hidden width of ``in_features // 2``.
        """
        hidden_features = in_features // 2
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_features, num_classes),
        )
class ModelFactory:
    """
    Callable factory that builds a ``Model`` for a stored architecture name.

    Args:
        name (str): Name of the model architecture to build.
        num_classes (int): The number of output classes.

    The name is validated when the factory is called, not when it is
    constructed.
    """

    def __init__(self, name: str, num_classes: int):
        """
        Store the settings used later by ``__call__``.

        Args:
            name (str): The name of the model.
            num_classes (int): The number of output classes.
        """
        self.num_classes = num_classes
        self.name = name

    def __call__(self):
        """
        Build the model described by this factory.

        Returns:
            Model: An instance of the selected model.

        Raises:
            ValueError: If the stored name is not a key of ``model_mapping``.
        """
        if self.name in model_mapping:
            return Model(self.name, self.num_classes)

        # Unknown name: list every supported option in the error message.
        valid_options = ", ".join(model_mapping)
        raise ValueError(
            f"Invalid model name: '{self.name}'. Available options: {valid_options}"
        )
if __name__ == "__main__":
    # Manual check: build a ResNet-50 with a 5-class head via the factory.
    model = ModelFactory("resnet50", 5)()