import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet


class Model(nn.Module):
    """EfficientNet backbone followed by a pooled linear classification head.

    Defaults to ``efficientnet-b5``, but any variant name accepted by
    ``EfficientNet.from_pretrained`` may be passed as ``model_name``.

    Args:
        model_name: EfficientNet variant to load with pretrained weights.
        pool_type: Pooling callable invoked as ``pool_type(feature_map, 1)``.
            It is expected to reduce the spatial dimensions to 1x1, as
            ``F.adaptive_avg_pool2d`` (the default) or
            ``F.adaptive_max_pool2d`` do.
        num_classes: Number of outputs of the linear head. Defaults to 1,
            which matches the original single-output head.
    """

    def __init__(
        self,
        model_name="efficientnet-b5",
        pool_type=F.adaptive_avg_pool2d,
        num_classes=1,
    ):
        super().__init__()
        self.pool_type = pool_type
        self.model_name = model_name
        # Build the backbone with pretrained weights via ``from_pretrained``.
        self.backbone = EfficientNet.from_pretrained(model_name)
        # Size the new head from the backbone's own ``_fc`` input width so it
        # matches whichever variant was loaded. ``_fc`` itself is kept (not
        # replaced) so state_dict keys stay compatible with saved checkpoints.
        in_features = self.backbone._fc.in_features
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return classifier outputs of shape ``(x.size(0), num_classes)``.

        The backbone's ``extract_features`` map is pooled to 1x1 with
        ``pool_type``, flattened per sample, and passed to the linear head.
        """
        features = self.pool_type(self.backbone.extract_features(x), 1)
        # ``flatten`` (unlike ``view``) also works if a custom ``pool_type``
        # returns a non-contiguous tensor; it keeps the batch dimension.
        features = torch.flatten(features, 1)
        return self.classifier(features)