MilesCranmer commited on
Commit
ef7a292
1 Parent(s): d146a3e

Fix variable_names in test

Browse files
Files changed (1) hide show
  1. test/test.py +2 -0
test/test.py CHANGED
@@ -1,6 +1,7 @@
1
  import unittest
2
  from unittest.mock import patch
3
  import numpy as np
 
4
  from pysr import pysr, get_hof, best, best_tex, best_callable, best_row
5
  from pysr.sr import run_feature_selection, _handle_feature_selection, _yesno
6
  import sympy
@@ -219,6 +220,7 @@ class TestFeatureSelection(unittest.TestCase):
219
  y=y,
220
  )
221
  self.assertTrue((2 in selection) and (3 in selection))
 
222
  self.assertEqual(set(selected_var_names), set("x2 x3".split(" ")))
223
  np.testing.assert_array_equal(
224
  np.sort(selected_X, axis=1), np.sort(X[:, [2, 3]], axis=1)
 
1
  import unittest
2
  from unittest.mock import patch
3
  import numpy as np
4
+ from uritemplate import variables
5
  from pysr import pysr, get_hof, best, best_tex, best_callable, best_row
6
  from pysr.sr import run_feature_selection, _handle_feature_selection, _yesno
7
  import sympy
 
220
  y=y,
221
  )
222
  self.assertTrue((2 in selection) and (3 in selection))
223
+ selected_var_names = [var_names[i] for i in selection]
224
  self.assertEqual(set(selected_var_names), set("x2 x3".split(" ")))
225
  np.testing.assert_array_equal(
226
  np.sort(selected_X, axis=1), np.sort(X[:, [2, 3]], axis=1)