tttc3 commited on
Commit
4819728
·
1 Parent(s): 9750ff9

Added validation for weights

Browse files
Files changed (1) hide show
  1. pysr/sr.py +17 -7
pysr/sr.py CHANGED
@@ -13,7 +13,11 @@ from datetime import datetime
13
  import warnings
14
  from multiprocessing import cpu_count
15
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
16
- from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
 
 
 
 
17
 
18
  from .julia_helpers import (
19
  init_julia,
@@ -980,13 +984,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
980
  parameter_value, str
981
  ):
982
  parameter_value = [parameter_value]
983
- elif parameter is "batch_size" and parameter_value < 1:
984
  warnings.warn(
985
  "Given :param`batch_size` must be greater than or equal to one. "
986
  ":param`batch_size` has been increased to equal one."
987
  )
988
  parameter_value = 1
989
- elif parameter is "progress" and not buffer_available:
990
  warnings.warn(
991
  "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
992
  )
@@ -1000,7 +1004,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1000
  )
1001
  return packed_modified_params
1002
 
1003
- def _validate_fit_params(self, X, y, Xresampled, variable_names):
1004
  """
1005
  Validates the parameters passed to the :term`fit` method.
1006
 
@@ -1018,6 +1022,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1018
  (n_resampled, n_features), default=None
1019
  Resampled training data used for denoising.
1020
 
 
 
 
 
1021
  variable_names : list[str] of length n_features
1022
  Names of each variable in the training dataset, `X`.
1023
 
@@ -1064,6 +1072,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1064
  # This method sets the n_features_in_ attribute
1065
  if Xresampled is not None:
1066
  Xresampled = check_array(Xresampled)
 
 
1067
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1068
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1069
  variable_names = self.feature_names_in_
@@ -1076,7 +1086,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1076
  else:
1077
  raise NotImplementedError("y shape not supported!")
1078
 
1079
- return X, y, Xresampled, variable_names
1080
 
1081
  def _pre_transform_training_data(
1082
  self, X, y, Xresampled, variable_names, random_state
@@ -1452,8 +1462,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1452
  mutated_params = self._validate_init_params()
1453
 
1454
  # Parameter input validation (for parameters defined in __init__)
1455
- X, y, Xresampled, variable_names = self._validate_fit_params(
1456
- X, y, Xresampled, variable_names
1457
  )
1458
 
1459
  if X.shape[0] > 10000 and not self.batching:
 
13
  import warnings
14
  from multiprocessing import cpu_count
15
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
16
+ from sklearn.utils.validation import (
17
+ _check_feature_names_in,
18
+ _check_sample_weight,
19
+ check_is_fitted,
20
+ )
21
 
22
  from .julia_helpers import (
23
  init_julia,
 
984
  parameter_value, str
985
  ):
986
  parameter_value = [parameter_value]
987
+ elif parameter == "batch_size" and parameter_value < 1:
988
  warnings.warn(
989
  "Given :param`batch_size` must be greater than or equal to one. "
990
  ":param`batch_size` has been increased to equal one."
991
  )
992
  parameter_value = 1
993
+ elif parameter == "progress" and not buffer_available:
994
  warnings.warn(
995
  "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
996
  )
 
1004
  )
1005
  return packed_modified_params
1006
 
1007
+ def _validate_fit_params(self, X, y, Xresampled, weights, variable_names):
1008
  """
1009
  Validates the parameters passed to the :term`fit` method.
1010
 
 
1022
  (n_resampled, n_features), default=None
1023
  Resampled training data used for denoising.
1024
 
1025
+ weights : {ndarray | pandas.DataFrame} of the same shape as y
1026
+ Each element is how to weight the mean-square-error loss
1027
+ for that particular element of y.
1028
+
1029
  variable_names : list[str] of length n_features
1030
  Names of each variable in the training dataset, `X`.
1031
 
 
1072
  # This method sets the n_features_in_ attribute
1073
  if Xresampled is not None:
1074
  Xresampled = check_array(Xresampled)
1075
+ if weights is not None:
1076
+ weights = _check_sample_weight(weights, y)
1077
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1078
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1079
  variable_names = self.feature_names_in_
 
1086
  else:
1087
  raise NotImplementedError("y shape not supported!")
1088
 
1089
+ return X, y, Xresampled, weights, variable_names
1090
 
1091
  def _pre_transform_training_data(
1092
  self, X, y, Xresampled, variable_names, random_state
 
1462
  mutated_params = self._validate_init_params()
1463
 
1464
  # Parameter input validation (for parameters defined in __init__)
1465
+ X, y, Xresampled, weights, variable_names = self._validate_fit_params(
1466
+ X, y, Xresampled, weights, variable_names
1467
  )
1468
 
1469
  if X.shape[0] > 10000 and not self.batching: