# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os.path as osp

from mmengine.config import Config
from mmengine.config.utils import (_get_cfg_metainfo,
                                   _get_external_cfg_base_path,
                                   _get_package_and_cfg_path)
from mmengine.registry import MODELS, DefaultScope
from mmengine.runner import load_checkpoint
from mmengine.utils import get_installed_path, install_package


def get_config(cfg_path: str, pretrained: bool = False) -> Config:
    """Get config from external package.

    The package named in ``cfg_path`` is installed on demand, then the
    config is looked up through the package metafile. If that lookup raises
    ``ValueError`` the path is resolved as a base config instead; in that
    fallback no ``model_path`` is attached, even when ``pretrained=True``.

    Args:
        cfg_path (str): External relative config path.
        pretrained (bool): Whether to save pretrained model path. If
            ``pretrained==True``, the url of pretrained model can be accessed
            by ``cfg.model_path``. Defaults to False.

    Examples:
        >>> cfg = get_config('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True)
        >>> # Equivalent to
        >>> # cfg = Config.fromfile('/path/to/faster-rcnn_r50_fpn_1x_coco.py')
        >>> cfg.model_path
        https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth

    Returns:
        Config: A `Config` parsed from external package.

    Raises:
        AssertionError: If ``pretrained`` is True and the config's metafile
            entry has no ``Weights`` field.
    """  # noqa: E501
    # Get package name and relative config path.
    package, cfg_path = _get_package_and_cfg_path(cfg_path)
    # Install package if it's not installed.
    install_package(package)
    package_path = get_installed_path(package)
    try:
        # Use `cfg_path` to search target config file.
        cfg_meta = _get_cfg_metainfo(package_path, cfg_path)
        cfg_path = osp.join(package_path, '.mim', cfg_meta['Config'])
        cfg = Config.fromfile(cfg_path)
        if pretrained:
            # Raised explicitly (instead of via `assert`) so the check still
            # runs under `python -O`; the exception type is kept so existing
            # callers see the same error.
            if 'Weights' not in cfg_meta:
                raise AssertionError(
                    'Cannot find `Weights` in cfg_file.metafile.yml, '
                    'please check the metafile')
            cfg.model_path = cfg_meta['Weights']
    except ValueError:
        # Since the base config does not contain a metafile, the absolute
        # config is `osp.join(package_path, cfg_path_prefix, cfg_name)`
        cfg_path = _get_external_cfg_base_path(package_path, cfg_path)
        cfg = Config.fromfile(cfg_path)
    # Any other exception simply propagates to the caller.
    return cfg


def get_model(cfg_path: str, pretrained: bool = False, **kwargs):
    """Get built model from external package.

    Args:
        cfg_path (str): External relative config path with prefix
            'package::'. The text before the first ``::`` is used as the
            package name and as the default registry scope.
        pretrained (bool): Whether to load pretrained model. Defaults to
            False.
        kwargs (dict): Default arguments to build model.

    Examples:
        >>> model = get_model('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True)
        >>> type(model)

    Returns:
        nn.Module: Built model.
    """  # noqa: E501
    package = cfg_path.partition('::')[0]
    with DefaultScope.overwrite_default_scope(package):  # type: ignore
        cfg = get_config(cfg_path, pretrained)
        # A top-level `data_preprocessor` is moved into the model config so
        # that it is built together with the model.
        if 'data_preprocessor' in cfg:
            cfg.model.data_preprocessor = cfg.data_preprocessor
        # Register the package's modules before building from the registry.
        models_module = importlib.import_module(f'{package}.utils')
        models_module.register_all_modules()  # type: ignore
        model = MODELS.build(cfg.model, default_args=kwargs)
        if pretrained:
            # NOTE: `cfg.model_path` is only set by `get_config` when the
            # config was found through the package metafile.
            load_checkpoint(model, cfg.model_path)
            # Hack to use pretrained weights.
            # If we do not set _is_init here, Runner will call
            # `model.init_weights()` to overwrite the pretrained model.
            model._is_init = True
        return model