MilesCranmer commited on
Commit
c3293a8
·
unverified ·
1 Parent(s): 5a621f9

fix: symbolic numbers in torch

Browse files
Files changed (2) hide show
  1. pysr/export_torch.py +5 -0
  2. pysr/test/test_torch.py +11 -0
pysr/export_torch.py CHANGED
@@ -116,6 +116,11 @@ def _initialize_torch():
116
  self._value = int(expr)
117
  self._torch_func = lambda: self._value
118
  self._args = ()
 
 
 
 
 
119
  elif issubclass(expr.func, sympy.Symbol):
120
  self._name = expr.name
121
  self._torch_func = lambda value: value
 
116
  self._value = int(expr)
117
  self._torch_func = lambda: self._value
118
  self._args = ()
119
+ elif issubclass(expr.func, sympy.NumberSymbol):
120
+ # Can get here from exp(1) or exact pi
121
+ self._value = float(expr)
122
+ self._torch_func = lambda: self._value
123
+ self._args = ()
124
  elif issubclass(expr.func, sympy.Symbol):
125
  self._name = expr.name
126
  self._torch_func = lambda value: value
pysr/test/test_torch.py CHANGED
@@ -173,6 +173,17 @@ class TestTorch(unittest.TestCase):
173
  decimal=3,
174
  )
175
 
 
 
 
 
 
 
 
 
 
 
 
176
  def test_feature_selection_custom_operators(self):
177
  rstate = np.random.RandomState(0)
178
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
 
173
  decimal=3,
174
  )
175
 
176
+ def test_issue_656(self):
177
+ # Should correctly map numeric symbols to floats
178
+ E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
179
+ m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"])
180
+ X = np.random.randn(10, 1)
181
+ np.testing.assert_almost_equal(
182
+ m(self.torch.tensor(X)).detach().numpy(),
183
+ np.exp(1) + X[:, 0],
184
+ decimal=3,
185
+ )
186
+
187
  def test_feature_selection_custom_operators(self):
188
  rstate = np.random.RandomState(0)
189
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})