import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
from config import Resnet50Config | |
from transformers import PreTrainedModel | |
class Resnet50FER(PreTrainedModel):
    """ResNet-50 classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The backbone is a torchvision ResNet-50 created without pretrained
    weights; its final fully connected layer is replaced by a linear head
    with ``config.num_classes`` outputs.
    """

    config_class = Resnet50Config

    def __init__(self, config):
        """Build the backbone and attach the task-specific head.

        Args:
            config: Configuration object; ``config.num_classes`` sets the
                output size of the classification head.
        """
        super().__init__(config)
        # Build the full ResNet-50 (including its stock ``fc`` layer) with
        # random initialisation. Calling the factory with no arguments is
        # equivalent to the old ``pretrained=False`` but avoids the keyword
        # torchvision deprecated in 0.13, and still works on older releases.
        self.resnet = models.resnet50()
        num_ftrs = self.resnet.fc.in_features
        # Swap the stock head for one sized to this task's classes.
        self.resnet.fc = nn.Linear(num_ftrs, config.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the ResNet-50 and return the head's raw output.

        No softmax or other activation is applied after the final linear
        layer.

        Args:
            x: Input batch passed straight to the torchvision ResNet
                (presumably image tensors of shape (batch, 3, H, W) --
                confirm against the caller).

        Returns:
            The output of the replaced ``fc`` layer, with
            ``config.num_classes`` features per sample.
        """
        return self.resnet(x)