Spaces:
Running
Running
MilesCranmer
commited on
fix: symbolic numbers in torch
Browse files- pysr/export_torch.py +5 -0
- pysr/test/test_torch.py +11 -0
pysr/export_torch.py
CHANGED
@@ -116,6 +116,11 @@ def _initialize_torch():
|
|
116 |
self._value = int(expr)
|
117 |
self._torch_func = lambda: self._value
|
118 |
self._args = ()
|
|
|
|
|
|
|
|
|
|
|
119 |
elif issubclass(expr.func, sympy.Symbol):
|
120 |
self._name = expr.name
|
121 |
self._torch_func = lambda value: value
|
|
|
116 |
self._value = int(expr)
|
117 |
self._torch_func = lambda: self._value
|
118 |
self._args = ()
|
119 |
+
elif issubclass(expr.func, sympy.NumberSymbol):
|
120 |
+
# Can get here from exp(1) or exact pi
|
121 |
+
self._value = float(expr)
|
122 |
+
self._torch_func = lambda: self._value
|
123 |
+
self._args = ()
|
124 |
elif issubclass(expr.func, sympy.Symbol):
|
125 |
self._name = expr.name
|
126 |
self._torch_func = lambda value: value
|
pysr/test/test_torch.py
CHANGED
@@ -173,6 +173,17 @@ class TestTorch(unittest.TestCase):
|
|
173 |
decimal=3,
|
174 |
)
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
def test_feature_selection_custom_operators(self):
|
177 |
rstate = np.random.RandomState(0)
|
178 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
|
|
173 |
decimal=3,
|
174 |
)
|
175 |
|
176 |
+
def test_issue_656(self):
|
177 |
+
# Should correctly map numeric symbols to floats
|
178 |
+
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
|
179 |
+
m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"])
|
180 |
+
X = np.random.randn(10, 1)
|
181 |
+
np.testing.assert_almost_equal(
|
182 |
+
m(self.torch.tensor(X)).detach().numpy(),
|
183 |
+
np.exp(1) + X[:, 0],
|
184 |
+
decimal=3,
|
185 |
+
)
|
186 |
+
|
187 |
def test_feature_selection_custom_operators(self):
|
188 |
rstate = np.random.RandomState(0)
|
189 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|