Spaces:
Running
Running
File size: 4,849 Bytes
f257e58 b896bd3 976f8d8 9a5df63 fab6f87 9a5df63 118c5f6 6024c83 9a5df63 6024c83 9a5df63 b896bd3 9a5df63 6024c83 9a5df63 f257e58 b896bd3 fab6f87 215a692 d423f0c f257e58 6d5ddcb 8f218cc f257e58 8f218cc f257e58 d423f0c f257e58 6d5ddcb f257e58 d423f0c c6f5c09 b2d7f41 c5cd4bb b896bd3 c5cd4bb b896bd3 2a802ab 3ef2b32 b896bd3 de4d559 c5cd4bb c6f5c09 c5cd4bb c6f5c09 c5cd4bb b896bd3 c5cd4bb b2d7f41 c5cd4bb b2d7f41 c5cd4bb b2d7f41 c5cd4bb 2a802ab 3ef2b32 2a802ab fab6f87 2a802ab fab6f87 3ef2b32 fab6f87 2a802ab c5cd4bb 2a802ab c5cd4bb 2a802ab c5cd4bb 2a802ab c5cd4bb b2d7f41 c5cd4bb b896bd3 c5cd4bb b896bd3 de4d559 3ef2b32 c6f5c09 c5cd4bb b2d7f41 c5cd4bb 3ef2b32 c6f5c09 c5cd4bb c6f5c09 c5cd4bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
"""Functions to help export PySR equations to LaTeX."""
from typing import List, Optional, Tuple
import pandas as pd
import sympy
from sympy.printing.latex import LatexPrinter
class PreciseLatexPrinter(LatexPrinter):
    """SymPy LaTeX printer that re-creates every Float at a fixed precision.

    Args:
        settings: Printer settings, forwarded unchanged to ``LatexPrinter``.
        prec: Precision passed to ``sympy.Float`` as its second argument
            (decimal digits) whenever a float node is printed.
    """

    def __init__(self, settings=None, prec=3):
        super().__init__(settings)
        # Consulted by _print_Float for every float in the expression.
        self.prec = prec

    def _print_Float(self, expr):
        """Print ``expr`` after rebuilding it with ``self.prec`` digits."""
        return super()._print_Float(sympy.Float(expr, self.prec))
def sympy2latex(expr, prec=3, full_prec=True, **settings) -> str:
    """Render a SymPy expression as LaTeX, printing floats at ``prec`` digits.

    Args:
        expr: Expression to render.
        prec: Float precision applied by ``PreciseLatexPrinter``.
        full_prec: Stored in the printer settings under ``"full_prec"``.
        **settings: Any further ``LatexPrinter`` settings.

    Returns:
        The LaTeX string produced by the printer.
    """
    printer_settings = dict(settings, full_prec=full_prec)
    latex_printer = PreciseLatexPrinter(settings=printer_settings, prec=prec)
    return latex_printer.doprint(expr)
def generate_table_environment(
    columns: Optional[List[str]] = None,
) -> Tuple[str, str]:
    """Build the opening and closing LaTeX for a booktabs-style table.

    Args:
        columns: Column keys in display order, each one of ``"equation"``,
            ``"complexity"``, ``"loss"`` or ``"score"``. Defaults to
            ``["equation", "complexity", "loss"]``.

    Returns:
        ``(top, bottom)``. ``top`` opens the ``table``, ``center`` and
        ``tabular`` environments (one centred column per key) and contains
        the header row between the top and mid rules. ``bottom`` holds the
        bottom rule and closes the environments. Body rows go in between.

    Raises:
        KeyError: If a column key is not one of the names listed above.
    """
    if columns is None:
        # Built per call instead of a shared mutable default list.
        columns = ["equation", "complexity", "loss"]

    column_map = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }
    # Unknown keys raise KeyError here, before any LaTeX is assembled.
    headers = [column_map[col] for col in columns]

    # One centred column per key; "@{}" removes the default padding at the
    # table's left and right edges.
    margins = "c" * len(columns)

    top_pieces = [
        r"\begin{table}[h]",
        r"\begin{center}",
        r"\begin{tabular}{@{}" + margins + r"@{}}",
        r"\toprule",
        " & ".join(headers) + r" \\",
        r"\midrule",
    ]
    bottom_pieces = [
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{center}",
        r"\end{table}",
    ]
    return "\n".join(top_pieces), "\n".join(bottom_pieces)
