File size: 942 Bytes
e8e478e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from .vision_transformer import vit_b_16, vit_b_32, vit_l_16, vit_l_32
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
# Registry of supported backbones: name -> constructor. Every constructor is
# invoked as ``ctor(pretrained=...)`` by ImagenetModel.
model_dict = dict(
    resnet18=resnet18,
    resnet34=resnet34,
    resnet50=resnet50,
    resnet101=resnet101,
    resnet152=resnet152,
    vit_b_16=vit_b_16,
    vit_b_32=vit_b_32,
    vit_l_16=vit_l_16,
    vit_l_32=vit_l_32,
)
# Width of each backbone's "penultimate" feature vector, i.e. the
# ``in_features`` of the linear head that ImagenetModel builds on top of it.
# Every name in ``model_dict`` needs an entry here, or ImagenetModel(name)
# raises KeyError.
# NOTE(review): the values are the standard torchvision architecture widths
# (ResNet-18/34 BasicBlock -> 512, ResNet-50/101/152 Bottleneck -> 2048,
# ViT-B hidden dim -> 768, ViT-L hidden dim -> 1024). This assumes the local
# .resnet / .vision_transformer modules keep those widths; confirm if they
# are modified.
CHANNELS = {
    "resnet18": 512,
    "resnet34": 512,
    "resnet50": 2048,
    "resnet101": 2048,
    "resnet152": 2048,
    "vit_b_16": 768,
    "vit_b_32": 768,
    "vit_l_16": 1024,
    "vit_l_32": 1024,
}
class ImagenetModel(nn.Module):
    """Backbone from ``model_dict`` topped with a freshly initialised linear head.

    The backbone is built via ``model_dict[name](pretrained=pretrained)``. Its
    output mapping's ``"penultimate"`` entry is fed to a new ``nn.Linear``
    with ``CHANNELS[name]`` inputs and ``num_classes`` outputs.

    Args:
        name: Backbone key; must be present in both ``model_dict`` and
            ``CHANNELS``.
        num_classes: Number of outputs of the linear head (default 1).
        pretrained: Forwarded to the backbone constructor (default True,
            matching the previous hard-coded behaviour).

    Raises:
        KeyError: If ``name`` is missing from ``model_dict`` or ``CHANNELS``.
    """

    def __init__(self, name, num_classes=1, pretrained=True):
        super().__init__()
        # Validate the name before constructing the backbone, so an
        # unsupported name fails fast with a helpful message. Otherwise it
        # would fail only after building the model (and possibly fetching
        # pretrained weights).
        if name not in model_dict or name not in CHANNELS:
            supported = sorted(set(model_dict) & set(CHANNELS))
            raise KeyError(
                f"Unsupported model name {name!r}; expected one of {supported}"
            )
        self.model = model_dict[name](pretrained=pretrained)
        # Manually defined head; its input width must match the backbone's
        # penultimate feature size.
        self.fc = nn.Linear(CHANNELS[name], num_classes)

    def forward(self, x):
        """Return the linear head applied to the backbone's penultimate features."""
        # The backbone returns a mapping; only the "penultimate" entry is used.
        # Its last dimension must equal CHANNELS[name] for nn.Linear to accept it.
        feature = self.model(x)["penultimate"]
        return self.fc(feature)
|