Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
b53e7fa
1
Parent(s):
6501ca0
Add additional test for loading from pickle file
Browse files- pysr/sr.py +6 -2
- test/test.py +10 -0
pysr/sr.py
CHANGED
@@ -926,7 +926,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
926 |
|
927 |
def _checkpoint(self):
|
928 |
"""Saves the model's current state to a checkpoint file.
|
929 |
-
|
930 |
This should only be used internally by PySRRegressor."""
|
931 |
# Save model state:
|
932 |
self.show_pickle_warnings_ = False
|
@@ -2132,8 +2132,12 @@ def load(
|
|
2132 |
assert n_features_in is None
|
2133 |
with open(str(equation_file) + ".pkl", "rb") as f:
|
2134 |
model = pkl.load(f)
|
|
|
|
|
2135 |
model.set_params(**pysr_kwargs)
|
2136 |
-
model.
|
|
|
|
|
2137 |
return model
|
2138 |
|
2139 |
# Else, we re-create it.
|
|
|
926 |
|
927 |
def _checkpoint(self):
|
928 |
"""Saves the model's current state to a checkpoint file.
|
929 |
+
|
930 |
This should only be used internally by PySRRegressor."""
|
931 |
# Save model state:
|
932 |
self.show_pickle_warnings_ = False
|
|
|
2132 |
assert n_features_in is None
|
2133 |
with open(str(equation_file) + ".pkl", "rb") as f:
|
2134 |
model = pkl.load(f)
|
2135 |
+
# Update any parameters if necessary, such as
|
2136 |
+
# extra_sympy_mappings:
|
2137 |
model.set_params(**pysr_kwargs)
|
2138 |
+
if "equations_" not in model.__dict__ or model.equations_ is None:
|
2139 |
+
model.refresh()
|
2140 |
+
|
2141 |
return model
|
2142 |
|
2143 |
# Else, we re-create it.
|
test/test.py
CHANGED
@@ -336,6 +336,16 @@ class TestPipeline(unittest.TestCase):
|
|
336 |
|
337 |
np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
|
338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
class TestBest(unittest.TestCase):
|
341 |
def setUp(self):
|
|
|
336 |
|
337 |
np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
|
338 |
|
339 |
+
# Try again, but using only the pickle file:
|
340 |
+
for file_to_delete in [str(equation_file), str(equation_file) + ".bkup"]:
|
341 |
+
if os.path.exists(file_to_delete):
|
342 |
+
os.remove(file_to_delete)
|
343 |
+
|
344 |
+
model3 = load(
|
345 |
+
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
346 |
+
)
|
347 |
+
np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
|
348 |
+
|
349 |
|
350 |
class TestBest(unittest.TestCase):
|
351 |
def setUp(self):
|