Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
beaf20b
1
Parent(s):
f119733
Make JAX custom operator test deterministic
Browse files- test/test_jax.py +10 -3
test/test_jax.py
CHANGED
@@ -80,9 +80,10 @@ class TestJAX(unittest.TestCase):
|
|
80 |
)
|
81 |
|
82 |
def test_feature_selection_custom_operators(self):
|
83 |
-
|
|
|
84 |
cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
|
85 |
-
y = X["k15"] ** 2 + cos_approx(X["k20"])
|
86 |
|
87 |
model = PySRRegressor(
|
88 |
progress=False,
|
@@ -94,7 +95,12 @@ class TestJAX(unittest.TestCase):
|
|
94 |
extra_jax_mappings={
|
95 |
"cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
|
96 |
},
|
|
|
|
|
|
|
|
|
97 |
)
|
|
|
98 |
model.fit(X.values, y.values)
|
99 |
f, parameters = model.jax().values()
|
100 |
|
@@ -104,4 +110,5 @@ class TestJAX(unittest.TestCase):
|
|
104 |
np_output = np_prediction(X.values)
|
105 |
jax_output = jax_prediction(X.values)
|
106 |
|
107 |
-
np.testing.assert_almost_equal(
|
|
|
|
80 |
)
|
81 |
|
82 |
def test_feature_selection_custom_operators(self):
|
83 |
+
rstate = np.random.RandomState(0)
|
84 |
+
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
85 |
cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
|
86 |
+
y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
|
87 |
|
88 |
model = PySRRegressor(
|
89 |
progress=False,
|
|
|
95 |
extra_jax_mappings={
|
96 |
"cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
|
97 |
},
|
98 |
+
random_state=0,
|
99 |
+
deterministic=True,
|
100 |
+
procs=0,
|
101 |
+
multithreading=False,
|
102 |
)
|
103 |
+
np.random.seed(0)
|
104 |
model.fit(X.values, y.values)
|
105 |
f, parameters = model.jax().values()
|
106 |
|
|
|
110 |
np_output = np_prediction(X.values)
|
111 |
jax_output = jax_prediction(X.values)
|
112 |
|
113 |
+
np.testing.assert_almost_equal(y.values, np_output, decimal=4)
|
114 |
+
np.testing.assert_almost_equal(y.values, jax_output, decimal=4)
|