Spaces:
Running
Running
tttc3
commited on
Commit
•
6881818
1
Parent(s):
3e8d44d
Updated parameter validation
Browse files- pysr/sr.py +112 -97
pysr/sr.py
CHANGED
@@ -529,6 +529,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
529 |
List of indices for input features that are selected when
|
530 |
:param`select_k_features` is set.
|
531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
|
533 |
The state for the julia SymbolicRegression.jl backend post fitting.
|
534 |
|
@@ -928,6 +934,71 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
928 |
else:
|
929 |
self.equation_file_ = self.equation_file
|
930 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
931 |
def _validate_fit_params(self, X, y, Xresampled, variable_names):
|
932 |
"""
|
933 |
Validates the parameters passed to the :term`fit` method.
|
@@ -965,39 +1036,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
965 |
|
966 |
"""
|
967 |
|
968 |
-
# Ensure instance parameters are allowable values:
|
969 |
-
if self.tournament_selection_n > self.population_size:
|
970 |
-
raise ValueError(
|
971 |
-
"tournament_selection_n parameter must be smaller than population_size."
|
972 |
-
)
|
973 |
-
|
974 |
-
if self.maxsize > 40:
|
975 |
-
warnings.warn(
|
976 |
-
"Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
977 |
-
)
|
978 |
-
elif self.maxsize < 7:
|
979 |
-
raise ValueError("PySR requires a maxsize of at least 7")
|
980 |
-
|
981 |
-
if self.extra_jax_mappings is not None:
|
982 |
-
for value in self.extra_jax_mappings.values():
|
983 |
-
if not isinstance(value, str):
|
984 |
-
raise ValueError(
|
985 |
-
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
986 |
-
)
|
987 |
-
|
988 |
-
if self.extra_torch_mappings is not None:
|
989 |
-
for value in self.extra_jax_mappings.values():
|
990 |
-
if not callable(value):
|
991 |
-
raise ValueError(
|
992 |
-
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
993 |
-
)
|
994 |
-
|
995 |
-
# NotImplementedError - Values that could be supported at a later time
|
996 |
-
if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
|
997 |
-
raise NotImplementedError(
|
998 |
-
f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
|
999 |
-
)
|
1000 |
-
|
1001 |
if isinstance(X, pd.DataFrame):
|
1002 |
if variable_names:
|
1003 |
variable_names = None
|
@@ -1020,13 +1058,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1020 |
"Spaces have been replaced with underscores. \n"
|
1021 |
"Please use valid names instead."
|
1022 |
)
|
1023 |
-
# Only numpy values are needed from Xresampled, column metadata is
|
1024 |
-
# provided by X
|
1025 |
-
if isinstance(Xresampled, pd.DataFrame):
|
1026 |
-
Xresampled = Xresampled.values
|
1027 |
|
1028 |
# Data validation and feature name fetching via sklearn
|
1029 |
# This method sets the n_features_in_ attribute
|
|
|
1030 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
1031 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
1032 |
variable_names = self.feature_names_in_
|
@@ -1126,7 +1161,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1126 |
|
1127 |
return X, y, variable_names
|
1128 |
|
1129 |
-
def _run(self, X, y, weights, seed):
|
1130 |
"""
|
1131 |
Run the symbolic regression fitting process on the julia backend.
|
1132 |
|
@@ -1138,10 +1173,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1138 |
y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
|
1139 |
Target values. Will be cast to X's dtype if necessary.
|
1140 |
|
1141 |
-
|
|
|
|
|
|
|
1142 |
Each element is how to weight the mean-square-error loss
|
1143 |
for that particular element of y.
|
1144 |
|
|
|
|
|
|
|
1145 |
Returns
|
1146 |
-------
|
1147 |
self : object
|
@@ -1159,66 +1200,17 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1159 |
|
1160 |
# These are the parameters which may be modified from the ones
|
1161 |
# specified in init, so we define them here locally:
|
1162 |
-
binary_operators =
|
1163 |
-
unary_operators =
|
1164 |
-
|
|
|
1165 |
nested_constraints = self.nested_constraints
|
1166 |
complexity_of_operators = self.complexity_of_operators
|
1167 |
-
multithreading =
|
1168 |
-
update_verbosity = self.update_verbosity
|
1169 |
-
maxdepth = self.maxdepth
|
1170 |
-
batch_size = self.batch_size
|
1171 |
-
progress = self.progress
|
1172 |
cluster_manager = self.cluster_manager
|
1173 |
-
|
1174 |
-
|
1175 |
-
|
1176 |
-
|
1177 |
-
# Deal with default values, and type conversions:
|
1178 |
-
if binary_operators is None:
|
1179 |
-
binary_operators = "+ * - /".split(" ")
|
1180 |
-
elif isinstance(binary_operators, str):
|
1181 |
-
binary_operators = [binary_operators]
|
1182 |
-
|
1183 |
-
if unary_operators is None:
|
1184 |
-
unary_operators = []
|
1185 |
-
elif isinstance(unary_operators, str):
|
1186 |
-
unary_operators = [unary_operators]
|
1187 |
-
|
1188 |
-
assert len(unary_operators) + len(binary_operators) > 0
|
1189 |
-
|
1190 |
-
if constraints is None:
|
1191 |
-
constraints = {}
|
1192 |
-
|
1193 |
-
if multithreading is None:
|
1194 |
-
# Default is multithreading=True, unless explicitly set,
|
1195 |
-
# or procs is set to 0 (serial mode).
|
1196 |
-
multithreading = self.procs != 0 and cluster_manager is None
|
1197 |
-
|
1198 |
-
if update_verbosity is None:
|
1199 |
-
update_verbosity = self.verbosity
|
1200 |
-
|
1201 |
-
if maxdepth is None:
|
1202 |
-
maxdepth = self.maxsize
|
1203 |
-
|
1204 |
-
# Warn if instance parameters are not sensible values:
|
1205 |
-
if batch_size < 1:
|
1206 |
-
warnings.warn(
|
1207 |
-
"Given :param`batch_size` must be greater than or equal to one. "
|
1208 |
-
":param`batch_size` has been increased to equal one."
|
1209 |
-
)
|
1210 |
-
batch_size = 1
|
1211 |
-
|
1212 |
-
# Handle presentation of the progress bar:
|
1213 |
-
buffer_available = "buffer" in sys.stdout.__dir__()
|
1214 |
-
if progress is not None:
|
1215 |
-
if progress and not buffer_available:
|
1216 |
-
warnings.warn(
|
1217 |
-
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
1218 |
-
)
|
1219 |
-
progress = False
|
1220 |
-
else:
|
1221 |
-
progress = buffer_available
|
1222 |
|
1223 |
# Start julia backend processes
|
1224 |
if Main is None:
|
@@ -1455,6 +1447,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1455 |
|
1456 |
self._setup_equation_file()
|
1457 |
|
|
|
|
|
1458 |
# Parameter input validation (for parameters defined in __init__)
|
1459 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
1460 |
X, y, Xresampled, variable_names
|
@@ -1505,7 +1499,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1505 |
)
|
1506 |
|
1507 |
# Fitting procedure
|
1508 |
-
return self._run(X
|
1509 |
|
1510 |
def refresh(self, checkpoint_file=None):
|
1511 |
"""
|
@@ -1736,6 +1730,27 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1736 |
"Couldn't find equation file! The equation search likely exited before a single iteration completed."
|
1737 |
)
|
1738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1739 |
ret_outputs = []
|
1740 |
|
1741 |
for output in all_outputs:
|
|
|
529 |
List of indices for input features that are selected when
|
530 |
:param`select_k_features` is set.
|
531 |
|
532 |
+
tempdir_ : Path
|
533 |
+
Path to the temporary equations directory.
|
534 |
+
|
535 |
+
equation_file_ : str
|
536 |
+
Output equation file name produced by the julia backend.
|
537 |
+
|
538 |
raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
|
539 |
The state for the julia SymbolicRegression.jl backend post fitting.
|
540 |
|
|
|
934 |
else:
|
935 |
self.equation_file_ = self.equation_file
|
936 |
|
937 |
+
def _validate_init_params(self):
|
938 |
+
|
939 |
+
# Immutable parameter validation
|
940 |
+
# Ensure instance parameters are allowable values:
|
941 |
+
if self.tournament_selection_n > self.population_size:
|
942 |
+
raise ValueError(
|
943 |
+
"tournament_selection_n parameter must be smaller than population_size."
|
944 |
+
)
|
945 |
+
|
946 |
+
if self.maxsize > 40:
|
947 |
+
warnings.warn(
|
948 |
+
"Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
949 |
+
)
|
950 |
+
elif self.maxsize < 7:
|
951 |
+
raise ValueError("PySR requires a maxsize of at least 7")
|
952 |
+
|
953 |
+
# NotImplementedError - Values that could be supported at a later time
|
954 |
+
if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
|
955 |
+
raise NotImplementedError(
|
956 |
+
f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
|
957 |
+
)
|
958 |
+
|
959 |
+
# 'Mutable' parameter validation
|
960 |
+
buffer_available = "buffer" in sys.stdout.__dir__()
|
961 |
+
modifiable_params = {
|
962 |
+
"binary_operators": "+ * - /".split(" "),
|
963 |
+
"unary_operators": [],
|
964 |
+
"maxdepth": self.maxsize,
|
965 |
+
"constraints": {},
|
966 |
+
"multithreading": self.procs != 0 and self.cluster_manager is None,
|
967 |
+
"batch_size": 1,
|
968 |
+
"update_verbosity": self.verbosity,
|
969 |
+
"progress": buffer_available,
|
970 |
+
}
|
971 |
+
packed_modified_params = {}
|
972 |
+
for parameter, default_value in modifiable_params.items():
|
973 |
+
parameter_value = getattr(self, parameter)
|
974 |
+
if parameter_value is None:
|
975 |
+
parameter_value = default_value
|
976 |
+
else:
|
977 |
+
# Special cases such as when binary_operators is a string
|
978 |
+
if parameter in ["binary_operators", "unary_operators"] and isinstance(
|
979 |
+
parameter_value, str
|
980 |
+
):
|
981 |
+
parameter_value = [parameter_value]
|
982 |
+
elif parameter is "batch_size" and parameter_value < 1:
|
983 |
+
warnings.warn(
|
984 |
+
"Given :param`batch_size` must be greater than or equal to one. "
|
985 |
+
":param`batch_size` has been increased to equal one."
|
986 |
+
)
|
987 |
+
parameter_value = 1
|
988 |
+
elif parameter is "progress" and not buffer_available:
|
989 |
+
warnings.warn(
|
990 |
+
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
991 |
+
)
|
992 |
+
parameter_value = False
|
993 |
+
packed_modified_params[parameter] = parameter_value
|
994 |
+
|
995 |
+
assert (
|
996 |
+
len(packed_modified_params["binary_operators"])
|
997 |
+
+ len(packed_modified_params["unary_operators"])
|
998 |
+
> 0
|
999 |
+
)
|
1000 |
+
return packed_modified_params
|
1001 |
+
|
1002 |
def _validate_fit_params(self, X, y, Xresampled, variable_names):
|
1003 |
"""
|
1004 |
Validates the parameters passed to the :term`fit` method.
|
|
|
1036 |
|
1037 |
"""
|
1038 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1039 |
if isinstance(X, pd.DataFrame):
|
1040 |
if variable_names:
|
1041 |
variable_names = None
|
|
|
1058 |
"Spaces have been replaced with underscores. \n"
|
1059 |
"Please use valid names instead."
|
1060 |
)
|
|
|
|
|
|
|
|
|
1061 |
|
1062 |
# Data validation and feature name fetching via sklearn
|
1063 |
# This method sets the n_features_in_ attribute
|
1064 |
+
Xresampled = check_array(Xresampled)
|
1065 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
1066 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
1067 |
variable_names = self.feature_names_in_
|
|
|
1161 |
|
1162 |
return X, y, variable_names
|
1163 |
|
1164 |
+
def _run(self, X, y, mutated_params, weights, seed):
|
1165 |
"""
|
1166 |
Run the symbolic regression fitting process on the julia backend.
|
1167 |
|
|
|
1173 |
y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
|
1174 |
Target values. Will be cast to X's dtype if necessary.
|
1175 |
|
1176 |
+
mutated_params : dict[str, Any]
|
1177 |
+
Dictionary of mutated versions of some parameters passed in __init__.
|
1178 |
+
|
1179 |
+
weights : {ndarray | pandas.DataFrame} of the same shape as y
|
1180 |
Each element is how to weight the mean-square-error loss
|
1181 |
for that particular element of y.
|
1182 |
|
1183 |
+
seed : int
|
1184 |
+
Random seed for julia backend process.
|
1185 |
+
|
1186 |
Returns
|
1187 |
-------
|
1188 |
self : object
|
|
|
1200 |
|
1201 |
# These are the parameters which may be modified from the ones
|
1202 |
# specified in init, so we define them here locally:
|
1203 |
+
binary_operators = mutated_params["binary_operators"]
|
1204 |
+
unary_operators = mutated_params["unary_operators"]
|
1205 |
+
maxdepth = mutated_params["maxdepth"]
|
1206 |
+
constraints = mutated_params["constraints"]
|
1207 |
nested_constraints = self.nested_constraints
|
1208 |
complexity_of_operators = self.complexity_of_operators
|
1209 |
+
multithreading = mutated_params["multithreading"]
|
|
|
|
|
|
|
|
|
1210 |
cluster_manager = self.cluster_manager
|
1211 |
+
batch_size = mutated_params["batch_size"]
|
1212 |
+
update_verbosity = mutated_params["update_verbosity"]
|
1213 |
+
progress = mutated_params["progress"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1214 |
|
1215 |
# Start julia backend processes
|
1216 |
if Main is None:
|
|
|
1447 |
|
1448 |
self._setup_equation_file()
|
1449 |
|
1450 |
+
mutated_params = self._validate_init_params()
|
1451 |
+
|
1452 |
# Parameter input validation (for parameters defined in __init__)
|
1453 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
1454 |
X, y, Xresampled, variable_names
|
|
|
1499 |
)
|
1500 |
|
1501 |
# Fitting procedure
|
1502 |
+
return self._run(X, y, mutated_params, weights=weights, seed=seed)
|
1503 |
|
1504 |
def refresh(self, checkpoint_file=None):
|
1505 |
"""
|
|
|
1730 |
"Couldn't find equation file! The equation search likely exited before a single iteration completed."
|
1731 |
)
|
1732 |
|
1733 |
+
# It is expected extra_jax/torch_mappings will be updated after fit.
|
1734 |
+
# Thus, validation is performed here instead of in _validate_init_params
|
1735 |
+
extra_jax_mappings = self.extra_jax_mappings
|
1736 |
+
extra_torch_mappings = self.extra_torch_mappings
|
1737 |
+
if extra_jax_mappings is not None:
|
1738 |
+
for value in self.extra_jax_mappings.values():
|
1739 |
+
if not isinstance(value, str):
|
1740 |
+
raise ValueError(
|
1741 |
+
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
1742 |
+
)
|
1743 |
+
else:
|
1744 |
+
extra_jax_mappings = {}
|
1745 |
+
if extra_torch_mappings is not None:
|
1746 |
+
for value in self.extra_jax_mappings.values():
|
1747 |
+
if not callable(value):
|
1748 |
+
raise ValueError(
|
1749 |
+
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
1750 |
+
)
|
1751 |
+
else:
|
1752 |
+
extra_torch_mappings = {}
|
1753 |
+
|
1754 |
ret_outputs = []
|
1755 |
|
1756 |
for output in all_outputs:
|