MilesCranmer commited on
Commit
2bd7782
·
unverified ·
1 Parent(s): dca10d6

refactor: improved type inference in return values

Browse files
Files changed (1) hide show
  1. pysr/sr.py +24 -14
pysr/sr.py CHANGED
@@ -2006,11 +2006,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2006
  X = self._validate_data(X, reset=False)
2007
 
2008
  try:
2009
- if self.nout_ > 1:
 
2010
  return np.stack(
2011
  [eq["lambda_format"](X) for eq in best_equation], axis=1
2012
  )
2013
- return best_equation["lambda_format"](X)
 
2014
  except Exception as error:
2015
  raise ValueError(
2016
  "Failed to evaluate the expression. "
@@ -2040,9 +2042,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2040
  """
2041
  self.refresh()
2042
  best_equation = self.get_best(index=index)
2043
- if self.nout_ > 1:
 
2044
  return [eq["sympy_format"] for eq in best_equation]
2045
- return best_equation["sympy_format"]
 
2046
 
2047
  def latex(self, index=None, precision=3):
2048
  """
@@ -2102,9 +2106,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2102
  self.set_params(output_jax_format=True)
2103
  self.refresh()
2104
  best_equation = self.get_best(index=index)
2105
- if self.nout_ > 1:
 
2106
  return [eq["jax_format"] for eq in best_equation]
2107
- return best_equation["jax_format"]
 
2108
 
2109
  def pytorch(self, index=None):
2110
  """
@@ -2132,9 +2138,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2132
  self.set_params(output_torch_format=True)
2133
  self.refresh()
2134
  best_equation = self.get_best(index=index)
2135
- if self.nout_ > 1:
 
 
2136
  return [eq["torch_format"] for eq in best_equation]
2137
- return best_equation["torch_format"]
2138
 
2139
  def _read_equation_file(self):
2140
  """Read the hall of fame file created by `SymbolicRegression.jl`."""
@@ -2233,10 +2240,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2233
  lastComplexity = 0
2234
  sympy_format = []
2235
  lambda_format = []
2236
- if self.output_jax_format:
2237
- jax_format = []
2238
- if self.output_torch_format:
2239
- torch_format = []
2240
 
2241
  for _, eqn_row in output.iterrows():
2242
  eqn = pysr2sympy(
@@ -2348,7 +2353,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2348
  """
2349
  self.refresh()
2350
 
2351
- if self.nout_ > 1:
2352
  if indices is not None:
2353
  assert isinstance(indices, list)
2354
  assert isinstance(indices[0], list)
@@ -2357,7 +2362,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2357
  table_string = sympy2multilatextable(
2358
  self.equations_, indices=indices, precision=precision, columns=columns
2359
  )
2360
- else:
2361
  if indices is not None:
2362
  assert isinstance(indices, list)
2363
  assert isinstance(indices[0], int)
@@ -2365,6 +2370,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2365
  table_string = sympy2latextable(
2366
  self.equations_, indices=indices, precision=precision, columns=columns
2367
  )
 
 
 
 
 
2368
 
2369
  preamble_string = [
2370
  r"\usepackage{breqn}",
 
2006
  X = self._validate_data(X, reset=False)
2007
 
2008
  try:
2009
+ if isinstance(best_equation, list):
2010
+ assert self.nout_ > 1
2011
  return np.stack(
2012
  [eq["lambda_format"](X) for eq in best_equation], axis=1
2013
  )
2014
+ else:
2015
+ return best_equation["lambda_format"](X)
2016
  except Exception as error:
2017
  raise ValueError(
2018
  "Failed to evaluate the expression. "
 
2042
  """
2043
  self.refresh()
2044
  best_equation = self.get_best(index=index)
2045
+ if isinstance(best_equation, list):
2046
+ assert self.nout_ > 1
2047
  return [eq["sympy_format"] for eq in best_equation]
2048
+ else:
2049
+ return best_equation["sympy_format"]
2050
 
2051
  def latex(self, index=None, precision=3):
2052
  """
 
2106
  self.set_params(output_jax_format=True)
2107
  self.refresh()
2108
  best_equation = self.get_best(index=index)
2109
+ if isinstance(best_equation, list):
2110
+ assert self.nout_ > 1
2111
  return [eq["jax_format"] for eq in best_equation]
2112
+ else:
2113
+ return best_equation["jax_format"]
2114
 
2115
  def pytorch(self, index=None):
2116
  """
 
2138
  self.set_params(output_torch_format=True)
2139
  self.refresh()
2140
  best_equation = self.get_best(index=index)
2141
+ if isinstance(best_equation, pd.Series):
2142
+ return best_equation["torch_format"]
2143
+ else:
2144
  return [eq["torch_format"] for eq in best_equation]
 
2145
 
2146
  def _read_equation_file(self):
2147
  """Read the hall of fame file created by `SymbolicRegression.jl`."""
 
2240
  lastComplexity = 0
2241
  sympy_format = []
2242
  lambda_format = []
2243
+ jax_format = []
2244
+ torch_format = []
 
 
2245
 
2246
  for _, eqn_row in output.iterrows():
2247
  eqn = pysr2sympy(
 
2353
  """
2354
  self.refresh()
2355
 
2356
+ if isinstance(self.equations_, list):
2357
  if indices is not None:
2358
  assert isinstance(indices, list)
2359
  assert isinstance(indices[0], list)
 
2362
  table_string = sympy2multilatextable(
2363
  self.equations_, indices=indices, precision=precision, columns=columns
2364
  )
2365
+ elif isinstance(self.equations_, pd.DataFrame):
2366
  if indices is not None:
2367
  assert isinstance(indices, list)
2368
  assert isinstance(indices[0], int)
 
2370
  table_string = sympy2latextable(
2371
  self.equations_, indices=indices, precision=precision, columns=columns
2372
  )
2373
+ else:
2374
+ raise ValueError(
2375
+ "Invalid type for equations_ to pass to `latex_table`. "
2376
+ "Expected a DataFrame or a list of DataFrames."
2377
+ )
2378
 
2379
  preamble_string = [
2380
  r"\usepackage{breqn}",