MilesCranmer commited on
Commit
2f528dc
·
unverified ·
1 Parent(s): 505af8d

test: refactor into separate HelpMessages class

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +100 -95
pysr/test/test.py CHANGED
@@ -563,6 +563,105 @@ class TestMiscellaneous(unittest.TestCase):
563
  test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
564
  self.assertEqual(test_pkl_file, str(expected_pkl_file))
565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  def test_deprecation(self):
567
  """Ensure that deprecation works as expected.
568
 
@@ -705,101 +804,6 @@ class TestMiscellaneous(unittest.TestCase):
705
  model.get_best()
706
  print("Failed", opt["kwargs"])
707
 
708
- def test_pickle_with_temp_equation_file(self):
709
- """If we have a temporary equation file, unpickle the estimator."""
710
- model = PySRRegressor(
711
- populations=int(1 + DEFAULT_POPULATIONS / 5),
712
- temp_equation_file=True,
713
- procs=0,
714
- multithreading=False,
715
- )
716
- nout = 3
717
- X = np.random.randn(100, 2)
718
- y = np.random.randn(100, nout)
719
- model.fit(X, y)
720
- contents = model.equation_file_contents_.copy()
721
-
722
- y_predictions = model.predict(X)
723
-
724
- equation_file_base = model.equation_file_
725
- for i in range(1, nout + 1):
726
- assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
727
-
728
- with tempfile.NamedTemporaryFile() as pickle_file:
729
- pkl.dump(model, pickle_file)
730
- pickle_file.seek(0)
731
- model2 = pkl.load(pickle_file)
732
-
733
- contents2 = model2.equation_file_contents_
734
- cols_to_check = ["equation", "loss", "complexity"]
735
- for frame1, frame2 in zip(contents, contents2):
736
- pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
737
-
738
- y_predictions2 = model2.predict(X)
739
- np.testing.assert_array_equal(y_predictions, y_predictions2)
740
-
741
- def test_scikit_learn_compatibility(self):
742
- """Test PySRRegressor compatibility with scikit-learn."""
743
- model = PySRRegressor(
744
- niterations=int(1 + DEFAULT_NITERATIONS / 10),
745
- populations=int(1 + DEFAULT_POPULATIONS / 3),
746
- ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
747
- verbosity=0,
748
- progress=False,
749
- random_state=0,
750
- deterministic=True, # Deterministic as tests require this.
751
- procs=0,
752
- multithreading=False,
753
- warm_start=False,
754
- temp_equation_file=True,
755
- ) # Return early.
756
-
757
- check_generator = check_estimator(model, generate_only=True)
758
- exception_messages = []
759
- for _, check in check_generator:
760
- if check.func.__name__ == "check_complex_data":
761
- # We can use complex data, so avoid this check.
762
- continue
763
- try:
764
- with warnings.catch_warnings():
765
- warnings.simplefilter("ignore")
766
- check(model)
767
- print("Passed", check.func.__name__)
768
- except Exception:
769
- error_message = str(traceback.format_exc())
770
- exception_messages.append(
771
- f"{check.func.__name__}:\n" + error_message + "\n"
772
- )
773
- print("Failed", check.func.__name__, "with:")
774
- # Add a leading tab to error message, which
775
- # might be multi-line:
776
- print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
777
- # If any checks failed don't let the test pass.
778
- self.assertEqual(len(exception_messages), 0)
779
-
780
- def test_param_groupings(self):
781
- """Test that param_groupings are complete"""
782
- param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
783
- if not param_groupings_file.exists():
784
- return
785
-
786
- # Read the file, discarding lines ending in ":",
787
- # and removing leading "\s*-\s*":
788
- params = []
789
- with open(param_groupings_file, "r") as f:
790
- for line in f.readlines():
791
- if line.strip().endswith(":"):
792
- continue
793
- if line.strip().startswith("-"):
794
- params.append(line.strip()[1:].strip())
795
-
796
- regressor_params = [
797
- p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
798
- ]
799
-
800
- # Check the sets are equal:
801
- self.assertSetEqual(set(params), set(regressor_params))
802
-
803
 
804
  TRUE_PREAMBLE = "\n".join(
805
  [
@@ -1148,6 +1152,7 @@ def runtests(just_tests=False):
1148
  TestBest,
1149
  TestFeatureSelection,
1150
  TestMiscellaneous,
 
1151
  TestLaTeXTable,
1152
  TestDimensionalConstraints,
1153
  ]
 
563
  test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
564
  self.assertEqual(test_pkl_file, str(expected_pkl_file))
565
 
566
+ def test_pickle_with_temp_equation_file(self):
567
+ """If we have a temporary equation file, unpickle the estimator."""
568
+ model = PySRRegressor(
569
+ populations=int(1 + DEFAULT_POPULATIONS / 5),
570
+ temp_equation_file=True,
571
+ procs=0,
572
+ multithreading=False,
573
+ )
574
+ nout = 3
575
+ X = np.random.randn(100, 2)
576
+ y = np.random.randn(100, nout)
577
+ model.fit(X, y)
578
+ contents = model.equation_file_contents_.copy()
579
+
580
+ y_predictions = model.predict(X)
581
+
582
+ equation_file_base = model.equation_file_
583
+ for i in range(1, nout + 1):
584
+ assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
585
+
586
+ with tempfile.NamedTemporaryFile() as pickle_file:
587
+ pkl.dump(model, pickle_file)
588
+ pickle_file.seek(0)
589
+ model2 = pkl.load(pickle_file)
590
+
591
+ contents2 = model2.equation_file_contents_
592
+ cols_to_check = ["equation", "loss", "complexity"]
593
+ for frame1, frame2 in zip(contents, contents2):
594
+ pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
595
+
596
+ y_predictions2 = model2.predict(X)
597
+ np.testing.assert_array_equal(y_predictions, y_predictions2)
598
+
599
+ def test_scikit_learn_compatibility(self):
600
+ """Test PySRRegressor compatibility with scikit-learn."""
601
+ model = PySRRegressor(
602
+ niterations=int(1 + DEFAULT_NITERATIONS / 10),
603
+ populations=int(1 + DEFAULT_POPULATIONS / 3),
604
+ ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
605
+ verbosity=0,
606
+ progress=False,
607
+ random_state=0,
608
+ deterministic=True, # Deterministic as tests require this.
609
+ procs=0,
610
+ multithreading=False,
611
+ warm_start=False,
612
+ temp_equation_file=True,
613
+ ) # Return early.
614
+
615
+ check_generator = check_estimator(model, generate_only=True)
616
+ exception_messages = []
617
+ for _, check in check_generator:
618
+ if check.func.__name__ == "check_complex_data":
619
+ # We can use complex data, so avoid this check.
620
+ continue
621
+ try:
622
+ with warnings.catch_warnings():
623
+ warnings.simplefilter("ignore")
624
+ check(model)
625
+ print("Passed", check.func.__name__)
626
+ except Exception:
627
+ error_message = str(traceback.format_exc())
628
+ exception_messages.append(
629
+ f"{check.func.__name__}:\n" + error_message + "\n"
630
+ )
631
+ print("Failed", check.func.__name__, "with:")
632
+ # Add a leading tab to error message, which
633
+ # might be multi-line:
634
+ print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
635
+ # If any checks failed don't let the test pass.
636
+ self.assertEqual(len(exception_messages), 0)
637
+
638
+ def test_param_groupings(self):
639
+ """Test that param_groupings are complete"""
640
+ param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
641
+ if not param_groupings_file.exists():
642
+ return
643
+
644
+ # Read the file, discarding lines ending in ":",
645
+ # and removing leading "\s*-\s*":
646
+ params = []
647
+ with open(param_groupings_file, "r") as f:
648
+ for line in f.readlines():
649
+ if line.strip().endswith(":"):
650
+ continue
651
+ if line.strip().startswith("-"):
652
+ params.append(line.strip()[1:].strip())
653
+
654
+ regressor_params = [
655
+ p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
656
+ ]
657
+
658
+ # Check the sets are equal:
659
+ self.assertSetEqual(set(params), set(regressor_params))
660
+
661
+
662
+ class TestHelpMessages(unittest.TestCase):
663
+ """Test user help messages."""
664
+
665
  def test_deprecation(self):
666
  """Ensure that deprecation works as expected.
667
 
 
804
  model.get_best()
805
  print("Failed", opt["kwargs"])
806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
 
808
  TRUE_PREAMBLE = "\n".join(
809
  [
 
1152
  TestBest,
1153
  TestFeatureSelection,
1154
  TestMiscellaneous,
1155
+ TestHelpMessages,
1156
  TestLaTeXTable,
1157
  TestDimensionalConstraints,
1158
  ]