MilesCranmer commited on
Commit
f2a7a62
1 Parent(s): f1e7133

Add mappings from sympy exact One/Half

Browse files
Files changed (2) hide show
  1. pysr/export_jax.py +3 -0
  2. pysr/export_torch.py +3 -0
pysr/export_jax.py CHANGED
@@ -48,6 +48,9 @@ _jnp_func_lookup = {
48
  sympy.Max: "jnp.max",
49
  sympy.Min: "jnp.min",
50
  sympy.Mod: "jnp.mod",
 
 
 
51
  }
52
 
53
 
 
48
  sympy.Max: "jnp.max",
49
  sympy.Min: "jnp.min",
50
  sympy.Mod: "jnp.mod",
51
+ sympy.Heaviside: "jnp.heaviside",
52
+ sympy.core.numbers.Half: "(lambda: 0.5)",
53
+ sympy.core.numbers.One: "(lambda: 1.0)",
54
  }
55
 
56
 
pysr/export_torch.py CHANGED
@@ -77,6 +77,9 @@ def _initialize_torch():
77
  sympy.Max: torch.max,
78
  sympy.Min: torch.min,
79
  sympy.Mod: torch.remainder,
 
 
 
80
  }
81
 
82
  class _Node(torch.nn.Module):
 
77
  sympy.Max: torch.max,
78
  sympy.Min: torch.min,
79
  sympy.Mod: torch.remainder,
80
+ sympy.Heaviside: torch.heaviside,
81
+ sympy.core.numbers.Half: (lambda: 0.5),
82
+ sympy.core.numbers.One: (lambda: 1.0),
83
  }
84
 
85
  class _Node(torch.nn.Module):