tttc3 commited on
Commit
fbb7cf7
·
1 Parent(s): bd90cfc

New updated checkpoint loading and tests

Browse files
Files changed (4) hide show
  1. pysr/sr.py +19 -8
  2. test/test.py +15 -12
  3. test/test_jax.py +13 -16
  4. test/test_torch.py +28 -23
pysr/sr.py CHANGED
@@ -1402,7 +1402,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1402
  Xresampled=None,
1403
  weights=None,
1404
  variable_names=None,
1405
- from_equation_file=False,
1406
  ):
1407
  """
1408
  Search for equations to fit the dataset and store them in `self.equations_`.
@@ -1500,18 +1499,21 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1500
  )
1501
 
1502
  # Fitting procedure
1503
- if not from_equation_file:
1504
- self._run(X=X, y=y, weights=weights, seed=seed)
1505
- else:
1506
- self.equations_ = self.get_hof()
1507
-
1508
- return self
1509
 
1510
- def refresh(self):
1511
  """
1512
  Updates self.equations_ with any new options passed, such as
1513
  :param`extra_sympy_mappings`.
 
 
 
 
 
1514
  """
 
 
 
1515
  self.equations_ = self.get_hof()
1516
 
1517
  def _decision_function(self, X, best_equation):
@@ -1709,6 +1711,15 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1709
  def get_hof(self):
1710
  """Get the equations from a hall of fame file. If no arguments
1711
  entered, the ones used previously from a call to PySR will be used."""
 
 
 
 
 
 
 
 
 
1712
  try:
1713
  if self.nout_ > 1:
1714
  all_outputs = []
 
1402
  Xresampled=None,
1403
  weights=None,
1404
  variable_names=None,
 
1405
  ):
1406
  """
1407
  Search for equations to fit the dataset and store them in `self.equations_`.
 
1499
  )
1500
 
1501
  # Fitting procedure
1502
+ return self._run(X=X, y=y, weights=weights, seed=seed)
 
 
 
 
 
1503
 
1504
+ def refresh(self, checkpoint_file=None):
1505
  """
1506
  Updates self.equations_ with any new options passed, such as
1507
  :param`extra_sympy_mappings`.
1508
+
1509
+ Parameters
1510
+ ----------
1511
+ checkpoint_file : str, default=None
1512
+ Path to checkpoint hall of fame file to be loaded.
1513
  """
1514
+ check_is_fitted(self, attributes=["equation_file_"])
1515
+ if checkpoint_file:
1516
+ self.equation_file_ = checkpoint_file
1517
  self.equations_ = self.get_hof()
1518
 
1519
  def _decision_function(self, X, best_equation):
 
1711
  def get_hof(self):
1712
  """Get the equations from a hall of fame file. If no arguments
1713
  entered, the ones used previously from a call to PySR will be used."""
1714
+ check_is_fitted(
1715
+ self,
1716
+ attributes=[
1717
+ "nout_",
1718
+ "equation_file_",
1719
+ "selection_mask_",
1720
+ "feature_names_in_",
1721
+ ],
1722
+ )
1723
  try:
1724
  if self.nout_ > 1:
1725
  all_outputs = []
test/test.py CHANGED
@@ -231,6 +231,17 @@ class TestPipeline(unittest.TestCase):
231
 
232
  class TestBest(unittest.TestCase):
233
  def setUp(self):
 
 
 
 
 
 
 
 
 
 
 
