MilesCranmer commited on
Commit
ef66f4a
1 Parent(s): 4db7637

Try to fix nb sanitizer

Browse files
pysr/_cli/main.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import warnings
2
 
3
  import click
@@ -55,19 +56,27 @@ def _tests(tests):
55
 
56
  Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
57
  """
 
58
  for test in tests.split(","):
59
  if test == "main":
60
- runtests()
61
  elif test == "jax":
62
- runtests_jax()
63
  elif test == "torch":
64
- runtests_torch()
65
  elif test == "cli":
66
  runtests_cli = get_runtests_cli()
67
- runtests_cli()
68
  elif test == "dev":
69
- runtests_dev()
70
  elif test == "startup":
71
- runtests_startup()
72
  else:
73
  warnings.warn(f"Invalid test {test}. Skipping.")
 
 
 
 
 
 
 
 
1
+ import unittest
2
  import warnings
3
 
4
  import click
 
56
 
57
  Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
58
  """
59
+ test_cases = []
60
  for test in tests.split(","):
61
  if test == "main":
62
+ test_cases.extend(runtests(just_tests=True))
63
  elif test == "jax":
64
+ test_cases.extend(runtests_jax(just_tests=True))
65
  elif test == "torch":
66
+ test_cases.extend(runtests_torch(just_tests=True))
67
  elif test == "cli":
68
  runtests_cli = get_runtests_cli()
69
+ test_cases.extend(runtests_cli(just_tests=True))
70
  elif test == "dev":
71
+ test_cases.extend(runtests_dev(just_tests=True))
72
  elif test == "startup":
73
+ test_cases.extend(runtests_startup(just_tests=True))
74
  else:
75
  warnings.warn(f"Invalid test {test}. Skipping.")
76
+
77
+ loader = unittest.TestLoader()
78
+ suite = unittest.TestSuite()
79
+ for test_case in test_cases:
80
+ suite.addTests(loader.loadTestsFromTestCase(test_case))
81
+ runner = unittest.TextTestRunner()
82
+ return runner.run(suite)
pysr/test/test.py CHANGED
@@ -1127,10 +1127,8 @@ class TestDimensionalConstraints(unittest.TestCase):
1127
  # TODO: Determine desired behavior if second .fit() call does not have units
1128
 
1129
 
1130
- def runtests():
1131
  """Run all tests in test.py."""
1132
- suite = unittest.TestSuite()
1133
- loader = unittest.TestLoader()
1134
  test_cases = [
1135
  TestPipeline,
1136
  TestBest,
@@ -1139,8 +1137,11 @@ def runtests():
1139
  TestLaTeXTable,
1140
  TestDimensionalConstraints,
1141
  ]
 
 
 
 
1142
  for test_case in test_cases:
1143
- tests = loader.loadTestsFromTestCase(test_case)
1144
- suite.addTests(tests)
1145
  runner = unittest.TextTestRunner()
1146
  return runner.run(suite)
 
1127
  # TODO: Determine desired behavior if second .fit() call does not have units
1128
 
1129
 
1130
+ def runtests(just_tests=False):
1131
  """Run all tests in test.py."""
 
 
1132
  test_cases = [
1133
  TestPipeline,
1134
  TestBest,
 
1137
  TestLaTeXTable,
1138
  TestDimensionalConstraints,
1139
  ]
1140
+ if just_tests:
1141
+ return test_cases
1142
+ suite = unittest.TestSuite()
1143
+ loader = unittest.TestLoader()
1144
  for test_case in test_cases:
1145
+ suite.addTests(loader.loadTestsFromTestCase(test_case))
 
1146
  runner = unittest.TextTestRunner()
1147
  return runner.run(suite)
pysr/test/test_cli.py CHANGED
@@ -68,11 +68,15 @@ def get_runtests():
68
  self.assertEqual(result.output.strip(), expected.strip())
69
  self.assertEqual(result.exit_code, 0)
70
 
71
- def runtests():
72
  """Run all tests in cliTest.py."""
 
 
 
73
  loader = unittest.TestLoader()
74
  suite = unittest.TestSuite()
75
- suite.addTests(loader.loadTestsFromTestCase(TestCli))
 
76
  runner = unittest.TextTestRunner()
77
  return runner.run(suite)
78
 
 
68
  self.assertEqual(result.output.strip(), expected.strip())
69
  self.assertEqual(result.exit_code, 0)
70
 
71
+ def runtests(just_tests=False):
72
  """Run all tests in cliTest.py."""
