MilesCranmer commited on
Commit
c5cd4bb
1 Parent(s): c6f5c09

Move `latex_table` code to `export_latex.py`

Browse files
Files changed (2) hide show
  1. pysr/export_latex.py +62 -51
  2. pysr/sr.py +11 -19
pysr/export_latex.py CHANGED
@@ -55,61 +55,72 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]):
55
  return top_latex_table, bottom_latex_table
56
 
57
 
58
- def generate_table(
59
- equations: List[pd.DataFrame],
60
- indices: List[List[int]],
61
- precision=3,
62
  columns=["equation", "complexity", "loss", "score"],
63
  ):
 
 
64
  latex_top, latex_bottom = generate_table_environment(columns)
 
65
 
66
- latex_equations = [
67
- [to_latex(eq, prec=precision) for eq in equation_set["sympy_format"]]
68
- for equation_set in equations
69
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- all_latex_table_str = []
72
-
73
- for output_feature, index_set in enumerate(indices):
74
- latex_table_content = []
75
- for i in index_set:
76
- latex_equation = latex_equations[output_feature][i]
77
- complexity = str(equations[output_feature].iloc[i]["complexity"])
78
- loss = to_latex(
79
- sympy.Float(equations[output_feature].iloc[i]["loss"]),
80
- prec=precision,
81
- )
82
- score = to_latex(
83
- sympy.Float(equations[output_feature].iloc[i]["score"]),
84
- prec=precision,
85
- )
86
-
87
- row_pieces = []
88
- for col in columns:
89
- if col == "equation":
90
- row_pieces.append(latex_equation)
91
- elif col == "complexity":
92
- row_pieces.append(complexity)
93
- elif col == "loss":
94
- row_pieces.append(loss)
95
- elif col == "score":
96
- row_pieces.append(score)
97
- else:
98
- raise ValueError(f"Unknown column: {col}")
99
-
100
- row_pieces = ["$" + piece + "$" for piece in row_pieces]
101
-
102
- latex_table_content.append(
103
- " & ".join(row_pieces) + r" \\",
104
- )
105
-
106
- this_latex_table = "\n".join(
107
- [
108
- latex_top,
109
- *latex_table_content,
110
- latex_bottom,
111
- ]
112
  )
113
- all_latex_table_str.append(this_latex_table)
 
114
 
115
- return "\n\n".join(all_latex_table_str)
 
55
  return top_latex_table, bottom_latex_table
56
 
57
 
58
+ def generate_single_table(
59
+ equations: pd.DataFrame,
60
+ indices: List[int] = None,
61
+ precision: int = 3,
62
  columns=["equation", "complexity", "loss", "score"],
63
  ):
64
+ assert isinstance(equations, pd.DataFrame)
65
+
66
  latex_top, latex_bottom = generate_table_environment(columns)
67
+ latex_table_content = []
68
 
69
+ if indices is None:
70
+ indices = range(len(equations))
71
+
72
+ for i in indices:
73
+ latex_equation = to_latex(
74
+ equations.iloc[i]["sympy_format"],
75
+ prec=precision,
76
+ )
77
+ complexity = str(equations.iloc[i]["complexity"])
78
+ loss = to_latex(
79
+ sympy.Float(equations.iloc[i]["loss"]),
80
+ prec=precision,
81
+ )
82
+ score = to_latex(
83
+ sympy.Float(equations.iloc[i]["score"]),
84
+ prec=precision,
85
+ )
86
+
87
+ row_pieces = []
88
+ for col in columns:
89
+ if col == "equation":
90
+ row_pieces.append(latex_equation)
91
+ elif col == "complexity":
92
+ row_pieces.append(complexity)
93
+ elif col == "loss":
94
+ row_pieces.append(loss)
95
+ elif col == "score":
96
+ row_pieces.append(score)
97
+ else:
98
+ raise ValueError(f"Unknown column: {col}")
99
+
100
+ row_pieces = ["$" + piece + "$" for piece in row_pieces]
101
+
102
+ latex_table_content.append(
103
+ " & ".join(row_pieces) + r" \\",
104
+ )
105
+
106
+ return "\n".join([latex_top, *latex_table_content, latex_bottom])
107
+
108
+
109
+ def generate_multiple_tables(
110
+ equations: List[pd.DataFrame],
111
+ indices: List[List[int]] = None,
112
+ precision: int = 3,
113
+ columns=["equation", "complexity", "loss", "score"],
114
+ ):
115
 
116
+ latex_tables = [
117
+ generate_single_table(
118
+ equations[i],
119
+ (None if not indices else indices[i]),
120
+ precision=precision,
121
+ columns=columns,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
+ for i in range(len(equations))
124
+ ]
125
 
126
+ return "\n\n".join(latex_tables)
pysr/sr.py CHANGED
@@ -27,7 +27,7 @@ from .julia_helpers import (
27
  import_error_string,
28
  )
29
  from .export_numpy import CallableEquation
30
- from .export_latex import to_latex, generate_table
31
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
32
 
33
 
@@ -2024,26 +2024,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2024
  """
2025
  self.refresh()
2026
 
2027
- # All indices:
2028
- if indices is None:
2029
- if self.nout_ > 1:
2030
- indices = [
2031
- list(range(len(out_equations))) for out_equations in self.equations_
2032
- ]
2033
- else:
2034
- indices = list(range(len(self.equations_)))
2035
-
2036
- equations = self.equations_
2037
-
2038
- if isinstance(indices[0], int):
2039
- assert self.nout_ == 1, "For multiple outputs, pass a list of lists."
2040
- indices = [indices]
2041
- equations = [equations]
2042
 
2043
- assert len(indices) == self.nout_
 
 
2044
 
2045
- return generate_table(
2046
- equations, indices=indices, precision=precision, columns=columns
2047
  )
2048
 
2049
 
 
27
  import_error_string,
28
  )
29
  from .export_numpy import CallableEquation
30
+ from .export_latex import generate_single_table, generate_multiple_tables, to_latex
31
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
32
 
33
 
 
2024
  """
2025
  self.refresh()
2026
 
2027
+ if self.nout_ > 1:
2028
+ if indices is not None:
2029
+ assert isinstance(indices, list)
2030
+ assert isinstance(indices[0], list)
2031
+ assert isinstance(len(indices), self.nout_)
 
 
 
 
 
 
 
 
 
 
2032
 
2033
+ generator_fnc = generate_multiple_tables
2034
+ else:
2035
+ generator_fnc = generate_single_table
2036
 
2037
+ return generator_fnc(
2038
+ self.equations_, indices=indices, precision=precision, columns=columns
2039
  )
2040
 
2041