File size: 761 Bytes
845deaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet


class Model(nn.Module):
    """
    EfficientNet backbone followed by global pooling and a linear head.

    The backbone is built with ``EfficientNet.from_pretrained(model_name)``
    (``"efficientnet-b5"`` by default). The backbone's own ``_fc`` layer is
    not used in ``forward``. Only its ``in_features`` is read, to size the
    new ``classifier`` head.

    Args:
        model_name: EfficientNet variant name passed to ``from_pretrained``.
        pool_type: Callable invoked as ``pool_type(feature_map, 1)`` to reduce
            the backbone feature map to a 1x1 spatial size. Defaults to
            ``F.adaptive_avg_pool2d``.
        num_classes: Number of outputs of the linear head. Defaults to 1,
            which is the original single-output head.
    """

    def __init__(
        self,
        model_name: str = "efficientnet-b5",
        pool_type=F.adaptive_avg_pool2d,
        num_classes: int = 1,
    ):
        super().__init__()
        self.pool_type = pool_type
        self.model_name = model_name
        self.backbone = EfficientNet.from_pretrained(model_name)
        # Reuse the input width of the backbone's built-in head so the new
        # classifier matches the number of pooled feature channels.
        in_features = self.backbone._fc.in_features
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract backbone features, pool them to 1x1, flatten, and classify.

        Args:
            x: Input batch passed to ``self.backbone.extract_features``.
               Presumably an (N, C, H, W) image tensor; TODO confirm with
               callers.

        Returns:
            Output of ``self.classifier``, one row per sample in the batch.
        """
        features = self.pool_type(self.backbone.extract_features(x), 1)
        # torch.flatten copies when needed, so unlike .view it does not fail
        # if pool_type returns a non-contiguous tensor.
        features = torch.flatten(features, 1)
        return self.classifier(features)