# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for registering and looking up functions or classes by key."""


def register(registered_collection, reg_key):
  """Register the decorated function or class in a collection.

  Store the decorated function or class in registered_collection, nested by
  the "/"-separated components of reg_key. For example, when
  reg_key="my_model/my_exp/my_config_0", the decorated function or class is
  stored under registered_collection["my_model"]["my_exp"]["my_config_0"].
  This decorator is meant to be used together with the lookup() function in
  this file.

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      put into this collection.
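    reg_key: The key under which the function or class is registered and later
      retrieved with lookup(). If reg_key is a string, it can be hierarchical,
      like my_model/my_exp/my_config_0; any other hashable key is stored as a
      single, non-nested entry.

  Returns:
    A decorator that registers its argument and returns it unchanged.

  Raises:
    KeyError: when the function or class to register already exists, or when
      a component of a hierarchical reg_key is already registered as a
      function or class.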
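
  Example:
    A minimal sketch; example_registry and my_config_0 are hypothetical
    names used only for illustration.

    >>> example_registry = {}
    >>> @register(example_registry, "my_model/my_exp/my_config_0")
    ... def my_config_0():
    ...   return {"learning_rate": 0.1}
    >>> example_registry["my_model"]["my_exp"]["my_config_0"] is my_config_0
    True
    >>> register(example_registry, "my_model/my_exp/my_config_0")(my_config_0)
    Traceback (most recent call last):
      ...
    KeyError: 'Function or class my_config_0 registered multiple times.'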
  """
  def decorator(fn_or_cls):
    """Put fn_or_cls in the dictionary."""
    if isinstance(reg_key, str):
      hierarchy = reg_key.split("/")
      collection = registered_collection
      for h_idx, entry_name in enumerate(hierarchy[:-1]):
        if entry_name not in collection:
          collection[entry_name] = {}
        collection = collection[entry_name]
        if not isinstance(collection, dict):
          raise KeyError(
              "Collection path {} at position {} already registered as "
              "a function or class.".format(entry_name, h_idx))
      leaf_reg_key = hierarchy[-1]
    else:
      collection = registered_collection
      leaf_reg_key = reg_key

    if leaf_reg_key in collection:
      raise KeyError("Function or class {} registered multiple times.".format(
          leaf_reg_key))

    collection[leaf_reg_key] = fn_or_cls
    return fn_or_cls
  return decorator


def lookup(registered_collection, reg_key):
  """Look up and return the decorated function or class in the collection.

  Look up the decorated function or class in registered_collection by walking
  the "/"-separated components of reg_key. For example, when
  reg_key="my_model/my_exp/my_config_0", this function returns
  registered_collection["my_model"]["my_exp"]["my_config_0"].

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      retrieved from this collection.
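    reg_key: The key for retrieving the registered function or class. If
      reg_key is a string, it can be hierarchical, like
      my_model/my_exp/my_config_0; any other key is looked up directly in
      registered_collection.

  Returns:
    The registered function or class. If a hierarchical reg_key names only a
    prefix of a registered path, the nested dictionary under that prefix is
    returned instead.

  Raises:
    LookupError: when reg_key cannot be found.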
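
  Example:
    A minimal sketch; example_registry and MyModel are hypothetical names
    used only for illustration.

    >>> example_registry = {}
    >>> @register(example_registry, "my_model/my_exp")
    ... class MyModel:
    ...   pass
    >>> lookup(example_registry, "my_model/my_exp") is MyModel
    True
    >>> lookup(example_registry, "my_model") == {"my_exp": MyModel}
    True
    >>> lookup(example_registry, "my_model/other_exp")
    Traceback (most recent call last):
      ...
    LookupError: collection path other_exp at position 1 never registered.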
  """
  if isinstance(reg_key, str):
    hierarchy = reg_key.split("/")
    collection = registered_collection
    for h_idx, entry_name in enumerate(hierarchy):
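      # A non-dict value means the path already ended at a registered
      # function or class, so it has no deeper components to descend into.
      # Without this check, `in` on a function or class raises TypeError.
      if not isinstance(collection, dict) or entry_name not in collection: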
        raise LookupError(
            "collection path {} at position {} never registered.".format(
                entry_name, h_idx))
      collection = collection[entry_name]
    return collection
  else:
    if reg_key not in registered_collection:
      raise LookupError("registration key {} never registered.".format(reg_key))
    return registered_collection[reg_key]