File size: 548 Bytes
cb1857e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
import torchvision
from torch import nn
from helper import setAllSeeds
def getEffNetModel(seed, numClasses, *, dropout=0.3):
    """Build a pretrained EfficientNet-B2 whose classifier head is replaced.

    Every parameter of the pretrained network is frozen before a new
    ``Dropout -> Linear`` head is attached, so only the new head is trainable.

    Args:
        seed: Passed to ``setAllSeeds`` before the model is built, so the
            randomly initialised head is reproducible (assumes setAllSeeds
            seeds torch's RNG -- TODO confirm in helper).
        numClasses: Number of output features of the new ``nn.Linear`` head.
        dropout: Dropout probability applied before the new head. Defaults
            to 0.3, the value this function previously hard-coded.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 and the
        preprocessing transform returned by the pretrained weights'
        ``transforms()``.
    """
    setAllSeeds(seed)
    effNetWeights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    effNetTransforms = effNetWeights.transforms()
    effNet = torchvision.models.efficientnet_b2(weights=effNetWeights)

    # Freeze the pretrained parameters. This runs before the head is replaced,
    # so the new classifier's parameters keep their default requires_grad=True.
    for param in effNet.parameters():
        param.requires_grad = False

    # Take the head's input width from the existing final Linear layer instead
    # of hard-coding 1408, so it stays correct if the backbone/variant changes.
    inFeatures = effNet.classifier[-1].in_features
    effNet.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(inFeatures, numClasses, bias=True),
    )
    return effNet, effNetTransforms
|