# File size: 626 Bytes
# 42f61ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torchvision

from torch import nn

def create_effnetb0_model(num_classes: int = 3, seed: int = 42):
    """Build a frozen EfficientNet-B0 backbone with a new classifier head.

    Loads torchvision's default pretrained EfficientNet-B0 weights and freezes
    every pretrained parameter. It then replaces the model's classifier with a
    fresh ``Dropout -> Linear`` head that produces ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new classifier head.
        seed: Value passed to ``torch.manual_seed`` immediately before the new
            head is constructed, so the head's initial weights are reproducible.

    Returns:
        A ``(model, transforms)`` tuple:

        - the modified EfficientNet-B0 model;
        - the preprocessing transforms bundled with the pretrained weights.
    """
    effnetb0_weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
    effnetb0_transforms = effnetb0_weights.transforms()
    effnetb0 = torchvision.models.efficientnet_b0(weights=effnetb0_weights)

    # Freeze all pretrained parameters. The flag is ``requires_grad``.
    # A misspelled name (e.g. ``required_grad``) would silently create a new,
    # unused attribute and leave the whole backbone trainable.
    for param in effnetb0.parameters():
        param.requires_grad = False

    # Seed just before building the head so its random init is reproducible.
    torch.manual_seed(seed)

    # The head is created after the freezing loop, so its freshly constructed
    # parameters keep the default requires_grad=True and are the only
    # trainable weights. 1280 is the input width expected by the head.
    effnetb0.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1280, out_features=num_classes),
    )

    return effnetb0, effnetb0_transforms