# Spaces: Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import tempfile
import unittest
from itertools import count

from detectron2.config import LazyConfig, LazyCall as L
from omegaconf import DictConfig
class TestLazyPythonConfig(unittest.TestCase):
    """Tests for ``LazyConfig`` load / save / override / ``to_py`` behavior.

    The fixture configs (``root_cfg.py`` and the files under ``dir1/``) are
    resolved relative to the directory containing this test file.
    """

    def setUp(self):
        # Resolve fixtures next to this file so the tests do not depend on the CWD.
        self.curr_dir = os.path.dirname(__file__)
        self.root_filename = os.path.join(self.curr_dir, "root_cfg.py")

    def test_load(self):
        """Load the root config, and check an in-memory edit does not survive a reload."""
        cfg = LazyConfig.load(self.root_filename)
        self.assertEqual(cfg.dir1a_dict.a, "modified")
        self.assertEqual(cfg.dir1b_dict.a, 1)
        self.assertEqual(cfg.lazyobj.x, "base_a_1")

        cfg.lazyobj.x = "new_x"
        # reload: a fresh load must return the on-disk value, not the edit above
        cfg = LazyConfig.load(self.root_filename)
        self.assertEqual(cfg.lazyobj.x, "base_a_1")

    def test_save_load(self):
        """Round-trip the root config through a YAML file.

        After reloading, ``_target_`` is the dotted string ``"itertools.count"``,
        while the in-memory config still holds the ``count`` callable itself.
        Apart from ``_target_``, the two configs must compare equal.
        """
        cfg = LazyConfig.load(self.root_filename)
        with tempfile.TemporaryDirectory(prefix="detectron2") as d:
            fname = os.path.join(d, "test_config.yaml")
            LazyConfig.save(cfg, fname)
            cfg2 = LazyConfig.load(fname)

        self.assertEqual(cfg2.lazyobj._target_, "itertools.count")
        self.assertEqual(cfg.lazyobj._target_, count)
        # _target_ legitimately differs in representation; drop it from both sides
        cfg2.lazyobj.pop("_target_")
        cfg.lazyobj.pop("_target_")
        # the rest are equal
        self.assertEqual(cfg, cfg2)

    def test_failed_save(self):
        """Save a config holding a lambda, and expect both the YAML and a ``.pkl`` file."""
        # allow_objects is set so the DictConfig can hold the lambda as a value.
        cfg = DictConfig({"x": lambda: 3}, flags={"allow_objects": True})
        with tempfile.TemporaryDirectory(prefix="detectron2") as d:
            fname = os.path.join(d, "test_config.yaml")
            LazyConfig.save(cfg, fname)
            self.assertTrue(os.path.exists(fname))
            self.assertTrue(os.path.exists(fname + ".pkl"))

    def test_overrides(self):
        """Apply dotted-key overrides and check how each value literal is typed.

        An unquoted number becomes an int, a quoted number stays a str, and a
        bare word becomes a str.
        """
        cfg = LazyConfig.load(self.root_filename)
        LazyConfig.apply_overrides(cfg, ["lazyobj.x=123", 'dir1b_dict.a="123"'])
        self.assertEqual(cfg.dir1b_dict.a, "123")
        self.assertEqual(cfg.lazyobj.x, 123)
        LazyConfig.apply_overrides(cfg, ["dir1b_dict.a=abc"])
        self.assertEqual(cfg.dir1b_dict.a, "abc")

    def test_invalid_overrides(self):
        """Overriding a sub-key of a non-container value raises ``KeyError``."""
        cfg = LazyConfig.load(self.root_filename)
        # lazyobj.x is the string "base_a_1" in the root config, so it has no ".xxx" child
        with self.assertRaises(KeyError):
            LazyConfig.apply_overrides(cfg, ["lazyobj.x.xxx=123"])

    def test_to_py(self):
        """Export the config as Python assignment statements and compare them verbatim."""
        cfg = LazyConfig.load(self.root_filename)
        cfg.lazyobj.x = {"a": 1, "b": 2, "c": L(count)(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]})}
        cfg.list = ["a", 1, "b", 3.2]
        py_str = LazyConfig.to_py(cfg)
        # NOTE: whitespace inside this literal is significant. The nested call is
        # wrapped one argument per line (4-space steps), and it must match
        # to_py's output exactly.
        expected = """cfg.dir1a_dict.a = "modified"
cfg.dir1a_dict.b = 2
cfg.dir1b_dict.a = 1
cfg.dir1b_dict.b = 2
cfg.lazyobj = itertools.count(
    x={
        "a": 1,
        "b": 2,
        "c": itertools.count(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]}),
    },
    y="base_a_1_from_b",
)
cfg.list = ["a", 1, "b", 3.2]
"""
        self.assertEqual(py_str, expected)

    def test_bad_import(self):
        """Loading ``dir1/bad_import.py`` raises ``ImportError`` mentioning "relative import"."""
        file = os.path.join(self.curr_dir, "dir1", "bad_import.py")
        with self.assertRaisesRegex(ImportError, "relative import"):
            LazyConfig.load(file)

    def test_bad_import2(self):
        """Loading ``dir1/bad_import2.py`` raises ``ImportError`` mentioning "not exist"."""
        file = os.path.join(self.curr_dir, "dir1", "bad_import2.py")
        with self.assertRaisesRegex(ImportError, "not exist"):
            LazyConfig.load(file)

    def test_load_rel(self):
        """Loading ``dir1/load_rel.py`` succeeds and yields a config containing key ``x``."""
        file = os.path.join(self.curr_dir, "dir1", "load_rel.py")
        cfg = LazyConfig.load(file)
        self.assertIn("x", cfg)