File size: 2,785 Bytes
9a7fe1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import Namespace
from typing import Union
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig
# Module-level index of every registry created by setup_registry(), keyed by
# the normalized registry name (leading '--' stripped, '-' -> '_'). Each value
# is a dict with the keys 'registry' (name -> class), 'default' (the default
# choice passed to setup_registry) and 'dataclass_registry' (name -> config
# dataclass).
REGISTRIES: dict = dict()
def setup_registry(registry_name: str,
                   base_class=None,
                   default=None,
                   required=False):
    """Create a named registry of classes and return helpers to use it.

    Args:
        registry_name: command-line style name that must start with ``--``.
            The leading dashes are stripped and remaining ``-`` characters
            are replaced by ``_`` (e.g. ``'--my-thing'`` -> ``'my_thing'``);
            the result keys the registry in the module-level ``REGISTRIES``.
        base_class: if not None, every registered class must subclass it.
        default: default choice; it is only recorded in this registry's
            ``REGISTRIES`` entry, not used by the helpers themselves.
        required: if True, ``build_x(None)`` raises ``ValueError``;
            otherwise it returns ``None``.

    Returns:
        A tuple ``(build_x, register_x, REGISTRY, DATACLASS_REGISTRY)``, or
        ``None`` if a registry with the same normalized name was already set
        up (the existing registry is left untouched in that case).

    Raises:
        ValueError: if ``registry_name`` does not start with ``--``.
    """
    # Raise rather than assert so the check survives ``python -O``; without it
    # the slice below would silently drop two real characters of the name.
    if not registry_name.startswith('--'):
        raise ValueError(
            "registry_name must start with '--', got {!r}".format(
                registry_name))
    registry_name = registry_name[2:].replace('-', '_')

    # maintain a registry of all registries
    if registry_name in REGISTRIES:
        return  # registry already exists; callers receive None here

    REGISTRY = {}
    REGISTRY_CLASS_NAMES = set()
    DATACLASS_REGISTRY = {}
    REGISTRIES[registry_name] = {
        'registry': REGISTRY,
        'default': default,
        'dataclass_registry': DATACLASS_REGISTRY,
    }

    def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args,
                **extra_kwargs):
        """Instantiate the class registered under the name ``cfg``.

        ``cfg`` must be a registered name (``str``) or ``None``; although the
        annotation also lists ``DictConfig`` and ``Namespace``, only names are
        supported here. If a dataclass was registered for the name, a
        default-constructed instance of it is passed on in place of the name.
        The class attribute ``build_<registry_name>`` is used as the factory
        when present, otherwise the class itself is called, in both cases as
        ``factory(cfg, *extra_args, **extra_kwargs)``.

        Returns:
            The factory's result, or ``None`` when ``cfg`` is ``None`` and
            the registry is not required.

        Raises:
            TypeError: if ``cfg`` is neither ``None`` nor a ``str``.
            ValueError: if ``cfg`` is ``None`` and the registry is required.
            KeyError: if no class is registered under ``cfg``.
        """
        # Explicit validation instead of ``assert``: lets ``None`` reach the
        # ``required`` handling below and behaves the same under ``-O``.
        if cfg is not None and not isinstance(cfg, str):
            raise TypeError(
                '{} must be given as a registered name (str) or None, '
                'got {}'.format(registry_name, type(cfg).__name__))
        choice = cfg
        if choice is None:
            if required:
                raise ValueError('{} is required!'.format(registry_name))
            return None
        if choice in DATACLASS_REGISTRY:
            # Hand the builder a default config object instead of the name.
            cfg = DATACLASS_REGISTRY[choice]()
        try:
            cls = REGISTRY[choice]
        except KeyError:
            # Keep KeyError for callers, but say which registry and what
            # the valid choices are.
            raise KeyError('Unknown {} {!r}; registered choices: {}'.format(
                registry_name, choice,
                ', '.join(sorted(map(str, REGISTRY))))) from None
        # Prefer a class-provided factory named ``build_<registry_name>``.
        builder = getattr(cls, 'build_' + registry_name, cls)
        return builder(cfg, *extra_args, **extra_kwargs)

    def register_x(name, dataclass=None):
        """Return a class decorator that registers a class under ``name``.

        If ``dataclass`` is given, it becomes the config type for ``name``
        and a default instance (with ``_name`` set to ``name``) is stored in
        Hydra's ``ConfigStore`` under group ``registry_name``.

        The decorator raises ``ValueError`` if ``name`` is already
        registered, if another registration already used the same class
        ``__name__``, or if the class does not extend ``base_class``. It
        returns the class unchanged (apart from the ``__dataclass``
        attribute it sets).
        """
        def register_x_cls(cls):
            if name in REGISTRY:
                raise ValueError('Cannot register duplicate {} ({})'.format(
                    registry_name, name))
            if cls.__name__ in REGISTRY_CLASS_NAMES:
                raise ValueError(
                    'Cannot register {} with duplicate class name ({})'.format(
                        registry_name, cls.__name__))
            if base_class is not None and not issubclass(cls, base_class):
                raise ValueError('{} must extend {}'.format(
                    cls.__name__, base_class.__name__))
            cls.__dataclass = dataclass
            if dataclass is not None:
                # Build and store the config node first so a failure here does
                # not leave a half-registered entry in DATACLASS_REGISTRY.
                node = dataclass()
                node._name = name
                cs = ConfigStore.instance()
                cs.store(name=name,
                         group=registry_name,
                         node=node,
                         provider='layoutlmft')
                DATACLASS_REGISTRY[name] = dataclass
            REGISTRY[name] = cls
            # Record the class name; without this the duplicate-class-name
            # check above could never trigger.
            REGISTRY_CLASS_NAMES.add(cls.__name__)
            return cls

        return register_x_cls

    return build_x, register_x, REGISTRY, DATACLASS_REGISTRY
|