Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
90d3ef7
1
Parent(s):
0857108
Add unittest for complex numbers
Browse files- pysr/test/test.py +14 -0
pysr/test/test.py
CHANGED
@@ -181,6 +181,20 @@ class TestPipeline(unittest.TestCase):
|
|
181 |
print("Model equations: ", model.sympy()[1])
|
182 |
print("True equation: x1^2")
|
183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
def test_empty_operators_single_input_warm_start(self):
|
185 |
X = self.rstate.randn(100, 1)
|
186 |
y = X[:, 0] + 3.0
|
|
|
181 |
print("Model equations: ", model.sympy()[1])
|
182 |
print("True equation: x1^2")
|
183 |
|
184 |
+
def test_complex_equations_anonymous_stop(self):
|
185 |
+
X = self.rstate.randn(100, 3) + 1j * self.rstate.randn(100, 3)
|
186 |
+
y = (2 + 1j) * np.cos(X[:, 0] * (0.5 - 0.3j))
|
187 |
+
model = PySRRegressor(
|
188 |
+
binary_operators=["+", "-", "*"],
|
189 |
+
unary_operators=["cos"],
|
190 |
+
**self.default_test_kwargs,
|
191 |
+
early_stop_condition="(loss, complexity) -> loss <= 1e-4 && complexity <= 6",
|
192 |
+
)
|
193 |
+
model.fit(X, y)
|
194 |
+
test_y = model.predict(X)
|
195 |
+
self.assertTrue(np.issubdtype(test_y.dtype, np.complexfloating))
|
196 |
+
self.assertLessEqual(np.average(np.abs(test_y - y) ** 2), 1e-4)
|
197 |
+
|
198 |
def test_empty_operators_single_input_warm_start(self):
|
199 |
X = self.rstate.randn(100, 1)
|
200 |
y = X[:, 0] + 3.0
|