Spaces:
Running
Running
MilesCranmer
commited on
Commit
·
8f218cc
1
Parent(s):
6d5ddcb
Include score in latex table; apply precision
Browse files- pysr/export_latex.py +4 -3
- pysr/sr.py +22 -4
pysr/export_latex.py
CHANGED
@@ -11,13 +11,14 @@ def set_precision_of_constants_in_string(s, precision=3):
|
|
11 |
return s
|
12 |
|
13 |
|
14 |
-
def generate_top_of_latex_table():
|
|
|
15 |
latex_table_pieces = [
|
16 |
r"\begin{table}[h]",
|
17 |
r"\begin{center}",
|
18 |
-
r"\begin{tabular}{@{}
|
19 |
r"\toprule",
|
20 |
-
|
21 |
r"\midrule",
|
22 |
]
|
23 |
return "\n".join(latex_table_pieces)
|
|
|
11 |
return s
|
12 |
|
13 |
|
14 |
+
def generate_top_of_latex_table(columns=["Equation", "Complexity", "Loss"]):
|
15 |
+
margins = "".join([("c" if col == "Equation" else "l") for col in columns])
|
16 |
latex_table_pieces = [
|
17 |
r"\begin{table}[h]",
|
18 |
r"\begin{center}",
|
19 |
+
r"\begin{tabular}{@{}" + margins + r"@{}}",
|
20 |
r"\toprule",
|
21 |
+
" & ".join(columns) + r" \\",
|
22 |
r"\midrule",
|
23 |
]
|
24 |
return "\n".join(latex_table_pieces)
|
pysr/sr.py
CHANGED
@@ -2004,7 +2004,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2004 |
return ret_outputs
|
2005 |
return ret_outputs[0]
|
2006 |
|
2007 |
-
def latex_table(self, indices=None, precision=3):
|
2008 |
"""Create a LaTeX/booktabs table for all, or some, of the equations.
|
2009 |
|
2010 |
Parameters
|
@@ -2016,6 +2016,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2016 |
precision : int, default=3
|
2017 |
The number of significant figures shown in the LaTeX
|
2018 |
representations.
|
|
|
|
|
2019 |
|
2020 |
Returns
|
2021 |
-------
|
@@ -2028,17 +2030,33 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2028 |
)
|
2029 |
if indices is None:
|
2030 |
indices = range(len(self.equations_))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2031 |
latex_table_content = []
|
2032 |
for i in indices:
|
2033 |
equation = self.latex(i, precision=precision)
|
2034 |
-
|
2035 |
-
loss =
|
|
|
|
|
|
|
|
|
|
|
2036 |
row_pieces = ["$" + equation + "$", complexity, loss]
|
|
|
|
|
|
|
2037 |
latex_table_content.append(
|
2038 |
" & ".join(row_pieces) + r" \\",
|
2039 |
)
|
2040 |
-
|
2041 |
latex_table_bottom = generate_bottom_of_latex_table()
|
|
|
2042 |
latex_table_str = "\n".join(
|
2043 |
[
|
2044 |
latex_table_top,
|
|
|
2004 |
return ret_outputs
|
2005 |
return ret_outputs[0]
|
2006 |
|
2007 |
+
def latex_table(self, indices=None, precision=3, include_score=False):
|
2008 |
"""Create a LaTeX/booktabs table for all, or some, of the equations.
|
2009 |
|
2010 |
Parameters
|
|
|
2016 |
precision : int, default=3
|
2017 |
The number of significant figures shown in the LaTeX
|
2018 |
representations.
|
2019 |
+
include_score : bool, default=False
|
2020 |
+
Whether to include the score in the table.
|
2021 |
|
2022 |
Returns
|
2023 |
-------
|
|
|
2030 |
)
|
2031 |
if indices is None:
|
2032 |
indices = range(len(self.equations_))
|
2033 |
+
|
2034 |
+
columns = ["Equation", "Complexity", "Loss"]
|
2035 |
+
if include_score:
|
2036 |
+
columns.append("Score")
|
2037 |
+
|
2038 |
+
latex_table_top = generate_top_of_latex_table(columns)
|
2039 |
+
|
2040 |
latex_table_content = []
|
2041 |
for i in indices:
|
2042 |
equation = self.latex(i, precision=precision)
|
2043 |
+
# Also convert these to reduced precision:
|
2044 |
+
# loss = self.equations_.iloc[i]["loss"]
|
2045 |
+
# score = self.equations_.iloc[i]["score"]
|
2046 |
+
complexity = "{:d}".format(self.equations_.iloc[i]["complexity"])
|
2047 |
+
loss = "{:.{p}g}".format(self.equations_.iloc[i]["loss"], p=precision)
|
2048 |
+
score = "{:.{p}g}".format(self.equations_.iloc[i]["score"], p=precision)
|
2049 |
+
|
2050 |
row_pieces = ["$" + equation + "$", complexity, loss]
|
2051 |
+
if include_score:
|
2052 |
+
row_pieces.append(score)
|
2053 |
+
|
2054 |
latex_table_content.append(
|
2055 |
" & ".join(row_pieces) + r" \\",
|
2056 |
)
|
2057 |
+
|
2058 |
latex_table_bottom = generate_bottom_of_latex_table()
|
2059 |
+
|
2060 |
latex_table_str = "\n".join(
|
2061 |
[
|
2062 |
latex_table_top,
|