MilesCranmer commited on
Commit
4c39e04
·
1 Parent(s): bb97e2c

Update documentation for sklearn interface

Browse files
Files changed (1) hide show
  1. docs/options.md +59 -42
docs/options.md CHANGED
@@ -1,10 +1,8 @@
1
  # Features and Options
2
 
3
- You likely don't need to tune the hyperparameters yourself,
4
- but if you would like, you can use `hyperparamopt.py` as an example.
5
-
6
  Some configurable features and options in `PySR` which you
7
  may find useful include:
 
8
  - `binary_operators`, `unary_operators`
9
  - `niterations`
10
  - `ncyclesperiteration`
@@ -21,18 +19,31 @@ may find useful include:
21
 
22
  These are described below
23
 
24
- The program will output a pandas DataFrame containing the equations,
25
- mean square error, and complexity. It will also dump to a csv
 
 
 
26
  at the end of every iteration,
27
- which is `hall_of_fame_{date_time}.csv` by default. It also prints the
28
- equations to stdout.
 
 
 
 
 
 
 
 
 
 
29
 
30
  ## Operators
31
 
32
  A list of operators can be found on the operators page.
33
  One can define custom operators in Julia by passing a string:
34
  ```python
35
- equations = pysr.pysr(X, y, niterations=100,
36
  binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
37
  extra_sympy_mappings={'special': lambda x, y: x**2 + y},
38
  unary_operators=["cos"])
@@ -51,8 +62,6 @@ so that the SymPy code can understand the output equation from Julia,
51
  when constructing a useable function. This step is optional, but
52
  is necessary for the `lambda_format` to work.
53
 
54
- One can also edit `operators.jl`.
55
-
56
  ## Iterations
57
 
58
  This is the total number of generations that `pysr` will run for.
@@ -78,15 +87,15 @@ each population stay closer to the best current equations.
78
 
79
  One can adjust the number of workers used by Julia with the
80
  `procs` option. You should set this equal to the number of cores
81
- you want `pysr` to use. This will also run `procs` number of
82
- populations simultaneously by default.
83
 
84
  ## Populations
85
 
86
- By default, `populations=procs`, but you can set a different
87
- number of populations with this option. More populations may increase
 
88
  the diversity of equations discovered, though will take longer to train.
89
- However, it may be more efficient to have `populations>procs`,
90
  as there are multiple populations running
91
  on each core.
92
 
@@ -100,7 +109,8 @@ instead of the usual 4, which creates more populations
100
  sigma = ...
101
  weights = 1/sigma**2
102
 
103
- equations = pysr.pysr(X, y, weights=weights, procs=10)
 
104
  ```
105
 
106
  ## Max size
@@ -147,55 +157,62 @@ expressions of complexity 5 (e.g., 5.0 + x2 exp(x3)).
147
 
148
  ## LaTeX, SymPy
149
 
150
- The `pysr` command will return a pandas dataframe. The `sympy_format`
151
- column gives sympy equations, and the `lambda_format` gives callable
152
- functions. These use the variable names you have provided.
 
 
 
153
 
154
  There are also some helper functions for doing this quickly.
155
- You can call `get_hof()` (or pass an equation file explicitly to this)
156
- to get this pandas dataframe.
157
-
158
- You can call the functions `best()` to get the sympy format
159
- for the best equation, using the `score` column to sort equations.
160
- `best_latex()` returns the LaTeX form of this, and `best_callable()`
161
- returns a callable function.
162
 
163
 
164
  ## Callable exports: numpy, pytorch, jax
165
 
166
  By default, the dataframe of equations will contain columns
167
- with the identifier `lambda_format`. These are simple functions
168
- which correspond to the equation, but executed
169
- with numpy functions. You can pass your `X` matrix to these functions
170
- just as you did to the `pysr` call. Thus, this allows
 
171
  you to numerically evaluate the equations over different output.
172
 
 
 
 
 
 
173
 
174
  One can do the same thing for PyTorch, which uses code
