from .ResNet import * from .VGGNet import * __all__ = ['get_backbone'] def get_backbone(model_name='', pretrained=True, num_classes=None, **kwargs): if 'res' in model_name: model = get_resnet(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) elif 'vgg' in model_name: model = get_vgg(model_name, pretrained=pretrained, num_classes=num_classes, **kwargs) else: raise NotImplementedError return model