# Source: Chirag1994, commit 845deaa ("my first commit"), 761 bytes.
import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
class Model(nn.Module):
    """
    EfficientNet backbone followed by global pooling and a new linear head.

    The backbone defaults to ``efficientnet-b5``, but any variant name accepted
    by ``efficientnet_pytorch.EfficientNet`` may be passed. The backbone's own
    ``_fc`` head is not used in ``forward``; only its ``in_features`` is read to
    size the replacement classifier.

    Args:
        model_name: EfficientNet variant name, e.g. ``"efficientnet-b5"``.
        pool_type: Pooling function, called as ``pool_type(feature_map, 1)`` to
            reduce the backbone feature map to a 1x1 spatial size
            (e.g. ``F.adaptive_avg_pool2d`` or ``F.adaptive_max_pool2d``).
        num_classes: Number of outputs of the linear head. The default of 1
            keeps the original single-output behaviour.
        pretrained: If True (default), load weights with
            ``EfficientNet.from_pretrained``. Otherwise build the architecture
            with ``EfficientNet.from_name`` (no weight download).
    """

    def __init__(
        self,
        model_name="efficientnet-b5",
        pool_type=F.adaptive_avg_pool2d,
        num_classes=1,
        pretrained=True,
    ):
        super().__init__()
        self.pool_type = pool_type
        self.model_name = model_name
        if pretrained:
            self.backbone = EfficientNet.from_pretrained(model_name)
        else:
            # Skips the download, e.g. when a checkpoint is loaded right after.
            self.backbone = EfficientNet.from_name(model_name)
        # Size the new head from the input width of the backbone's own classifier.
        in_features = self.backbone._fc.in_features
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """
        Run the backbone, pool its feature map to 1x1, and apply the linear head.

        Args:
            x: Input batch. Assumed to be a (batch, channels, H, W) image
                tensor; TODO confirm against callers.

        Returns:
            Tensor of shape (batch, number of classifier outputs).
        """
        feature_map = self.backbone.extract_features(x)
        pooled = self.pool_type(feature_map, 1)
        # Collapse (batch, channels, 1, 1) -> (batch, channels).
        # flatten also works on non-contiguous tensors, unlike view.
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)