Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
c9cead8
1
Parent(s):
7cda629
Make torch custom operator test deterministic
Browse files- test/test_torch.py +10 -3
test/test_torch.py
CHANGED
@@ -160,9 +160,10 @@ class TestTorch(unittest.TestCase):
|
|
160 |
)
|
161 |
|
162 |
def test_feature_selection_custom_operators(self):
|
163 |
-
|
|
|
164 |
cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
|
165 |
-
y = X["k15"] ** 2 + cos_approx(X["k20"])
|
166 |
|
167 |
model = PySRRegressor(
|
168 |
progress=False,
|
@@ -172,7 +173,12 @@ class TestTorch(unittest.TestCase):
|
|
172 |
early_stop_condition=1e-5,
|
173 |
extra_sympy_mappings={"cos_approx": cos_approx},
|
174 |
extra_torch_mappings={"cos_approx": cos_approx},
|
|
|
|
|
|
|
|
|
175 |
)
|
|
|
176 |
model.fit(X.values, y.values)
|
177 |
torch_module = model.pytorch()
|
178 |
|
@@ -180,4 +186,5 @@ class TestTorch(unittest.TestCase):
|
|
180 |
|
181 |
torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
|
182 |
|
183 |
-
np.testing.assert_almost_equal(
|
|
|
|
160 |
)
|
161 |
|
162 |
def test_feature_selection_custom_operators(self):
|
163 |
+
rstate = np.random.RandomState(0)
|
164 |
+
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
165 |
cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
|
166 |
+
y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
|
167 |
|
168 |
model = PySRRegressor(
|
169 |
progress=False,
|
|
|
173 |
early_stop_condition=1e-5,
|
174 |
extra_sympy_mappings={"cos_approx": cos_approx},
|
175 |
extra_torch_mappings={"cos_approx": cos_approx},
|
176 |
+
random_state=0,
|
177 |
+
deterministic=True,
|
178 |
+
procs=0,
|
179 |
+
multithreading=False,
|
180 |
)
|
181 |
+
np.random.seed(0)
|
182 |
model.fit(X.values, y.values)
|
183 |
torch_module = model.pytorch()
|
184 |
|
|
|
186 |
|
187 |
torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
|
188 |
|
189 |
+
np.testing.assert_almost_equal(y.values, np_output, decimal=4)
|
190 |
+
np.testing.assert_almost_equal(y.values, torch_output, decimal=4)
|