demo_B21_AIML / model.py
Ramendra's picture
Upload 3 files
bee69fc
raw
history blame
1.03 kB
from torch import manual_seed, nn
from torchvision import transforms, models
def create_model_alexnet(num_classes: int = 2, seed: int = 42):
    """Create an ImageNet-pretrained AlexNet with a frozen backbone and a new head.

    Every parameter of the pretrained network is frozen (``requires_grad`` is
    set to ``False``). The final classifier layer is then replaced with a
    freshly initialised ``nn.Linear``, which is trainable by default.

    Args:
        num_classes (int, optional): Number of output features of the new
            classifier head. Defaults to 2.
        seed (int, optional): Value passed to ``torch.manual_seed`` right
            before the new head is created, so its initial weights are
            reproducible. Note that this reseeds torch's global RNG as a
            side effect. Defaults to 42.

    Returns:
        tuple: ``(model, transforms)``, where ``model`` is the AlexNet
        ``torch.nn.Module`` and ``transforms`` is the preprocessing callable
        returned by the pretrained weights' ``transforms()`` method.
    """
    # Use the enum alias directly. Chaining members (``IMAGENET1K_V1.DEFAULT``)
    # relies on member-to-member enum access, which some Python versions
    # reject (e.g. 3.11 raises AttributeError). DEFAULT is IMAGENET1K_V1.
    weights = models.AlexNet_Weights.DEFAULT
    auto_transform = weights.transforms()
    model_alexnet = models.alexnet(weights=weights)

    # Freeze all pretrained parameters; the head built below is new and
    # therefore trainable.
    for param in model_alexnet.parameters():
        param.requires_grad = False

    # Take the head's input width from the layer being replaced rather than
    # hard-coding it.
    in_features = model_alexnet.classifier[6].in_features

    # Seed immediately before constructing the head so its random
    # initialisation is reproducible.
    manual_seed(seed)
    model_alexnet.classifier[6] = nn.Linear(in_features, out_features=num_classes)
    return model_alexnet, auto_transform