Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved.
import inspect | |
import warnings | |
from functools import partial | |
from .misc import is_seq_of | |
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            An entry is only used when ``cfg`` does not define the same key.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg`` is not a dict, ``default_args`` is neither a
            dict nor None, ``registry`` is not a :obj:`Registry`, or the
            resolved "type" is neither a str nor a class.
        KeyError: If neither ``cfg`` nor ``default_args`` contains the key
            "type", or the type name is not found in ``registry``.
        Exception: Any exception raised while constructing the object is
            re-raised as the same exception type with the class name
            prefixed to its message.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    # Validate ``default_args`` before it is inspected below, so that a wrong
    # type is reported as such instead of failing inside the ``in`` test.
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')

    # Work on a copy so the caller's config dict is left untouched.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            # Some exception types cannot be built from a single message
            # (e.g. UnicodeDecodeError); keep the original error then.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
class Registry:
    """A registry to map strings to classes.

    Registered object could be built from registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        # Check the parent before any of its attributes are read.
        if parent is not None:
            assert isinstance(parent, Registry)

        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # NOTE: ``infer_scope`` must be called directly from ``__init__``;
        # it relies on a fixed call-stack depth to find the caller's module.
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

        if parent is not None:
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        """Return the number of classes registered in this registry."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return True if ``key`` resolves to a registered class."""
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack()[0] is this function and [1] is ``__init__``, so
        # index 2 is the frame in which the registry is being instantiated.
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        split_filename = filename.split('.')
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        """str: The registry name."""
        return self._name

    @property
    def scope(self):
        """str: The scope of this registry."""
        return self._scope

    @property
    def module_dict(self):
        """dict: Mapping from registered name to class."""
        return self._module_dict

    @property
    def children(self):
        """dict: Mapping from child scope to child registry."""
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if it cannot be found.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
            return None

        # get from self._children
        if scope in self._children:
            return self._children[scope].get(real_key)

        # This registry is already the root and does not know the scope:
        # there is nowhere else to look.
        if self.parent is None:
            return None

        # goto root
        parent = self.parent
        while parent.parent is not None:
            parent = parent.parent
        return parent.get(key)

    def build(self, *args, **kwargs):
        """Call ``self.build_func`` with this registry as ``registry``."""
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        """Store ``module_class`` under one or several names.

        Args:
            module_class (type): The class to register.
            module_name (str | list[str] | None): Name(s) to register the
                class under. The class name is used if None.
            force (bool): Whether to override an existing entry.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        """Register ``cls`` through the old API and emit a warning.

        Returns ``cls`` after registration, or a decorator when ``cls`` is
        None.
        """
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            The registered ``module`` when it is given, otherwise a decorator
            that registers the decorated class and returns it.

        Raises:
            TypeError: If ``force`` is not a bool or ``name`` has an
                unsupported type.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register