Spaces:
Running
Running
"""Code for exporting discovered expressions to numpy""" | |
import warnings | |
from typing import List, Union | |
import numpy as np | |
import pandas as pd | |
from numpy.typing import NDArray | |
from sympy import Expr, Symbol, lambdify # type: ignore | |
def sympy2numpy(eqn, sympy_symbols, *, selection=None):
    """Wrap a sympy expression as a numpy-evaluable ``CallableEquation``.

    Args:
        eqn: sympy expression to evaluate.
        sympy_symbols: input symbols of ``eqn``; their order is the positional
            argument order used when evaluating on arrays.
        selection: optional column mask, forwarded unchanged (keyword-only).

    Returns:
        A ``CallableEquation`` built from the given arguments.
    """
    wrapped = CallableEquation(eqn, sympy_symbols, selection=selection)
    return wrapped
class CallableEquation:
    """Simple wrapper for numpy lambda functions built with sympy.

    Evaluates a sympy expression (compiled with ``sympy.lambdify``) on either
    a ``pandas.DataFrame`` -- columns looked up by symbol name -- or a 2-D
    array -- columns passed positionally in ``sympy_symbols`` order.
    """

    _sympy: Expr  # expression being evaluated
    _sympy_symbols: List[Symbol]  # inputs; order = lambda's positional args
    _selection: Union[NDArray[np.bool_], None]  # optional column mask for arrays

    def __init__(self, eqn, sympy_symbols, selection=None):
        """Store the expression, its input symbols and an optional mask.

        Args:
            eqn: sympy expression to evaluate.
            sympy_symbols: symbols of ``eqn``. Their order is the positional
                argument order of the compiled lambda, and their ``str``
                names are the DataFrame column names read in ``__call__``.
            selection: optional boolean mask over the columns of a full
                feature array, used to filter array inputs in ``__call__``.
        """
        self._sympy = eqn
        self._sympy_symbols = sympy_symbols
        self._selection = selection

    def __repr__(self):
        return f"PySRFunction(X=>{self._sympy})"

    def __call__(self, X):
        """Evaluate the expression on every row of ``X``.

        Args:
            X: a ``pandas.DataFrame`` containing a column named ``str(s)`` for
                every symbol ``s``, or a 2-D array of shape
                ``(n_samples, n_features)``.

        Returns:
            numpy array of shape ``(n_samples,)``.

        Warns:
            UserWarning: when ``selection`` is set and ``X`` (array input) does
                not have exactly ``selection.sum()`` columns; ``X`` is then
                filtered with the mask before evaluation.
        """
        expected_shape = (X.shape[0],)
        if isinstance(X, pd.DataFrame):
            # Lambda function takes as argument the symbol names, so pass each
            # column by keyword; the column mask is irrelevant in this case.
            # Multiplying by ones broadcasts to one value per row (e.g. when the
            # lambda returns a scalar for a constant expression).
            return self._lambda(
                **{k: X[k].values for k in map(str, self._sympy_symbols)}
            ) * np.ones(expected_shape)

        if self._selection is not None:
            # An array with selection.sum() columns is treated as already
            # filtered; any other width is assumed to be the full feature set.
            if X.shape[1] != self._selection.sum():
                warnings.warn(
                    "`X` should be of shape (n_samples, len(self._selection)). "
                    "Automatically filtering `X` to selection. "
                    "Note: Filtered `X` column order may not match column order in fit "
                    "this may lead to incorrect predictions and other errors.",
                    stacklevel=2,
                )
                X = X[:, self._selection]

        # Columns are unpacked positionally, in sympy_symbols order.
        return self._lambda(*X.T) * np.ones(expected_shape)

    @property
    def _lambda(self):
        """The expression compiled into a numpy function of ``sympy_symbols``.

        Must be a property: ``__call__`` invokes ``self._lambda(...)`` with the
        data as arguments, which would raise ``TypeError`` if this were a plain
        zero-argument method. Rebuilt on every access (not cached).
        NOTE(review): presumably uncached so instances never store the
        generated function (which may not pickle) -- confirm before caching.
        """
        return lambdify(self._sympy_symbols, self._sympy)