|
|
|
|
|
import collections.abc as abc |
|
import dataclasses |
|
import logging |
|
from typing import Any |
|
|
|
from detectron2.utils.registry import _convert_target_to_string, locate |
|
|
|
__all__ = ["dump_dataclass", "instantiate"] |
|
|
|
|
|
def dump_dataclass(obj: Any):
    """
    Dump a dataclass recursively into a dict that can be later instantiated.

    The result holds a "_target_" entry (the value returned by
    ``_convert_target_to_string`` for the type of ``obj``) plus one entry per
    dataclass field:

    * a field whose value is a dataclass *instance* is dumped recursively;
    * a list/tuple field is turned into a list whose dataclass-instance
      elements are dumped recursively (other elements are kept untouched);
    * any other value, including a dataclass *type*, is stored as-is.

    Args:
        obj: a dataclass object

    Returns:
        dict

    Raises:
        AssertionError: if ``obj`` is not an instance of a dataclass.
    """
    assert dataclasses.is_dataclass(obj) and not isinstance(
        obj, type
    ), "dump_dataclass() requires an instance of a dataclass."

    def _is_dataclass_instance(x):
        # dataclasses.is_dataclass() is also True for dataclass *classes*;
        # recursing into a class would trip the assertion above, so only
        # real instances qualify for a recursive dump.
        return dataclasses.is_dataclass(x) and not isinstance(x, type)

    ret = {"_target_": _convert_target_to_string(type(obj))}
    for f in dataclasses.fields(obj):
        v = getattr(obj, f.name)
        if _is_dataclass_instance(v):
            v = dump_dataclass(v)
        elif isinstance(v, (list, tuple)):
            # Sequences are always emitted as plain lists.
            v = [dump_dataclass(x) if _is_dataclass_instance(x) else x for x in v]
        ret[f.name] = v
    return ret
|
|
|
|
|
def instantiate(cfg):
    """
    Recursively build the objects described by dictionaries that carry a
    "_target_" entry together with the arguments for that target.

    Handling by input kind, checked in this order:

    * omegaconf ``ListConfig``: every element is instantiated and the results
      are wrapped in a new ``ListConfig`` created with ``allow_objects`` set.
    * plain ``list``: a new list of the instantiated elements.
    * ``DictConfig`` whose ``_metadata.object_type`` is a dataclass: the value
      of ``OmegaConf.to_object(cfg)``.
    * mapping containing "_target_": all values are instantiated, the target
      is resolved (a string target goes through ``locate``) and then called
      with the remaining entries as keyword arguments.
    * anything else: returned unchanged.

    Args:
        cfg: a dict-like object with "_target_" that defines the caller, and
            other keys that define the arguments

    Returns:
        object instantiated by cfg

    Raises:
        AssertionError: if a string target cannot be located (``locate``
            returned None) or the resolved target is not callable.
        TypeError: re-raised from the target call after an error is logged.
    """
    from omegaconf import ListConfig, DictConfig, OmegaConf

    if isinstance(cfg, ListConfig):
        built = [instantiate(item) for item in cfg]
        return ListConfig(built, flags={"allow_objects": True})

    if isinstance(cfg, list):
        # Plain lists are handled separately from ListConfig and stay plain lists.
        return [instantiate(item) for item in cfg]

    # A DictConfig backed by a dataclass type is converted by omegaconf itself.
    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
        return OmegaConf.to_object(cfg)

    # Nothing to build: hand the value back untouched.
    if not isinstance(cfg, abc.Mapping) or "_target_" not in cfg:
        return cfg

    # Instantiate every value (the target entry included), then split the
    # target off from what will become the keyword arguments.
    kwargs = {key: instantiate(value) for key, value in cfg.items()}
    cls = instantiate(kwargs.pop("_target_"))

    if not isinstance(cls, str):
        try:
            cls_name = cls.__module__ + "." + cls.__qualname__
        except Exception:
            # The target may be any object, so the attributes above may be
            # missing; fall back to its string form for the log message.
            cls_name = str(cls)
    else:
        cls_name = cls
        cls = locate(cls_name)
        assert cls is not None, cls_name
    assert callable(cls), f"_target_ {cls} does not define a callable object"

    try:
        return cls(**kwargs)
    except TypeError:
        # Log which target failed, then let the original error propagate.
        logging.getLogger(__name__).error(f"Error when instantiating {cls_name}!")
        raise
|
|