Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
b4cb407
1
Parent(s):
ce5b119
Fix feature selection for JAX export
Browse files- pysr/export_jax.py +4 -1
- pysr/sr.py +1 -0
pysr/export_jax.py
CHANGED
@@ -109,7 +109,7 @@ def _initialize_jax():
|
|
109 |
jsp = _jsp
|
110 |
|
111 |
|
112 |
-
def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
|
113 |
"""Returns a function f and its parameters;
|
114 |
the function takes an input matrix, and a list of arguments:
|
115 |
f(X, parameters)
|
@@ -192,6 +192,9 @@ def sympy2jax(expression, symbols_in, extra_jax_mappings=None):
|
|
192 |
)
|
193 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
194 |
text = f"def {hash_string}(X, parameters):\n"
|
|
|
|
|
|
|
195 |
text += " return "
|
196 |
text += functional_form_text
|
197 |
ldict = {}
|
|
|
109 |
jsp = _jsp
|
110 |
|
111 |
|
112 |
+
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
|
113 |
"""Returns a function f and its parameters;
|
114 |
the function takes an input matrix, and a list of arguments:
|
115 |
f(X, parameters)
|
|
|
192 |
)
|
193 |
hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
|
194 |
text = f"def {hash_string}(X, parameters):\n"
|
195 |
+
if selection is not None:
|
196 |
+
# Impose the feature selection:
|
197 |
+
text += f" X = X[:, {list(selection)}]\n"
|
198 |
text += " return "
|
199 |
text += functional_form_text
|
200 |
ldict = {}
|
pysr/sr.py
CHANGED
@@ -1740,6 +1740,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1740 |
func, params = sympy2jax(
|
1741 |
eqn,
|
1742 |
sympy_symbols,
|
|
|
1743 |
extra_jax_mappings=self.extra_jax_mappings,
|
1744 |
)
|
1745 |
jax_format.append({"callable": func, "parameters": params})
|
|
|
1740 |
func, params = sympy2jax(
|
1741 |
eqn,
|
1742 |
sympy_symbols,
|
1743 |
+
selection=self.selection_mask_,
|
1744 |
extra_jax_mappings=self.extra_jax_mappings,
|
1745 |
)
|
1746 |
jax_format.append({"callable": func, "parameters": params})
|