# Hugging Face Hub file: h5/model.py
# Uploaded by Diego192 in commit 3e04925 ("Upload 8 files"); original size 752 bytes.
import torch
import torch.nn as nn
import torchvision.models as models
from config import Resnet50Config
from transformers import PreTrainedModel
class Resnet50FER(PreTrainedModel):
    """ResNet-50 image classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The backbone is torchvision's ResNet-50 built without pretrained weights
    (no ImageNet checkpoint is loaded). Its final fully connected layer is
    replaced by a fresh ``nn.Linear`` that outputs ``config.num_classes``
    scores.

    Args:
        config: A ``Resnet50Config``; only ``config.num_classes`` is read here.
    """

    config_class = Resnet50Config

    def __init__(self, config):
        super().__init__(config)
        # Build the full ResNet-50 (including its stock 1000-way ``fc`` head)
        # with untrained weights; the head is replaced below.
        try:
            # torchvision >= 0.13: ``weights=None`` is the supported spelling;
            # the legacy ``pretrained=False`` flag is deprecated there.
            backbone = models.resnet50(weights=None)
        except TypeError:
            # Older torchvision has no ``weights`` argument.
            backbone = models.resnet50(pretrained=False)
        # Kept under the ``resnet`` attribute so state_dict keys
        # (``resnet.*``) stay compatible with existing checkpoints.
        self.resnet = backbone
        # Replace the ImageNet head with one sized for this task;
        # ``in_features`` is the width of the pooled backbone features.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, config.num_classes)

    def forward(self, x):
        """Run the classifier on ``x``.

        Args:
            x: Input batch passed straight to the ResNet-50 backbone
                (presumably a float image tensor of shape (batch, 3, H, W);
                TODO confirm against the preprocessing pipeline).

        Returns:
            The output of the replaced ``fc`` layer: raw, unnormalised scores,
            ``config.num_classes`` per example (no softmax is applied here).
        """
        return self.resnet(x)