Spaces: Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from efficientnet_pytorch import EfficientNet | |
class Model(nn.Module):
    """EfficientNet backbone with a global-pooling + linear head.

    The backbone's built-in classifier (``_fc``) is not used for prediction;
    only its ``in_features`` is read to size the new head. In ``forward`` the
    feature maps from ``extract_features`` are pooled with ``pool_type``,
    flattened, and passed through a fresh ``nn.Linear`` layer.

    Args:
        model_name: EfficientNet variant name passed to ``efficientnet_pytorch``
            (default ``"efficientnet-b5"``).
        pool_type: Callable invoked as ``pool_type(feature_maps, 1)``, e.g.
            ``F.adaptive_avg_pool2d`` (default) or ``F.adaptive_max_pool2d``.
        num_classes: Output width of the linear head. Defaults to 1 (one
            output per sample), matching the original hard-coded head.
        pretrained: If True (default), build the backbone with
            ``EfficientNet.from_pretrained`` (may need network access to fetch
            weights); if False, build an untrained backbone with
            ``EfficientNet.from_name``, e.g. before loading a fine-tuned
            ``state_dict``.
    """

    def __init__(
        self,
        model_name: str = "efficientnet-b5",
        pool_type=F.adaptive_avg_pool2d,
        num_classes: int = 1,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        self.pool_type = pool_type
        self.model_name = model_name
        if pretrained:
            self.backbone = EfficientNet.from_pretrained(model_name)
        else:
            self.backbone = EfficientNet.from_name(model_name)
        # Reuse the width of the backbone's own classifier input. This relies on
        # it matching the channel count of `extract_features` output, as the
        # original code did.
        in_features = self.backbone._fc.in_features
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone, pool, flatten and apply the linear head.

        ``x`` is passed straight to ``self.backbone.extract_features``;
        presumably an image batch of shape ``(batch, 3, H, W)`` -- TODO confirm
        against callers.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        feature_maps = self.backbone.extract_features(x)
        pooled = self.pool_type(feature_maps, 1)
        # flatten (rather than view) keeps the batch dim and also works when a
        # custom pool_type returns a non-contiguous tensor.
        features = torch.flatten(pooled, 1)
        return self.classifier(features)