File size: 6,139 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import dataclasses
import importlib
from types import ModuleType
from typing import Any, Dict, List, Type

import yaml

from llm_studio.python_configs.base import DefaultConfigProblemBase
from llm_studio.src.utils.type_annotations import KNOWN_TYPE_ANNOTATIONS


def rreload(module):
    """Reload the modules referenced by a module's ``Config`` attributes.

    Despite the name this goes exactly one level deep: every attribute of
    ``module`` whose name contains "Config" is inspected, and each of its
    attributes whose type is exactly ``ModuleType`` is passed to
    ``importlib.reload``. Nothing is returned.

    Args:
        module: module whose Config attributes should be scanned
    """
    # dir(module) is evaluated once, up front, as in a plain for-loop.
    config_objects = (
        getattr(module, name) for name in dir(module) if "Config" in name
    )
    for config_obj in config_objects:
        for member_name in dir(config_obj):
            member = getattr(config_obj, member_name)
            # exact type check (not isinstance) on purpose: only plain modules
            if type(member) is ModuleType:
                importlib.reload(member)


def _load_cls(module_path: str, cls_name: str) -> Any:
    """Import a module by path, reload it, and return one of its attributes.

    The module is reloaded, then the modules referenced by its ``Config``
    attributes are reloaded via ``rreload``, then the module is reloaded once
    more before the attribute is looked up.

    Args:
        module_path: path to the module; a trailing ".py" is stripped and
            "/" is replaced by "." (e.g. "pkg/sub/mod.py" -> "pkg.sub.mod")
        cls_name: name of the class

    Returns:
        Loaded python class

    Raises:
        AssertionError: if the module has no attribute named ``cls_name``.
            Raised explicitly (not via ``assert``) so the check is not
            stripped when Python runs with ``-O``.
    """
    module_name = module_path
    if module_name.endswith(".py"):
        module_name = module_name[:-3]
    module_name = module_name.replace("/", ".")

    module = importlib.import_module(module_name)
    module = importlib.reload(module)
    rreload(module)
    # Reload again after rreload, presumably so names the module imported from
    # the freshly reloaded submodules are re-bound to the new definitions.
    module = importlib.reload(module)

    if not hasattr(module, cls_name):
        raise AssertionError(
            "{} file should contain {} class".format(module_path, cls_name)
        )

    return getattr(module, cls_name)


def load_config_py(config_path: str, config_name: str = "Config"):
    """Instantiate a config class defined in a python file.

    Args:
        config_path: path to the config file, forwarded to ``_load_cls``
        config_name: name of the config class inside that file

    Returns:
        A new instance of the config class, constructed with no arguments
    """
    config_cls = _load_cls(config_path, config_name)
    return config_cls()


def _get_type_annotation_error(v: Any, type_annotation: Type) -> ValueError:
    """Build (without raising) the error for a value that cannot be shown.

    Args:
        v: the offending config value
        type_annotation: the annotation declared for that value

    Returns:
        A ValueError describing why the value is unsupported
    """
    message = (
        f"Cannot show {v}: not a dataclass"
        f" and {type_annotation} is not a known type annotation."
    )
    return ValueError(message)


def convert_cfg_base_to_nested_dictionary(cfg: DefaultConfigProblemBase) -> dict:
    """Returns a grouped config settings dict for a given configuration.

    Fields are visited in ``cfg._get_order()`` order. Fields starting with an
    underscore are skipped. Fields whose annotation is in
    KNOWN_TYPE_ANNOTATIONS are copied as-is. Dataclass-valued fields are
    flattened with ``parse_cfg_dataclass`` into one dict per group, with tuple
    values converted to lists. Finally ``cfg.problem_type`` is added under
    "problem_type".

    Args:
        cfg: configuration

    Returns:
        Dict of configuration settings

    Raises:
        AssertionError: if a public top-level field name contains "api" or
            "secret".
        ValueError: if a field's annotation is unknown and its value is not
            a dataclass.
    """
    cfg_dict = cfg.__dict__
    type_annotations = cfg.get_annotations()
    cfg_dict = {key: cfg_dict[key] for key in cfg._get_order()}

    grouped_cfg_dict = {}

    for k, v in cfg_dict.items():
        if k.startswith("_"):
            continue

        # presumably guards against credentials ending up in saved configs
        if any(word in k for word in ("api", "secret")):
            raise AssertionError(
                "Config item must not contain the word 'api' or 'secret'"
            )

        type_annotation = type_annotations[k]

        if type_annotation in KNOWN_TYPE_ANNOTATIONS:
            grouped_cfg_dict[k] = v
        elif dataclasses.is_dataclass(v):
            # parse_cfg_dataclass returns a list of single-entry dicts; merge
            # them into one dict for this group, turning tuples into lists.
            group_items = {
                item_key: (
                    list(item_value) if isinstance(item_value, tuple) else item_value
                )
                for item in parse_cfg_dataclass(cfg=v)
                for item_key, item_value in item.items()
            }
            grouped_cfg_dict[k] = group_items
        else:
            raise _get_type_annotation_error(v, type_annotation)

    # not an explicit field in the config
    grouped_cfg_dict["problem_type"] = cfg.problem_type
    return grouped_cfg_dict