175
  from [sympytorch](https://github.com/patrick-kidger/sympytorch),
176
  and for JAX, which uses code from
177
  [sympy2jax](https://github.com/MilesCranmer/sympy2jax).
178
 
179
- For torch, set the argument `output_torch_format=True`, which
180
- will generate a column `torch_format`. Each element of this column
181
- is a PyTorch module which runs the equation, using PyTorch functions,
182
  over `X` (as a PyTorch tensor). This is differentiable, and the
183
  parameters of this PyTorch module correspond to the learned parameters
184
  in the equation, and are trainable.
 
 
 
 
185
 
186
- For jax, set the argument `output_jax_format=True`, which
187
- will generate a column `jax_format`. Each element of this column
188
- is a dictionary containing a `'callable'` (a JAX function),
189
  and `'parameters'` (a list of parameters in the equation).
190
- One can execute this function with: `element['callable'](X, element['parameters'])`.
 
 
 
 
191
  Since the parameter list is a jax array, this therefore lets you also
192
  train the parameters within JAX (and is differentiable).
193
 
194
- If you forget to turn these on when calling the function initially,
195
- you can re-run `get_hof(output_jax_format=True)`, and it will re-use
196
- the equations and other state properties, assuming you haven't
197
- re-run `pysr` in the meantime!
198
-
199
  ## `loss`
200
 
201
  The default loss is mean-square error, and weighted mean-square error.
 
1
  # Features and Options
2
 
 
 
 
3
  Some configurable features and options in `PySR` which you
4
  may find useful include:
5
+ - `model_selection`
6
  - `binary_operators`, `unary_operators`
7
  - `niterations`
8
  - `ncyclesperiteration`
 
19
 
20
  These are described below
21
 
22
+ The program will output a pandas DataFrame containing the equations
23
+ to `PySRRegressor.equations` containing the loss value
24
+ and complexity.
25
+
26
+ It will also dump to a csv
27
  at the end of every iteration,
28
+ which is `hall_of_fame_{date_time}.csv` by default.
29
+ It also prints the equations to stdout.
30
+
31
+ ## Model selection
32
+
33
+ By default, `PySRRegressor` uses `model_selection='best'`
34
+ which selects an equation from `PySRRegressor.equations` using
35
+ a combination of accuracy and complexity.
36
+ You can also select `model_selection='accuracy'`.
37
+
38
+ By printing a model (i.e., `print(model)`), you can see
39
+ the equation selection with the arrow shown in the `pick` column.
40
 
41
  ## Operators
42
 
43
  A list of operators can be found on the operators page.
44
  One can define custom operators in Julia by passing a string:
45
  ```python
46
+ PySRRegressor(niterations=100,
47
  binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
48
  extra_sympy_mappings={'special': lambda x, y: x**2 + y},
49
  unary_operators=["cos"])
 
62
  when constructing a useable function. This step is optional, but
63
  is necessary for the `lambda_format` to work.
64
 
 
 
65
  ## Iterations
66
 
67
  This is the total number of generations that `pysr` will run for.
 
87
 
88
  One can adjust the number of workers used by Julia with the
89
  `procs` option. You should set this equal to the number of cores
90
+ you want `pysr` to use.
 
91
 
92
  ## Populations
93
 
94
+ By default, `populations=20`, but you can set a different
95
+ number of populations with this option.
96
+ More populations may increase
97
  the diversity of equations discovered, though will take longer to train.
98
+ However, it is usually more efficient to have `populations>procs`,
99
  as there are multiple populations running
100
  on each core.
101
 
 
109
  sigma = ...
110
  weights = 1/sigma**2
111
 
112
+ model = PySRRegressor(procs=10)
113
+ model.fit(X, y, weights=weights)
114
  ```
115
 
116
  ## Max size
 
157
 
158
  ## LaTeX, SymPy
159
 
160
+ After running `model.fit(...)`, you can look at
161
+ `model.equations` which is a pandas dataframe.
162
+ The `sympy_format` column gives sympy equations,
163
+ and the `lambda_format` gives callable functions.
164
+ You can optionally pass a pandas dataframe to the callable function,
165
+ if you called `.fit` on a pandas dataframe as well.
166
 
167
  There are also some helper functions for doing this quickly.
168
+ - `model.latex()` will generate a TeX formatted output of your equation.
169
+ - `model.sympy()` will return the SymPy representation.
170
+ - `model.jax()` will return a callable JAX function combined with parameters (see below)
171
+ - `model.pytorch()` will return a PyTorch model (see below).
 
 
 
172
 
173
 
174
  ## Callable exports: numpy, pytorch, jax
175
 
176
  By default, the dataframe of equations will contain columns
177
+ with the identifier `lambda_format`.
178
+ These are simple functions which correspond to the equation, but executed
179
+ with numpy functions.
180
+ You can pass your `X` matrix to these functions
181
+ just as you did to the `model.fit` call. Thus, this allows
182
  you to numerically evaluate the equations over different output.
183
 
184
+ Calling `model.predict` will execute the `lambda_format` of
185
+ the best equation, and return the result. If you selected
186
+ `model_selection="best"`, this will use an equation that combines
187
+ accuracy with simplicity. For `model_selection="accuracy"`, this will just
188
+ look at accuracy.
189
 
190
  One can do the same thing for PyTorch, which uses code
191
  from [sympytorch](https://github.com/patrick-kidger/sympytorch),
192
  and for JAX, which uses code from
193
  [sympy2jax](https://github.com/MilesCranmer/sympy2jax).
194
 
195
+ Calling `model.pytorch()` will return
196
+ a PyTorch module which runs the equation, using PyTorch functions,
 
197
  over `X` (as a PyTorch tensor). This is differentiable, and the
198
  parameters of this PyTorch module correspond to the learned parameters
199
  in the equation, and are trainable.
200
+ ```python
201
+ output = model.pytorch()
202
+ output['callable'](X)
203
+ ```
204
 
205
+ For JAX, you can equivalently set the argument `output_jax_format=True`.
206
+ This will return a dictionary containing a `'callable'` (a JAX function),
 
207
  and `'parameters'` (a list of parameters in the equation).
208
+ You can execute this function with:
209
+ ```python
210
+ output = model.jax()
211
+ output['callable'](X, output['parameters'])
212
+ ```
213
  Since the parameter list is a jax array, this therefore lets you also
214
  train the parameters within JAX (and is differentiable).
215
 
 
 
 
 
 
216
  ## `loss`
217
 
218
  The default loss is mean-square error, and weighted mean-square error.