Spaces:
Running
Running
tttc3
commited on
Commit
·
73c6ffd
1
Parent(s):
32a2de6
Fixed jax export compatibility with refactor
Browse files- pysr/export_jax.py +1 -4
pysr/export_jax.py
CHANGED
@@ -109,7 +109,7 @@ def _initialize_jax():
|
|
109 |
jsp = _jsp
|
110 |
|
111 |
|
112 |
-
def sympy2jax(expression, symbols_in,
|
113 |
"""Returns a function f and its parameters;
|
114 |
the function takes an input matrix, and a list of arguments:
|
115 |
f(X, parameters)
|
@@ -192,9 +192,6 @@ def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
|
|
192 |
)
|
193 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
194 |
text = f"def {hash_string}(X, parameters):\n"
|
195 |
-
if selection is not None:
|
196 |
-
# Impose the feature selection:
|
197 |
-
text += f" X = X[:, {list(selection)}]\n"
|
198 |
text += " return "
|
199 |
text += functional_form_text
|
200 |
ldict = {}
|
|
|
109 |
jsp = _jsp
|
110 |
|
111 |
|
112 |
+
def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
|
113 |
"""Returns a function f and its parameters;
|
114 |
the function takes an input matrix, and a list of arguments:
|
115 |
f(X, parameters)
|
|
|
192 |
)
|
193 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
194 |
text = f"def {hash_string}(X, parameters):\n"
|
|
|
|
|
|
|
195 |
text += " return "
|
196 |
text += functional_form_text
|
197 |
ldict = {}
|