MilesCranmer commited on
Commit
bb81a9a
·
unverified ·
1 Parent(s): f457f66

test: add test of normal complexity_of_variables

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +21 -18
pysr/test/test.py CHANGED
@@ -178,24 +178,27 @@ class TestPipeline(unittest.TestCase):
178
  self.assertLessEqual(mse2, 1e-4)
179
 
180
  def test_custom_variable_complexity(self):
181
- y = self.X[:, [0, 1]] ** 2
182
- model = PySRRegressor(
183
- binary_operators=["*", "+"],
184
- verbosity=0,
185
- **self.default_test_kwargs,
186
- early_stop_condition="stop_if(l, c) = l < 1e-4 && c <= 7",
187
- )
188
- model.fit(
189
- self.X,
190
- y,
191
- complexity_of_variables=[2, 3] + [100 for _ in range(self.X.shape[1] - 2)],
192
- )
193
- equations = model.equations_
194
- self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
195
- self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
196
-
197
- self.assertEqual(model.get_best()[0]["complexity"], 5)
198
- self.assertEqual(model.get_best()[1]["complexity"], 7)
 
 
 
199
 
200
  def test_multioutput_weighted_with_callable_temp_equation(self):
201
  X = self.X.copy()
 
178
  self.assertLessEqual(mse2, 1e-4)
179
 
180
  def test_custom_variable_complexity(self):
181
+ for case in (1, 2):
182
+ y = self.X[:, [0, 1]] ** 2
183
+ model = PySRRegressor(
184
+ binary_operators=["*", "+"],
185
+ verbosity=0,
186
+ **self.default_test_kwargs,
187
+ early_stop_condition="stop_if(l, c) = l < 1e-4 && c <= 7",
188
+ )
189
+ if case == 1:
190
+ complexity_of_variables = [2, 3] + [
191
+ 100 for _ in range(self.X.shape[1] - 2)
192
+ ]
193
+ elif case == 2:
194
+ complexity_of_variables = 2
195
+ model.fit(self.X, y, complexity_of_variables=complexity_of_variables)
196
+ equations = model.equations_
197
+ self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
198
+ self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
199
+
200
+ self.assertEqual(model.get_best()[0]["complexity"], 5)
201
+ self.assertEqual(model.get_best()[1]["complexity"], 7 if case == 1 else 5)
202
 
203
  def test_multioutput_weighted_with_callable_temp_equation(self):
204
  X = self.X.copy()