Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
22eb380
1
Parent(s):
2621e9c
Add test for pickling + units
Browse files- pysr/test/test.py +25 -1
pysr/test/test.py
CHANGED
@@ -1010,10 +1010,15 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
1010 |
"""Check that units are propagated correctly."""
|
1011 |
X = np.ones((100, 3))
|
1012 |
y = np.ones((100, 1))
|
|
|
|
|
1013 |
model = PySRRegressor(
|
1014 |
binary_operators=["+", "*"],
|
1015 |
early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
|
1016 |
-
|
|
|
|
|
|
|
1017 |
complexity_of_constants=10,
|
1018 |
weight_mutate_constant=0.0,
|
1019 |
should_optimize_constants=False,
|
@@ -1021,6 +1026,7 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
1021 |
deterministic=True,
|
1022 |
procs=0,
|
1023 |
random_state=0,
|
|
|
1024 |
)
|
1025 |
model.fit(
|
1026 |
X,
|
@@ -1034,6 +1040,24 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
1034 |
self.assertIn("x2", best["equation"])
|
1035 |
self.assertEqual(best["complexity"], 3)
|
1036 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1037 |
|
1038 |
def runtests():
|
1039 |
"""Run all tests in test.py."""
|
|
|
1010 |
"""Check that units are propagated correctly."""
|
1011 |
X = np.ones((100, 3))
|
1012 |
y = np.ones((100, 1))
|
1013 |
+
temp_dir = Path(tempfile.mkdtemp())
|
1014 |
+
equation_file = str(temp_dir / "equation_file.csv")
|
1015 |
model = PySRRegressor(
|
1016 |
binary_operators=["+", "*"],
|
1017 |
early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
|
1018 |
+
progress=False,
|
1019 |
+
model_selection="accuracy",
|
1020 |
+
niterations=DEFAULT_NITERATIONS * 2,
|
1021 |
+
populations=DEFAULT_POPULATIONS * 2,
|
1022 |
complexity_of_constants=10,
|
1023 |
weight_mutate_constant=0.0,
|
1024 |
should_optimize_constants=False,
|
|
|
1026 |
deterministic=True,
|
1027 |
procs=0,
|
1028 |
random_state=0,
|
1029 |
+
equation_file=equation_file,
|
1030 |
)
|
1031 |
model.fit(
|
1032 |
X,
|
|
|
1040 |
self.assertIn("x2", best["equation"])
|
1041 |
self.assertEqual(best["complexity"], 3)
|
1042 |
|
1043 |
+
# With pkl file:
|
1044 |
+
pkl_file = str(temp_dir / "equation_file.pkl")
|
1045 |
+
model2 = PySRRegressor.from_file(pkl_file)
|
1046 |
+
best2 = model2.get_best()
|
1047 |
+
self.assertIn("x0", best2["equation"])
|
1048 |
+
|
1049 |
+
# From csv file alone (we need to delete pkl file:)
|
1050 |
+
# First, we delete the pkl file:
|
1051 |
+
os.remove(pkl_file)
|
1052 |
+
model3 = PySRRegressor.from_file(
|
1053 |
+
equation_file, binary_operators=["+", "*"], n_features_in=X.shape[1]
|
1054 |
+
)
|
1055 |
+
best3 = model3.get_best()
|
1056 |
+
self.assertIn("x0", best3["equation"])
|
1057 |
+
|
1058 |
+
|
1059 |
+
# TODO: Determine desired behavior if second .fit() call does not have units
|
1060 |
+
|
1061 |
|
1062 |
def runtests():
|
1063 |
"""Run all tests in test.py."""
|