MilesCranmer commited on
Commit
cd54791
·
1 Parent(s): 989f731

Create scikit-learn API

Browse files
Files changed (2) hide show
  1. pysr/__init__.py +1 -0
  2. pysr/sklearn.py +57 -0
pysr/__init__.py CHANGED
@@ -11,3 +11,4 @@ from .sr import (
11
  from .feynman_problems import Problem, FeynmanProblem
12
  from .export_jax import sympy2jax
13
  from .export_torch import sympy2torch
 
 
11
  from .feynman_problems import Problem, FeynmanProblem
12
  from .export_jax import sympy2jax
13
  from .export_torch import sympy2torch
14
+ from .sklearn import PySRRegressor
pysr/sklearn.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pysr import pysr, best_row
2
+ from sklearn.base import BaseEstimator
3
+
4
+
5
+ class PySRRegressor(BaseEstimator):
6
+ def __init__(self, model_selection="accuracy", **params):
7
+ """Initialize settings for pysr.pysr call.
8
+
9
+ :param model_selection: How to select a model. Can be 'accuracy' or 'best'. 'best' will optimize a combination of complexity and accuracy.
10
+ :type model_selection: str
11
+ """
12
+ super().__init__()
13
+ self.model_selection = model_selection
14
+ self.params = params
15
+
16
+ # Stored equations:
17
+ self.equations = None
18
+
19
+ def __repr__(self):
20
+ return f"PySRRegressor(equations={self.get_best()['sympy_format']})"
21
+
22
+ def set_params(self, **params):
23
+ """Set parameters for pysr.pysr call or model_selection strategy."""
24
+ for key, value in params.items():
25
+ if key == "model_selection":
26
+ self.model_selection = value
27
+ self.params[key] = value
28
+
29
+ return self
30
+
31
+ def get_params(self, deep=True):
32
+ del deep
33
+ return {**self.params, "model_selection": self.model_selection}
34
+
35
+ def get_best(self):
36
+ if self.equations is None:
37
+ return 0.0
38
+ if self.model_selection == "accuracy":
39
+ return self.equations.iloc[-1]
40
+ elif self.model_selection == "best":
41
+ return best_row(self.equations)
42
+ else:
43
+ raise NotImplementedError
44
+
45
+ def fit(self, X, y):
46
+ self.equations = pysr(
47
+ X=X,
48
+ y=y,
49
+ **self.params,
50
+ )
51
+ return self
52
+
53
+ def predict(self, X):
54
+ equation_row = self.get_best()
55
+ np_format = equation_row["lambda_format"]
56
+
57
+ return np_format(X)