File size: 1,307 Bytes
07423df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from llm_studio.python_configs.text_causal_classification_modeling_config import (
ConfigProblemBase as CausalClassificationConfigProblemBase,
)
from llm_studio.python_configs.text_causal_language_modeling_config import (
ConfigProblemBase as CausalConfigProblemBase,
)
from llm_studio.python_configs.text_sequence_to_sequence_modeling_config import (
ConfigProblemBase as Seq2SeqConfigProblemBase,
)
from llm_studio.src.utils.config_utils import (
NON_GENERATION_PROBLEM_TYPES,
convert_cfg_base_to_nested_dictionary,
)
def test_from_dict():
    """Round-trip every problem config through a nested dict and back.

    For each config class, a default instance is converted to a nested
    dictionary with ``convert_cfg_base_to_nested_dictionary``, rebuilt via
    the class's ``from_dict``, and converted again. Both dictionaries must
    expose exactly the same keys and equal values under every key (checked
    entry by entry for nested sections so failures point at the exact field).
    """
    for cfg_class in [
        CausalConfigProblemBase,
        Seq2SeqConfigProblemBase,
        CausalClassificationConfigProblemBase,
    ]:
        name = cfg_class.__name__
        cfg = cfg_class()
        cfg_dict = convert_cfg_base_to_nested_dictionary(cfg)
        cfg2 = cfg_class.from_dict(cfg_dict)  # type: ignore
        cfg_dict_2 = convert_cfg_base_to_nested_dictionary(cfg2)

        # Compare key sets up front: keys added by the round trip would
        # otherwise go unchecked, and dropped keys would surface as a
        # KeyError instead of an assertion failure.
        assert set(cfg_dict_2) == set(cfg_dict), name

        for key, value in cfg_dict.items():
            restored = cfg_dict_2[key]
            if isinstance(value, dict):
                # Nested section: make sure it is still a mapping with the
                # same fields, then compare field by field so the failure
                # message names the offending entry.
                assert isinstance(restored, dict), f"{name}.{key}"
                assert set(restored) == set(value), f"{name}.{key}"
                for sub_key, sub_value in value.items():
                    assert restored[sub_key] == sub_value, f"{name}.{key}.{sub_key}"
            assert restored == value, f"{name}.{key}"
def test_classification_config_is_in_non_generating_problem_types():
    """Check that the causal classification problem type is non-generating.

    The default ``problem_type`` of ``CausalClassificationConfigProblemBase``
    must be listed in ``NON_GENERATION_PROBLEM_TYPES``.
    """
    problem_type = CausalClassificationConfigProblemBase().problem_type
    assert problem_type in NON_GENERATION_PROBLEM_TYPES
|