Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
c86910d
1
Parent(s):
1d19a08
Add full objective test
Browse files- pysr/test/test.py +15 -3
pysr/test/test.py
CHANGED
@@ -72,8 +72,10 @@ class TestPipeline(unittest.TestCase):
|
|
72 |
print(model.equations_)
|
73 |
self.assertLessEqual(model.get_best()["loss"], 1e-4)
|
74 |
|
75 |
-
def
|
|
|
76 |
y = self.X[:, 0]
|
|
|
77 |
model = PySRRegressor(
|
78 |
**self.default_test_kwargs,
|
79 |
# Turbo needs to work with unsafe operators:
|
@@ -81,11 +83,21 @@ class TestPipeline(unittest.TestCase):
|
|
81 |
procs=2,
|
82 |
multithreading=False,
|
83 |
turbo=True,
|
84 |
-
early_stop_condition="stop_if(loss, complexity) = loss < 1e-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
)
|
86 |
model.fit(self.X, y)
|
87 |
print(model.equations_)
|
88 |
-
|
|
|
|
|
89 |
|
90 |
def test_high_precision_search_custom_loss(self):
|
91 |
y = 1.23456789 * self.X[:, 0]
|
|
|
72 |
print(model.equations_)
|
73 |
self.assertLessEqual(model.get_best()["loss"], 1e-4)
|
74 |
|
75 |
+
def test_multiprocessing_turbo_custom_objective(self):
|
76 |
+
rstate = np.random.RandomState(0)
|
77 |
y = self.X[:, 0]
|
78 |
+
y += rstate.randn(*y.shape) * 1e-4
|
79 |
model = PySRRegressor(
|
80 |
**self.default_test_kwargs,
|
81 |
# Turbo needs to work with unsafe operators:
|
|
|
83 |
procs=2,
|
84 |
multithreading=False,
|
85 |
turbo=True,
|
86 |
+
early_stop_condition="stop_if(loss, complexity) = loss < 1e-10 && complexity == 1",
|
87 |
+
full_objective="""
|
88 |
+
function my_objective(tree::Node{T}, dataset::Dataset{T}, options::Options) where T
|
89 |
+
prediction, flag = eval_tree_array(tree, dataset.X, options)
|
90 |
+
!flag && return T(Inf)
|
91 |
+
abs3(x) = abs(x) ^ 3
|
92 |
+
return sum(abs3, prediction .- dataset.y) / length(prediction)
|
93 |
+
end
|
94 |
+
""",
|
95 |
)
|
96 |
model.fit(self.X, y)
|
97 |
print(model.equations_)
|
98 |
+
best_loss = model.equations_.iloc[-1]["loss"]
|
99 |
+
self.assertLessEqual(best_loss, 1e-10)
|
100 |
+
self.assertGreaterEqual(best_loss, 0.0)
|
101 |
|
102 |
def test_high_precision_search_custom_loss(self):
|
103 |
y = 1.23456789 * self.X[:, 0]
|