MilesCranmer committed on
Commit
e7f614d
2 Parent(s): ad8332d 781f479

Merge pull request #127 from MilesCranmer/Manual-model-selection

Browse files
Files changed (2) hide show
  1. pysr/sr.py +72 -18
  2. test/test.py +19 -0
pysr/sr.py CHANGED
@@ -271,12 +271,15 @@ class CallableEquation:
271
  return f"PySRFunction(X=>{self._sympy})"
272
 
273
  def __call__(self, X):
 
274
  if isinstance(X, pd.DataFrame):
275
  # Lambda function takes as argument:
276
- return self._lambda(**{k: X[k].values for k in X.columns})
 
 
277
  elif self._selection is not None:
278
- return self._lambda(*X[:, self._selection].T)
279
- return self._lambda(*X.T)
280
 
281
 
282
  def _get_julia_project(julia_project):
@@ -779,10 +782,25 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
779
  **{key: self.__getattribute__(key) for key in self.surface_parameters},
780
  }
781
 
782
- def get_best(self):
783
- """Get best equation using `model_selection`."""
 
 
 
 
 
 
 
 
784
  if self.equations is None:
785
  raise ValueError("No equations have been generated yet.")
 
 
 
 
 
 
 
786
  if self.model_selection == "accuracy":
787
  if isinstance(self.equations, list):
788
  return [eq.iloc[-1] for eq in self.equations]
@@ -826,7 +844,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
826
  # such as extra_sympy_mappings.
827
  self.equations = self.get_hof()
828
 
829
- def predict(self, X):
830
  """Predict y from input X using the equation chosen by `model_selection`.
831
 
832
  You may see what equation is used by printing this object. X should have the same
@@ -834,36 +852,64 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
834
 
835
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
836
  :type X: np.ndarray/pandas.DataFrame
837
- :return: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
 
 
 
 
 
838
  """
839
  self.refresh()
840
- best = self.get_best()
841
  if self.multioutput:
842
  return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
843
  return best["lambda_format"](X)
844
 
845
- def sympy(self):
846
- """Return sympy representation of the equation(s) chosen by `model_selection`."""
 
 
 
 
 
 
 
847
  self.refresh()
848
- best = self.get_best()
849
  if self.multioutput:
850
  return [eq["sympy_format"] for eq in best]
851
  return best["sympy_format"]
852
 
853
- def latex(self):
854
- """Return latex representation of the equation(s) chosen by `model_selection`."""
 
 
 
 
 
 
 
 
855
  self.refresh()
856
- sympy_representation = self.sympy()
857
  if self.multioutput:
858
  return [sympy.latex(s) for s in sympy_representation]
859
  return sympy.latex(sympy_representation)
860
 
861
- def jax(self):
862
  """Return jax representation of the equation(s) chosen by `model_selection`.
863
 
864
  Each equation (multiple given if there are multiple outputs) is a dictionary
865
  containing {"callable": func, "parameters": params}. To call `func`, pass
866
  func(X, params). This function is differentiable using `jax.grad`.
 
 
 
 
 
 
 
 
867
  """
868
  if self.using_pandas:
869
  warnings.warn(
@@ -873,18 +919,26 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
873
  )
874
  self.set_params(output_jax_format=True)
875
  self.refresh()
876
- best = self.get_best()
877
  if self.multioutput:
878
  return [eq["jax_format"] for eq in best]
879
  return best["jax_format"]
880
 
881
- def pytorch(self):
882
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
883
 
884
  Each equation (multiple given if there are multiple outputs) is a PyTorch module
885
  containing the parameters as trainable attributes. You can use the module like
886
  any other PyTorch module: `module(X)`, where `X` is a tensor with the same
887
  column ordering as trained with.
 
 
 
 
 
 
 
 
