Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
2f296b6
1
Parent(s):
5ada6c7
Test that pickle works without equation file
Browse files- test/test.py +52 -18
test/test.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import inspect
|
2 |
import unittest
|
3 |
import numpy as np
|
@@ -8,13 +10,14 @@ from sklearn.utils.estimator_checks import check_estimator
|
|
8 |
import sympy
|
9 |
import pandas as pd
|
10 |
import warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
DEFAULT_NITERATIONS = (
|
13 |
-
inspect.signature(PySRRegressor.__init__).parameters["niterations"].default
|
14 |
-
)
|
15 |
-
DEFAULT_POPULATIONS = (
|
16 |
-
inspect.signature(PySRRegressor.__init__).parameters["populations"].default
|
17 |
-
)
|
18 |
|
19 |
class TestPipeline(unittest.TestCase):
|
20 |
def setUp(self):
|
@@ -399,14 +402,49 @@ class TestMiscellaneous(unittest.TestCase):
|
|
399 |
with self.assertRaises(ValueError):
|
400 |
model.fit(X, y)
|
401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
def test_scikit_learn_compatibility(self):
|
403 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
404 |
model = PySRRegressor(
|
405 |
-
|
|
|
|
|
406 |
verbosity=0,
|
407 |
progress=False,
|
408 |
random_state=0,
|
409 |
-
deterministic=True,
|
410 |
procs=0,
|
411 |
multithreading=False,
|
412 |
warm_start=False,
|
@@ -419,20 +457,16 @@ class TestMiscellaneous(unittest.TestCase):
|
|
419 |
try:
|
420 |
with warnings.catch_warnings():
|
421 |
warnings.simplefilter("ignore")
|
422 |
-
# To ensure an equation file is written for each output in
|
423 |
-
# nout, set stop condition to niterations=1
|
424 |
-
if check.func.__name__ == "check_regressor_multioutput":
|
425 |
-
model.set_params(niterations=1, max_evals=None)
|
426 |
-
else:
|
427 |
-
model.set_params(max_evals=10000)
|
428 |
check(model)
|
429 |
print("Passed", check.func.__name__)
|
430 |
-
except Exception
|
431 |
-
error_message = str(
|
432 |
-
exception_messages.append(
|
|
|
|
|
433 |
print("Failed", check.func.__name__, "with:")
|
434 |
# Add a leading tab to error message, which
|
435 |
# might be multi-line:
|
436 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
437 |
# If any checks failed don't let the test pass.
|
438 |
-
self.assertEqual(
|
|
|
1 |
+
import os
|
2 |
+
import traceback
|
3 |
import inspect
|
4 |
import unittest
|
5 |
import numpy as np
|
|
|
10 |
import sympy
|
11 |
import pandas as pd
|
12 |
import warnings
|
13 |
+
import pickle as pkl
|
14 |
+
import tempfile
|
15 |
+
|
16 |
+
DEFAULT_PARAMS = inspect.signature(PySRRegressor.__init__).parameters
|
17 |
+
DEFAULT_NITERATIONS = DEFAULT_PARAMS["niterations"].default
|
18 |
+
DEFAULT_POPULATIONS = DEFAULT_PARAMS["populations"].default
|
19 |
+
DEFAULT_NCYCLES = DEFAULT_PARAMS["ncyclesperiteration"].default
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
class TestPipeline(unittest.TestCase):
|
23 |
def setUp(self):
|
|
|
402 |
with self.assertRaises(ValueError):
|
403 |
model.fit(X, y)
|
404 |
|
405 |
+
def test_pickle_with_temp_equation_file(self):
|
406 |
+
"""If we have a temporary equation file, unpickle the estimator."""
|
407 |
+
model = PySRRegressor(
|
408 |
+
populations=int(1 + DEFAULT_POPULATIONS / 5),
|
409 |
+
temp_equation_file=True,
|
410 |
+
procs=0,
|
411 |
+
multithreading=False,
|
412 |
+
)
|
413 |
+
nout = 3
|
414 |
+
X = np.random.randn(100, 2)
|
415 |
+
y = np.random.randn(100, nout)
|
416 |
+
model.fit(X, y)
|
417 |
+
contents = model.equation_file_contents_.copy()
|
418 |
+
|
419 |
+
y_predictions = model.predict(X)
|
420 |
+
|
421 |
+
equation_file_base = model.equation_file_
|
422 |
+
for i in range(1, nout + 1):
|
423 |
+
assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
|
424 |
+
|
425 |
+
with tempfile.NamedTemporaryFile() as pickle_file:
|
426 |
+
pkl.dump(model, pickle_file)
|
427 |
+
pickle_file.seek(0)
|
428 |
+
model2 = pkl.load(pickle_file)
|
429 |
+
|
430 |
+
contents2 = model2.equation_file_contents_
|
431 |
+
cols_to_check = ["equation", "loss", "complexity"]
|
432 |
+
for frame1, frame2 in zip(contents, contents2):
|
433 |
+
pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
|
434 |
+
|
435 |
+
y_predictions2 = model2.predict(X)
|
436 |
+
np.testing.assert_array_equal(y_predictions, y_predictions2)
|
437 |
+
|
438 |
def test_scikit_learn_compatibility(self):
|
439 |
"""Test PySRRegressor compatibility with scikit-learn."""
|
440 |
model = PySRRegressor(
|
441 |
+
niterations=int(1 + DEFAULT_NITERATIONS / 10),
|
442 |
+
populations=int(1 + DEFAULT_POPULATIONS / 3),
|
443 |
+
ncyclesperiteration=int(2 + DEFAULT_NCYCLES / 10),
|
444 |
verbosity=0,
|
445 |
progress=False,
|
446 |
random_state=0,
|
447 |
+
deterministic=True, # Deterministic as tests require this.
|
448 |
procs=0,
|
449 |
multithreading=False,
|
450 |
warm_start=False,
|
|
|
457 |
try:
|
458 |
with warnings.catch_warnings():
|
459 |
warnings.simplefilter("ignore")
|
|
|
|
|
|
|
|
|
|
|
|
|
460 |
check(model)
|
461 |
print("Passed", check.func.__name__)
|
462 |
+
except Exception:
|
463 |
+
error_message = str(traceback.format_exc())
|
464 |
+
exception_messages.append(
|
465 |
+
f"{check.func.__name__}:\n" + error_message + "\n"
|
466 |
+
)
|
467 |
print("Failed", check.func.__name__, "with:")
|
468 |
# Add a leading tab to error message, which
|
469 |
# might be multi-line:
|
470 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
471 |
# If any checks failed don't let the test pass.
|
472 |
+
self.assertEqual(len(exception_messages), 0)
|