Spaces:
Running
Running
MilesCranmer
commited on
Commit
·
3752ba6
1
Parent(s):
0dbee97
Include necessary packages in latex_table()
Browse files- pysr/export_latex.py +0 -10
- pysr/sr.py +10 -1
- test/test.py +33 -5
pysr/export_latex.py
CHANGED
@@ -6,9 +6,6 @@ from typing import List
|
|
6 |
import warnings
|
7 |
|
8 |
|
9 |
-
raised_long_equation_warning = False
|
10 |
-
|
11 |
-
|
12 |
class PreciseLatexPrinter(LatexPrinter):
|
13 |
"""Modified SymPy printer with custom float precision."""
|
14 |
|
@@ -70,8 +67,6 @@ def generate_single_table(
|
|
70 |
"""Generate a booktabs-style LaTeX table for a single set of equations."""
|
71 |
assert isinstance(equations, pd.DataFrame)
|
72 |
|
73 |
-
global raised_long_equation_warning
|
74 |
-
|
75 |
latex_top, latex_bottom = generate_table_environment(columns)
|
76 |
latex_table_content = []
|
77 |
|
@@ -101,11 +96,6 @@ def generate_single_table(
|
|
101 |
"$" + output_variable_name + " = " + latex_equation + "$"
|
102 |
)
|
103 |
else:
|
104 |
-
if not raised_long_equation_warning:
|
105 |
-
warnings.warn(
|
106 |
-
"Please add \\usepackage{breqn} to your LaTeX preamble."
|
107 |
-
)
|
108 |
-
raised_long_equation_warning = True
|
109 |
|
110 |
broken_latex_equation = " ".join(
|
111 |
[
|
|
|
6 |
import warnings
|
7 |
|
8 |
|
|
|
|
|
|
|
9 |
class PreciseLatexPrinter(LatexPrinter):
|
10 |
"""Modified SymPy printer with custom float precision."""
|
11 |
|
|
|
67 |
"""Generate a booktabs-style LaTeX table for a single set of equations."""
|
68 |
assert isinstance(equations, pd.DataFrame)
|
69 |
|
|
|
|
|
70 |
latex_top, latex_bottom = generate_table_environment(columns)
|
71 |
latex_table_content = []
|
72 |
|
|
|
96 |
"$" + output_variable_name + " = " + latex_equation + "$"
|
97 |
)
|
98 |
else:
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
broken_latex_equation = " ".join(
|
101 |
[
|
pysr/sr.py
CHANGED
@@ -2197,9 +2197,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2197 |
|
2198 |
generator_fnc = generate_single_table
|
2199 |
|
2200 |
-
|
2201 |
self.equations_, indices=indices, precision=precision, columns=columns
|
2202 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2203 |
|
2204 |
|
2205 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
|
|
|
2197 |
|
2198 |
generator_fnc = generate_single_table
|
2199 |
|
2200 |
+
table_string = generator_fnc(
|
2201 |
self.equations_, indices=indices, precision=precision, columns=columns
|
2202 |
)
|
2203 |
+
preamble_string = [
|
2204 |
+
r"\usepackage{breqn}",
|
2205 |
+
r"\usepackage{booktabs}",
|
2206 |
+
r"\usepackage{tabularx}",
|
2207 |
+
"",
|
2208 |
+
"...",
|
2209 |
+
"",
|
2210 |
+
]
|
2211 |
+
return "\n".join(preamble_string + [table_string])
|
2212 |
|
2213 |
|
2214 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
|
test/test.py
CHANGED
@@ -608,6 +608,18 @@ class TestMiscellaneous(unittest.TestCase):
|
|
608 |
self.assertEqual(len(exception_messages), 0)
|
609 |
|
610 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
611 |
class TestLaTeXTable(unittest.TestCase):
|
612 |
def setUp(self):
|
613 |
equations = pd.DataFrame(
|
@@ -618,6 +630,7 @@ class TestLaTeXTable(unittest.TestCase):
|
|
618 |
)
|
619 |
)
|
620 |
self.model = manually_create_model(equations)
|
|
|
621 |
|
622 |
def create_true_latex(self, middle_part, include_score=False):
|
623 |
if include_score:
|
@@ -657,7 +670,9 @@ class TestLaTeXTable(unittest.TestCase):
|
|
657 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
|
658 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
|
659 |
"""
|
660 |
-
true_latex_table_str =
|
|
|
|
|
661 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
662 |
|
663 |
def test_other_precision(self):
|
@@ -669,7 +684,9 @@ class TestLaTeXTable(unittest.TestCase):
|
|
669 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
|
670 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
|
671 |
"""
|
672 |
-
true_latex_table_str =
|
|
|
|
|
673 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
674 |
|
675 |
def test_include_score(self):
|
@@ -679,7 +696,11 @@ class TestLaTeXTable(unittest.TestCase):
|
|
679 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
680 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
|
681 |
"""
|
682 |
-
true_latex_table_str =
|
|
|
|
|
|
|
|
|
683 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
684 |
|
685 |
def test_last_equation(self):
|
@@ -689,7 +710,9 @@ class TestLaTeXTable(unittest.TestCase):
|
|
689 |
middle_part = r"""
|
690 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
|
691 |
"""
|
692 |
-
true_latex_table_str =
|
|
|
|
|
693 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
694 |
|
695 |
def test_multi_output(self):
|
@@ -723,6 +746,7 @@ class TestLaTeXTable(unittest.TestCase):
|
|
723 |
self.create_true_latex(part, include_score=True)
|
724 |
for part in [middle_part_1, middle_part_2]
|
725 |
)
|
|
|
726 |
latex_table_str = model.latex_table()
|
727 |
|
728 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
@@ -771,5 +795,9 @@ class TestLaTeXTable(unittest.TestCase):
|
|
771 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
772 |
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
|
773 |
"""
|
774 |
-
true_latex_table_str =
|
|
|
|
|
|
|
|
|
775 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
|
|
608 |
self.assertEqual(len(exception_messages), 0)
|
609 |
|
610 |
|
611 |
+
TRUE_PREAMBLE = "\n".join(
|
612 |
+
[
|
613 |
+
r"\usepackage{breqn}",
|
614 |
+
r"\usepackage{booktabs}",
|
615 |
+
r"\usepackage{tabularx}",
|
616 |
+
"",
|
617 |
+
"...",
|
618 |
+
"",
|
619 |
+
]
|
620 |
+
)
|
621 |
+
|
622 |
+
|
623 |
class TestLaTeXTable(unittest.TestCase):
|
624 |
def setUp(self):
|
625 |
equations = pd.DataFrame(
|
|
|
630 |
)
|
631 |
)
|
632 |
self.model = manually_create_model(equations)
|
633 |
+
self.maxDiff = None
|
634 |
|
635 |
def create_true_latex(self, middle_part, include_score=False):
|
636 |
if include_score:
|
|
|
670 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
|
671 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
|
672 |
"""
|
673 |
+
true_latex_table_str = (
|
674 |
+
TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
|
675 |
+
)
|
676 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
677 |
|
678 |
def test_other_precision(self):
|
|
|
684 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
|
685 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
|
686 |
"""
|
687 |
+
true_latex_table_str = (
|
688 |
+
TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
|
689 |
+
)
|
690 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
691 |
|
692 |
def test_include_score(self):
|
|
|
696 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
697 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
|
698 |
"""
|
699 |
+
true_latex_table_str = (
|
700 |
+
TRUE_PREAMBLE
|
701 |
+
+ "\n"
|
702 |
+
+ self.create_true_latex(middle_part, include_score=True)
|
703 |
+
)
|
704 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
705 |
|
706 |
def test_last_equation(self):
|
|
|
710 |
middle_part = r"""
|
711 |
$y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
|
712 |
"""
|
713 |
+
true_latex_table_str = (
|
714 |
+
TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
|
715 |
+
)
|
716 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
717 |
|
718 |
def test_multi_output(self):
|
|
|
746 |
self.create_true_latex(part, include_score=True)
|
747 |
for part in [middle_part_1, middle_part_2]
|
748 |
)
|
749 |
+
true_latex_table_str = TRUE_PREAMBLE + "\n" + true_latex_table_str
|
750 |
latex_table_str = model.latex_table()
|
751 |
|
752 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
|
|
795 |
$y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
|
796 |
\begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
|
797 |
"""
|
798 |
+
true_latex_table_str = (
|
799 |
+
TRUE_PREAMBLE
|
800 |
+
+ "\n"
|
801 |
+
+ self.create_true_latex(middle_part, include_score=True)
|
802 |
+
)
|
803 |
self.assertEqual(latex_table_str, true_latex_table_str)
|