Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
9a5df63
1
Parent(s):
0cb1353
Use custom sympy LatexPrinter for precision
Browse files- pysr/export_latex.py +20 -8
- pysr/sr.py +4 -8
- test/test.py +23 -0
pysr/export_latex.py
CHANGED
@@ -1,14 +1,26 @@
|
|
1 |
"""Functions to help export PySR equations to LaTeX."""
|
2 |
-
import
|
|
|
3 |
|
4 |
|
5 |
-
|
6 |
-
"""
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
def generate_top_of_latex_table(columns=["Equation", "Complexity", "Loss"]):
|
|
|
1 |
"""Functions to help export PySR equations to LaTeX."""
|
2 |
+
import sympy
|
3 |
+
from sympy.printing.latex import LatexPrinter
|
4 |
|
5 |
|
6 |
+
class PreciseLatexPrinter(LatexPrinter):
|
7 |
+
"""Modified SymPy printer with custom float precision."""
|
8 |
+
def __init__(self, settings=None, prec=3):
|
9 |
+
super().__init__(settings)
|
10 |
+
self.prec = prec
|
11 |
+
|
12 |
+
def _print_Float(self, expr):
|
13 |
+
# Reduce precision of float:
|
14 |
+
reduced_float = sympy.Float(expr, self.prec)
|
15 |
+
return super()._print_Float(reduced_float)
|
16 |
+
|
17 |
+
|
18 |
+
def to_latex(expr, prec=3, **settings):
|
19 |
+
"""Convert sympy expression to LaTeX with custom precision."""
|
20 |
+
if len(settings) == 0:
|
21 |
+
settings = None
|
22 |
+
printer = PreciseLatexPrinter(settings=settings, prec=prec)
|
23 |
+
return printer.doprint(expr)
|
24 |
|
25 |
|
26 |
def generate_top_of_latex_table(columns=["Equation", "Complexity", "Loss"]):
|
pysr/sr.py
CHANGED
@@ -28,7 +28,7 @@ from .julia_helpers import (
|
|
28 |
)
|
29 |
from .export_numpy import CallableEquation
|
30 |
from .export_latex import (
|
31 |
-
|
32 |
generate_top_of_latex_table,
|
33 |
generate_bottom_of_latex_table,
|
34 |
)
|
@@ -1752,14 +1752,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1752 |
if self.nout_ > 1:
|
1753 |
output = []
|
1754 |
for s in sympy_representation:
|
1755 |
-
|
1756 |
-
|
1757 |
-
raw_latex, precision
|
1758 |
-
)
|
1759 |
-
output.append(reduced_latex)
|
1760 |
return output
|
1761 |
-
|
1762 |
-
return set_precision_of_constants_in_string(raw_latex, precision)
|
1763 |
|
1764 |
def jax(self, index=None):
|
1765 |
"""
|
|
|
28 |
)
|
29 |
from .export_numpy import CallableEquation
|
30 |
from .export_latex import (
|
31 |
+
to_latex,
|
32 |
generate_top_of_latex_table,
|
33 |
generate_bottom_of_latex_table,
|
34 |
)
|
|
|
1752 |
if self.nout_ > 1:
|
1753 |
output = []
|
1754 |
for s in sympy_representation:
|
1755 |
+
latex = to_latex(s, prec=precision)
|
1756 |
+
output.append(latex)
|
|
|
|
|
|
|
1757 |
return output
|
1758 |
+
return to_latex(sympy_representation, prec=precision)
|
|
|
1759 |
|
1760 |
def jax(self, index=None):
|
1761 |
"""
|
test/test.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
|
|
6 |
from sklearn import model_selection
|
7 |
from pysr import PySRRegressor
|
8 |
from pysr.sr import run_feature_selection, _handle_feature_selection
|
|
|
9 |
from sklearn.utils.estimator_checks import check_estimator
|
10 |
import sympy
|
11 |
import pandas as pd
|
@@ -573,3 +574,25 @@ class TestLaTeXTable(unittest.TestCase):
|
|
573 |
"""
|
574 |
true_latex_table_str = self.create_true_latex(middle_part)
|
575 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from sklearn import model_selection
|
7 |
from pysr import PySRRegressor
|
8 |
from pysr.sr import run_feature_selection, _handle_feature_selection
|
9 |
+
from pysr.export_latex import to_latex
|
10 |
from sklearn.utils.estimator_checks import check_estimator
|
11 |
import sympy
|
12 |
import pandas as pd
|
|
|
574 |
"""
|
575 |
true_latex_table_str = self.create_true_latex(middle_part)
|
576 |
self.assertEqual(latex_table_str, true_latex_table_str)
|
577 |
+
|
578 |
+
def test_latex_float_precision(self):
|
579 |
+
"""Test that we can print latex expressions with custom precision"""
|
580 |
+
expr = sympy.Float(4583.4485748, dps=50)
|
581 |
+
self.assertEqual(to_latex(expr, prec=6), r"4583.45")
|
582 |
+
self.assertEqual(to_latex(expr, prec=5), r"4583.4")
|
583 |
+
self.assertEqual(to_latex(expr, prec=4), r"4583.0")
|
584 |
+
self.assertEqual(to_latex(expr, prec=3), r"4.58 \cdot 10^{3}")
|
585 |
+
self.assertEqual(to_latex(expr, prec=2), r"4.6 \cdot 10^{3}")
|
586 |
+
|
587 |
+
# Multiple numbers:
|
588 |
+
x = sympy.Symbol("x")
|
589 |
+
expr = x * 3232.324857384 - 1.4857485e-10
|
590 |
+
self.assertEqual(
|
591 |
+
to_latex(expr, prec=2), "3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
|
592 |
+
)
|
593 |
+
self.assertEqual(
|
594 |
+
to_latex(expr, prec=3), "3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
|
595 |
+
)
|
596 |
+
self.assertEqual(
|
597 |
+
to_latex(expr, prec=8), "3232.3249 x - 1.4857485 \cdot 10^{-10}"
|
598 |
+
)
|