MilesCranmer commited on
Commit
b07eb2d
1 Parent(s): c96b30c

Test selection inside jax/torch

Browse files
Files changed (2) hide show
  1. test/test_jax.py +3 -3
  2. test/test_torch.py +3 -3
test/test_jax.py CHANGED
@@ -17,9 +17,9 @@ class TestJAX(unittest.TestCase):
17
  f, params = sympy2jax(cosx, [x, y, z])
18
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
19
  def test_pipeline(self):
20
- X = np.random.randn(100, 2)
21
  equations = pd.DataFrame({
22
- 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
23
  'MSE': [1.0, 0.1, 1e-5],
24
  'Complexity': [1, 2, 3]
25
  })
@@ -30,7 +30,7 @@ class TestJAX(unittest.TestCase):
30
  equations = get_hof(
31
  'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
32
  extra_sympy_mappings={}, output_jax_format=True,
33
- multioutput=False, nout=1)
34
 
35
  jformat = equations.iloc[-1].jax_format
36
  np.testing.assert_almost_equal(
 
17
  f, params = sympy2jax(cosx, [x, y, z])
18
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
19
  def test_pipeline(self):
20
+ X = np.random.randn(100, 10)
21
  equations = pd.DataFrame({
22
+ 'Equation': ['1.0', 'cos(x1)', 'square(cos(x1))'],
23
  'MSE': [1.0, 0.1, 1e-5],
24
  'Complexity': [1, 2, 3]
25
  })
 
30
  equations = get_hof(
31
  'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
32
  extra_sympy_mappings={}, output_jax_format=True,
33
+ multioutput=False, nout=1, selection=[1, 2, 3])
34
 
35
  jformat = equations.iloc[-1].jax_format
36
  np.testing.assert_almost_equal(
test/test_torch.py CHANGED
@@ -16,9 +16,9 @@ class TestTorch(unittest.TestCase):
16
  np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
17
  )
18
  def test_pipeline(self):
19
- X = np.random.randn(100, 2)
20
  equations = pd.DataFrame({
21
- 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
22
  'MSE': [1.0, 0.1, 1e-5],
23
  'Complexity': [1, 2, 3]
24
  })
@@ -29,7 +29,7 @@ class TestTorch(unittest.TestCase):
29
  equations = get_hof(
30
  'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
31
  extra_sympy_mappings={}, output_torch_format=True,
32
- multioutput=False, nout=1)
33
 
34
  tformat = equations.iloc[-1].torch_format
35
  np.testing.assert_almost_equal(
 
16
  np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
17
  )
18
  def test_pipeline(self):
19
+ X = np.random.randn(100, 10)
20
  equations = pd.DataFrame({
21
+ 'Equation': ['1.0', 'cos(x1)', 'square(cos(x1))'],
22
  'MSE': [1.0, 0.1, 1e-5],
23
  'Complexity': [1, 2, 3]
24
  })
 
29
  equations = get_hof(
30
  'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
31
  extra_sympy_mappings={}, output_torch_format=True,
32
+ multioutput=False, nout=1, selection=[1, 2, 3])
33
 
34
  tformat = equations.iloc[-1].torch_format
35
  np.testing.assert_almost_equal(