File size: 1,041 Bytes
e8e478e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from .clip_models import CLIPModel
from .imagenet_models import ImagenetModel
from .transformer import FeatureTransformer
# Every model identifier accepted by ``get_model``, in "<family>:<arch>" form
# (except the bare "FeatureTransformer"). Kept as a list so existing callers
# that index or iterate it keep working.
VALID_NAMES = [
    # ImageNet backbones -> ImagenetModel("<arch>")
    "Imagenet:resnet18",
    "Imagenet:resnet34",
    "Imagenet:resnet50",
    "Imagenet:resnet101",
    "Imagenet:resnet152",
    "Imagenet:vgg11",
    "Imagenet:vgg19",
    "Imagenet:swin-b",
    "Imagenet:swin-s",
    "Imagenet:swin-t",
    "Imagenet:vit_b_16",
    "Imagenet:vit_b_32",
    "Imagenet:vit_l_16",
    "Imagenet:vit_l_32",
    # CLIP backbones -> CLIPModel("<arch>")
    "CLIP:RN50",
    "CLIP:RN101",
    "CLIP:RN50x4",
    "CLIP:RN50x16",
    "CLIP:RN50x64",
    "CLIP:ViT-B/32",
    "CLIP:ViT-B/16",
    "CLIP:ViT-L/14",
    "CLIP:ViT-L/14@336px",
    # Standalone transformer -> FeatureTransformer(**kwargs)
    "FeatureTransformer",
]
def get_model(name, **kwargs):
    """Construct the model identified by ``name``.

    ``name`` must be one of ``VALID_NAMES``. Dispatch is by family prefix:

    * ``"Imagenet:<arch>"`` -> ``ImagenetModel("<arch>")``
    * ``"CLIP:<arch>"``     -> ``CLIPModel("<arch>")``
    * ``"FeatureTransformer"`` -> ``FeatureTransformer(**kwargs)``

    Args:
        name: A model identifier listed in ``VALID_NAMES``.
        **kwargs: Forwarded to ``FeatureTransformer`` only; they are ignored
            for the ImageNet and CLIP backbones.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``name`` is not in ``VALID_NAMES``.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which let an invalid name fall through and return None.
    if name not in VALID_NAMES:
        raise ValueError(
            f"Unknown model name {name!r}; expected one of {VALID_NAMES}"
        )

    # Split on the first colon instead of slicing with hard-coded prefix
    # lengths; for "FeatureTransformer" there is no colon and arch is "".
    family, _, arch = name.partition(":")
    if family == "Imagenet":
        return ImagenetModel(arch)
    if family == "CLIP":
        return CLIPModel(arch)
    if name == "FeatureTransformer":
        return FeatureTransformer(**kwargs)

    # Only reachable if VALID_NAMES gains an entry with no branch above.
    # Raised explicitly so it also fires under ``python -O``.
    raise AssertionError(f"No constructor registered for model name {name!r}")
|