import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class ResNet(nn.Module):
    def __init__(
        self, resnet_type="resnet18", trainable_layers=3, num_output_neurons=2
    ):
        super().__init__()

        # Map resnet_type to the corresponding torchvision constructor and weights
        resnet_dict = {
            "resnet18": (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
            "resnet34": (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
            "resnet50": (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2),
            "resnet101": (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V2),
            "resnet152": (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V2),
        }

        # Ensure the provided resnet_type is valid
        if resnet_type not in resnet_dict:
            raise ValueError(
                f"Invalid resnet_type. Expected one of: {list(resnet_dict.keys())}"
            )

        # Load the specified ResNet model with pre-trained weights
        model_func, weights = resnet_dict[resnet_type]
        self.resnet = model_func(weights=weights)

        # Drop the final two children (the average-pooling and fully connected
        # layers), keeping only the convolutional backbone
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])

        # Global average pooling to collapse the spatial dimensions to 1x1
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Number of input features to the first fully connected layer:
        # 512 for the basic-block variants, 2048 for the bottleneck variants
        if resnet_type in ["resnet18", "resnet34"]:
            fc_in_features = 512
        else:
            fc_in_features = 2048

        # Fully connected head with Batch Normalization and Dropout
        self.fc1 = nn.Linear(fc_in_features, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.dropout1 = nn.Dropout(0.5)  # Helps prevent overfitting

        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(64, num_output_neurons)  # Output layer (2 = binary)

        # Set requires_grad on the backbone based on the number of trainable layers
        self.set_trainable_layers(trainable_layers)

    def set_trainable_layers(self, trainable_layers):
        # If trainable_layers is 0, freeze the entire backbone
        if trainable_layers == 0:
            for param in self.resnet.parameters():
                param.requires_grad = False
        else:
            # Total number of top-level children in the backbone
            total_layers = len(list(self.resnet.children()))

            # Freeze everything except the last `trainable_layers` children
            for i, layer in enumerate(self.resnet.children()):
                requires_grad = i >= total_layers - trainable_layers
                for param in layer.parameters():
                    param.requires_grad = requires_grad

    def forward(self, x):
        # Convolutional backbone
        x = self.resnet(x)

        # Global average pooling, then flatten to (batch, features)
        x = self.pool(x)
        x = x.view(x.size(0), -1)

        # Classification head
        x = F.relu(self.fc1(x))
        x = self.bn1(x)
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.bn2(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        return x
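

# --- Usage sketch (not part of the original module) ---
# A minimal smoke test of the model above. The input shape (4, 3, 224, 224)
# and the arguments passed to ResNet(...) are illustrative assumptions, not
# values taken from the original code. Note that torchvision downloads the
# pre-trained weights on first use, and that BatchNorm1d requires a batch
# size greater than 1 in training mode, hence model.eval() here.
if __name__ == "__main__":
    import torch

    model = ResNet(resnet_type="resnet18", trainable_layers=2, num_output_neurons=2)
    model.eval()  # disable dropout and use BatchNorm running statistics

    with torch.no_grad():
        dummy = torch.randn(4, 3, 224, 224)  # dummy batch of four RGB images
        logits = model(dummy)

    print(logits.shape)  # expected: torch.Size([4, 2])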