MilesCranmer commited on
Commit
e6c9a63
·
unverified ·
2 Parent(s): 3538029 58c6697

Merge pull request #47 from MilesCranmer/coverage

Browse files
Files changed (3) hide show
  1. .github/workflows/CI.yml +16 -2
  2. test/test.py +35 -31
  3. test/test_jax.py +10 -8
.github/workflows/CI.yml CHANGED
@@ -59,14 +59,28 @@ jobs:
59
  python -m pip install --upgrade pip
60
  pip install -r requirements.txt
61
  python setup.py install
 
 
62
  - name: "Install JAX"
63
  if: matrix.os != 'windows-latest'
64
  run: pip install jax jaxlib # (optional import)
65
  shell: bash
66
  - name: "Run tests"
67
- run: python test/test.py
 
 
 
68
  shell: bash
69
  - name: "Run JAX tests"
70
  if: matrix.os != 'windows-latest'
71
- run: python test/test_jax.py
 
 
 
72
  shell: bash
 
 
 
 
 
 
 
59
  python -m pip install --upgrade pip
60
  pip install -r requirements.txt
61
  python setup.py install
62
+ - name: "Install Coverage tool"
63
+ run: pip install coverage coveralls
64
  - name: "Install JAX"
65
  if: matrix.os != 'windows-latest'
66
  run: pip install jax jaxlib # (optional import)
67
  shell: bash
68
  - name: "Run tests"
69
+ run: |
70
+ cd test
71
+ coverage run --source=pysr --omit=pysr.feynman_problems -m unittest test
72
+ cd ..
73
  shell: bash
74
  - name: "Run JAX tests"
75
  if: matrix.os != 'windows-latest'
76
+ run: |
77
+ cd test
78
+ coverage run --append --source=pysr --omit=pysr.feynman_problems -m unittest test_jax
79
+ cd ..
80
  shell: bash
81
+ - name: Coveralls
82
+ env:
83
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
84
+ run: |
85
+ cd test
86
+ coveralls --service=github
test/test.py CHANGED
@@ -1,38 +1,42 @@
 
1
  import numpy as np
2
  from pysr import pysr
3
  import sympy
4
- X = np.random.randn(100, 5)
5
 
6
- default_test_kwargs = dict(
7
- niterations=10,
8
- populations=4,
9
- user_input=False,
10
- annealing=True,
11
- useFrequency=False,
12
- )
 
 
 
 
 
 
 
 
 
 
13
 
14
- print("Test 1 - defaults; simple linear relation")
15
- y = X[:, 0]
16
- equations = pysr(X, y, **default_test_kwargs)
17
- print(equations)
18
- assert equations.iloc[-1]['MSE'] < 1e-4
 
 
 
 
19
 
20
- print("Test 2 - test custom operator, and multiple outputs")
21
- y = X[:, [0, 1]]**2
22
- equations = pysr(X, y,
23
- unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
24
- extra_sympy_mappings={'square': lambda x: x**2},
25
- **default_test_kwargs)
26
- print(equations)
27
- assert equations[0].iloc[-1]['MSE'] < 1e-4
28
- assert equations[1].iloc[-1]['MSE'] < 1e-4
29
 
30
- X = np.random.randn(100, 1)
31
- y = X[:, 0] + 3.0
32
- print("Test 3 - empty operator list, and single dimension input")
33
- equations = pysr(X, y,
34
- unary_operators=[], binary_operators=["plus"],
35
- **default_test_kwargs)
36
-
37
- print(equations)
38
- assert equations.iloc[-1]['MSE'] < 1e-4
 
1
+ import unittest
2
  import numpy as np
3
  from pysr import pysr
4
  import sympy
 
5
 
6
+ class TestPipeline(unittest.TestCase):
7
+ def setUp(self):
8
+ self.default_test_kwargs = dict(
9
+ niterations=10,
10
+ populations=4,
11
+ user_input=False,
12
+ annealing=True,
13
+ useFrequency=False,
14
+ )
15
+ np.random.seed(0)
16
+ self.X = np.random.randn(100, 5)
17
+
18
+ def test_linear_relation(self):
19
+ y = self.X[:, 0]
20
+ equations = pysr(self.X, y, **self.default_test_kwargs)
21
+ print(equations)
22
+ self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
23
 
24
+ def test_multioutput_custom_operator(self):
25
+ y = self.X[:, [0, 1]]**2
26
+ equations = pysr(self.X, y,
27
+ unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
28
+ extra_sympy_mappings={'square': lambda x: x**2},
29
+ **self.default_test_kwargs)
30
+ print(equations)
31
+ self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
32
+ self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
33
 
34
+ def test_empty_operators_single_input(self):
35
+ X = np.random.randn(100, 1)
36
+ y = X[:, 0] + 3.0
37
+ equations = pysr(X, y,
38
+ unary_operators=[], binary_operators=["plus"],
39
+ **self.default_test_kwargs)
 
 
 
40
 
41
+ print(equations)
42
+ self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
 
 
 
 
 
 
 
test/test_jax.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  from pysr import pysr, sympy2jax
3
  from jax import numpy as jnp
@@ -5,11 +6,12 @@ from jax import random
5
  from jax import grad
6
  import sympy
7
 
8
- print("Test JAX 1 - test export")
9
- x, y, z = sympy.symbols('x y z')
10
- cosx = 1.0 * sympy.cos(x) + y
11
- key = random.PRNGKey(0)
12
- X = random.normal(key, (1000, 2))
13
- true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
14
- f, params = sympy2jax(cosx, [x, y, z])
15
- assert jnp.all(jnp.isclose(f(X, params), true)).item()
 
 
1
+ import unittest
2
  import numpy as np
3
  from pysr import pysr, sympy2jax
4
  from jax import numpy as jnp
 
6
  from jax import grad
7
  import sympy
8
 
9
+ class TestJAX(unittest.TestCase):
10
+ def test_sympy2jax(self):
11
+ x, y, z = sympy.symbols('x y z')
12
+ cosx = 1.0 * sympy.cos(x) + y
13
+ key = random.PRNGKey(0)
14
+ X = random.normal(key, (1000, 2))
15
+ true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
16
+ f, params = sympy2jax(cosx, [x, y, z])
17
+ self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())