from .clip_models import CLIPModel
from .imagenet_models import ImagenetModel
from .transformer import FeatureTransformer

# Every identifier accepted by get_model(). The text before the first ':'
# selects the model family; the remainder is handed to that family's
# constructor unchanged.
VALID_NAMES = [
    'Imagenet:resnet18',
    'Imagenet:resnet34',
    'Imagenet:resnet50',
    'Imagenet:resnet101',
    'Imagenet:resnet152',
    'Imagenet:vgg11',
    'Imagenet:vgg19',
    'Imagenet:swin-b',
    'Imagenet:swin-s',
    'Imagenet:swin-t',
    'Imagenet:vit_b_16',
    'Imagenet:vit_b_32',
    'Imagenet:vit_l_16',
    'Imagenet:vit_l_32',
    'CLIP:RN50',
    'CLIP:RN101',
    'CLIP:RN50x4',
    'CLIP:RN50x16',
    'CLIP:RN50x64',
    'CLIP:ViT-B/32',
    'CLIP:ViT-B/16',
    'CLIP:ViT-L/14',
    'CLIP:ViT-L/14@336px',
    'FeatureTransformer'
]


class _UnknownModelNameError(ValueError, AssertionError):
    """Raised by get_model() when it cannot map a name to a model.

    Derives from AssertionError as well as ValueError because the check was
    previously a bare ``assert``; callers that catch AssertionError keep
    working, while new code can catch the more conventional ValueError.
    """


def get_model(name, **kwargs):
    """Build the model identified by *name*.

    Args:
        name: One of the entries of ``VALID_NAMES``.
        **kwargs: Forwarded to ``FeatureTransformer`` only. They are ignored
            for ``Imagenet:*`` and ``CLIP:*`` names.

    Returns:
        ``ImagenetModel(<variant>)`` for ``"Imagenet:<variant>"``,
        ``CLIPModel(<variant>)`` for ``"CLIP:<variant>"``, or
        ``FeatureTransformer(**kwargs)`` for names starting with
        ``"FeatureTransformer"``.

    Raises:
        ValueError: If *name* is not in ``VALID_NAMES`` or matches no known
            family. The raised exception is also an ``AssertionError``.
    """
    # Explicit raise rather than `assert`: asserts are removed under
    # `python -O`, which would let unknown names reach the constructors.
    if name not in VALID_NAMES:
        raise _UnknownModelNameError(
            f"unknown model name {name!r}; expected one of {VALID_NAMES}"
        )

    # Split on the first ':' only, so variants such as 'ViT-B/32' or
    # 'ViT-L/14@336px' are passed through intact.
    family, _, variant = name.partition(":")
    if family == "Imagenet":
        return ImagenetModel(variant)
    if family == "CLIP":
        return CLIPModel(variant)
    if name.startswith("FeatureTransformer"):
        return FeatureTransformer(**kwargs)

    # Only reachable if VALID_NAMES gains an entry with no matching branch.
    raise _UnknownModelNameError(
        f"no model constructor is registered for {name!r}"
    )