# demo/models/__init__.py
# Last change: commit e8e478e by ybbwcwaps, "Add FakeVideoDetect" (1.04 kB).
from .clip_models import CLIPModel
from .imagenet_models import ImagenetModel
from .transformer import FeatureTransformer
# Model identifiers accepted by get_model(). "Imagenet:" and "CLIP:" entries are
# "<family>:<arch>" strings whose <arch> suffix is handed to ImagenetModel and
# CLIPModel respectively; "FeatureTransformer" carries no architecture suffix.
# Built per family for readability; the result is still a plain list.
VALID_NAMES = [
    *(
        f"Imagenet:{arch}"
        for arch in (
            "resnet18",
            "resnet34",
            "resnet50",
            "resnet101",
            "resnet152",
            "vgg11",
            "vgg19",
            "swin-b",
            "swin-s",
            "swin-t",
            "vit_b_16",
            "vit_b_32",
            "vit_l_16",
            "vit_l_32",
        )
    ),
    *(
        f"CLIP:{arch}"
        for arch in (
            "RN50",
            "RN101",
            "RN50x4",
            "RN50x16",
            "RN50x64",
            "ViT-B/32",
            "ViT-B/16",
            "ViT-L/14",
            "ViT-L/14@336px",
        )
    ),
    "FeatureTransformer",
]
class UnknownModelError(ValueError, AssertionError):
    """Raised by ``get_model`` for a name that is not in ``VALID_NAMES``.

    Subclasses ``AssertionError`` as well as ``ValueError`` so callers that
    caught the ``AssertionError`` from the former ``assert`` keep working.
    """


def get_model(name, **kwargs):
    """Instantiate a model from its registry identifier.

    Args:
        name: One of ``VALID_NAMES``. ``"Imagenet:<arch>"`` builds an
            ``ImagenetModel(<arch>)``, ``"CLIP:<arch>"`` builds a
            ``CLIPModel(<arch>)``, and ``"FeatureTransformer"`` builds a
            ``FeatureTransformer``.
        **kwargs: Forwarded only to ``FeatureTransformer``; ignored for the
            ``Imagenet:`` and ``CLIP:`` families.

    Returns:
        The constructed model instance.

    Raises:
        UnknownModelError: if ``name`` is not in ``VALID_NAMES`` (a subclass of
            both ``ValueError`` and ``AssertionError``).
        AssertionError: if ``name`` is listed in ``VALID_NAMES`` but has no
            dispatch branch below (registry and factory out of sync).
    """
    # Explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let an unknown name fall through and make this
    # function silently return None.
    if name not in VALID_NAMES:
        raise UnknownModelError(
            f"Unknown model name {name!r}; expected one of {VALID_NAMES}"
        )

    # Split on the first ':' only, so architectures such as "ViT-L/14@336px"
    # are passed through intact; "FeatureTransformer" yields an empty arch.
    family, _, arch = name.partition(":")
    if family == "Imagenet":
        return ImagenetModel(arch)
    if family == "CLIP":
        return CLIPModel(arch)
    if family == "FeatureTransformer":
        return FeatureTransformer(**kwargs)

    # Only reachable if VALID_NAMES gains an entry without a matching branch.
    raise AssertionError(
        f"{name!r} is listed in VALID_NAMES but get_model has no branch for it"
    )