MilesCranmer commited on
Commit
05cf610
1 Parent(s): 6add8e3

Test sympy export to jax

Browse files
Files changed (2) hide show
  1. .github/workflows/CI.yml +1 -0
  2. test/test.py +14 -2
.github/workflows/CI.yml CHANGED
@@ -58,6 +58,7 @@ jobs:
58
  run: |
59
  python -m pip install --upgrade pip
60
  pip install -r requirements.txt
 
61
  python setup.py install
62
  shell: bash
63
  - name: "Run tests"
 
58
  run: |
59
  python -m pip install --upgrade pip
60
  pip install -r requirements.txt
61
+ pip install jax jaxlib # (optional import)
62
  python setup.py install
63
  shell: bash
64
  - name: "Run tests"
test/test.py CHANGED
@@ -1,5 +1,9 @@
1
  import numpy as np
2
- from pysr import pysr
 
 
 
 
3
  X = np.random.randn(100, 5)
4
 
5
  print("Test 1 - defaults; simple linear relation")
@@ -27,6 +31,14 @@ equations = pysr(X, y,
27
  unary_operators=[], binary_operators=["plus"],
28
  niterations=10,
29
  user_input=False)
30
-
31
  print(equations)
32
  assert equations.iloc[-1]['MSE'] < 1e-4
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
+ from pysr import pysr, sympy2jax
3
+ from jax import numpy as jnp
4
+ from jax import random
5
+ from jax import grad
6
+ import sympy
7
  X = np.random.randn(100, 5)
8
 
9
  print("Test 1 - defaults; simple linear relation")
 
31
  unary_operators=[], binary_operators=["plus"],
32
  niterations=10,
33
  user_input=False)
 
34
  print(equations)
35
  assert equations.iloc[-1]['MSE'] < 1e-4
36
+
37
+ print("Test 4 - text JAX export")
38
+ x, y, z = sympy.symbols('x y z')
39
+ cosx = 1.0 * sympy.cos(x) + y
40
+ key = random.PRNGKey(0)
41
+ X = random.normal(key, (1000, 2))
42
+ true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
43
+ f, params = sympy2jax(cosx, [x])
44
+ assert jnp.all(jnp.isclose(f(X, params), true)).item()