Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
05cf610
1
Parent(s):
6add8e3
Test sympy export to jax
Browse files- .github/workflows/CI.yml +1 -0
- test/test.py +14 -2
.github/workflows/CI.yml
CHANGED
@@ -58,6 +58,7 @@ jobs:
|
|
58 |
run: |
|
59 |
python -m pip install --upgrade pip
|
60 |
pip install -r requirements.txt
|
|
|
61 |
python setup.py install
|
62 |
shell: bash
|
63 |
- name: "Run tests"
|
|
|
58 |
run: |
|
59 |
python -m pip install --upgrade pip
|
60 |
pip install -r requirements.txt
|
61 |
+
pip install jax jaxlib # (optional import)
|
62 |
python setup.py install
|
63 |
shell: bash
|
64 |
- name: "Run tests"
|
test/test.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
import numpy as np
|
2 |
-
from pysr import pysr
|
|
|
|
|
|
|
|
|
3 |
X = np.random.randn(100, 5)
|
4 |
|
5 |
print("Test 1 - defaults; simple linear relation")
|
@@ -27,6 +31,14 @@ equations = pysr(X, y,
|
|
27 |
unary_operators=[], binary_operators=["plus"],
|
28 |
niterations=10,
|
29 |
user_input=False)
|
30 |
-
|
31 |
print(equations)
|
32 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
+
from pysr import pysr, sympy2jax
|
3 |
+
from jax import numpy as jnp
|
4 |
+
from jax import random
|
5 |
+
from jax import grad
|
6 |
+
import sympy
|
7 |
X = np.random.randn(100, 5)
|
8 |
|
9 |
print("Test 1 - defaults; simple linear relation")
|
|
|
31 |
unary_operators=[], binary_operators=["plus"],
|
32 |
niterations=10,
|
33 |
user_input=False)
|
|
|
34 |
print(equations)
|
35 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
36 |
+
|
37 |
+
print("Test 4 - text JAX export")
|
38 |
+
x, y, z = sympy.symbols('x y z')
|
39 |
+
cosx = 1.0 * sympy.cos(x) + y
|
40 |
+
key = random.PRNGKey(0)
|
41 |
+
X = random.normal(key, (1000, 2))
|
42 |
+
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
43 |
+
f, params = sympy2jax(cosx, [x])
|
44 |
+
assert jnp.all(jnp.isclose(f(X, params), true)).item()
|