MilesCranmer commited on
Commit
2ceb526
·
1 Parent(s): 66dcb6d

Add JAX export functionality

Browse files
Files changed (3) hide show
  1. pysr/__init__.py +1 -0
  2. pysr/export.py +158 -0
  3. pysr/sr.py +2 -2
pysr/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
 
 
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
3
+ from .export import sympy2jax
pysr/export.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools as ft
2
+ import sympy
3
+ import string
4
+ import random
5
+
6
+ try:
7
+ import jax
8
+ from jax import numpy as jnp
9
+ from jax.scipy import special as jsp
10
+
11
+ # Special since need to reduce arguments.
12
+ MUL = 0
13
+ ADD = 1
14
+
15
+ _jnp_func_lookup = {
16
+ sympy.Mul: MUL,
17
+ sympy.Add: ADD,
18
+ sympy.div: "jnp.div",
19
+ sympy.Abs: "jnp.abs",
20
+ sympy.sign: "jnp.sign",
21
+ # Note: May raise error for ints.
22
+ sympy.ceiling: "jnp.ceil",
23
+ sympy.floor: "jnp.floor",
24
+ sympy.log: "jnp.log",
25
+ sympy.exp: "jnp.exp",
26
+ sympy.sqrt: "jnp.sqrt",
27
+ sympy.cos: "jnp.cos",
28
+ sympy.acos: "jnp.acos",
29
+ sympy.sin: "jnp.sin",
30
+ sympy.asin: "jnp.asin",
31
+ sympy.tan: "jnp.tan",
32
+ sympy.atan: "jnp.atan",
33
+ sympy.atan2: "jnp.atan2",
34
+ # Note: Also may give NaN for complex results.
35
+ sympy.cosh: "jnp.cosh",
36
+ sympy.acosh: "jnp.acosh",
37
+ sympy.sinh: "jnp.sinh",
38
+ sympy.asinh: "jnp.asinh",
39
+ sympy.tanh: "jnp.tanh",
40
+ sympy.atanh: "jnp.atanh",
41
+ sympy.Pow: "jnp.power",
42
+ sympy.re: "jnp.real",
43
+ sympy.im: "jnp.imag",
44
+ sympy.arg: "jnp.angle",
45
+ # Note: May raise error for ints and complexes
46
+ sympy.erf: "jsp.erf",
47
+ sympy.erfc: "jsp.erfc",
48
+ sympy.LessThan: "jnp.le",
49
+ sympy.GreaterThan: "jnp.ge",
50
+ sympy.And: "jnp.logical_and",
51
+ sympy.Or: "jnp.logical_or",
52
+ sympy.Not: "jnp.logical_not",
53
+ sympy.Max: "jnp.max",
54
+ sympy.Min: "jnp.min",
55
+ sympy.Mod: "jnp.mod",
56
+ sympy.round: 'jnp.round'
57
+ }
58
+ except ImportError:
59
+ ...
60
+
61
+ def sympy2jaxtext(expr, parameters, symbols_in):
62
+ if issubclass(expr.func, sympy.Float):
63
+ parameters.append(float(expr))
64
+ return f"parameters[{len(parameters) - 1}]"
65
+ elif issubclass(expr.func, sympy.Integer):
66
+ return "{int(expr)}"
67
+ elif issubclass(expr.func, sympy.Symbol):
68
+ return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
69
+ else:
70
+ _func = _jnp_func_lookup[expr.func]
71
+ args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
72
+ if _func == MUL:
73
+ return ' * '.join(['(' + arg + ')' for arg in args])
74
+ elif _func == ADD:
75
+ return ' + '.join(['(' + arg + ')' for arg in args])
76
+ else:
77
+ return f'{_func}({", ".join(args)})'
78
+
79
+ def sympy2jax(equation, symbols_in):
80
+ """Returns a function f and its parameters;
81
+ the function takes an input matrix, and a list of arguments:
82
+ f(X, parameters)
83
+ where the parameters appear in the JAX equation.
84
+
85
+ # Examples:
86
+
87
+ Let's create a function in SymPy:
88
+ ```python
89
+ x, y = symbols('x y')
90
+ cosx = 1.0 * sympy.cos(x) + 3.2 * y
91
+ ```
92
+ Let's get the JAX version. We pass the equation, and
93
+ the symbols required.
94
+ ```python
95
+ f, params = sympy2jax(cosx, [x, y])
96
+ ```
97
+ The order you supply the symbols is the same order
98
+ you should supply the features when calling
99
+ the function `f` (shape `[nrows, nfeatures]`).
100
+ In this case, features=2 for x and y.
101
+ The `params` in this case will be
102
+ `jnp.array([1.0, 3.2])`. You pass these parameters
103
+ when calling the function, which will let you change them
104
+ and take gradients.
105
+
106
+ Let's generate some JAX data to pass:
107
+ ```python
108
+ key = random.PRNGKey(0)
109
+ X = random.normal(key, (10, 2))
110
+ ```
111
+
112
+ We can call the function with:
113
+ ```python
114
+ f(X, params)
115
+
116
+ #> DeviceArray([-2.6080756 , 0.72633684, -6.7557726 , -0.2963162 ,
117
+ # 6.6014843 , 5.032483 , -0.810931 , 4.2520013 ,
118
+ # 3.5427954 , -2.7479894 ], dtype=float32)
119
+ ```
120
+
121
+ We can take gradients with respect
122
+ to the parameters for each row with JAX
123
+ gradient parameters now:
124
+ ```python
125
+ jac_f = jax.jacobian(f, argnums=1)
126
+ jac_f(X, params)
127
+
128
+ #> DeviceArray([[ 0.49364874, -0.9692889 ],
129
+ # [ 0.8283714 , -0.0318858 ],
130
+ # [-0.7447336 , -1.8784496 ],
131
+ # [ 0.70755106, -0.3137085 ],
132
+ # [ 0.944834 , 1.767703 ],
133
+ # [ 0.51673377, 1.4111717 ],
134
+ # [ 0.87347716, -0.52637756],
135
+ # [ 0.8760679 , 1.0549792 ],
136
+ # [ 0.9961824 , 0.79581654],
137
+ # [-0.88465923, -0.5822907 ]], dtype=float32)
138
+ ```
139
+
140
+ We can also JIT-compile our function:
141
+ ```python
142
+ compiled_f = jax.jit(f)
143
+ compiled_f(X, params)
144
+
145
+ #> DeviceArray([-2.6080756 , 0.72633684, -6.7557726 , -0.2963162 ,
146
+ # 6.6014843 , 5.032483 , -0.810931 , 4.2520013 ,
147
+ # 3.5427954 , -2.7479894 ], dtype=float32)
148
+ ```
149
+ """
150
+ parameters = []
151
+ functional_form_text = sympy2jaxtext(equation, parameters, symbols_in)
152
+ hash_string = 'A' + str(hash([equation, symbols_in]))
153
+ text = f"def {hash_string}(X, parameters):\n"
154
+ text += " return "
155
+ text += functional_form_text
156
+ ldict = {}
157
+ exec(text, globals(), ldict)
158
+ return ldict['f'], jnp.array(parameters)
pysr/sr.py CHANGED
@@ -47,8 +47,8 @@ sympy_mappings = {
47
  'erf': lambda x : sympy.erf(x),
48
  'erfc': lambda x : sympy.erfc(x),
49
  'logm': lambda x : sympy.log(abs(x)),
50
- 'logm10':lambda x : sympy.log(abs(x), base=10),
51
- 'logm2': lambda x : sympy.log(abs(x), base=2),
52
  'log1p': lambda x : sympy.log(x + 1),
53
  'floor': lambda x : sympy.floor(x),
54
  'ceil': lambda x : sympy.ceil(x),
 
47
  'erf': lambda x : sympy.erf(x),
48
  'erfc': lambda x : sympy.erfc(x),
49
  'logm': lambda x : sympy.log(abs(x)),
50
+ 'logm10':lambda x : sympy.log(abs(x), 10),
51
+ 'logm2': lambda x : sympy.log(abs(x), 2),
52
  'log1p': lambda x : sympy.log(x + 1),
53
  'floor': lambda x : sympy.floor(x),
54
  'ceil': lambda x : sympy.ceil(x),