Spaces: Running on CPU Upgrade
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class ResNet(nn.Module): | |
def __init__( | |
self, resnet_type="resnet18", trainable_layers=3, num_output_neurons=2 | |
): | |
super(ResNet, self).__init__() | |
# Dictionary to map resnet_type to the corresponding torchvision model and weights | |
resnet_dict = { | |
"resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1), | |
"resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1), | |
"resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2), | |
"resnet101": (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V2), | |
"resnet152": (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V2), | |
} | |
# Ensure the provided resnet_type is valid | |
if resnet_type not in resnet_dict: | |
raise ValueError( | |
f"Invalid resnet_type. Expected one of: {list(resnet_dict.keys())}" | |
) | |
# Load the specified ResNet model with pre-trained weights | |
model_func, weights = resnet_dict[resnet_type] | |
self.resnet = model_func(weights=weights) | |
# Remove the last fully connected layer | |
self.resnet = nn.Sequential(*list(self.resnet.children())[:-2]) | |
# Additional pooling to reduce dimensionality further | |
self.pool = nn.AdaptiveAvgPool2d((1, 1)) # Global average pooling | |
# Number of input features to the first fully connected layer | |
if resnet_type in ["resnet18", "resnet34"]: | |
fc_in_features = 512 | |
else: | |
fc_in_features = 2048 | |
# Simplified fully connected layers with Batch Normalization and Dropout | |
self.fc1 = nn.Linear( | |
fc_in_features, 128 | |
) # Input features depend on the resnet type | |
self.bn1 = nn.BatchNorm1d(128) # Batch Normalization | |
self.dropout1 = nn.Dropout(0.5) # Helps prevent overfitting | |
self.fc2 = nn.Linear(128, 64) | |
self.bn2 = nn.BatchNorm1d(64) # Batch Normalization | |
self.dropout2 = nn.Dropout(0.5) # Helps prevent overfitting | |
self.fc3 = nn.Linear( | |
64, num_output_neurons | |
) # Output layer for binary classification | |
# Set the requires_grad attribute based on the number of trainable layers | |
self.set_trainable_layers(trainable_layers) | |
def set_trainable_layers(self, trainable_layers): | |
# If trainable_layers is 0, freeze all layers | |
if trainable_layers == 0: | |
for param in self.resnet.parameters(): | |
param.requires_grad = False | |
else: | |
# Get the total number of layers in resnet | |
total_layers = len(list(self.resnet.children())) | |
# Make the last `trainable_layers` layers trainable | |
for i, layer in enumerate(self.resnet.children()): | |
if i < total_layers - trainable_layers: | |
for param in layer.parameters(): | |
param.requires_grad = False | |
else: | |
for param in layer.parameters(): | |
param.requires_grad = True | |
def forward(self, x): | |
# Use the ResNet backbone | |
x = self.resnet(x) | |
# Global average pooling | |
x = self.pool(x) | |
# Flattening the output for the dense layer | |
x = x.view(x.size(0), -1) # Adjust this based on the actual output size | |
x = F.relu(self.fc1(x)) | |
x = self.bn1(x) | |
x = self.dropout1(x) | |
x = F.relu(self.fc2(x)) | |
x = self.bn2(x) | |
x = self.dropout2(x) | |
x = self.fc3(x) | |
return x | |