Spaces:
Running
Running
File size: 1,558 Bytes
5a01e6f 976f8d8 5a01e6f b2d7f41 5a01e6f b2d7f41 5a01e6f b2d7f41 5a01e6f 83d8e67 5a01e6f bd90cfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
"""Code for exporting discovered expressions to numpy"""
import warnings
import numpy as np
import pandas as pd
from sympy import lambdify
def sympy2numpy(eqn, sympy_symbols, *, selection=None):
return CallableEquation(eqn, sympy_symbols, selection=selection)
class CallableEquation:
"""Simple wrapper for numpy lambda functions built with sympy"""
def __init__(self, eqn, sympy_symbols, selection=None):
self._sympy = eqn
self._sympy_symbols = sympy_symbols
self._selection = selection
def __repr__(self):
return f"PySRFunction(X=>{self._sympy})"
def __call__(self, X):
expected_shape = (X.shape[0],)
if isinstance(X, pd.DataFrame):
# Lambda function takes as argument:
return self._lambda(
**{k: X[k].values for k in map(str, self._sympy_symbols)}
) * np.ones(expected_shape)
if self._selection is not None:
if X.shape[1] != len(self._selection):
warnings.warn(
"`X` should be of shape (n_samples, len(self._selection)). "
"Automatically filtering `X` to selection. "
"Note: Filtered `X` column order may not match column order in fit "
"this may lead to incorrect predictions and other errors."
)
X = X[:, self._selection]
return self._lambda(*X.T) * np.ones(expected_shape)
@property
def _lambda(self):
return lambdify(self._sympy_symbols, self._sympy)
|