MilesCranmer commited on
Commit
9a5df63
1 Parent(s): 0cb1353

Use custom sympy LatexPrinter for precision

Browse files
Files changed (3) hide show
  1. pysr/export_latex.py +20 -8
  2. pysr/sr.py +4 -8
  3. test/test.py +23 -0
pysr/export_latex.py CHANGED
@@ -1,14 +1,26 @@
1
  """Functions to help export PySR equations to LaTeX."""
2
- import re
 
3
 
4
 
5
- def set_precision_of_constants_in_string(s, precision=3):
6
- """Set precision of constants in string."""
7
- constants = re.findall(r"\b[-+]?\d*\.\d+|\b[-+]?\d+\.?\d*", s)
8
- for c in constants:
9
- reduced_c = "{:.{precision}g}".format(float(c), precision=precision)
10
- s = s.replace(c, reduced_c)
11
- return s
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  def generate_top_of_latex_table(columns=["Equation", "Complexity", "Loss"]):
 
1
  """Functions to help export PySR equations to LaTeX."""
2
+ import sympy
3
+ from sympy.printing.latex import LatexPrinter
4
 
5
 
6
+ class PreciseLatexPrinter(LatexPrinter):
7
+ """Modified SymPy printer with custom float precision."""
8
+ def __init__(self, settings=None, prec=3):
9
+ super().__init__(settings)
10
+ self.prec = prec
11
+
12
+ def _print_Float(self, expr):
13
+ # Reduce precision of float:
14
+ reduced_float = sympy.Float(expr, self.prec)
15
+ return super()._print_Float(reduced_float)
16
+
17
+
18
+ def to_latex(expr, prec=3, **settings):
19
+ """Convert sympy expression to LaTeX with custom precision."""
20
+ if len(settings) == 0:
21
+ settings = None
22
+ printer = PreciseLatexPrinter(settings=settings, prec=prec)
23
+ return printer.doprint(expr)
24
 
25
 
26
  def generate_top_of_latex_table(columns=["Equation", "Complexity", "Loss"]):
pysr/sr.py CHANGED
@@ -28,7 +28,7 @@ from .julia_helpers import (
28
  )
29
  from .export_numpy import CallableEquation
30
  from .export_latex import (
31
- set_precision_of_constants_in_string,
32
  generate_top_of_latex_table,
33
  generate_bottom_of_latex_table,
34
  )
@@ -1752,14 +1752,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1752
  if self.nout_ > 1:
1753
  output = []
1754
  for s in sympy_representation:
1755
- raw_latex = sympy.latex(s)
1756
- reduced_latex = set_precision_of_constants_in_string(
1757
- raw_latex, precision
1758
- )
1759
- output.append(reduced_latex)
1760
  return output
1761
- raw_latex = sympy.latex(sympy_representation)
1762
- return set_precision_of_constants_in_string(raw_latex, precision)
1763
 
1764
  def jax(self, index=None):
1765
  """
 
28
  )
29
  from .export_numpy import CallableEquation
30
  from .export_latex import (
31
+ to_latex,
32
  generate_top_of_latex_table,
33
  generate_bottom_of_latex_table,
34
  )
 
1752
  if self.nout_ > 1:
1753
  output = []
1754
  for s in sympy_representation:
1755
+ latex = to_latex(s, prec=precision)
1756
+ output.append(latex)
 
 
 
1757
  return output
1758
+ return to_latex(sympy_representation, prec=precision)
 
1759
 
1760
  def jax(self, index=None):
1761
  """
test/test.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  from sklearn import model_selection
7
  from pysr import PySRRegressor
8
  from pysr.sr import run_feature_selection, _handle_feature_selection
 
9
  from sklearn.utils.estimator_checks import check_estimator
10
  import sympy
11
  import pandas as pd
@@ -573,3 +574,25 @@ class TestLaTeXTable(unittest.TestCase):
573
  """
574
  true_latex_table_str = self.create_true_latex(middle_part)
575
  self.assertEqual(latex_table_str, true_latex_table_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from sklearn import model_selection
7
  from pysr import PySRRegressor
8
  from pysr.sr import run_feature_selection, _handle_feature_selection
9
+ from pysr.export_latex import to_latex
10
  from sklearn.utils.estimator_checks import check_estimator
11
  import sympy
12
  import pandas as pd
 
574
  """
575
  true_latex_table_str = self.create_true_latex(middle_part)
576
  self.assertEqual(latex_table_str, true_latex_table_str)
577
+
578
+ def test_latex_float_precision(self):
579
+ """Test that we can print latex expressions with custom precision"""
580
+ expr = sympy.Float(4583.4485748, dps=50)
581
+ self.assertEqual(to_latex(expr, prec=6), r"4583.45")
582
+ self.assertEqual(to_latex(expr, prec=5), r"4583.4")
583
+ self.assertEqual(to_latex(expr, prec=4), r"4583.0")
584
+ self.assertEqual(to_latex(expr, prec=3), r"4.58 \cdot 10^{3}")
585
+ self.assertEqual(to_latex(expr, prec=2), r"4.6 \cdot 10^{3}")
586
+
587
+ # Multiple numbers:
588
+ x = sympy.Symbol("x")
589
+ expr = x * 3232.324857384 - 1.4857485e-10
590
+ self.assertEqual(
591
+ to_latex(expr, prec=2), "3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
592
+ )
593
+ self.assertEqual(
594
+ to_latex(expr, prec=3), "3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
595
+ )
596
+ self.assertEqual(
597
+ to_latex(expr, prec=8), "3232.3249 x - 1.4857485 \cdot 10^{-10}"
598
+ )