PySR / pysr /export_latex.py
MilesCranmer's picture
Move `latex_table` code to `export_latex.py`
c5cd4bb
raw
history blame
3.6 kB
"""Functions to help export PySR equations to LaTeX."""
import sympy
from sympy.printing.latex import LatexPrinter
import pandas as pd
from typing import List
class PreciseLatexPrinter(LatexPrinter):
"""Modified SymPy printer with custom float precision."""
def __init__(self, settings=None, prec=3):
super().__init__(settings)
self.prec = prec
def _print_Float(self, expr):
# Reduce precision of float:
reduced_float = sympy.Float(expr, self.prec)
return super()._print_Float(reduced_float)
def to_latex(expr, prec=3, full_prec=True, **settings):
"""Convert sympy expression to LaTeX with custom precision."""
settings["full_prec"] = full_prec
printer = PreciseLatexPrinter(settings=settings, prec=prec)
return printer.doprint(expr)
def generate_table_environment(columns=["equation", "complexity", "loss"]):
margins = "".join([("l" if col == "equation" else "c") for col in columns])
column_map = {
"complexity": "Complexity",
"loss": "Loss",
"equation": "Equation",
"score": "Score",
}
columns = [column_map[col] for col in columns]
top_pieces = [
r"\begin{table}[h]",
r"\begin{center}",
r"\begin{tabular}{@{}" + margins + r"@{}}",
r"\toprule",
" & ".join(columns) + r" \\",
r"\midrule",
]
bottom_pieces = [
r"\bottomrule",
r"\end{tabular}",
r"\end{center}",
r"\end{table}",
]
top_latex_table = "\n".join(top_pieces)
bottom_latex_table = "\n".join(bottom_pieces)
return top_latex_table, bottom_latex_table
def generate_single_table(
equations: pd.DataFrame,
indices: List[int] = None,
precision: int = 3,
columns=["equation", "complexity", "loss", "score"],
):
assert isinstance(equations, pd.DataFrame)
latex_top, latex_bottom = generate_table_environment(columns)
latex_table_content = []
if indices is None:
indices = range(len(equations))
for i in indices:
latex_equation = to_latex(
equations.iloc[i]["sympy_format"],
prec=precision,
)
complexity = str(equations.iloc[i]["complexity"])
loss = to_latex(
sympy.Float(equations.iloc[i]["loss"]),
prec=precision,
)
score = to_latex(
sympy.Float(equations.iloc[i]["score"]),
prec=precision,
)
row_pieces = []
for col in columns:
if col == "equation":
row_pieces.append(latex_equation)
elif col == "complexity":
row_pieces.append(complexity)
elif col == "loss":
row_pieces.append(loss)
elif col == "score":
row_pieces.append(score)
else:
raise ValueError(f"Unknown column: {col}")
row_pieces = ["$" + piece + "$" for piece in row_pieces]
latex_table_content.append(
" & ".join(row_pieces) + r" \\",
)
return "\n".join([latex_top, *latex_table_content, latex_bottom])
def generate_multiple_tables(
equations: List[pd.DataFrame],
indices: List[List[int]] = None,
precision: int = 3,
columns=["equation", "complexity", "loss", "score"],
):
latex_tables = [
generate_single_table(
equations[i],
(None if not indices else indices[i]),
precision=precision,
columns=columns,
)
for i in range(len(equations))
]
return "\n\n".join(latex_tables)