MilesCranmer commited on
Commit
f119733
·
1 Parent(s): c9cead8

Fix sympy2torch for rational numbers

Browse files
Files changed (1) hide show
  1. pysr/export_torch.py +5 -0
pysr/export_torch.py CHANGED
@@ -94,6 +94,11 @@ def _initialize_torch():
94
  self._value = torch.nn.Parameter(torch.tensor(float(expr)))
95
  self._torch_func = lambda: self._value
96
  self._args = ()
 
 
 
 
 
97
  elif issubclass(expr.func, sympy.UnevaluatedExpr):
98
  if len(expr.args) != 1 or not issubclass(
99
  expr.args[0].func, sympy.Float
 
94
  self._value = torch.nn.Parameter(torch.tensor(float(expr)))
95
  self._torch_func = lambda: self._value
96
  self._args = ()
97
+ elif issubclass(expr.func, sympy.Rational):
98
+ # This is some fraction fixed in the operator.
99
+ self._value = float(expr)
100
+ self._torch_func = lambda: self._value
101
+ self._args = ()
102
  elif issubclass(expr.func, sympy.UnevaluatedExpr):
103
  if len(expr.args) != 1 or not issubclass(
104
  expr.args[0].func, sympy.Float