|
from .clip_models import CLIPModel |
|
from .imagenet_models import ImagenetModel |
|
from .transformer import FeatureTransformer |
|
|
|
|
|
# Every model name accepted by get_model, in "<family>:<architecture>" form
# (plus the bare "FeatureTransformer" entry). Built per family so that adding
# an architecture only means extending the matching tuple below.
VALID_NAMES = [
    *(f'Imagenet:{arch}' for arch in (
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'vgg11', 'vgg19',
        'swin-b', 'swin-s', 'swin-t',
        'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32',
    )),
    *(f'CLIP:{arch}' for arch in (
        'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
        'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px',
    )),
    'FeatureTransformer',
]
|
|
|
|
|
|
|
|
|
|
|
def get_model(name, **kwargs):
    """Construct the model registered under ``name``.

    Args:
        name: One of ``VALID_NAMES``. ``"Imagenet:<arch>"`` builds
            ``ImagenetModel("<arch>")``, ``"CLIP:<arch>"`` builds
            ``CLIPModel("<arch>")``, and ``"FeatureTransformer"`` builds
            ``FeatureTransformer(**kwargs)``.
        **kwargs: Forwarded only to ``FeatureTransformer``; silently ignored
            for the Imagenet and CLIP families.

    Returns:
        The newly constructed model object.

    Raises:
        ValueError: If ``name`` is not listed in ``VALID_NAMES``.
    """
    # Explicit raise rather than ``assert``: assertions are stripped under
    # ``python -O``, which would let unknown names reach the constructors or
    # make this function silently return None.
    if name not in VALID_NAMES:
        raise ValueError(
            f"Unknown model name {name!r}; expected one of: "
            f"{', '.join(VALID_NAMES)}"
        )

    imagenet_prefix = "Imagenet:"
    clip_prefix = "CLIP:"
    if name.startswith(imagenet_prefix):
        return ImagenetModel(name[len(imagenet_prefix):])
    if name.startswith(clip_prefix):
        return CLIPModel(name[len(clip_prefix):])
    if name == "FeatureTransformer":
        return FeatureTransformer(**kwargs)

    # Only reachable if VALID_NAMES gains an entry without a matching branch
    # above; raised explicitly so the check survives ``python -O``.
    raise AssertionError(f"No constructor registered for model name {name!r}")
|
|