"""ResNet-based image classifier: a pre-trained torchvision ResNet backbone
with a custom fully connected classification head."""
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class ResNet(nn.Module):
def __init__(
self, resnet_type="resnet18", trainable_layers=3, num_output_neurons=2
):
super(ResNet, self).__init__()
# Dictionary to map resnet_type to the corresponding torchvision model and weights
resnet_dict = {
"resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
"resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
"resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2),
"resnet101": (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V2),
"resnet152": (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V2),
}
# Ensure the provided resnet_type is valid
if resnet_type not in resnet_dict:
raise ValueError(
f"Invalid resnet_type. Expected one of: {list(resnet_dict.keys())}"
)
# Load the specified ResNet model with pre-trained weights
model_func, weights = resnet_dict[resnet_type]
self.resnet = model_func(weights=weights)
# Remove the last fully connected layer
self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])
# Additional pooling to reduce dimensionality further
self.pool = nn.AdaptiveAvgPool2d((1, 1)) # Global average pooling
# Number of input features to the first fully connected layer
if resnet_type in ["resnet18", "resnet34"]:
fc_in_features = 512
else:
fc_in_features = 2048
# Simplified fully connected layers with Batch Normalization and Dropout
self.fc1 = nn.Linear(
fc_in_features, 128
) # Input features depend on the resnet type
self.bn1 = nn.BatchNorm1d(128) # Batch Normalization
self.dropout1 = nn.Dropout(0.5) # Helps prevent overfitting
self.fc2 = nn.Linear(128, 64)
self.bn2 = nn.BatchNorm1d(64) # Batch Normalization
self.dropout2 = nn.Dropout(0.5) # Helps prevent overfitting
self.fc3 = nn.Linear(
64, num_output_neurons
) # Output layer for binary classification
# Set the requires_grad attribute based on the number of trainable layers
self.set_trainable_layers(trainable_layers)
def set_trainable_layers(self, trainable_layers):
# If trainable_layers is 0, freeze all layers
if trainable_layers == 0:
for param in self.resnet.parameters():
param.requires_grad = False
else:
# Get the total number of layers in resnet
total_layers = len(list(self.resnet.children()))
# Make the last `trainable_layers` layers trainable
for i, layer in enumerate(self.resnet.children()):
if i < total_layers - trainable_layers:
for param in layer.parameters():
param.requires_grad = False
else:
for param in layer.parameters():
param.requires_grad = True
def forward(self, x):
# Use the ResNet backbone
x = self.resnet(x)
# Global average pooling
x = self.pool(x)
# Flattening the output for the dense layer
x = x.view(x.size(0), -1) # Adjust this based on the actual output size
x = F.relu(self.fc1(x))
x = self.bn1(x)
x = self.dropout1(x)
x = F.relu(self.fc2(x))
x = self.bn2(x)
x = self.dropout2(x)
x = self.fc3(x)
return x