bhimrazy's picture
Adds model factory with support for several models
30fdb7f
raw
history blame contribute delete
No virus
4.46 kB
from torchvision import models
from torch import nn
# Registry of supported architectures.
#
# Each key is the public model name; each value is a pair of
#   (torchvision constructor, config dict)
# where the config dict carries:
#   "weights": the pretrained weights enum passed to the constructor
#   "family":  which head layout the architecture uses ("densenet", "resnet"
#              or "vit"), so the classifier can be located and replaced
model_mapping = {
    # DenseNet family
    "densenet121": (
        models.densenet121,
        {"weights": models.DenseNet121_Weights.DEFAULT, "family": "densenet"},
    ),
    "densenet161": (
        models.densenet161,
        {"weights": models.DenseNet161_Weights.DEFAULT, "family": "densenet"},
    ),
    "densenet169": (
        models.densenet169,
        {"weights": models.DenseNet169_Weights.DEFAULT, "family": "densenet"},
    ),
    "densenet201": (
        models.densenet201,
        {"weights": models.DenseNet201_Weights.DEFAULT, "family": "densenet"},
    ),
    # ResNet family (pinned to the IMAGENET1K_V2 weights)
    "resnet50": (
        models.resnet50,
        {"weights": models.ResNet50_Weights.IMAGENET1K_V2, "family": "resnet"},
    ),
    "resnet101": (
        models.resnet101,
        {"weights": models.ResNet101_Weights.IMAGENET1K_V2, "family": "resnet"},
    ),
    "resnet152": (
        models.resnet152,
        {"weights": models.ResNet152_Weights.IMAGENET1K_V2, "family": "resnet"},
    ),
    # Vision Transformer family
    "vit-b-16": (
        models.vit_b_16,
        {"weights": models.ViT_B_16_Weights.DEFAULT, "family": "vit"},
    ),
    "vit-b-32": (
        models.vit_b_32,
        {"weights": models.ViT_B_32_Weights.DEFAULT, "family": "vit"},
    ),
    # Further architectures go here, using the same (constructor, config)
    # layout. Model only handles the three families listed above.
}
class Model(nn.Module):
    """Pretrained torchvision backbone with a frozen body and a new head.

    The architecture is looked up by name in ``model_mapping`` and
    instantiated with the weights configured there. All of its parameters are
    then frozen, and the classification layer of its family is replaced by the
    small MLP built in ``_create_classifier``, which outputs ``num_classes``
    values.
    """

    # Attribute of the backbone that holds its classification head, keyed by
    # model family. This is the single place that defines which families are
    # supported.
    _HEAD_ATTRIBUTES = {
        "densenet": "classifier",
        "resnet": "fc",
        "vit": "heads",
    }

    def __init__(self, model_name: str, num_classes: int):
        """
        Initialize Model instance.

        Args:
            model_name (str): Name of the model architecture; must be a key
                of ``model_mapping``.
            num_classes (int): Number of output classes.

        Raises:
            KeyError: If ``model_name`` is not a key of ``model_mapping``.
            ValueError: If the configured family has no known head attribute.
        """
        super().__init__()
        model_class, model_config = model_mapping[model_name]
        family = model_config["family"]

        # Validate the family before building the backbone: an unknown family
        # used to fall through every branch and leave the original head in
        # place with all parameters frozen.
        head_attribute = self._head_attribute(family)

        self.model = model_class(weights=model_config["weights"])

        # Freeze model parameters. The replacement head is created afterwards,
        # so this loop does not touch it.
        for param in self.model.parameters():
            param.requires_grad = False

        in_features = self._get_in_features(family)
        setattr(
            self.model,
            head_attribute,
            self._create_classifier(in_features, num_classes),
        )

    def forward(self, x):
        """Forward pass through the model."""
        return self.model(x)

    @classmethod
    def _head_attribute(cls, family: str) -> str:
        """Return the name of the backbone attribute holding the head.

        Raises:
            ValueError: If ``family`` is not a supported model family.
        """
        try:
            return cls._HEAD_ATTRIBUTES[family]
        except KeyError:
            supported = ", ".join(cls._HEAD_ATTRIBUTES)
            raise ValueError(
                f"Unsupported model family: '{family}'. "
                f"Supported families: {supported}"
            ) from None

    def _get_in_features(self, family: str) -> int:
        """Return the number of input features for the classifier.

        The value is read from the head currently attached to ``self.model``,
        so this is meant to be called before that head is replaced.

        Raises:
            ValueError: If ``family`` is not a supported model family.
        """
        head = getattr(self.model, self._head_attribute(family))
        if family == "vit":
            # For ViT the feature count is read from ``heads.head`` rather
            # than from ``heads`` itself.
            head = head.head
        return head.in_features

    def _create_classifier(self, in_features: int, num_classes: int) -> nn.Sequential:
        """Create the classifier module.

        The head is Linear -> ReLU -> Dropout(0.5) -> Linear, with a hidden
        width of ``in_features // 2``.
        """
        hidden_features = in_features // 2
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_features, num_classes),
        )
class ModelFactory:
    """
    Callable factory that builds a ``Model`` for a named architecture.

    The name and class count are only stored at construction time; the name
    is checked against ``model_mapping`` when the factory is called.

    Args:
        name (str): Key of the architecture in ``model_mapping``.
        num_classes (int): The number of output classes.

    Raises:
        ValueError: On call, if ``name`` is not a key of ``model_mapping``.
    """

    def __init__(self, name: str, num_classes: int):
        """
        Store the settings used later to build the model.

        Args:
            name (str): The name of the model.
            num_classes (int): The number of output classes.
        """
        self.num_classes = num_classes
        self.name = name

    def __call__(self):
        """
        Build the model described by the stored name and class count.

        Returns:
            Model: An instance of the selected model.

        Raises:
            ValueError: If the stored name is not a key of ``model_mapping``.
        """
        # Happy path first: a known name goes straight to construction.
        if self.name in model_mapping:
            return Model(self.name, self.num_classes)

        # Unknown name: list every registered architecture in the error.
        valid_options = ", ".join(model_mapping)
        raise ValueError(
            f"Invalid model name: '{self.name}'. Available options: {valid_options}"
        )
if __name__ == "__main__":
    # Manual smoke test: build a ResNet-50 with a 5-class head through the
    # factory. This constructs the backbone with its configured weights.
    model = ModelFactory(name="resnet50", num_classes=5)()