demo / models /imagenet_models.py
ybbwcwaps
Add FakeVideoDetect
e8e478e
raw
history blame
942 Bytes
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from .vision_transformer import vit_b_16, vit_b_32, vit_l_16, vit_l_32
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
model_dict = {
'resnet18': resnet18,
'resnet34': resnet34,
'resnet50': resnet50,
'resnet101': resnet101,
'resnet152': resnet152,
'vit_b_16': vit_b_16,
'vit_b_32': vit_b_32,
'vit_l_16': vit_l_16,
'vit_l_32': vit_l_32
}
CHANNELS = {
"resnet50" : 2048,
"vit_b_16" : 768,
}
class ImagenetModel(nn.Module):
def __init__(self, name, num_classes=1):
super(ImagenetModel, self).__init__()
self.model = model_dict[name](pretrained=True)
self.fc = nn.Linear(CHANNELS[name], num_classes) #manually define a fc layer here
def forward(self, x):
feature = self.model(x)["penultimate"]
return self.fc(feature)