MilesCranmer commited on
Commit
db8bfce
1 Parent(s): abd0cfa

Add warm start test

Browse files
Files changed (2) hide show
  1. pysr/sr.py +2 -0
  2. pysr/test/test.py +15 -2
pysr/sr.py CHANGED
@@ -1784,6 +1784,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1784
 
1785
  y_variable_names = None
1786
  if len(y.shape) > 1:
 
 
1787
  y_variable_names = [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
1788
 
1789
  # Call to Julia backend.
 
1784
 
1785
  y_variable_names = None
1786
  if len(y.shape) > 1:
1787
+ # We set these manually so that they respect Python's 0 indexing
1788
+ # (by default Julia will use y1, y2...)
1789
  y_variable_names = [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
1790
 
1791
  # Call to Julia backend.
pysr/test/test.py CHANGED
@@ -1007,14 +1007,17 @@ class TestDimensionalConstraints(unittest.TestCase):
1007
  )
1008
 
1009
  def test_unit_propagation(self):
1010
- """Check that units are propagated correctly."""
 
 
 
1011
  X = np.ones((100, 3))
1012
  y = np.ones((100, 1))
1013
  temp_dir = Path(tempfile.mkdtemp())
1014
  equation_file = str(temp_dir / "equation_file.csv")
1015
  model = PySRRegressor(
1016
  binary_operators=["+", "*"],
1017
- early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
1018
  progress=False,
1019
  model_selection="accuracy",
1020
  niterations=DEFAULT_NITERATIONS * 2,
@@ -1027,6 +1030,7 @@ class TestDimensionalConstraints(unittest.TestCase):
1027
  procs=0,
1028
  random_state=0,
1029
  equation_file=equation_file,
 
1030
  )
1031
  model.fit(
1032
  X,
@@ -1039,6 +1043,8 @@ class TestDimensionalConstraints(unittest.TestCase):
1039
  self.assertNotIn("x1", best["equation"])
1040
  self.assertIn("x2", best["equation"])
1041
  self.assertEqual(best["complexity"], 3)
 
 
1042
 
1043
  # With pkl file:
1044
  pkl_file = str(temp_dir / "equation_file.pkl")
@@ -1055,6 +1061,13 @@ class TestDimensionalConstraints(unittest.TestCase):
1055
  best3 = model3.get_best()
1056
  self.assertIn("x0", best3["equation"])
1057
 
 
 
 
 
 
 
 
1058
 
1059
  # TODO: Determine desired behavior if second .fit() call does not have units
1060
 
 
1007
  )
1008
 
1009
  def test_unit_propagation(self):
1010
+ """Check that units are propagated correctly.
1011
+
1012
+ This also tests that variables have the correct names.
1013
+ """
1014
  X = np.ones((100, 3))
1015
  y = np.ones((100, 1))
1016
  temp_dir = Path(tempfile.mkdtemp())
1017
  equation_file = str(temp_dir / "equation_file.csv")
1018
  model = PySRRegressor(
1019
  binary_operators=["+", "*"],
1020
+ early_stop_condition="(l, c) -> l < 1e-6 && c == 3",
1021
  progress=False,
1022
  model_selection="accuracy",
1023
  niterations=DEFAULT_NITERATIONS * 2,
 
1030
  procs=0,
1031
  random_state=0,
1032
  equation_file=equation_file,
1033
+ warm_start=True,
1034
  )
1035
  model.fit(
1036
  X,
 
1043
  self.assertNotIn("x1", best["equation"])
1044
  self.assertIn("x2", best["equation"])
1045
  self.assertEqual(best["complexity"], 3)
1046
+ self.assertEqual(model.equations_.iloc[0].complexity, 1)
1047
+ self.assertGreater(model.equations_.iloc[0].loss, 1e-6)
1048
 
1049
  # With pkl file:
1050
  pkl_file = str(temp_dir / "equation_file.pkl")
 
1061
  best3 = model3.get_best()
1062
  self.assertIn("x0", best3["equation"])
1063
 
1064
+ # Try warm start, but with no units provided (should
1065
+ # be a different dataset, and thus different result):
1066
+ model.fit(X, y)
1067
+ model.early_stop_condition = "(l, c) -> l < 1e-6 && c == 1"
1068
+ self.assertEqual(model.equations_.iloc[0].complexity, 1)
1069
+ self.assertLess(model.equations_.iloc[0].loss, 1e-6)
1070
+
1071
 
1072
  # TODO: Determine desired behavior if second .fit() call does not have units
1073