888
  """
889
  if self.using_pandas:
890
  warnings.warn(
@@ -894,7 +948,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
894
  )
895
  self.set_params(output_torch_format=True)
896
  self.refresh()
897
- best = self.get_best()
898
  if self.multioutput:
899
  return [eq["torch_format"] for eq in best]
900
  return best["torch_format"]
 
271
  return f"PySRFunction(X=>{self._sympy})"
272
 
273
  def __call__(self, X):
274
+ expected_shape = (X.shape[0],)
275
  if isinstance(X, pd.DataFrame):
276
  # Lambda function takes as argument:
277
+ return self._lambda(**{k: X[k].values for k in X.columns}) * np.ones(
278
+ expected_shape
279
+ )
280
  elif self._selection is not None:
281
+ return self._lambda(*X[:, self._selection].T) * np.ones(expected_shape)
282
+ return self._lambda(*X.T) * np.ones(expected_shape)
283
 
284
 
285
  def _get_julia_project(julia_project):
 
782
  **{key: self.__getattribute__(key) for key in self.surface_parameters},
783
  }
784
 
785
+ def get_best(self, index=None):
786
+ """Get best equation using `model_selection`.
787
+
788
+ :param index: Optional. If you wish to select a particular equation
789
+ from `self.equations`, give the row number here. This overrides
790
+ the `model_selection` parameter.
791
+ :type index: int
792
+ :returns: Dictionary representing the best expression found.
793
+ :type: pd.Series
794
+ """
795
  if self.equations is None:
796
  raise ValueError("No equations have been generated yet.")
797
+
798
+ if index is not None:
799
+ if isinstance(self.equations, list):
800
+ assert isinstance(index, list)
801
+ return [eq.iloc[i] for eq, i in zip(self.equations, index)]
802
+ return self.equations.iloc[index]
803
+
804
  if self.model_selection == "accuracy":
805
  if isinstance(self.equations, list):
806
  return [eq.iloc[-1] for eq in self.equations]
 
844
  # such as extra_sympy_mappings.
845
  self.equations = self.get_hof()
846
 
847
+ def predict(self, X, index=None):
848
  """Predict y from input X using the equation chosen by `model_selection`.
849
 
850
  You may see what equation is used by printing this object. X should have the same
 
852
 
853
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
854
  :type X: np.ndarray/pandas.DataFrame
855
+ :param index: Optional. If you want to compute the output of
856
+ an expression using a particular row of
857
+ `self.equations`, you may specify the index here.
858
+ :type index: int
859
+ :returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
860
+ :type: np.ndarray
861
  """
862
  self.refresh()
863
+ best = self.get_best(index=index)
864
  if self.multioutput:
865
  return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
866
  return best["lambda_format"](X)
867
 
868
def sympy(self, index=None):
    """Return the SymPy form of the selected equation(s).

    :param index: Optional. Passed through to `get_best`; when given, it
        picks a particular row of `self.equations` instead of relying on
        the `model_selection` parameter.
    :type index: int
    :returns: The `sympy_format` entry of the selected equation, or a list
        of such entries (one per output) when `self.multioutput` is set.
    """
    self.refresh()
    selected = self.get_best(index=index)
    if not self.multioutput:
        return selected["sympy_format"]
    return [row["sympy_format"] for row in selected]
882
 
883
def latex(self, index=None):
    """Return the LaTeX form of the selected equation(s).

    :param index: Optional. If you wish to select a particular equation
        from `self.equations`, give the index number here. This overrides
        the `model_selection` parameter.
    :type index: int
    :returns: LaTeX expression as a string, or a list of strings (one per
        output) when `self.multioutput` is set.
    :type: str
    """
    # self.sympy() already calls self.refresh() before selecting the
    # equation, so refreshing here as well would only repeat that work.
    expressions = self.sympy(index=index)
    if self.multioutput:
        return [sympy.latex(expr) for expr in expressions]
    return sympy.latex(expressions)
898
 
899
+ def jax(self, index=None):
900
  """Return jax representation of the equation(s) chosen by `model_selection`.