234
  equations = pd.DataFrame(
235
  {
236
  "equation": ["1.0", "cos(x0)", "square(cos(x0))"],
@@ -243,18 +254,6 @@ class TestBest(unittest.TestCase):
243
  "equation_file.csv.bkup", sep="|"
244
  )
245
 
246
- self.model = PySRRegressor(
247
- equation_file="equation_file.csv",
248
- variable_names="x0 x1".split(" "),
249
- extra_sympy_mappings={},
250
- output_jax_format=False,
251
- model_selection="accuracy",
252
- )
253
- self.rstate = np.random.RandomState(0)
254
- # Placeholder values needed to fit the model from an equation file
255
- self.X = self.rstate.randn(10, 2)
256
- self.y = np.cos(self.X[:, 0]) ** 2
257
- self.model.fit(self.X, self.y, from_equation_file=True)
258
  self.model.refresh()
259
  self.equations_ = self.model.equations_
260
 
@@ -308,6 +307,10 @@ class TestFeatureSelection(unittest.TestCase):
308
  class TestMiscellaneous(unittest.TestCase):
309
  """Test miscellaneous functions."""
310
 
 
 
 
 
311
  def test_deprecation(self):
312
  """Ensure that deprecation works as expected.
313
 
 
231
 
232
  class TestBest(unittest.TestCase):
233
  def setUp(self):
234
+ self.rstate = np.random.RandomState(0)
235
+ self.X = self.rstate.randn(10, 2)
236
+ self.y = np.cos(self.X[:, 0]) ** 2
237
+ self.model = PySRRegressor(
238
+ niterations=1,
239
+ extra_sympy_mappings={},
240
+ output_jax_format=False,
241
+ model_selection="accuracy",
242
+ equation_file="equation_file.csv",
243
+ )
244
+ self.model.fit(self.X, self.y)
245
  equations = pd.DataFrame(
246
  {
247
  "equation": ["1.0", "cos(x0)", "square(cos(x0))"],
 
254
  "equation_file.csv.bkup", sep="|"
255
  )
256
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  self.model.refresh()
258
  self.equations_ = self.model.equations_
259
 
 
307
  class TestMiscellaneous(unittest.TestCase):
308
  """Test miscellaneous functions."""
309
 
310
+ def setUp(self):
311
+ # Allows all scikit-learn exception messages to be read.
312
+ self.maxDiff = None
313
+
314
  def test_deprecation(self):
315
  """Ensure that deprecation works as expected.
316
 
test/test_jax.py CHANGED
@@ -23,6 +23,13 @@ class TestJAX(unittest.TestCase):
23
 
24
  def test_pipeline_pandas(self):
25
  X = pd.DataFrame(np.random.randn(100, 10))
 
 
 
 
 
 
 
26
  equations = pd.DataFrame(
27
  {
28
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
@@ -35,14 +42,7 @@ class TestJAX(unittest.TestCase):
35
  "equation_file.csv.bkup", sep="|"
36
  )
37
 
38
- model = PySRRegressor(
39
- equation_file="equation_file.csv",
40
- output_jax_format=True,
41
- variable_names="x1 x2 x3".split(" "),
42
- )
43
-
44
- model.fit(X, y=np.ones(X.shape[0]), from_equation_file=True)
45
- model.refresh()
46
  jformat = model.jax()
47
 
48
  np.testing.assert_almost_equal(
@@ -53,6 +53,10 @@ class TestJAX(unittest.TestCase):
53
 
54
  def test_pipeline(self):
55
  X = np.random.randn(100, 10)
 
 
 
 
56
  equations = pd.DataFrame(
57
  {
58
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
@@ -65,14 +69,7 @@ class TestJAX(unittest.TestCase):
65
  "equation_file.csv.bkup", sep="|"
66
  )
67
 
68
- model = PySRRegressor(
69
- equation_file="equation_file.csv",
70
- output_jax_format=True,
71
- variable_names="x1 x2 x3".split(" "),
72
- )
73
-
74
- model.fit(X, y=np.ones(X.shape[0]), from_equation_file=True)
75
- model.refresh()
76
  jformat = model.jax()
77
 
78
  np.testing.assert_almost_equal(
 
23
 
24
  def test_pipeline_pandas(self):
25
  X = pd.DataFrame(np.random.randn(100, 10))
26
+ y = np.ones(X.shape[0])
27
+ model = PySRRegressor(
28
+ max_evals=10000,
29
+ output_jax_format=True,
30
+ )
31
+ model.fit(X, y)
32
+
33
  equations = pd.DataFrame(
34
  {
35
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
 
42
  "equation_file.csv.bkup", sep="|"
43
  )
44
 
45
+ model.refresh(checkpoint_file="equation_file.csv")
 
 
 
 
 
 
 
46
  jformat = model.jax()
47
 
48
  np.testing.assert_almost_equal(
 
53
 
54
  def test_pipeline(self):
55
  X = np.random.randn(100, 10)
56
+ y = np.ones(X.shape[0])
57
+ model = PySRRegressor(max_evals=10000, output_jax_format=True)
58
+ model.fit(X, y)
59
+
60
  equations = pd.DataFrame(
61
  {
62
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
 
69
  "equation_file.csv.bkup", sep="|"
70
  )
71
 
72
+ model.refresh(checkpoint_file="equation_file.csv")
 
 
 
 
 
 
 
73
  jformat = model.jax()
74
 
75
  np.testing.assert_almost_equal(
test/test_torch.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  import pandas as pd
4
  from pysr import sympy2torch, PySRRegressor
5
  import sympy
6
- from functools import partial
7
 
8
 
9
  class TestTorch(unittest.TestCase):
@@ -15,6 +14,7 @@ class TestTorch(unittest.TestCase):
15
  cosx = 1.0 * sympy.cos(x) + y
16
 
17
  import torch
 
18
  X = torch.tensor(np.random.randn(1000, 3))
19
  true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
20
  torch_module = sympy2torch(cosx, [x, y, z])
@@ -23,7 +23,6 @@ class TestTorch(unittest.TestCase):
23
  )
24
 
25
  def test_pipeline_pandas(self):
26
- X = pd.DataFrame(np.random.randn(100, 10))
27
  equations = pd.DataFrame(
28
  {
29
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
@@ -36,18 +35,16 @@ class TestTorch(unittest.TestCase):
36
  "equation_file.csv.bkup", sep="|"
37
  )
38
 
 
 
39
  model = PySRRegressor(
 
40
  model_selection="accuracy",
41
- equation_file="equation_file.csv",
42
  extra_sympy_mappings={},
43
  output_torch_format=True,
44
  )
45
- # Because a model hasn't been fit via the `fit` method, some
46
- # attributes will not/cannot be set. For the purpose of
47
- # testing, these attributes will be set manually here.
48
- model.fit(X, y=np.ones(X.shape[0]), from_equation_file=True)
49
- model.refresh()
50
-
51
  tformat = model.pytorch()
52
  self.assertEqual(str(tformat), "_SingleSymPyModule(expression=cos(x1)**2)")
53
  import torch
@@ -60,6 +57,14 @@ class TestTorch(unittest.TestCase):
60
 
61
  def test_pipeline(self):
62
  X = np.random.randn(100, 10)
 
 
 
 
 
 
 
 
63
  equations = pd.DataFrame(
64
  {
65
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
@@ -72,20 +77,13 @@ class TestTorch(unittest.TestCase):
72
  "equation_file.csv.bkup", sep="|"
73
  )
74
 
75
- model = PySRRegressor(
76
- model_selection="accuracy",
77
- equation_file="equation_file.csv",
78
- extra_sympy_mappings={},
79
- output_torch_format=True,
80
- )
81
-
82
- model.fit(X, y=np.ones(X.shape[0]), from_equation_file=True)
83
- model.refresh()
84
 
85
  tformat = model.pytorch()
86
  self.assertEqual(str(tformat), "_SingleSymPyModule(expression=cos(x1)**2)")
87
 
88
  import torch
 
89
  np.testing.assert_almost_equal(
90
  tformat(torch.tensor(X)).detach().numpy(),
91
  np.square(np.cos(X[:, 1])), # 2nd feature
@@ -99,6 +97,7 @@ class TestTorch(unittest.TestCase):
99
  module = sympy2torch(expression, [x, y, z])
100
 
101
  import torch
 
102
  X = torch.rand(100, 3).float() * 10
103
 
104
  true_out = (
@@ -112,6 +111,13 @@ class TestTorch(unittest.TestCase):
112
 
113
  def test_custom_operator(self):
114
  X = np.random.randn(100, 3)
 
 
 
 
 
 
 
115
 
116
  equations = pd.DataFrame(
117
  {
@@ -126,15 +132,13 @@ class TestTorch(unittest.TestCase):
126
  )
127
 
128
  import torch
129
- model = PySRRegressor(
130
- model_selection="accuracy",
131
  equation_file="equation_file_custom_operator.csv",
132
  extra_sympy_mappings={"mycustomoperator": sympy.sin},
133
  extra_torch_mappings={"mycustomoperator": torch.sin},
134
- output_torch_format=True,
135
  )
136
- model.fit(X, y=np.ones(X.shape[0]), from_equation_file=True)
137
- model.refresh()
138
  self.assertEqual(str(model.sympy()), "sin(x1)")
139
  # Will automatically use the set global state from get_hof.
140
 
@@ -160,6 +164,7 @@ class TestTorch(unittest.TestCase):
160
 
161
  np_output = model.predict(X.values)
162
  import torch
 
163
  torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
164
 
165
  np.testing.assert_almost_equal(np_output, torch_output, decimal=4)
 
3
  import pandas as pd
4
  from pysr import sympy2torch, PySRRegressor
5
  import sympy
 
6
 
7
 
8
  class TestTorch(unittest.TestCase):
 
14
  cosx = 1.0 * sympy.cos(x) + y
15
 
16
  import torch
17
+
18
  X = torch.tensor(np.random.randn(1000, 3))
19
  true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
20
  torch_module = sympy2torch(cosx, [x, y, z])
 
23
  )
24
 
25
  def test_pipeline_pandas(self):
 
26
  equations = pd.DataFrame(
27
  {
28
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
 
35
  "equation_file.csv.bkup", sep="|"
36
  )
37
 
38
+ X = pd.DataFrame(np.random.randn(100, 10))
39
+ y = np.ones(X.shape[0])
40
  model = PySRRegressor(
41
+ max_evals=10000,
42
  model_selection="accuracy",
 
43
  extra_sympy_mappings={},
44
  output_torch_format=True,
45
  )
46
+ model.fit(X, y)
47
+ model.refresh(checkpoint_file="equation_file.csv")
 
 
 
 
48
  tformat = model.pytorch()
49
  self.assertEqual(str(tformat), "_SingleSymPyModule(expression=cos(x1)**2)")
50
  import torch
 
57
 
58
  def test_pipeline(self):
59
  X = np.random.randn(100, 10)
60
+ y = np.ones(X.shape[0])
61
+ model = PySRRegressor(
62
+ max_evals=10000,
63
+ model_selection="accuracy",
64
+ output_torch_format=True,
65
+ )
66
+ model.fit(X, y)
67
+
68
  equations = pd.DataFrame(
69
  {
70
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
 
77
  "equation_file.csv.bkup", sep="|"
78
  )
79
 
80
+ model.refresh(checkpoint_file="equation_file.csv")
 
 
 
 
 
 
 
 
81
 
82
  tformat = model.pytorch()
83
  self.assertEqual(str(tformat), "_SingleSymPyModule(expression=cos(x1)**2)")
84
 
85
  import torch
86
+
87
  np.testing.assert_almost_equal(
88
  tformat(torch.tensor(X)).detach().numpy(),
89
  np.square(np.cos(X[:, 1])), # 2nd feature
 
97
  module = sympy2torch(expression, [x, y, z])
98
 
99
  import torch
100
+
101
  X = torch.rand(100, 3).float() * 10
102
 
103
  true_out = (
 
111
 
112
  def test_custom_operator(self):
113
  X = np.random.randn(100, 3)
114
+ y = np.ones(X.shape[0])
115
+ model = PySRRegressor(
116
+ max_evals=10000,
117
+ model_selection="accuracy",
118
+ output_torch_format=True,
119
+ )
120
+ model.fit(X, y)
121
 
122
  equations = pd.DataFrame(
123
  {
 
132
  )
133
 
134
  import torch
135
+
136
+ model.set_params(
137
  equation_file="equation_file_custom_operator.csv",
138
  extra_sympy_mappings={"mycustomoperator": sympy.sin},
139
  extra_torch_mappings={"mycustomoperator": torch.sin},
 
140
  )
141
+ model.refresh(checkpoint_file="equation_file_custom_operator.csv")
 
142
  self.assertEqual(str(model.sympy()), "sin(x1)")
143
  # Will automatically use the set global state from get_hof.
144
 
 
164
 
165
  np_output = model.predict(X.values)
166
  import torch
167
+
168
  torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
169
 
170
  np.testing.assert_almost_equal(np_output, torch_output, decimal=4)