Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
d3b73f7
1
Parent(s):
7847c48
Fix multi-output scoring
Browse files- pysr/sr.py +24 -20
pysr/sr.py
CHANGED
@@ -372,7 +372,7 @@ def pysr(X=None, y=None, weights=None,
|
|
372 |
|
373 |
|
374 |
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
375 |
-
multioutput, **kwargs):
|
376 |
global global_n_features
|
377 |
global global_equation_file
|
378 |
global global_variable_names
|
@@ -730,7 +730,7 @@ def run_feature_selection(X, y, select_k_features):
|
|
730 |
|
731 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
732 |
extra_sympy_mappings=None, output_jax_format=False,
|
733 |
-
multioutput=False, nout=
|
734 |
"""Get the equations from a hall of fame file. If no arguments
|
735 |
entered, the ones used previously from a call to PySR will be used."""
|
736 |
|
@@ -763,26 +763,28 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
763 |
except FileNotFoundError:
|
764 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
765 |
|
766 |
-
scores = []
|
767 |
-
lastMSE = None
|
768 |
-
lastComplexity = 0
|
769 |
-
sympy_format = []
|
770 |
-
lambda_format = []
|
771 |
-
if output_jax_format:
|
772 |
-
jax_format = []
|
773 |
-
use_custom_variable_names = (len(variable_names) != 0)
|
774 |
-
local_sympy_mappings = {
|
775 |
-
**extra_sympy_mappings,
|
776 |
-
**sympy_mappings
|
777 |
-
}
|
778 |
-
|
779 |
-
if use_custom_variable_names:
|
780 |
-
sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
|
781 |
-
else:
|
782 |
-
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
783 |
-
|
784 |
ret_outputs = []
|
|
|
785 |
for output in all_outputs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
786 |
for i in range(len(output)):
|
787 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
788 |
sympy_format.append(eqn)
|
@@ -842,6 +844,7 @@ def best(equations=None):
|
|
842 |
By default this uses the last equation file.
|
843 |
"""
|
844 |
if equations is None: equations = get_hof()
|
|
|
845 |
return [best_row(eq)['sympy_format'].simplify() for eq in equations]
|
846 |
else:
|
847 |
return best_row(equations)['sympy_format'].simplify()
|
@@ -851,6 +854,7 @@ def best_callable(equations=None):
|
|
851 |
By default this uses the last equation file.
|
852 |
"""
|
853 |
if equations is None: equations = get_hof()
|
|
|
854 |
return [best_row(eq)['lambda_format'] for eq in equations]
|
855 |
else:
|
856 |
return best_row(equations)['lambda_format']
|
|
|
372 |
|
373 |
|
374 |
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
375 |
+
multioutput, nout, **kwargs):
|
376 |
global global_n_features
|
377 |
global global_equation_file
|
378 |
global global_variable_names
|
|
|
730 |
|
731 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
732 |
extra_sympy_mappings=None, output_jax_format=False,
|
733 |
+
multioutput=False, nout=None, **kwargs):
|
734 |
"""Get the equations from a hall of fame file. If no arguments
|
735 |
entered, the ones used previously from a call to PySR will be used."""
|
736 |
|
|
|
763 |
except FileNotFoundError:
|
764 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
765 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
766 |
ret_outputs = []
|
767 |
+
|
768 |
for output in all_outputs:
|
769 |
+
|
770 |
+
scores = []
|
771 |
+
lastMSE = None
|
772 |
+
lastComplexity = 0
|
773 |
+
sympy_format = []
|
774 |
+
lambda_format = []
|
775 |
+
if output_jax_format:
|
776 |
+
jax_format = []
|
777 |
+
use_custom_variable_names = (len(variable_names) != 0)
|
778 |
+
local_sympy_mappings = {
|
779 |
+
**extra_sympy_mappings,
|
780 |
+
**sympy_mappings
|
781 |
+
}
|
782 |
+
|
783 |
+
if use_custom_variable_names:
|
784 |
+
sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
|
785 |
+
else:
|
786 |
+
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
787 |
+
|
788 |
for i in range(len(output)):
|
789 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
790 |
sympy_format.append(eqn)
|
|
|
844 |
By default this uses the last equation file.
|
845 |
"""
|
846 |
if equations is None: equations = get_hof()
|
847 |
+
if isinstance(equations, list):
|
848 |
return [best_row(eq)['sympy_format'].simplify() for eq in equations]
|
849 |
else:
|
850 |
return best_row(equations)['sympy_format'].simplify()
|
|
|
854 |
By default this uses the last equation file.
|
855 |
"""
|
856 |
if equations is None: equations = get_hof()
|
857 |
+
if isinstance(equations, list):
|
858 |
return [best_row(eq)['lambda_format'] for eq in equations]
|
859 |
else:
|
860 |
return best_row(equations)['lambda_format']
|