MilesCranmer commited on
Commit
2d025c2
·
unverified ·
1 Parent(s): 42005bd

Add test for units

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +37 -0
pysr/test/test.py CHANGED
@@ -906,6 +906,42 @@ class TestLaTeXTable(unittest.TestCase):
906
  self.assertEqual(latex_table_str, true_latex_table_str)
907
 
908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909
  def runtests():
910
  """Run all tests in test.py."""
911
  suite = unittest.TestSuite()
@@ -916,6 +952,7 @@ def runtests():
916
  TestFeatureSelection,
917
  TestMiscellaneous,
918
  TestLaTeXTable,
 
919
  ]
920
  for test_case in test_cases:
921
  tests = loader.loadTestsFromTestCase(test_case)
 
906
  self.assertEqual(latex_table_str, true_latex_table_str)
907
 
908
 
909
+ class TestDimensionalConstraints(unittest.TestCase):
910
+ def setUp(self):
911
+ self.default_test_kwargs = dict(
912
+ progress=False,
913
+ model_selection="accuracy",
914
+ niterations=DEFAULT_NITERATIONS * 2,
915
+ populations=DEFAULT_POPULATIONS * 2,
916
+ temp_equation_file=True,
917
+ )
918
+ self.rstate = np.random.RandomState(0)
919
+ self.X = self.rstate.randn(100, 5)
920
+
921
+ def test_dimensional_constraints(self):
922
+ y = np.cos(self.X[:, 0])
923
+ model = PySRRegressor(
924
+ unary_operators=["cos"],
925
+ **self.default_test_kwargs,
926
+ early_stop_condition=1e-8,
927
+ )
928
+ model.fit(self.X, y, X_units=["m", "m", "m", "m", "m"], y_units="m")
929
+
930
+ # The best expression should have complexity larger than just 2:
931
+ self.assertGreater(model.get_best()["complexity"], 2)
932
+ self.assertLess(model.get_best()["loss"], 1e-6)
933
+ self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
934
+
935
+
936
+ # TODO: add tests for:
937
+ # - custom operators + dimensions
938
+ # - invalid number of dimensions
939
+ # - X
940
+ # - y
941
+ # - no constants, so that it needs to find the right fraction
942
+ # - custom dimensional_constraint_penalty
943
+
944
+
945
  def runtests():
946
  """Run all tests in test.py."""
947
  suite = unittest.TestSuite()
 
952
  TestFeatureSelection,
953
  TestMiscellaneous,
954
  TestLaTeXTable,
955
+ TestDimensionalConstraints,
956
  ]
957
  for test_case in test_cases:
958
  tests = loader.loadTestsFromTestCase(test_case)