def convert_nested_dictionary_to_cfg_base(
    cfg_dict: Dict[str, Any]
) -> DefaultConfigProblemBase:
    """
    Inverse operation of convert_cfg_base_to_nested_dictionary.

    The "problem_type" entry selects the module
    ``llm_studio.python_configs.<problem_type>_config``, whose
    ``ConfigProblemBase.from_dict`` builds the config object.

    Args:
        cfg_dict: nested config dictionary; must contain "problem_type"

    Returns:
        The object returned by ``ConfigProblemBase.from_dict(cfg_dict)``

    Raises:
        KeyError: if cfg_dict has no "problem_type" entry.
        NotImplementedError: if no config module exists for the problem type.
        ModuleNotFoundError: if the config module exists but a module it
            imports is missing.
    """
    problem_type = cfg_dict["problem_type"]
    module_name = f"llm_studio.python_configs.{problem_type}_config"
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        missing = err.name
        # Translate only when the missing module is the config module itself
        # (or one of its parent packages). A missing dependency inside an
        # existing config module is a different problem and is re-raised.
        if missing is not None and not (
            module_name == missing or module_name.startswith(missing + ".")
        ):
            raise
        raise NotImplementedError(
            f"Problem Type {problem_type} not implemented"
        ) from None
    return module.ConfigProblemBase.from_dict(cfg_dict)


def get_parent_element(cfg):
    """Return the parent-experiment entry of a config, if it has one.

    Args:
        cfg: configuration object; may or may not define
            ``_parent_experiment``

    Returns:
        ``{"Parent Experiment": value}`` when ``cfg._parent_experiment``
        exists and differs from the empty string, otherwise ``None``.
    """
    parent = getattr(cfg, "_parent_experiment", "")
    if parent != "":
        return {"Parent Experiment": parent}
    return None


def parse_cfg_dataclass(cfg) -> List[Dict]:
    """Returns all single config settings for a given configuration.

    The result is a flat list of single-entry ``{name: value}`` dicts:
    - The parent-experiment entry comes first, if present.
    - Then one entry per public field whose annotation is in
      KNOWN_TYPE_ANNOTATIONS, in ``cfg._get_order()`` order. Float-annotated
      values are passed through ``float()``.
    - Nested dataclasses are flattened recursively in place.
    - Fields that start with "_", contain "api", or have any other
      annotation/value are skipped.

    Args:
        cfg: configuration

    Returns:
        List of single-entry dicts
    """
    parent_element = get_parent_element(cfg)
    items = [parent_element] if parent_element else []

    field_values = cfg.__dict__
    annotations = cfg.get_annotations()
    ordered_fields = {name: field_values[name] for name in cfg._get_order()}

    for name, value in ordered_fields.items():
        # private fields, and anything mentioning "api", are never listed
        if name.startswith("_") or "api" in name:
            continue

        annotation = annotations[name]
        if annotation in KNOWN_TYPE_ANNOTATIONS:
            items.append({name: float(value) if annotation == float else value})
        elif dataclasses.is_dataclass(value):
            items.extend(parse_cfg_dataclass(cfg=value))
        # any other field is silently left out

    return items


def save_config_yaml(path: str, cfg: DefaultConfigProblemBase) -> None:
    """Saves config as yaml file.

    The config is first converted with
    ``convert_cfg_base_to_nested_dictionary``. Protected attributes (starting
    with an underscore) are not included, and nested configs are converted to
    nested dictionaries. The result is written with ``yaml.dump`` using an
    indent of 4.

    Args:
        path: path of file to save to
        cfg: config to save
    """
    cfg_dict = convert_cfg_base_to_nested_dictionary(cfg)
    with open(path, "w") as fp:
        yaml.dump(cfg_dict, fp, indent=4)


def load_config_yaml(path: str):
    """Reads a yaml config file and rebuilds the config object from it.

    Args:
        path: path of file to load from

    Returns:
        config object built by ``convert_nested_dictionary_to_cfg_base``
    """
    # NOTE(review): yaml.FullLoader can construct some Python-specific tags;
    # do not point this at files from untrusted sources. SafeLoader would be
    # stricter but may reject tags that yaml.dump emits (e.g. python/tuple).
    with open(path, "r") as handle:
        raw_cfg = yaml.load(handle, Loader=yaml.FullLoader)
    cfg = convert_nested_dictionary_to_cfg_base(raw_cfg)
    return cfg


# Hard-coded on purpose: importing ConfigProblemBase from python_configs to
# read its problem_type here would create a circular import.
NON_GENERATION_PROBLEM_TYPES = [
    "text_causal_classification_modeling",
]