901
 
902
  Each equation (multiple given if there are multiple outputs) is a dictionary
903
  containing {"callable": func, "parameters": params}. To call `func`, pass
904
  func(X, params). This function is differentiable using `jax.grad`.
905
+
906
+ :param index: Optional. If you wish to select a particular equation
907
+ from `self.equations`, give the index number here. This overrides
908
+ the `model_selection` parameter.
909
+ :type index: int
910
+ :returns: Dictionary of callable jax function in "callable" key,
911
+ and jax array of parameters as "parameters" key.
912
+ :type: dict
913
  """
914
  if self.using_pandas:
915
  warnings.warn(
 
919
  )
920
  self.set_params(output_jax_format=True)
921
  self.refresh()
922
+ best = self.get_best(index=index)
923
  if self.multioutput:
924
  return [eq["jax_format"] for eq in best]
925
  return best["jax_format"]
926
 
927
+ def pytorch(self, index=None):
928
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
929
 
930
  Each equation (multiple given if there are multiple outputs) is a PyTorch module
931
  containing the parameters as trainable attributes. You can use the module like
932
  any other PyTorch module: `module(X)`, where `X` is a tensor with the same
933
  column ordering as trained with.
934
+
935
+
936
+ :param index: Optional. If you wish to select a particular equation
937
+ from `self.equations`, give the row number here. This overrides
938
+ the `model_selection` parameter.
939
+ :type index: int
940
+ :returns: PyTorch module representing the expression.
941
+ :type: torch.nn.Module
942
  """
943
  if self.using_pandas:
944
  warnings.warn(
 
948
  )
949
  self.set_params(output_torch_format=True)
950
  self.refresh()
951
+ best = self.get_best(index=index)
952
  if self.multioutput:
953
  return [eq["torch_format"] for eq in best]
954
  return best["torch_format"]
test/test.py CHANGED
@@ -52,6 +52,19 @@ class TestPipeline(unittest.TestCase):
52
  self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
53
  self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def test_multioutput_weighted_with_callable_temp_equation(self):
56
  y = self.X[:, [0, 1]] ** 2
57
  w = np.random.rand(*y.shape)
@@ -206,6 +219,12 @@ class TestBest(unittest.TestCase):
206
  def test_best(self):
207
  self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
208
 
 
 
 
 
 
 
209
  def test_best_tex(self):
210
  self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
211
 
 
52
  self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
53
  self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
54
 
55
+ test_y1 = model.predict(self.X)
56
+ test_y2 = model.predict(self.X, index=[-1, -1])
57
+
58
+ mse1 = np.average((test_y1 - y) ** 2)
59
+ mse2 = np.average((test_y2 - y) ** 2)
60
+
61
+ self.assertLessEqual(mse1, 1e-4)
62
+ self.assertLessEqual(mse2, 1e-4)
63
+
64
+ bad_y = model.predict(self.X, index=[0, 0])
65
+ bad_mse = np.average((bad_y - y) ** 2)
66
+ self.assertGreater(bad_mse, 1e-4)
67
+
68
  def test_multioutput_weighted_with_callable_temp_equation(self):
69
  y = self.X[:, [0, 1]] ** 2
70
  w = np.random.rand(*y.shape)
 
219
  def test_best(self):
220
  self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
221
 
222
def test_index_selection(self):
    """Check that `sympy(index)` returns the equation at the given row."""
    x0 = sympy.Symbol("x0")
    # (row index, expected expression), checked in this order.
    expected_by_index = [
        (-1, sympy.cos(x0) ** 2),
        (2, sympy.cos(x0) ** 2),
        (1, sympy.cos(x0)),
        (0, 1.0),
    ]
    for row, expected in expected_by_index:
        self.assertEqual(self.model.sympy(row), expected)
227
+
228
  def test_best_tex(self):
229
  self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
230