Spaces:
Runtime error
Runtime error
File size: 761 Bytes
845deaa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientnet_pytorch import EfficientNet
class Model(nn.Module):
    """EfficientNet backbone followed by a global pool and a linear head.

    The backbone is created with ``EfficientNet.from_pretrained(model_name)``.
    Its own ``_fc`` layer stays attached, so state-dict keys are unchanged.
    However, ``_fc`` is not used in ``forward``; only its ``in_features`` is
    read, to size the replacement head.

    Args:
        model_name: EfficientNet variant name passed to ``from_pretrained``
            (default ``"efficientnet-b5"``).
        pool_type: Callable ``pool(feature_map, output_size)`` applied to the
            backbone's feature map. It is always called with output size ``1``.
            Defaults to ``F.adaptive_avg_pool2d``.
        num_classes: Number of outputs of the linear head. Defaults to 1,
            which was the previously hard-coded value.
    """

    def __init__(
        self,
        model_name="efficientnet-b5",
        pool_type=F.adaptive_avg_pool2d,
        num_classes=1,
    ):
        super().__init__()
        self.pool_type = pool_type
        self.model_name = model_name
        # NOTE(review): from_pretrained presumably fetches/loads pretrained
        # weights (possible network/disk I/O) — confirm for deployment.
        self.backbone = EfficientNet.from_pretrained(model_name)
        # Size the new head from the backbone's original classifier input.
        in_features = self.backbone._fc.in_features
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Run the backbone, pool its feature map to 1x1, and apply the head.

        Args:
            x: Input batch fed to ``self.backbone.extract_features``.

        Returns:
            The output of ``self.classifier``. There is one row per item
            along dim 0 of the pooled features, and ``num_classes`` columns.
        """
        feature_map = self.backbone.extract_features(x)
        pooled = self.pool_type(feature_map, 1)
        # flatten (unlike view) also works when a custom pool_type returns a
        # non-contiguous tensor.
        features = torch.flatten(pooled, 1)
        return self.classifier(features)
|