MilesCranmer commited on
Commit
7fc7b82
·
unverified ·
2 Parent(s): 06ca0e3 36e5dde

Merge pull request #658 from MilesCranmer/fix-number-symbol

Browse files
Dockerfile CHANGED
@@ -1,8 +1,8 @@
1
  # This builds a dockerfile containing a working copy of PySR
2
  # with all pre-requisites installed.
3
 
4
- ARG JLVERSION=1.9.4
5
- ARG PYVERSION=3.11.6
6
  ARG BASE_IMAGE=bullseye
7
 
8
  FROM julia:${JLVERSION}-${BASE_IMAGE} AS jl
 
1
  # This builds a dockerfile containing a working copy of PySR
2
  # with all pre-requisites installed.
3
 
4
+ ARG JLVERSION=1.10.4
5
+ ARG PYVERSION=3.12.2
6
  ARG BASE_IMAGE=bullseye
7
 
8
  FROM julia:${JLVERSION}-${BASE_IMAGE} AS jl
pysr/export_jax.py CHANGED
@@ -55,7 +55,9 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
- elif issubclass(expr.func, sympy.Rational):
 
 
59
  return f"{float(expr)}"
60
  elif issubclass(expr.func, sympy.Integer):
61
  return f"{int(expr)}"
 
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
+ elif issubclass(expr.func, sympy.Rational) or issubclass(
59
+ expr.func, sympy.NumberSymbol
60
+ ):
61
  return f"{float(expr)}"
62
  elif issubclass(expr.func, sympy.Integer):
63
  return f"{int(expr)}"
pysr/export_sympy.py CHANGED
@@ -87,7 +87,12 @@ def pysr2sympy(
87
  **sympy_mappings,
88
  }
89
 
90
- return sympify(equation, locals=local_sympy_mappings)
 
 
 
 
 
91
 
92
 
93
  def assert_valid_sympy_symbol(var_name: str) -> None:
 
87
  **sympy_mappings,
88
  }
89
 
90
+ try:
91
+ return sympify(equation, locals=local_sympy_mappings, evaluate=False)
92
+ except TypeError as e:
93
+ if "got an unexpected keyword argument 'evaluate'" in str(e):
94
+ return sympify(equation, locals=local_sympy_mappings)
95
+ raise TypeError(f"Error processing equation '{equation}'") from e
96
 
97
 
98
  def assert_valid_sympy_symbol(var_name: str) -> None:
pysr/export_torch.py CHANGED
@@ -116,6 +116,11 @@ def _initialize_torch():
116
  self._value = int(expr)
117
  self._torch_func = lambda: self._value
118
  self._args = ()
 
 
 
 
 
119
  elif issubclass(expr.func, sympy.Symbol):
120
  self._name = expr.name
121
  self._torch_func = lambda value: value
 
116
  self._value = int(expr)
117
  self._torch_func = lambda: self._value
118
  self._args = ()
119
+ elif issubclass(expr.func, sympy.NumberSymbol):
120
+ # Can get here from exp(1) or exact pi
121
+ self._value = float(expr)
122
+ self._torch_func = lambda: self._value
123
+ self._args = ()
124
  elif issubclass(expr.func, sympy.Symbol):
125
  self._name = expr.name
126
  self._torch_func = lambda value: value
pysr/test/test.py CHANGED
@@ -674,7 +674,7 @@ class TestMiscellaneous(unittest.TestCase):
674
  pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
675
 
676
  y_predictions2 = model2.predict(X)
677
- np.testing.assert_array_equal(y_predictions, y_predictions2)
678
 
679
  def test_scikit_learn_compatibility(self):
680
  """Test PySRRegressor compatibility with scikit-learn."""
@@ -1039,7 +1039,7 @@ class TestLaTeXTable(unittest.TestCase):
1039
  middle_part_2 = r"""
1040
  $y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
1041
  $y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
1042
- $y_{1} = x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
1043
  """
