MilesCranmer commited on
Commit
526d334
·
unverified ·
1 Parent(s): a5eaab9

fix: type inference issue in return value of get_best

Browse files
Files changed (1) hide show
  1. pysr/sr.py +7 -4
pysr/sr.py CHANGED
@@ -1179,8 +1179,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1179
  Raised when an invalid model selection strategy is provided.
1180
  """
1181
  check_is_fitted(self, attributes=["equations_"])
1182
- if self.equations_ is None:
1183
- raise ValueError("No equations have been generated yet.")
1184
 
1185
  if index is not None:
1186
  if isinstance(self.equations_, list):
@@ -1188,17 +1186,22 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1188
  index, list
1189
  ), "With multiple output features, index must be a list."
1190
  return [eq.iloc[i] for eq, i in zip(self.equations_, index)]
1191
- return self.equations_.iloc[index]
 
 
 
1192
 
1193
  if isinstance(self.equations_, list):
1194
  return [
1195
  eq.loc[idx_model_selection(eq, self.model_selection)]
1196
  for eq in self.equations_
1197
  ]
1198
- else:
1199
  return self.equations_.loc[
1200
  idx_model_selection(self.equations_, self.model_selection)
1201
  ]
 
 
1202
 
1203
  def _setup_equation_file(self):
1204
  """
 
1179
  Raised when an invalid model selection strategy is provided.
1180
  """
1181
  check_is_fitted(self, attributes=["equations_"])
 
 
1182
 
1183
  if index is not None:
1184
  if isinstance(self.equations_, list):
 
1186
  index, list
1187
  ), "With multiple output features, index must be a list."
1188
  return [eq.iloc[i] for eq, i in zip(self.equations_, index)]
1189
+ elif isinstance(self.equations_, pd.DataFrame):
1190
+ return self.equations_.iloc[index]
1191
+ else:
1192
+ raise ValueError("No equations have been generated yet.")
1193
 
1194
  if isinstance(self.equations_, list):
1195
  return [
1196
  eq.loc[idx_model_selection(eq, self.model_selection)]
1197
  for eq in self.equations_
1198
  ]
1199
+ elif isinstance(self.equations_, pd.DataFrame):
1200
  return self.equations_.loc[
1201
  idx_model_selection(self.equations_, self.model_selection)
1202
  ]
1203
+ else:
1204
+ raise ValueError("No equations have been generated yet.")
1205
 
1206
  def _setup_equation_file(self):
1207
  """