Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
c6c8728
1
Parent(s):
8f218cc
Test all aspects of generated LaTeX table
Browse files- test/test.py +107 -19
test/test.py
CHANGED
@@ -281,19 +281,38 @@ class TestPipeline(unittest.TestCase):
|
|
281 |
self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
|
282 |
|
283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
class TestBest(unittest.TestCase):
|
285 |
def setUp(self):
|
286 |
self.rstate = np.random.RandomState(0)
|
287 |
self.X = self.rstate.randn(10, 2)
|
288 |
self.y = np.cos(self.X[:, 0]) ** 2
|
289 |
-
self.model = PySRRegressor(
|
290 |
-
progress=False,
|
291 |
-
niterations=1,
|
292 |
-
extra_sympy_mappings={},
|
293 |
-
output_jax_format=False,
|
294 |
-
model_selection="accuracy",
|
295 |
-
equation_file="equation_file.csv",
|
296 |
-
)
|
297 |
equations = pd.DataFrame(
|
298 |
{
|
299 |
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
@@ -301,17 +320,7 @@ class TestBest(unittest.TestCase):
|
|
301 |
"complexity": [1, 2, 3],
|
302 |
}
|
303 |
)
|
304 |
-
|
305 |
-
# Set up internal parameters as if it had been fitted:
|
306 |
-
self.model.equation_file_ = "equation_file.csv"
|
307 |
-
self.model.nout_ = 1
|
308 |
-
self.model.selection_mask_ = None
|
309 |
-
self.model.feature_names_in_ = np.array(["x0", "x1"], dtype=object)
|
310 |
-
equations["complexity loss equation".split(" ")].to_csv(
|
311 |
-
"equation_file.csv.bkup", sep="|"
|
312 |
-
)
|
313 |
-
|
314 |
-
self.model.refresh()
|
315 |
self.equations_ = self.model.equations_
|
316 |
|
317 |
def test_best(self):
|
@@ -485,3 +494,82 @@ class TestMiscellaneous(unittest.TestCase):
|
|
485 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
486 |
# If any checks failed don't let the test pass.
|
487 |
self.assertEqual(len(exception_messages), 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
|
282 |
|
283 |
|
284 |
+
def manually_create_model(equations, feature_names=None):
|
285 |
+
if feature_names is None:
|
286 |
+
feature_names = ["x0", "x1"]
|
287 |
+
|
288 |
+
model = PySRRegressor(
|
289 |
+
progress=False,
|
290 |
+
niterations=1,
|
291 |
+
extra_sympy_mappings={},
|
292 |
+
output_jax_format=False,
|
293 |
+
model_selection="accuracy",
|
294 |
+
equation_file="equation_file.csv",
|
295 |
+
)
|
296 |
+
|
297 |
+
# Set up internal parameters as if it had been fitted:
|
298 |
+
model.equation_file_ = "equation_file.csv"
|
299 |
+
model.nout_ = 1
|
300 |
+
model.selection_mask_ = None
|
301 |
+
model.feature_names_in_ = np.array(feature_names, dtype=object)
|
302 |
+
equations["complexity loss equation".split(" ")].to_csv(
|
303 |
+
"equation_file.csv.bkup", sep="|"
|
304 |
+
)
|
305 |
+
|
306 |
+
model.refresh()
|
307 |
+
|
308 |
+
return model
|
309 |
+
|
310 |
+
|
311 |
class TestBest(unittest.TestCase):
|
312 |
def setUp(self):
|
313 |
self.rstate = np.random.RandomState(0)
|
314 |
self.X = self.rstate.randn(10, 2)
|
315 |
self.y = np.cos(self.X[:, 0]) ** 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
equations = pd.DataFrame(
|
317 |
{
|
318 |
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
|
|
320 |
"complexity": [1, 2, 3],
|
321 |
}
|
322 |
)
|
323 |
+
self.model = manually_create_model(equations)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
self.equations_ = self.model.equations_
|
325 |
|
326 |
def test_best(self):
|
|
|
494 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
495 |
# If any checks failed don't let the test pass.
|
496 |
self.assertEqual(len(exception_messages), 0)
|
497 |
+
|
498 |
+
|
499 |
+
class TestLaTeXTable(unittest.TestCase):
|
500 |
+
def create_true_latex(self, middle_part, include_score=False):
|
501 |
+
if include_score:
|
502 |
+
true_latex_table_str = r"""
|
503 |
+
\begin{table}[h]
|
504 |
+
\begin{center}
|
505 |
+
\begin{tabular}{@{}clll@{}}
|
506 |
+
\toprule
|
507 |
+
Equation & Complexity & Loss & Score \\
|
508 |
+
\midrule"""
|
509 |
+
else:
|
510 |
+
true_latex_table_str = r"""
|
511 |
+
\begin{table}[h]
|
512 |
+
\begin{center}
|
513 |
+
\begin{tabular}{@{}cll@{}}
|
514 |
+
\toprule
|
515 |
+
Equation & Complexity & Loss \\
|
516 |
+
\midrule"""
|
517 |
+
true_latex_table_str += middle_part
|
518 |
+
true_latex_table_str += r"""\bottomrule
|
519 |
+
\end{tabular}
|
520 |
+
\end{center}
|
521 |
+
\end{table}
|
522 |
+
"""
|
523 |
+
# First, remove empty lines:
|
524 |
+
true_latex_table_str = "\n".join(
|
525 |
+
[line.strip() for line in true_latex_table_str.split("\n") if len(line) > 0]
|
526 |
+
)
|
527 |
+
return true_latex_table_str.strip()
|
528 |
+
|
529 |
+
def test_simple_table(self):
|
530 |
+
equations = pd.DataFrame(
|
531 |
+
dict(
|
532 |
+
equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
|
533 |
+
loss=[1.052, 0.02315, 1.12347e-15],
|
534 |
+
complexity=[1, 2, 8],
|
535 |
+
)
|
536 |
+
)
|
537 |
+
model = manually_create_model(equations)
|
538 |
+
|
539 |
+
# Regular table:
|
540 |
+
latex_table_str = model.latex_table()
|
541 |
+
middle_part = r"""
|
542 |
+
$x_{0}$ & 1 & 1.05 \\
|
543 |
+
$\cos{\left(x_{0} \right)}$ & 2 & 0.0232 \\
|
544 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 \\
|
545 |
+
"""
|
546 |
+
true_latex_table_str = self.create_true_latex(middle_part)
|
547 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
548 |
+
|
549 |
+
# Different precision:
|
550 |
+
latex_table_str = model.latex_table(precision=5)
|
551 |
+
middle_part = r"""
|
552 |
+
$x_{0}$ & 1 & 1.052 \\
|
553 |
+
$\cos{\left(x_{0} \right)}$ & 2 & 0.02315 \\
|
554 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.1235e-15 \\
|
555 |
+
"""
|
556 |
+
true_latex_table_str = self.create_true_latex(middle_part)
|
557 |
+
self.assertEqual(latex_table_str, self.create_true_latex(middle_part))
|
558 |
+
|
559 |
+
# Including score:
|
560 |
+
latex_table_str = model.latex_table(include_score=True)
|
561 |
+
middle_part = r"""
|
562 |
+
$x_{0}$ & 1 & 1.05 & 0 \\
|
563 |
+
$\cos{\left(x_{0} \right)}$ & 2 & 0.0232 & 3.82 \\
|
564 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 & 5.11 \\
|
565 |
+
"""
|
566 |
+
true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
|
567 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
568 |
+
|
569 |
+
# Only last equation:
|
570 |
+
latex_table_str = model.latex_table(indices=[2])
|
571 |
+
middle_part = r"""
|
572 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 \\
|
573 |
+
"""
|
574 |
+
true_latex_table_str = self.create_true_latex(middle_part)
|
575 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|