MilesCranmer commited on
Commit
ce904eb
1 Parent(s): af0be92

Remove need for YAML in unittests

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +9 -17
pysr/test/test.py CHANGED
@@ -10,7 +10,6 @@ import pandas as pd
10
  import warnings
11
  import pickle as pkl
12
  import tempfile
13
- import yaml
14
  from pathlib import Path
15
 
16
  from .. import julia_helpers
@@ -718,29 +717,22 @@ class TestMiscellaneous(unittest.TestCase):
718
  param_groupings_file = (
719
  Path(__file__).parent.parent.parent / "docs" / "param_groupings.yml"
720
  )
721
- # Read the file:
 
 
722
  with open(param_groupings_file, "r") as f:
723
- param_groupings = yaml.load(f, Loader=yaml.SafeLoader)
724
-
725
- # Get all leafs of this yaml file:
726
- def get_leafs(d):
727
- if isinstance(d, dict):
728
- for v in d.values():
729
- yield from get_leafs(v)
730
- elif isinstance(d, list):
731
- for v in d:
732
- yield from get_leafs(v)
733
- else:
734
- yield d
735
-
736
- leafs = list(get_leafs(param_groupings))
737
 
738
  regressor_params = [
739
  p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
740
  ]
741
 
742
  # Check the sets are equal:
743
- self.assertSetEqual(set(leafs), set(regressor_params))
744
 
745
 
746
  TRUE_PREAMBLE = "\n".join(
 
10
  import warnings
11
  import pickle as pkl
12
  import tempfile
 
13
  from pathlib import Path
14
 
15
  from .. import julia_helpers
 
717
  param_groupings_file = (
718
  Path(__file__).parent.parent.parent / "docs" / "param_groupings.yml"
719
  )
720
+ # Read the file, discarding lines ending in ":",
721
+ # and removing leading "\s*-\s*":
722
+ params = []
723
  with open(param_groupings_file, "r") as f:
724
+ for line in f.readlines():
725
+ if line.strip().endswith(":"):
726
+ continue
727
+ if line.strip().startswith("-"):
728
+ params.append(line.strip()[1:].strip())
 
 
 
 
 
 
 
 
 
729
 
730
  regressor_params = [
731
  p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
732
  ]
733
 
734
  # Check the sets are equal:
735
+ self.assertSetEqual(set(params), set(regressor_params))
736
 
737
 
738
  TRUE_PREAMBLE = "\n".join(