|
|
|
import torch |
|
import torchvision |
|
from torch import nn |
|
from helper import setAllSeeds |
|
|
|
def getEffNetModel(seed, numClasses):
    """Build a frozen, pretrained EfficientNet-B2 with a fresh classifier head.

    Loads EfficientNet-B2 with its DEFAULT pretrained weights, disables
    gradients on every pretrained parameter, and replaces the classifier
    with a new Dropout + Linear head that outputs ``numClasses`` values.
    Only the new head's parameters are left trainable.

    Args:
        seed: Value forwarded to ``setAllSeeds`` before the model is built
            (presumably seeds the RNGs so the new head's initialisation is
            reproducible -- confirm against ``helper.setAllSeeds``).
        numClasses: Number of output features of the new linear layer.

    Returns:
        A ``(model, transforms)`` tuple: the modified EfficientNet-B2 and the
        preprocessing transforms bundled with the pretrained weights.
    """
    # Seed first so everything that follows runs after the seeding call.
    setAllSeeds(seed)

    effNetWeights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    effNetTransforms = effNetWeights.transforms()
    effNet = torchvision.models.efficientnet_b2(weights=effNetWeights)

    # Freeze the pretrained parameters; the replacement head created below
    # is built afterwards, so its parameters keep requires_grad=True.
    for param in effNet.parameters():
        param.requires_grad = False

    # Read the feature width from the pretrained head instead of hard-coding
    # it (1408 for EfficientNet-B2) so the new head always matches the
    # backbone's output size.
    inFeatures = effNet.classifier[-1].in_features

    effNet.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(inFeatures, numClasses, bias=True),
    )

    return effNet, effNetTransforms
|
|