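# NOTE: The next lines are residue from the file-hosting web page this file
# was copied from; they are not part of the module.
#   NCTCMumbai's picture
#   Upload 2571 files
#   0b8359d
#   raw
#   history blame
#   8.96 kB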
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base configurations to standardize experiments."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy
import functools
from typing import Any, List, Mapping, Optional, Type
import dataclasses
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
@dataclasses.dataclass
class Config(params_dict.ParamsDict):
    """The base configuration class that supports YAML/JSON based overrides.

    * It recursively enforces a whitelist of basic types and container types,
      so it avoids surprises with copy and reuse caused by unanticipated
      types.
    * It converts dict to Config even within a (single-level) sequence,
      e.g. for config = Config({'key': [{'a': 42}]}),
      type(config.key[0]) is Config rather than dict. Nested sequences are
      rejected with a TypeError.
    """

    # NOTE: these two constants are deliberately left un-annotated so that
    # `dataclasses` does not turn them into fields.
    # It's safe to add bytes and other immutable types here.
    IMMUTABLE_TYPES = (str, int, float, bool, type(None))
    # It's safe to add set, frozenset and other collections here.
    SEQUENCE_TYPES = (list, tuple)

    # Init-only values; they are forwarded to `ParamsDict.__init__` by
    # `__post_init__` and are not stored as dataclass fields.
    default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
    restrictions: dataclasses.InitVar[Optional[List[str]]] = None

    @classmethod
    def _isvalidsequence(cls, v):
        """Check if the input values are valid sequences.

        Args:
          v: Input sequence.

        Returns:
          True if the sequence is valid. Valid sequence includes the sequence
          type in cls.SEQUENCE_TYPES and element type is in
          cls.IMMUTABLE_TYPES or is dict or ParamsDict. The elements must be
          homogeneous with respect to these three categories; an empty
          sequence is valid.
        """
        if not isinstance(v, cls.SEQUENCE_TYPES):
            return False
        return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
                all(isinstance(e, dict) for e in v) or
                all(isinstance(e, params_dict.ParamsDict) for e in v))

    @classmethod
    def _import_config(cls, v, subconfig_type):
        """Returns v with dicts converted to Configs, recursively.

        Args:
          v: The value to import.
          subconfig_type: The ParamsDict subclass used to wrap plain dicts.

        Returns:
          `v` itself for immutable values, a new sequence of the same type
          with imported elements for sequences, a deep copy for ParamsDict
          instances, and `subconfig_type(v)` for dicts.

        Raises:
          TypeError: if `subconfig_type` is not a ParamsDict subclass, if `v`
            is an invalid sequence, or if `v` has an unsupported type.
        """
        if not issubclass(subconfig_type, params_dict.ParamsDict):
            raise TypeError(
                'Subconfig_type should be subclass of ParamsDict, found '
                '{!r}'.format(subconfig_type))
        if isinstance(v, cls.IMMUTABLE_TYPES):
            return v
        if isinstance(v, cls.SEQUENCE_TYPES):
            # Only support one layer of sequence.
            if not cls._isvalidsequence(v):
                raise TypeError(
                    'Invalid sequence: only supports single level {!r} of '
                    '{!r} or dict or ParamsDict found: {!r}'.format(
                        cls.SEQUENCE_TYPES, cls.IMMUTABLE_TYPES, v))
            import_fn = functools.partial(
                cls._import_config, subconfig_type=subconfig_type)
            return type(v)(map(import_fn, v))
        if isinstance(v, params_dict.ParamsDict):
            # Deepcopy here is a temporary solution for preserving type in
            # nested Config object.
            return copy.deepcopy(v)
        if isinstance(v, dict):
            return subconfig_type(v)
        raise TypeError('Unknown type: {!r}'.format(type(v)))

    @classmethod
    def _export_config(cls, v):
        """Returns v with Configs converted to dicts, recursively.

        Args:
          v: The value to export.

        Returns:
          `v` itself for immutable values, a new sequence of the same type
          with exported elements for sequences, and `v.as_dict()` for
          ParamsDict instances.

        Raises:
          TypeError: if `v` is (or contains) a plain dict or a value of an
            unsupported type.
        """
        if isinstance(v, cls.IMMUTABLE_TYPES):
            return v
        if isinstance(v, cls.SEQUENCE_TYPES):
            return type(v)(map(cls._export_config, v))
        if isinstance(v, params_dict.ParamsDict):
            return v.as_dict()
        if isinstance(v, dict):
            raise TypeError('dict value not supported in converting.')
        raise TypeError('Unknown type: {!r}'.format(type(v)))

    @classmethod
    def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
        """Get element type by the field name.

        Args:
          k: the key/name of the field.

        Returns:
          Config as default. If a type annotation is found for `k`,
          1) returns the type of the annotation if it is subtype of Config;
          2) returns the element type if the annotation of `k` is
             List[SubType] or Tuple[SubType] and SubType is a subtype of
             ParamsDict.
        """
        subconfig_type = Config
        if k not in cls.__annotations__:
            return subconfig_type

        type_annotation = cls.__annotations__[k]
        # Directly Config subtype.
        if (isinstance(type_annotation, type) and
                issubclass(type_annotation, Config)):
            return type_annotation

        # Check if the field is a sequence of subtypes.
        field_type = getattr(type_annotation, '__origin__', type(None))
        if (isinstance(field_type, type) and
                issubclass(field_type, cls.SEQUENCE_TYPES)):
            # `__args__` may be missing or empty; fall back to NoneType so
            # the default subconfig type is kept.
            element_args = (
                getattr(type_annotation, '__args__', None) or (type(None),))
            element_type = element_args[0]
            # The element annotation is not necessarily a class (e.g.
            # List[Any], List[Optional[int]] or an unbound TypeVar);
            # issubclass() would raise TypeError on those, so check first.
            if (isinstance(element_type, type) and
                    issubclass(element_type, params_dict.ParamsDict)):
                subconfig_type = element_type
        return subconfig_type

    def __post_init__(self, default_params, restrictions, *args, **kwargs):
        """Forwards the dataclass init-only values to the ParamsDict init."""
        super().__init__(*args,
                         default_params=default_params,
                         restrictions=restrictions,
                         **kwargs)

    def _set(self, k, v):
        """Overrides same method in ParamsDict.

        Also called by ParamsDict methods.

        Args:
          k: key to set.
          v: value.

        Raises:
          TypeError: if `v` (or one of its elements) has a type that
            `_import_config` does not support.
        """
        subconfig_type = self._get_subconfig_type(k)
        if not isinstance(v, dict):
            self.__dict__[k] = self._import_config(v, subconfig_type)
            return

        current = self.__dict__.get(k)
        if k not in self.__dict__ or not current:
            # If the key does not exist or the value is None (or otherwise
            # falsy), a new Config-family object should be created for the
            # key.
            self.__dict__[k] = subconfig_type(v)
        else:
            # Merge the dict into the existing nested config.
            current.override(v)

    def __setattr__(self, k, v):
        """Sets attribute `k`, refusing changes once the config is locked.

        Keys listed in `RESERVED_ATTR` (inherited; presumably ParamsDict's
        internal bookkeeping attributes) bypass the lock check.

        Raises:
          ValueError: if the config is locked and `k` is not reserved.
        """
        if k not in self.RESERVED_ATTR and getattr(self, '_locked', False):
            raise ValueError(
                'The Config has been locked. No change is allowed.')
        self._set(k, v)

    def _override(self, override_dict, is_strict=True):
        """Overrides same method in ParamsDict.

        Also called by ParamsDict methods.

        Args:
          override_dict: dictionary of key/value pairs to write into this
            config. Keys are processed in sorted order.
          is_strict: If True, not allows to add new keys.

        Raises:
          KeyError: overriding reserved keys or keys not exist
            (is_strict=True).
        """
        for k, v in sorted(override_dict.items()):
            if k in self.RESERVED_ATTR:
                raise KeyError('The key {!r} is internally reserved. '
                               'Can not be overridden.'.format(k))
            if k not in self.__dict__:
                if is_strict:
                    raise KeyError('The key {!r} does not exist in {!r}. '
                                   'To extend the existing keys, use '
                                   '`override` with `is_strict` = False.'
                                   .format(k, type(self)))
                self._set(k, v)
                continue

            current = self.__dict__[k]
            if isinstance(v, dict) and current:
                # Recurse so that strictness applies to nested keys as well.
                current._override(v, is_strict)  # pylint: disable=protected-access
            elif isinstance(v, params_dict.ParamsDict) and current:
                current._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
            else:
                self._set(k, v)

    def as_dict(self):
        """Returns a dict representation of params_dict.ParamsDict.

        For the nested params_dict.ParamsDict, a nested dict will be
        returned. Keys listed in `RESERVED_ATTR` are omitted.
        """
        return {
            k: self._export_config(v)
            for k, v in self.__dict__.items()
            if k not in self.RESERVED_ATTR
        }

    def replace(self, **kwargs):
        """Like `override`, but returns a copy with the current config unchanged."""
        params = self.__class__(self)
        params.override(kwargs, is_strict=True)
        return params

    @classmethod
    def from_yaml(cls, file_path: str):
        """Builds a config from a YAML file applied over the class defaults.

        Args:
          file_path: path of the YAML file, opened through `tf.io.gfile`.

        Returns:
          A `cls()` instance overridden (strictly) with the file content. An
          empty file yields the default config.
        """
        # Note: This only works if the Config has all default values.
        with tf.io.gfile.GFile(file_path, 'r') as f:
            # An explicit Loader is required by PyYAML >= 6 (and avoids the
            # warning on 5.x); FullLoader is what 5.x used by default.
            # NOTE(review): this assumes config files are trusted. If they
            # can come from untrusted sources, switch to `yaml.safe_load`
            # (which does not construct Python-specific tags).
            loaded = yaml.load(f, Loader=yaml.FullLoader)
        config = cls()
        # An empty YAML document loads as None: nothing to override.
        if loaded is not None:
            config.override(loaded)
        return config

    @classmethod
    def from_json(cls, file_path: str):
        """Wrapper for `from_yaml`."""
        return cls.from_yaml(file_path)

    @classmethod
    def from_args(cls, *args, **kwargs):
        """Builds a config from the given list of arguments.

        Positional arguments are matched, in order, to the names in
        `cls.__annotations__`; keyword arguments are applied afterwards and
        take precedence.
        """
        attributes = list(cls.__annotations__.keys())
        default_params = dict(zip(attributes, args))
        default_params.update(kwargs)
        return cls(default_params)