1044
  true_latex_table_str = "\n\n".join(
1045
  self.create_true_latex(part, include_score=True)
@@ -1092,7 +1092,7 @@ class TestLaTeXTable(unittest.TestCase):
1092
  middle_part = r"""
1093
  $y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
1094
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
1095
- \begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
1096
  """
1097
  true_latex_table_str = (
1098
  TRUE_PREAMBLE
 
674
  pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
675
 
676
  y_predictions2 = model2.predict(X)
677
+ np.testing.assert_array_almost_equal(y_predictions, y_predictions2)
678
 
679
  def test_scikit_learn_compatibility(self):
680
  """Test PySRRegressor compatibility with scikit-learn."""
 
1039
  middle_part_2 = r"""
1040
  $y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
1041
  $y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
1042
+ $y_{1} = x_{0} x_{0} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
1043
  """
1044
  true_latex_table_str = "\n\n".join(
1045
  self.create_true_latex(part, include_score=True)
 
1092
  middle_part = r"""
1093
  $y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
1094
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
1095
+ \begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0} x_{0} x_{0} + x_{0} x_{0} x_{0} x_{0} x_{0} + 3.20 x_{0} - 1.20 x_{1} + x_{1} x_{1} x_{1} + 5.20 \sin{\left(- 2.60 x_{0} + 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0} x_{0} x_{0} + 3.20 x_{0} - 1.20 x_{1} + x_{1} x_{1} x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
1096
  """
1097
  true_latex_table_str = (
1098
  TRUE_PREAMBLE
pysr/test/test_jax.py CHANGED
@@ -5,27 +5,29 @@ import numpy as np
5
  import pandas as pd
6
  import sympy
7
 
 
8
  from pysr import PySRRegressor, sympy2jax
9
 
10
 
11
  class TestJAX(unittest.TestCase):
12
  def setUp(self):
13
  np.random.seed(0)
 
 
 
14
 
15
  def test_sympy2jax(self):
16
- from jax import numpy as jnp
17
  from jax import random
18
 
19
  x, y, z = sympy.symbols("x y z")
20
  cosx = 1.0 * sympy.cos(x) + y
21
  key = random.PRNGKey(0)
22
  X = random.normal(key, (1000, 2))
23
- true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
24
  f, params = sympy2jax(cosx, [x, y, z])
25
- self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
26
 
27
  def test_pipeline_pandas(self):
28
- from jax import numpy as jnp
29
 
30
  X = pd.DataFrame(np.random.randn(100, 10))
31
  y = np.ones(X.shape[0])
@@ -52,14 +54,12 @@ class TestJAX(unittest.TestCase):
52
  jformat = model.jax()
53
 
54
  np.testing.assert_almost_equal(
55
- np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
56
  np.square(np.cos(X.values[:, 1])), # Select feature 1
57
  decimal=3,
58
  )
59
 
60
  def test_pipeline(self):
61
- from jax import numpy as jnp
62
-
63
  X = np.random.randn(100, 10)
64
  y = np.ones(X.shape[0])
65
  model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
@@ -81,11 +81,39 @@ class TestJAX(unittest.TestCase):
81
  jformat = model.jax()
82
 
83
  np.testing.assert_almost_equal(
84
- np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
85
  np.square(np.cos(X[:, 1])), # Select feature 1
86
  decimal=3,
87
  )
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def test_feature_selection_custom_operators(self):
90
  rstate = np.random.RandomState(0)
91
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
 
5
  import pandas as pd
6
  import sympy
7
 
8
+ import pysr
9
  from pysr import PySRRegressor, sympy2jax
10
 
11
 
12
  class TestJAX(unittest.TestCase):
13
  def setUp(self):
14
  np.random.seed(0)
15
+ from jax import numpy as jnp
16
+
17
+ self.jnp = jnp
18
 
19
  def test_sympy2jax(self):
 
20
  from jax import random
21
 
22
  x, y, z = sympy.symbols("x y z")
23
  cosx = 1.0 * sympy.cos(x) + y
24
  key = random.PRNGKey(0)
25
  X = random.normal(key, (1000, 2))
26
+ true = 1.0 * self.jnp.cos(X[:, 0]) + X[:, 1]
27
  f, params = sympy2jax(cosx, [x, y, z])
28
+ self.assertTrue(self.jnp.all(self.jnp.isclose(f(X, params), true)).item())
29
 
30
  def test_pipeline_pandas(self):
 
31
 
32
  X = pd.DataFrame(np.random.randn(100, 10))
33
  y = np.ones(X.shape[0])
 
54
  jformat = model.jax()
55
 
56
  np.testing.assert_almost_equal(
57
+ np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
58
  np.square(np.cos(X.values[:, 1])), # Select feature 1
59
  decimal=3,
60
  )
61
 
62
  def test_pipeline(self):
 
 
63
  X = np.random.randn(100, 10)
64
  y = np.ones(X.shape[0])
65
  model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
 
81
  jformat = model.jax()
82
 
83
  np.testing.assert_almost_equal(
84
+ np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
85
  np.square(np.cos(X[:, 1])), # Select feature 1
86
  decimal=3,
87
  )
88
 
89
+ def test_avoid_simplification(self):
90
+ ex = pysr.export_sympy.pysr2sympy(
91
+ "square(exp(sign(0.44796443))) + 1.5 * x1",
92
+ feature_names_in=["x1"],
93
+ extra_sympy_mappings={"square": lambda x: x**2},
94
+ )
95
+ f, params = pysr.export_jax.sympy2jax(ex, [sympy.symbols("x1")])
96
+ key = np.random.RandomState(0)
97
+ X = key.randn(10, 1)
98
+ np.testing.assert_almost_equal(
99
+ np.array(f(self.jnp.array(X), params)),
100
+ np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
101
+ decimal=3,
102
+ )
103
+
104
+ def test_issue_656(self):
105
+ import sympy
106
+
107
+ E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
108
+ f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
109
+ key = np.random.RandomState(0)
110
+ X = key.randn(10, 1)
111
+ np.testing.assert_almost_equal(
112
+ np.array(f(self.jnp.array(X), params)),
113
+ np.exp(1) + X[:, 0],
114
+ decimal=3,
115
+ )
116
+
117
  def test_feature_selection_custom_operators(self):
118
  rstate = np.random.RandomState(0)
119
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
pysr/test/test_torch.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import pandas as pd
5
  import sympy
6
 
 
7
  from pysr import PySRRegressor, sympy2torch
8
 
9
 
@@ -153,10 +154,43 @@ class TestTorch(unittest.TestCase):
153
  decimal=3,
154
  )
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def test_feature_selection_custom_operators(self):
157
  rstate = np.random.RandomState(0)
158
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
159
- cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
 
 
 
160
  y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
161
 
162
  model = PySRRegressor(
 
4
  import pandas as pd
5
  import sympy
6
 
7
+ import pysr
8
  from pysr import PySRRegressor, sympy2torch
9
 
10
 
 
154
  decimal=3,
155
  )
156
 
157
+ def test_avoid_simplification(self):
158
+ # SymPy should not simplify without permission
159
+ torch = self.torch
160
+ ex = pysr.export_sympy.pysr2sympy(
161
+ "square(exp(sign(0.44796443))) + 1.5 * x1",
162
+ # ^ Normally this would become exp1 and require
163
+ # its own mapping
164
+ feature_names_in=["x1"],
165
+ extra_sympy_mappings={"square": lambda x: x**2},
166
+ )
167
+ m = pysr.export_torch.sympy2torch(ex, ["x1"])
168
+ rng = np.random.RandomState(0)
169
+ X = rng.randn(10, 1)
170
+ np.testing.assert_almost_equal(
171
+ m(torch.tensor(X)).detach().numpy(),
172
+ np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
173
+ decimal=3,
174
+ )
175
+
176
+ def test_issue_656(self):
177
+ # Should correctly map numeric symbols to floats
178
+ E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
179
+ m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"])
180
+ X = np.random.randn(10, 1)
181
+ np.testing.assert_almost_equal(
182
+ m(self.torch.tensor(X)).detach().numpy(),
183
+ np.exp(1) + X[:, 0],
184
+ decimal=3,
185
+ )
186
+
187
  def test_feature_selection_custom_operators(self):
188
  rstate = np.random.RandomState(0)
189
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
190
+
191
+ def cos_approx(x):
192
+ return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
193
+
194
  y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
195
 
196
  model = PySRRegressor(