Spaces:
Running
Running
MilesCranmer
commited on
Refactor sympy and export functionality
Browse files- pysr/export_latex.py +7 -7
- pysr/export_numpy.py +6 -3
- pysr/export_sympy.py +72 -0
- pysr/sr.py +26 -84
- pysr/test/test.py +9 -10
pysr/export_latex.py
CHANGED
@@ -19,7 +19,7 @@ class PreciseLatexPrinter(LatexPrinter):
|
|
19 |
return super()._print_Float(reduced_float)
|
20 |
|
21 |
|
22 |
-
def
|
23 |
"""Convert sympy expression to LaTeX with custom precision."""
|
24 |
settings["full_prec"] = full_prec
|
25 |
printer = PreciseLatexPrinter(settings=settings, prec=prec)
|
@@ -56,7 +56,7 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]):
|
|
56 |
return top_latex_table, bottom_latex_table
|
57 |
|
58 |
|
59 |
-
def
|
60 |
equations: pd.DataFrame,
|
61 |
indices: List[int] = None,
|
62 |
precision: int = 3,
|
@@ -74,16 +74,16 @@ def generate_single_table(
|
|
74 |
indices = range(len(equations))
|
75 |
|
76 |
for i in indices:
|
77 |
-
latex_equation =
|
78 |
equations.iloc[i]["sympy_format"],
|
79 |
prec=precision,
|
80 |
)
|
81 |
complexity = str(equations.iloc[i]["complexity"])
|
82 |
-
loss =
|
83 |
sympy.Float(equations.iloc[i]["loss"]),
|
84 |
prec=precision,
|
85 |
)
|
86 |
-
score =
|
87 |
sympy.Float(equations.iloc[i]["score"]),
|
88 |
prec=precision,
|
89 |
)
|
@@ -124,7 +124,7 @@ def generate_single_table(
|
|
124 |
return "\n".join([latex_top, *latex_table_content, latex_bottom])
|
125 |
|
126 |
|
127 |
-
def
|
128 |
equations: List[pd.DataFrame],
|
129 |
indices: List[List[int]] = None,
|
130 |
precision: int = 3,
|
@@ -135,7 +135,7 @@ def generate_multiple_tables(
|
|
135 |
# TODO: Let user specify custom output variable
|
136 |
|
137 |
latex_tables = [
|
138 |
-
|
139 |
equations[i],
|
140 |
(None if not indices else indices[i]),
|
141 |
precision=precision,
|
|
|
19 |
return super()._print_Float(reduced_float)
|
20 |
|
21 |
|
22 |
+
def sympy2latex(expr, prec=3, full_prec=True, **settings):
|
23 |
"""Convert sympy expression to LaTeX with custom precision."""
|
24 |
settings["full_prec"] = full_prec
|
25 |
printer = PreciseLatexPrinter(settings=settings, prec=prec)
|
|
|
56 |
return top_latex_table, bottom_latex_table
|
57 |
|
58 |
|
59 |
+
def sympy2latextable(
|
60 |
equations: pd.DataFrame,
|
61 |
indices: List[int] = None,
|
62 |
precision: int = 3,
|
|
|
74 |
indices = range(len(equations))
|
75 |
|
76 |
for i in indices:
|
77 |
+
latex_equation = sympy2latex(
|
78 |
equations.iloc[i]["sympy_format"],
|
79 |
prec=precision,
|
80 |
)
|
81 |
complexity = str(equations.iloc[i]["complexity"])
|
82 |
+
loss = sympy2latex(
|
83 |
sympy.Float(equations.iloc[i]["loss"]),
|
84 |
prec=precision,
|
85 |
)
|
86 |
+
score = sympy2latex(
|
87 |
sympy.Float(equations.iloc[i]["score"]),
|
88 |
prec=precision,
|
89 |
)
|
|
|
124 |
return "\n".join([latex_top, *latex_table_content, latex_bottom])
|
125 |
|
126 |
|
127 |
+
def sympy2multilatextable(
|
128 |
equations: List[pd.DataFrame],
|
129 |
indices: List[List[int]] = None,
|
130 |
precision: int = 3,
|
|
|
135 |
# TODO: Let user specify custom output variable
|
136 |
|
137 |
latex_tables = [
|
138 |
+
sympy2latextable(
|
139 |
equations[i],
|
140 |
(None if not indices else indices[i]),
|
141 |
precision=precision,
|
pysr/export_numpy.py
CHANGED
@@ -6,14 +6,17 @@ import pandas as pd
|
|
6 |
from sympy import lambdify
|
7 |
|
8 |
|
|
|
|
|
|
|
|
|
9 |
class CallableEquation:
|
10 |
"""Simple wrapper for numpy lambda functions built with sympy"""
|
11 |
|
12 |
-
def __init__(self,
|
13 |
self._sympy = eqn
|
14 |
self._sympy_symbols = sympy_symbols
|
15 |
self._selection = selection
|
16 |
-
self._variable_names = variable_names
|
17 |
|
18 |
def __repr__(self):
|
19 |
return f"PySRFunction(X=>{self._sympy})"
|
@@ -23,7 +26,7 @@ class CallableEquation:
|
|
23 |
if isinstance(X, pd.DataFrame):
|
24 |
# Lambda function takes as argument:
|
25 |
return self._lambda(
|
26 |
-
**{k: X[k].values for k in self.
|
27 |
) * np.ones(expected_shape)
|
28 |
if self._selection is not None:
|
29 |
if X.shape[1] != len(self._selection):
|
|
|
6 |
from sympy import lambdify
|
7 |
|
8 |
|
9 |
+
def sympy2numpy(eqn, sympy_symbols, *, selection=None):
|
10 |
+
return CallableEquation(eqn, sympy_symbols, selection=selection)
|
11 |
+
|
12 |
+
|
13 |
class CallableEquation:
|
14 |
"""Simple wrapper for numpy lambda functions built with sympy"""
|
15 |
|
16 |
+
def __init__(self, eqn, sympy_symbols, selection=None):
|
17 |
self._sympy = eqn
|
18 |
self._sympy_symbols = sympy_symbols
|
19 |
self._selection = selection
|
|
|
20 |
|
21 |
def __repr__(self):
|
22 |
return f"PySRFunction(X=>{self._sympy})"
|
|
|
26 |
if isinstance(X, pd.DataFrame):
|
27 |
# Lambda function takes as argument:
|
28 |
return self._lambda(
|
29 |
+
**{k: X[k].values for k in map(str, self._sympy_symbols)}
|
30 |
) * np.ones(expected_shape)
|
31 |
if self._selection is not None:
|
32 |
if X.shape[1] != len(self._selection):
|
pysr/export_sympy.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Define utilities to export to sympy"""
|
2 |
+
from typing import Callable, Dict, List, Optional
|
3 |
+
|
4 |
+
import sympy
|
5 |
+
from sympy import sympify
|
6 |
+
|
7 |
+
sympy_mappings = {
|
8 |
+
"div": lambda x, y: x / y,
|
9 |
+
"mult": lambda x, y: x * y,
|
10 |
+
"sqrt": lambda x: sympy.sqrt(x),
|
11 |
+
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
|
12 |
+
"square": lambda x: x**2,
|
13 |
+
"cube": lambda x: x**3,
|
14 |
+
"plus": lambda x, y: x + y,
|
15 |
+
"sub": lambda x, y: x - y,
|
16 |
+
"neg": lambda x: -x,
|
17 |
+
"pow": lambda x, y: x**y,
|
18 |
+
"pow_abs": lambda x, y: abs(x) ** y,
|
19 |
+
"cos": sympy.cos,
|
20 |
+
"sin": sympy.sin,
|
21 |
+
"tan": sympy.tan,
|
22 |
+
"cosh": sympy.cosh,
|
23 |
+
"sinh": sympy.sinh,
|
24 |
+
"tanh": sympy.tanh,
|
25 |
+
"exp": sympy.exp,
|
26 |
+
"acos": sympy.acos,
|
27 |
+
"asin": sympy.asin,
|
28 |
+
"atan": sympy.atan,
|
29 |
+
"acosh": lambda x: sympy.acosh(x),
|
30 |
+
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
|
31 |
+
"asinh": sympy.asinh,
|
32 |
+
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
33 |
+
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
34 |
+
"abs": abs,
|
35 |
+
"mod": sympy.Mod,
|
36 |
+
"erf": sympy.erf,
|
37 |
+
"erfc": sympy.erfc,
|
38 |
+
"log": lambda x: sympy.log(x),
|
39 |
+
"log10": lambda x: sympy.log(x, 10),
|
40 |
+
"log2": lambda x: sympy.log(x, 2),
|
41 |
+
"log1p": lambda x: sympy.log(x + 1),
|
42 |
+
"log_abs": lambda x: sympy.log(abs(x)),
|
43 |
+
"log10_abs": lambda x: sympy.log(abs(x), 10),
|
44 |
+
"log2_abs": lambda x: sympy.log(abs(x), 2),
|
45 |
+
"log1p_abs": lambda x: sympy.log(abs(x) + 1),
|
46 |
+
"floor": sympy.floor,
|
47 |
+
"ceil": sympy.ceiling,
|
48 |
+
"sign": sympy.sign,
|
49 |
+
"gamma": sympy.gamma,
|
50 |
+
}
|
51 |
+
|
52 |
+
|
53 |
+
def create_sympy_symbols(
|
54 |
+
feature_names_in: Optional[List[str]] = None,
|
55 |
+
) -> List[sympy.Symbol]:
|
56 |
+
return [sympy.Symbol(variable) for variable in feature_names_in]
|
57 |
+
|
58 |
+
|
59 |
+
def pysr2sympy(
|
60 |
+
equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None
|
61 |
+
) -> sympy.Expr:
|
62 |
+
local_sympy_mappings = {
|
63 |
+
**(extra_sympy_mappings if extra_sympy_mappings else {}),
|
64 |
+
**sympy_mappings,
|
65 |
+
}
|
66 |
+
|
67 |
+
return sympify(equation, locals=local_sympy_mappings)
|
68 |
+
|
69 |
+
|
70 |
+
def assert_valid_sympy_symbol(var_name: str) -> None:
|
71 |
+
if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
|
72 |
+
raise ValueError(f"Variable name {var_name} is already a function name.")
|
pysr/sr.py
CHANGED
@@ -14,15 +14,16 @@ from pathlib import Path
|
|
14 |
|
15 |
import numpy as np
|
16 |
import pandas as pd
|
17 |
-
import sympy
|
18 |
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
|
19 |
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
20 |
from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
|
21 |
-
from sympy import sympify
|
22 |
|
23 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
24 |
-
from .
|
25 |
-
from .
|
|
|
|
|
|
|
26 |
from .julia_helpers import (
|
27 |
_escape_filename,
|
28 |
_load_backend,
|
@@ -37,51 +38,6 @@ Main = None # TODO: Rename to more descriptive name like "julia_runtime"
|
|
37 |
|
38 |
already_ran = False
|
39 |
|
40 |
-
sympy_mappings = {
|
41 |
-
"div": lambda x, y: x / y,
|
42 |
-
"mult": lambda x, y: x * y,
|
43 |
-
"sqrt": lambda x: sympy.sqrt(x),
|
44 |
-
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
|
45 |
-
"square": lambda x: x**2,
|
46 |
-
"cube": lambda x: x**3,
|
47 |
-
"plus": lambda x, y: x + y,
|
48 |
-
"sub": lambda x, y: x - y,
|
49 |
-
"neg": lambda x: -x,
|
50 |
-
"pow": lambda x, y: x**y,
|
51 |
-
"pow_abs": lambda x, y: abs(x) ** y,
|
52 |
-
"cos": sympy.cos,
|
53 |
-
"sin": sympy.sin,
|
54 |
-
"tan": sympy.tan,
|
55 |
-
"cosh": sympy.cosh,
|
56 |
-
"sinh": sympy.sinh,
|
57 |
-
"tanh": sympy.tanh,
|
58 |
-
"exp": sympy.exp,
|
59 |
-
"acos": sympy.acos,
|
60 |
-
"asin": sympy.asin,
|
61 |
-
"atan": sympy.atan,
|
62 |
-
"acosh": lambda x: sympy.acosh(x),
|
63 |
-
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
|
64 |
-
"asinh": sympy.asinh,
|
65 |
-
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
66 |
-
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
|
67 |
-
"abs": abs,
|
68 |
-
"mod": sympy.Mod,
|
69 |
-
"erf": sympy.erf,
|
70 |
-
"erfc": sympy.erfc,
|
71 |
-
"log": lambda x: sympy.log(x),
|
72 |
-
"log10": lambda x: sympy.log(x, 10),
|
73 |
-
"log2": lambda x: sympy.log(x, 2),
|
74 |
-
"log1p": lambda x: sympy.log(x + 1),
|
75 |
-
"log_abs": lambda x: sympy.log(abs(x)),
|
76 |
-
"log10_abs": lambda x: sympy.log(abs(x), 10),
|
77 |
-
"log2_abs": lambda x: sympy.log(abs(x), 2),
|
78 |
-
"log1p_abs": lambda x: sympy.log(abs(x) + 1),
|
79 |
-
"floor": sympy.floor,
|
80 |
-
"ceil": sympy.ceiling,
|
81 |
-
"sign": sympy.sign,
|
82 |
-
"gamma": sympy.gamma,
|
83 |
-
}
|
84 |
-
|
85 |
|
86 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
87 |
warnings.warn(
|
@@ -188,10 +144,6 @@ def _check_assertions(
|
|
188 |
assert len(variable_names) == X.shape[1]
|
189 |
# Check none of the variable names are function names:
|
190 |
for var_name in variable_names:
|
191 |
-
if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
|
192 |
-
raise ValueError(
|
193 |
-
f"Variable name {var_name} is already a function name."
|
194 |
-
)
|
195 |
# Check if alphanumeric only:
|
196 |
if not re.match(r"^[ββββββ
ββββa-zA-Z0-9_]+$", var_name):
|
197 |
raise ValueError(
|
@@ -199,6 +151,7 @@ def _check_assertions(
|
|
199 |
"Only alphanumeric characters, numbers, "
|
200 |
"and underscores are allowed."
|
201 |
)
|
|
|
202 |
if X_units is not None and len(X_units) != X.shape[1]:
|
203 |
raise ValueError(
|
204 |
"The number of units in `X_units` must equal the number of features in `X`."
|
@@ -2116,10 +2069,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2116 |
if self.nout_ > 1:
|
2117 |
output = []
|
2118 |
for s in sympy_representation:
|
2119 |
-
latex =
|
2120 |
output.append(latex)
|
2121 |
return output
|
2122 |
-
return
|
2123 |
|
2124 |
def jax(self, index=None):
|
2125 |
"""
|
@@ -2282,53 +2235,41 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2282 |
jax_format = []
|
2283 |
if self.output_torch_format:
|
2284 |
torch_format = []
|
2285 |
-
local_sympy_mappings = {
|
2286 |
-
**(self.extra_sympy_mappings if self.extra_sympy_mappings else {}),
|
2287 |
-
**sympy_mappings,
|
2288 |
-
}
|
2289 |
-
|
2290 |
-
sympy_symbols = [
|
2291 |
-
sympy.Symbol(variable) for variable in self.feature_names_in_
|
2292 |
-
]
|
2293 |
|
2294 |
for _, eqn_row in output.iterrows():
|
2295 |
-
eqn =
|
|
|
|
|
|
|
2296 |
sympy_format.append(eqn)
|
2297 |
|
2298 |
-
#
|
|
|
2299 |
lambda_format.append(
|
2300 |
-
|
2301 |
-
|
|
|
|
|
2302 |
)
|
2303 |
)
|
2304 |
|
2305 |
# JAX:
|
2306 |
if self.output_jax_format:
|
2307 |
-
from .export_jax import sympy2jax
|
2308 |
-
|
2309 |
func, params = sympy2jax(
|
2310 |
eqn,
|
2311 |
sympy_symbols,
|
2312 |
selection=self.selection_mask_,
|
2313 |
-
extra_jax_mappings=
|
2314 |
-
self.extra_jax_mappings if self.extra_jax_mappings else {}
|
2315 |
-
),
|
2316 |
)
|
2317 |
jax_format.append({"callable": func, "parameters": params})
|
2318 |
|
2319 |
# Torch:
|
2320 |
if self.output_torch_format:
|
2321 |
-
from .export_torch import sympy2torch
|
2322 |
-
|
2323 |
module = sympy2torch(
|
2324 |
eqn,
|
2325 |
sympy_symbols,
|
2326 |
selection=self.selection_mask_,
|
2327 |
-
extra_torch_mappings=
|
2328 |
-
self.extra_torch_mappings
|
2329 |
-
if self.extra_torch_mappings
|
2330 |
-
else {}
|
2331 |
-
),
|
2332 |
)
|
2333 |
torch_format.append(module)
|
2334 |
|
@@ -2410,17 +2351,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2410 |
assert isinstance(indices[0], list)
|
2411 |
assert len(indices) == self.nout_
|
2412 |
|
2413 |
-
|
|
|
|
|
2414 |
else:
|
2415 |
if indices is not None:
|
2416 |
assert isinstance(indices, list)
|
2417 |
assert isinstance(indices[0], int)
|
2418 |
|
2419 |
-
|
|
|
|
|
2420 |
|
2421 |
-
table_string = generator_fnc(
|
2422 |
-
self.equations_, indices=indices, precision=precision, columns=columns
|
2423 |
-
)
|
2424 |
preamble_string = [
|
2425 |
r"\usepackage{breqn}",
|
2426 |
r"\usepackage{booktabs}",
|
|
|
14 |
|
15 |
import numpy as np
|
16 |
import pandas as pd
|
|
|
17 |
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
|
18 |
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
19 |
from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
|
|
|
20 |
|
21 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
22 |
+
from .export_jax import sympy2jax
|
23 |
+
from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
|
24 |
+
from .export_numpy import sympy2numpy
|
25 |
+
from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
|
26 |
+
from .export_torch import sympy2torch
|
27 |
from .julia_helpers import (
|
28 |
_escape_filename,
|
29 |
_load_backend,
|
|
|
38 |
|
39 |
already_ran = False
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
43 |
warnings.warn(
|
|
|
144 |
assert len(variable_names) == X.shape[1]
|
145 |
# Check none of the variable names are function names:
|
146 |
for var_name in variable_names:
|
|
|
|
|
|
|
|
|
147 |
# Check if alphanumeric only:
|
148 |
if not re.match(r"^[ββββββ
ββββa-zA-Z0-9_]+$", var_name):
|
149 |
raise ValueError(
|
|
|
151 |
"Only alphanumeric characters, numbers, "
|
152 |
"and underscores are allowed."
|
153 |
)
|
154 |
+
assert_valid_sympy_symbol(var_name)
|
155 |
if X_units is not None and len(X_units) != X.shape[1]:
|
156 |
raise ValueError(
|
157 |
"The number of units in `X_units` must equal the number of features in `X`."
|
|
|
2069 |
if self.nout_ > 1:
|
2070 |
output = []
|
2071 |
for s in sympy_representation:
|
2072 |
+
latex = sympy2latex(s, prec=precision)
|
2073 |
output.append(latex)
|
2074 |
return output
|
2075 |
+
return sympy2latex(sympy_representation, prec=precision)
|
2076 |
|
2077 |
def jax(self, index=None):
|
2078 |
"""
|
|
|
2235 |
jax_format = []
|
2236 |
if self.output_torch_format:
|
2237 |
torch_format = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2238 |
|
2239 |
for _, eqn_row in output.iterrows():
|
2240 |
+
eqn = pysr2sympy(
|
2241 |
+
eqn_row["equation"],
|
2242 |
+
extra_sympy_mappings=self.extra_sympy_mappings,
|
2243 |
+
)
|
2244 |
sympy_format.append(eqn)
|
2245 |
|
2246 |
+
# NumPy:
|
2247 |
+
sympy_symbols = create_sympy_symbols(self.feature_names_in_)
|
2248 |
lambda_format.append(
|
2249 |
+
sympy2numpy(
|
2250 |
+
eqn,
|
2251 |
+
sympy_symbols,
|
2252 |
+
selection=self.selection_mask_,
|
2253 |
)
|
2254 |
)
|
2255 |
|
2256 |
# JAX:
|
2257 |
if self.output_jax_format:
|
|
|
|
|
2258 |
func, params = sympy2jax(
|
2259 |
eqn,
|
2260 |
sympy_symbols,
|
2261 |
selection=self.selection_mask_,
|
2262 |
+
extra_jax_mappings=self.extra_jax_mappings,
|
|
|
|
|
2263 |
)
|
2264 |
jax_format.append({"callable": func, "parameters": params})
|
2265 |
|
2266 |
# Torch:
|
2267 |
if self.output_torch_format:
|
|
|
|
|
2268 |
module = sympy2torch(
|
2269 |
eqn,
|
2270 |
sympy_symbols,
|
2271 |
selection=self.selection_mask_,
|
2272 |
+
extra_torch_mappings=self.extra_torch_mappings,
|
|
|
|
|
|
|
|
|
2273 |
)
|
2274 |
torch_format.append(module)
|
2275 |
|
|
|
2351 |
assert isinstance(indices[0], list)
|
2352 |
assert len(indices) == self.nout_
|
2353 |
|
2354 |
+
table_string = sympy2multilatextable(
|
2355 |
+
self.equations_, indices=indices, precision=precision, columns=columns
|
2356 |
+
)
|
2357 |
else:
|
2358 |
if indices is not None:
|
2359 |
assert isinstance(indices, list)
|
2360 |
assert isinstance(indices[0], int)
|
2361 |
|
2362 |
+
table_string = sympy2latextable(
|
2363 |
+
self.equations_, indices=indices, precision=precision, columns=columns
|
2364 |
+
)
|
2365 |
|
|
|
|
|
|
|
2366 |
preamble_string = [
|
2367 |
r"\usepackage{breqn}",
|
2368 |
r"\usepackage{booktabs}",
|
pysr/test/test.py
CHANGED
@@ -10,11 +10,10 @@ from pathlib import Path
|
|
10 |
import numpy as np
|
11 |
import pandas as pd
|
12 |
import sympy
|
13 |
-
from sklearn import model_selection
|
14 |
from sklearn.utils.estimator_checks import check_estimator
|
15 |
|
16 |
from .. import PySRRegressor, julia_helpers
|
17 |
-
from ..export_latex import
|
18 |
from ..sr import (
|
19 |
_check_assertions,
|
20 |
_csv_filename_to_pkl_filename,
|
@@ -884,23 +883,23 @@ class TestLaTeXTable(unittest.TestCase):
|
|
884 |
def test_latex_float_precision(self):
|
885 |
"""Test that we can print latex expressions with custom precision"""
|
886 |
expr = sympy.Float(4583.4485748, dps=50)
|
887 |
-
self.assertEqual(
|
888 |
-
self.assertEqual(
|
889 |
-
self.assertEqual(
|
890 |
-
self.assertEqual(
|
891 |
-
self.assertEqual(
|
892 |
|
893 |
# Multiple numbers:
|
894 |
x = sympy.Symbol("x")
|
895 |
expr = x * 3232.324857384 - 1.4857485e-10
|
896 |
self.assertEqual(
|
897 |
-
|
898 |
)
|
899 |
self.assertEqual(
|
900 |
-
|
901 |
)
|
902 |
self.assertEqual(
|
903 |
-
|
904 |
)
|
905 |
|
906 |
def test_latex_break_long_equation(self):
|
|
|
10 |
import numpy as np
|
11 |
import pandas as pd
|
12 |
import sympy
|
|
|
13 |
from sklearn.utils.estimator_checks import check_estimator
|
14 |
|
15 |
from .. import PySRRegressor, julia_helpers
|
16 |
+
from ..export_latex import sympy2latex
|
17 |
from ..sr import (
|
18 |
_check_assertions,
|
19 |
_csv_filename_to_pkl_filename,
|
|
|
883 |
def test_latex_float_precision(self):
|
884 |
"""Test that we can print latex expressions with custom precision"""
|
885 |
expr = sympy.Float(4583.4485748, dps=50)
|
886 |
+
self.assertEqual(sympy2latex(expr, prec=6), r"4583.45")
|
887 |
+
self.assertEqual(sympy2latex(expr, prec=5), r"4583.4")
|
888 |
+
self.assertEqual(sympy2latex(expr, prec=4), r"4583.")
|
889 |
+
self.assertEqual(sympy2latex(expr, prec=3), r"4.58 \cdot 10^{3}")
|
890 |
+
self.assertEqual(sympy2latex(expr, prec=2), r"4.6 \cdot 10^{3}")
|
891 |
|
892 |
# Multiple numbers:
|
893 |
x = sympy.Symbol("x")
|
894 |
expr = x * 3232.324857384 - 1.4857485e-10
|
895 |
self.assertEqual(
|
896 |
+
sympy2latex(expr, prec=2), r"3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
|
897 |
)
|
898 |
self.assertEqual(
|
899 |
+
sympy2latex(expr, prec=3), r"3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
|
900 |
)
|
901 |
self.assertEqual(
|
902 |
+
sympy2latex(expr, prec=8), r"3232.3249 x - 1.4857485 \cdot 10^{-10}"
|
903 |
)
|
904 |
|
905 |
def test_latex_break_long_equation(self):
|