def _format_equation_cell(
    latex_equation: str, output_variable_name: str, max_equation_length: int
) -> str:
    """Return the table cell for one equation, wrapping long ones."""
    if len(latex_equation) < max_equation_length:
        return "$" + output_variable_name + " = " + latex_equation + "$"
    # Long equations go in a minipage holding a dmath* environment (LaTeX
    # breqn package) so they can break across lines inside the cell.
    return " ".join(
        [
            r"\begin{minipage}{0.8\linewidth}",
            r"\vspace{-1em}",
            r"\begin{dmath*}",
            output_variable_name + " = " + latex_equation,
            r"\end{dmath*}",
            r"\end{minipage}",
        ]
    )


def sympy2latextable(
    equations: pd.DataFrame,
    indices: Optional[List[int]] = None,
    precision: int = 3,
    columns: List[str] = ["equation", "complexity", "loss", "score"],
    max_equation_length: int = 50,
    output_variable_name: str = "y",
) -> str:
    """Generate a booktabs-style LaTeX table for a single set of equations.

    Args:
        equations: Equations to tabulate. Rows are selected by position.
            The ``sympy_format``, ``complexity``, ``loss`` and ``score``
            columns are read only when the matching table column is requested
            (``sympy_format`` backs the ``"equation"`` column).
        indices: Positional row numbers to include, in order. Defaults to
            every row.
        precision: Precision used when printing floats, both inside the
            equation and for loss/score values.
        columns: Table columns in display order, each one of ``"equation"``,
            ``"complexity"``, ``"loss"`` or ``"score"``.
        max_equation_length: Equations whose LaTeX is at least this many
            characters long are wrapped so they can break across lines.
        output_variable_name: Left-hand side printed before ``=``.

    Returns:
        The complete LaTeX table as a single string.

    Raises:
        ValueError: If ``columns`` contains an unknown name.
    """
    assert isinstance(equations, pd.DataFrame)

    # Validate before generate_table_environment. Its header lookup would
    # otherwise fail first with a bare KeyError, so this error would never fire.
    for col in columns:
        if col not in ("equation", "complexity", "loss", "score"):
            raise ValueError(f"Unknown column: {col}")

    latex_top, latex_bottom = generate_table_environment(columns)

    if indices is None:
        # Rows are fetched positionally (.iloc), so the default must be
        # positions too. Index labels only match positions for a RangeIndex.
        indices = range(len(equations))

    latex_table_content = []
    for i in indices:
        row_pieces = []
        for col in columns:
            if col == "equation":
                latex_equation = sympy2latex(
                    equations["sympy_format"].iloc[i], prec=precision
                )
                row_pieces.append(
                    _format_equation_cell(
                        latex_equation, output_variable_name, max_equation_length
                    )
                )
            elif col == "complexity":
                row_pieces.append("$" + str(equations["complexity"].iloc[i]) + "$")
            else:
                # "loss" or "score" (validated above): floats at `precision`.
                value = sympy.Float(equations[col].iloc[i])
                row_pieces.append("$" + sympy2latex(value, prec=precision) + "$")
        latex_table_content.append(" & ".join(row_pieces) + r" \\")

    return "\n".join([latex_top, *latex_table_content, latex_bottom])
def sympy2multilatextable(
    equations: List[pd.DataFrame],
    indices: Optional[List[List[int]]] = None,
    precision: int = 3,
    columns: List[str] = ["equation", "complexity", "loss", "score"],
    output_variable_names: Optional[List[str]] = None,
    max_equation_length: int = 50,
) -> str:
    """Generate one LaTeX table per equation set, separated by blank lines.

    Args:
        equations: One DataFrame per output, each rendered with
            ``sympy2latextable``.
        indices: Optional per-set row selections; ``indices[k]`` is used for
            ``equations[k]``. When ``None`` or empty, every set uses
            ``sympy2latextable``'s default rows.
        precision: Float precision, forwarded to every table.
        columns: Table columns, forwarded to every table.
        output_variable_names: Left-hand-side names, one per set. Defaults
            to ``y_{0}``, ``y_{1}``, ...
        max_equation_length: Forwarded to every table; controls when long
            equations are wrapped.

    Returns:
        All tables joined with a blank line between consecutive tables.
    """
    latex_tables = []
    for k, frame in enumerate(equations):
        if output_variable_names is None:
            name = "y_{" + str(k) + "}"
        else:
            name = output_variable_names[k]
        latex_tables.append(
            sympy2latextable(
                frame,
                indices[k] if indices else None,
                precision=precision,
                columns=columns,
                max_equation_length=max_equation_length,
                output_variable_name=name,
            )
        )
    return "\n\n".join(latex_tables)
|