File size: 1,307 Bytes
07423df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from llm_studio.python_configs.text_causal_classification_modeling_config import (
    ConfigProblemBase as CausalClassificationConfigProblemBase,
)
from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigProblemBase as CausalConfigProblemBase,
)
from llm_studio.python_configs.text_sequence_to_sequence_modeling_config import (
    ConfigProblemBase as Seq2SeqConfigProblemBase,
)
from llm_studio.src.utils.config_utils import (
    NON_GENERATION_PROBLEM_TYPES,
    convert_cfg_base_to_nested_dictionary,
)


def test_from_dict():
    """Round-trip every problem-type config through its nested-dict form.

    A default config is converted to a nested dictionary, rebuilt with
    ``from_dict`` and converted again; the second dictionary must match
    the first for every top-level key.
    """
    config_classes = (
        CausalConfigProblemBase,
        Seq2SeqConfigProblemBase,
        CausalClassificationConfigProblemBase,
    )
    for config_cls in config_classes:
        original = convert_cfg_base_to_nested_dictionary(config_cls())
        rebuilt_cfg = config_cls.from_dict(original)  # type: ignore
        rebuilt = convert_cfg_base_to_nested_dictionary(rebuilt_cfg)

        for section, expected in original.items():
            if isinstance(expected, dict):
                # Check entry by entry first so a failure points at the
                # exact nested field rather than the whole section.
                for field, expected_value in expected.items():
                    assert rebuilt[section][field] == expected_value
            assert rebuilt[section] == expected


def test_classification_config_is_in_non_generating_problem_types():
    """The default classification config must report a non-generation problem type."""
    problem_type = CausalClassificationConfigProblemBase().problem_type
    assert problem_type in NON_GENERATION_PROBLEM_TYPES