73
+ tests = [TestCli]
74
+ if just_tests:
75
+ return tests
76
  loader = unittest.TestLoader()
77
  suite = unittest.TestSuite()
78
+ for test in tests:
79
+ suite.addTests(loader.loadTestsFromTestCase(test))
80
  runner = unittest.TextTestRunner()
81
  return runner.run(suite)
82
 
pysr/test/test_dev.py CHANGED
@@ -47,9 +47,13 @@ class TestDev(unittest.TestCase):
47
  self.assertEqual(test_result.stdout.decode("utf-8").strip(), "2.3")
48
 
49
 
50
- def runtests():
 
 
 
51
  suite = unittest.TestSuite()
52
  loader = unittest.TestLoader()
53
- suite.addTests(loader.loadTestsFromTestCase(TestDev))
 
54
  runner = unittest.TextTestRunner()
55
  return runner.run(suite)
 
47
  self.assertEqual(test_result.stdout.decode("utf-8").strip(), "2.3")
48
 
49
 
50
+ def runtests(just_tests=False):
51
+ tests = [TestDev]
52
+ if just_tests:
53
+ return tests
54
  suite = unittest.TestSuite()
55
  loader = unittest.TestLoader()
56
+ for test in tests:
57
+ suite.addTests(loader.loadTestsFromTestCase(test))
58
  runner = unittest.TextTestRunner()
59
  return runner.run(suite)
pysr/test/test_jax.py CHANGED
@@ -121,10 +121,14 @@ class TestJAX(unittest.TestCase):
121
  np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
122
 
123
 
124
- def runtests():
125
  """Run all tests in test_jax.py."""
 
 
 
126
  loader = unittest.TestLoader()
127
  suite = unittest.TestSuite()
128
- suite.addTests(loader.loadTestsFromTestCase(TestJAX))
 
129
  runner = unittest.TextTestRunner()
130
  return runner.run(suite)
 
121
  np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
122
 
123
 
124
+ def runtests(just_tests=False):
125
  """Run all tests in test_jax.py."""
126
+ tests = [TestJAX]
127
+ if just_tests:
128
+ return tests
129
  loader = unittest.TestLoader()
130
  suite = unittest.TestSuite()
131
+ for test in tests:
132
+ suite.addTests(loader.loadTestsFromTestCase(test))
133
  runner = unittest.TextTestRunner()
134
  return runner.run(suite)
pysr/test/test_startup.py CHANGED
@@ -143,9 +143,13 @@ class TestStartup(unittest.TestCase):
143
  self.assertEqual(result.returncode, 0)
144
 
145
 
146
- def runtests():
 
 
 
147
  suite = unittest.TestSuite()
148
  loader = unittest.TestLoader()
149
- suite.addTests(loader.loadTestsFromTestCase(TestStartup))
 
150
  runner = unittest.TextTestRunner()
151
  return runner.run(suite)
 
143
  self.assertEqual(result.returncode, 0)
144
 
145
 
146
+ def runtests(just_tests=False):
147
+ tests = [TestStartup]
148
+ if just_tests:
149
+ return tests
150
  suite = unittest.TestSuite()
151
  loader = unittest.TestLoader()
152
+ for test in tests:
153
+ suite.addTests(loader.loadTestsFromTestCase(test))
154
  runner = unittest.TextTestRunner()
155
  return runner.run(suite)
pysr/test/test_torch.py CHANGED
@@ -184,10 +184,14 @@ class TestTorch(unittest.TestCase):
184
  np.testing.assert_almost_equal(y.values, torch_output, decimal=3)
185
 
186
 
187
- def runtests():
188
  """Run all tests in test_torch.py."""
 
 
 
189
  loader = unittest.TestLoader()
190
  suite = unittest.TestSuite()
191
- suite.addTests(loader.loadTestsFromTestCase(TestTorch))
 
192
  runner = unittest.TextTestRunner()
193
  return runner.run(suite)
 
184
  np.testing.assert_almost_equal(y.values, torch_output, decimal=3)
185
 
186
 
187
+ def runtests(just_tests=False):
188
  """Run all tests in test_torch.py."""
189
+ tests = [TestTorch]
190
+ if just_tests:
191
+ return tests
192
  loader = unittest.TestLoader()
193
  suite = unittest.TestSuite()
194
+ for test in tests:
195
+ suite.addTests(loader.loadTestsFromTestCase(test))
196
  runner = unittest.TextTestRunner()
197
  return runner.run(suite)