Spaces:
Sleeping
Sleeping
File size: 1,026 Bytes
bee69fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from torch import manual_seed, nn
from torchvision import transforms, models
def create_model_alexnet(num_classes: int = 2, seed: int = 42):
    """Create a feature-extracting AlexNet model and its matching transforms.

    Loads AlexNet with torchvision's default pretrained weights and freezes
    every parameter of the pretrained network. It then replaces the final
    classifier layer (``classifier[6]``) with a freshly initialised
    ``nn.Linear`` that outputs ``num_classes`` logits. Because the freeze
    happens before the replacement, only the new head has
    ``requires_grad=True``.

    Args:
        num_classes (int, optional): Number of output features of the new
            classifier head. Defaults to 2.
        seed (int, optional): Value passed to ``torch.manual_seed`` right
            before the new head is created, so its initial weights are
            reproducible. Defaults to 42.

    Returns:
        tuple: ``(model, transform)``, where ``model`` is the AlexNet
        ``torch.nn.Module`` with the replaced head and ``transform`` is the
        preprocessing pipeline returned by ``weights.transforms()``.
    """
    # Use the enum's DEFAULT alias directly. The original
    # ``IMAGENET1K_V1.DEFAULT`` looked up one enum member through another,
    # which Python discourages and has changed across versions.
    weights = models.AlexNet_Weights.DEFAULT
    auto_transform = weights.transforms()
    model_alexnet = models.alexnet(weights=weights)

    # Freeze the pretrained network so only the new head will be trained.
    for param in model_alexnet.parameters():
        param.requires_grad = False

    # Seed torch's RNG so the new head's initialisation is reproducible.
    # Note: this mutates the global RNG state seen by the caller as well.
    manual_seed(seed)
    # Take the input width from the layer being replaced instead of
    # hard-coding 4096, so the head always matches the backbone.
    in_features = model_alexnet.classifier[6].in_features
    model_alexnet.classifier[6] = nn.Linear(in_features, out_features=num_classes)
    return model_alexnet, auto_transform
|