MilesCranmer commited on
Commit
2f296b6
1 Parent(s): 5ada6c7

Test that pickle works without equation file

Browse files
Files changed (1) hide show
  1. test/test.py +52 -18
test/test.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import inspect
2
  import unittest
3
  import numpy as np
@@ -8,13 +10,14 @@ from sklearn.utils.estimator_checks import check_estimator
8
  import sympy
9
  import pandas as pd
10
  import warnings
 
 
 
 
 
 
 
11
 
12
- DEFAULT_NITERATIONS = (
13
- inspect.signature(PySRRegressor.__init__).parameters["niterations"].default
14
- )
15
- DEFAULT_POPULATIONS = (
16
- inspect.signature(PySRRegressor.__init__).parameters["populations"].default
17
- )
18
 
19
  class TestPipeline(unittest.TestCase):
20
  def setUp(self):
@@ -399,14 +402,49 @@ class TestMiscellaneous(unittest.TestCase):
399
  with self.assertRaises(ValueError):
400
  model.fit(X, y)
401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  def test_scikit_learn_compatibility(self):
403
  """Test PySRRegressor compatibility with scikit-learn."""
404
  model = PySRRegressor(
405
- max_evals=1000,
 
 
406
  verbosity=0,
407
  progress=False,
408
  random_state=0,
409
- deterministic=True,
410
  procs=0,
411
  multithreading=False,
412
  warm_start=False,
@@ -419,20 +457,16 @@ class TestMiscellaneous(unittest.TestCase):
419
  try:
420
  with warnings.catch_warnings():
421
  warnings.simplefilter("ignore")
422
- # To ensure an equation file is written for each output in
423
- # nout, set stop condition to niterations=1
424
- if check.func.__name__ == "check_regressor_multioutput":
425
- model.set_params(niterations=1, max_evals=None)
426
- else:
427
- model.set_params(max_evals=10000)
428
  check(model)
429
  print("Passed", check.func.__name__)
430
- except Exception as e:
431
- error_message = str(e)
432
- exception_messages.append(f"{check.func.__name__}: {error_message}\n")
 
 
433
  print("Failed", check.func.__name__, "with:")
434
  # Add a leading tab to error message, which
435
  # might be multi-line:
436
  print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
437
  # If any checks failed don't let the test pass.
438
- self.assertEqual([], exception_messages)
 
1
+ import os
2
+ import traceback
3
  import inspect
4
  import unittest
5
  import numpy as np
 
10
  import sympy
11
  import pandas as pd
12
  import warnings
13
+ import pickle as pkl
14
+ import tempfile
15
+
16
+ DEFAULT_PARAMS = inspect.signature(PySRRegressor.__init__).parameters
17
+ DEFAULT_NITERATIONS = DEFAULT_PARAMS["niterations"].default
18
+ DEFAULT_POPULATIONS = DEFAULT_PARAMS["populations"].default
19
+ DEFAULT_NCYCLES = DEFAULT_PARAMS["ncyclesperiteration"].default
20
 
 
 
 
 
 
 
21
 
22
  class TestPipeline(unittest.TestCase):
23
  def setUp(self):
 
402
  with self.assertRaises(ValueError):
403
  model.fit(X, y)
404
 
405
+ def test_pickle_with_temp_equation_file(self):
406
+ """If we have a temporary equation file, unpickle the estimator."""
407
+ model = PySRRegressor(
408
+ populations=int(1 + DEFAULT_POPULATIONS / 5),
409
+ temp_equation_file=True,
410
+ procs=0,
411
+ multithreading=False,
412
+ )
413
+ nout = 3
414
+ X = np.random.randn(100, 2)
415
+ y = np.random.randn(100, nout)
416
+ model.fit(X, y)
417
+ contents = model.equation_file_contents_.copy()
418
+
419
+ y_predictions = model.predict(X)
420
+
421
+ equation_file_base = model.equation_file_
422
+ for i in range(1, nout + 1):
423
+ assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
424
+
425
+ with tempfile.NamedTemporaryFile() as pickle_file:
426
+ pkl.dump(model, pickle_file)
427
+ pickle_file.seek(0)
428
+ model2 = pkl.load(pickle_file)
429
+
430
+ contents2 = model2.equation_file_contents_
431
+ cols_to_check = ["equation", "loss", "complexity"]
432
+ for frame1, frame2 in zip(contents, contents2):
433
+ pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
434
+
435
+ y_predictions2 = model2.predict(X)
436
+ np.testing.assert_array_equal(y_predictions, y_predictions2)
437
+
438
  def test_scikit_learn_compatibility(self):
439
  """Test PySRRegressor compatibility with scikit-learn."""
440
  model = PySRRegressor(
441
+ niterations=int(1 + DEFAULT_NITERATIONS / 10),
442
+ populations=int(1 + DEFAULT_POPULATIONS / 3),
443
+ ncyclesperiteration=int(2 + DEFAULT_NCYCLES / 10),
444
  verbosity=0,
445
  progress=False,
446
  random_state=0,
447
+ deterministic=True, # Deterministic as tests require this.
448
  procs=0,
449
  multithreading=False,
450
  warm_start=False,
 
457
  try:
458
  with warnings.catch_warnings():
459
  warnings.simplefilter("ignore")
 
 
 
 
 
 
460
  check(model)
461
  print("Passed", check.func.__name__)
462
+ except Exception:
463
+ error_message = str(traceback.format_exc())
464
+ exception_messages.append(
465
+ f"{check.func.__name__}:\n" + error_message + "\n"
466
+ )
467
  print("Failed", check.func.__name__, "with:")
468
  # Add a leading tab to error message, which
469
  # might be multi-line:
470
  print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
471
  # If any checks failed don't let the test pass.
472
+ self.assertEqual(len(exception_messages), 0)