Spaces:
Sleeping
Sleeping
# Copyright (c) Alibaba. All rights reserved. | |
import inspect | |
import warnings | |
import functools | |
from functools import partial | |
from typing import Any, Dict, Optional | |
from collections import abc | |
from inspect import getfullargspec | |
def is_seq_of(seq, expected_type, seq_type=None):
    """Return whether ``seq`` is a container whose items all match a type.

    Args:
        seq (Sequence): The object to inspect.
        expected_type (type | tuple[type]): Type (or tuple of types, as
            accepted by ``isinstance``) every item must be an instance of.
        seq_type (type, optional): Container type ``seq`` itself must be an
            instance of. Defaults to ``collections.abc.Sequence``.

    Returns:
        bool: ``True`` when ``seq`` has the required container type and every
        item is an instance of ``expected_type``; ``False`` otherwise.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = abc.Sequence
    # Short-circuits: items are only inspected when the container matches.
    return isinstance(seq, container_type) and all(
        isinstance(element, expected_type) for element in seq)
def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator to check if some arguments are deprecated and try to
    replace the deprecated ``src_arg_name`` with ``dst_arg_name``.

    Deprecated names passed as keyword arguments are renamed before the
    decorated function is called. Deprecated names that are bound
    positionally only trigger a warning, because positional arguments are
    forwarded unchanged.

    Args:
        name_dict (dict):
            key (str): Deprecated argument names.
            val (str): Expected argument names.
        cls_name (str, optional): If given, warning messages refer to the
            decorated function as ``'{cls_name}.{func_name}'``.
            Default: None.

    Returns:
        func: A decorator producing the wrapped function.

    Raises:
        AssertionError: (raised by the wrapped function) if a deprecated
            keyword and its replacement are passed at the same time.
    """

    def api_warning_wrapper(old_func):
        # The argument spec and display name never change for ``old_func``,
        # so compute them once here instead of on every call.
        args_info = getfullargspec(old_func)
        func_name = old_func.__name__
        if cls_name is not None:
            func_name = f'{cls_name}.{func_name}'

        # Preserve __name__, __doc__, __wrapped__, ... of the original.
        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            if args:
                # Names of the parameters the positional arguments bind to.
                # Slicing copies the list, so ``args_info`` is not mutated.
                arg_names = args_info.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in arg_names:
                        # stacklevel=2 attributes the warning to the caller
                        # rather than to this wrapper.
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead',
                            DeprecationWarning,
                            stacklevel=2)
                        # Only the local name list is updated; the positional
                        # values themselves are forwarded unchanged.
                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in kwargs:
                        assert dst_arg_name not in kwargs, (
                            f'The expected behavior is to replace '
                            f'the deprecated key `{src_arg_name}` to '
                            f'new key `{dst_arg_name}`, but got them '
                            f'in the arguments at the same time, which '
                            f'is confusing. `{src_arg_name} will be '
                            f'deprecated in the future, please '
                            f'use `{dst_arg_name}` instead.')
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead',
                            DeprecationWarning,
                            stacklevel=2)
                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
            # apply converted arguments to the decorated method
            output = old_func(*args, **kwargs)
            return output

        return new_func

    return api_warning_wrapper
def build_from_cfg(cfg: Dict,
                   registry: 'Registry',
                   default_args: Optional[Dict] = None) -> Any:
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='ResNet'), MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Return a result of the calling function

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            Keys already present in ``cfg`` take precedence.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has the wrong
            type, or if ``type`` is neither a str nor a class/function.
        KeyError: If neither ``cfg`` nor ``default_args`` has a "type" key,
            or if a string ``type`` is not registered in ``registry``.
        Exception: Any exception raised while constructing/calling the
            target is re-raised as the same type with the target's name
            prefixed; the original is kept as ``__cause__``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Copy so that popping "type" and merging defaults never mutates cfg.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name, so re-raise the same
        # exception type with the name prefixed, chaining the original so
        # its traceback is preserved.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            # Some exception types (e.g. UnicodeDecodeError) cannot be built
            # from a single message; fall back to the original exception.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
class Registry:
    """A registry to map strings to classes or functions.

    Registered object could be built from registry. Meanwhile, registered
    functions could be called from registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the
            top-level package name of the module that creates the registry,
            e.g. mmdet, mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        # Validate before ``parent.build_func`` is read below.
        if parent is not None:
            assert isinstance(parent, Registry)
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # ``infer_scope`` looks two frames up the stack, so it must be called
        # directly from here to resolve to the code creating this registry.
        self._scope = self.infer_scope() if scope is None else scope
        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg
        self.parent = parent
        if parent is not None:
            parent._add_children(self)

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The top-level package name of the module that instantiates the
        registry will be returned.

        Example:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            str: The inferred scope name.
        """
        # We access the caller using inspect.currentframe() instead of
        # inspect.stack() for performance reasons. See details in PR #1844
        frame = inspect.currentframe()
        try:
            # get the frame where `infer_scope()` is called
            infer_scope_caller = frame.f_back.f_back
            module = inspect.getmodule(infer_scope_caller)
            if module is not None:
                module_name = module.__name__
            else:
                # getmodule() finds no module for code without a backing
                # source file (e.g. compiled from a string); fall back to the
                # ``__name__`` visible in the caller's globals.
                module_name = infer_scope_caller.f_globals.get('__name__', '')
        finally:
            # A frame stored in one of its own locals forms a reference
            # cycle; drop it explicitly (see the ``inspect`` docs).
            del frame
        return module_name.split('.')[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.
        """
        scope, separator, remainder = key.partition('.')
        if separator:
            return scope, remainder
        return None, key

    @property
    def name(self):
        """str: The registry name."""
        return self._name

    @property
    def scope(self):
        """str: The registry scope (given or inferred)."""
        return self._scope

    @property
    def module_dict(self):
        """dict: Mapping from registered names to classes/functions."""
        return self._module_dict

    @property
    def children(self):
        """dict: Mapping from scope name to child :obj:`Registry`."""
        return self._children

    def get(self, key):
        """Get the registry record.

        A key may carry a scope prefix (``'mmdet.ResNet'``). Unscoped keys or
        keys with this registry's own scope are looked up locally; a scope of
        a child registry is delegated to that child; any other scope is
        retried from the root registry.

        Args:
            key (str): The class name in string format.

        Returns:
            class | function | None: The corresponding registered object, or
            ``None`` if it cannot be found.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            return self._module_dict.get(real_key)
        if scope in self._children:
            # get from self._children
            return self._children[scope].get(real_key)
        # goto root: it can reach every registry in the hierarchy.
        root = self
        while root.parent is not None:
            root = root.parent
        if root is self:
            # Already at the root and the scope is unknown: nothing to find.
            return None
        return root.get(key)

    def build(self, *args, **kwargs):
        """Build an object by calling ``self.build_func`` with this registry
        passed as the ``registry`` keyword argument."""
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module, module_name=None, force=False):
        """Store ``module`` under one or more names.

        Raises:
            TypeError: If ``module`` is neither a class nor a function.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not inspect.isclass(module) and not inspect.isfunction(module):
            raise TypeError('module must be a class or a function, '
                            f'but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module

    def deprecated_register_module(self, cls=None, force=False):
        """Old-style registration API; emits a DeprecationWarning."""
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.',
            DeprecationWarning)
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | Sequence[str] | None): The module name(s) to be
                registered. If not specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class or function to be registered.

        Returns:
            The registered ``module`` when ``module`` is given, otherwise a
            decorator that registers and returns its argument.

        Raises:
            TypeError: If ``force`` is not a bool or ``name`` has an invalid
                type.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register