Spaces:
Running
Running
File size: 3,442 Bytes
b2d7f41 75c23d4 b2d7f41 e84bed4 b2d7f41 583beaf b2d7f41 018c3a1 b2d7f41 21fb25d 2e51033 5dbc3b7 21fb25d b2d7f41 fb5f0a1 583beaf fb5f0a1 b2d7f41 583beaf b2d7f41 fb5f0a1 583beaf fb5f0a1 b896bd3 fb5f0a1 b2d7f41 fb5f0a1 b2d7f41 36e5dde b2d7f41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
"""Define utilities to export to sympy"""
from typing import Callable, Dict, List, Optional
import sympy # type: ignore
from sympy import sympify
from .utils import ArrayLike
# Translation table from PySR operator names to callables that build the
# equivalent SymPy expression. pysr2sympy passes these entries to
# ``sympify`` as its ``locals`` namespace, so operator names appearing in an
# equation string resolve to the callables below.
sympy_mappings = dict(
    # --- arithmetic --------------------------------------------------------
    div=lambda x, y: x / y,
    mult=lambda x, y: x * y,
    sqrt=lambda x: sympy.sqrt(x),
    sqrt_abs=lambda x: sympy.sqrt(abs(x)),
    square=lambda x: x**2,
    cube=lambda x: x**3,
    plus=lambda x, y: x + y,
    sub=lambda x, y: x - y,
    neg=lambda x: -x,
    pow=lambda x, y: x**y,
    pow_abs=lambda x, y: abs(x) ** y,
    # --- trigonometric / hyperbolic / exponential --------------------------
    cos=sympy.cos,
    sin=sympy.sin,
    tan=sympy.tan,
    cosh=sympy.cosh,
    sinh=sympy.sinh,
    tanh=sympy.tanh,
    exp=sympy.exp,
    acos=sympy.acos,
    asin=sympy.asin,
    atan=sympy.atan,
    acosh=lambda x: sympy.acosh(x),
    # abs(x) + 1 is always >= 1 for real x, keeping acosh's argument real-valued.
    acosh_abs=lambda x: sympy.acosh(abs(x) + 1),
    asinh=sympy.asinh,
    # Mod(x + 1, 2) - 1 wraps a real argument into [-1, 1) before atanh.
    # NOTE(review): plain "atanh" also receives this wrapping, identical to
    # "atanh_clip" -- confirm that is intended.
    atanh=lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
    atanh_clip=lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
    # --- miscellaneous -----------------------------------------------------
    abs=abs,
    mod=sympy.Mod,
    erf=sympy.erf,
    erfc=sympy.erfc,
    # --- logarithms --------------------------------------------------------
    log=lambda x: sympy.log(x),
    log10=lambda x: sympy.log(x, 10),
    log2=lambda x: sympy.log(x, 2),
    log1p=lambda x: sympy.log(x + 1),
    log_abs=lambda x: sympy.log(abs(x)),
    log10_abs=lambda x: sympy.log(abs(x), 10),
    log2_abs=lambda x: sympy.log(abs(x), 2),
    log1p_abs=lambda x: sympy.log(abs(x) + 1),
    # --- rounding, sign and gamma -----------------------------------------
    floor=sympy.floor,
    ceil=sympy.ceiling,
    sign=sympy.sign,
    gamma=sympy.gamma,
    # ceiling(x - 0.5) sends exact halves downward (0.5 -> 0, 1.5 -> 1).
    # NOTE(review): if the search-side ``round`` rounds ties to even, exported
    # expressions disagree at exact halves -- confirm.
    round=lambda x: sympy.ceiling(x - 0.5),
    # --- piecewise / logical ----------------------------------------------
    max=lambda x, y: sympy.Piecewise((y, x < y), (x, True)),
    min=lambda x, y: sympy.Piecewise((x, x < y), (y, True)),
    greater=lambda x, y: sympy.Piecewise((1.0, x > y), (0.0, True)),
    cond=lambda x, y: sympy.Piecewise((y, x > 0), (0.0, True)),
    logical_or=lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)),
    logical_and=lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)),
    relu=lambda x: sympy.Piecewise((0.0, x < 0), (x, True)),
)
def create_sympy_symbols_map(
    feature_names_in: ArrayLike[str],
) -> Dict[str, sympy.Symbol]:
    """Map each feature name to a ``sympy.Symbol`` with that same name.

    When a name occurs more than once, the entry from its last occurrence is
    the one kept.
    """
    symbols_by_name: Dict[str, sympy.Symbol] = {}
    for name in feature_names_in:
        symbols_by_name[name] = sympy.Symbol(name)
    return symbols_by_name
def create_sympy_symbols(
    feature_names_in: ArrayLike[str],
) -> List[sympy.Symbol]:
    """Return one ``sympy.Symbol`` per feature name, in input order.

    Duplicated names yield duplicated entries; nothing is de-duplicated.
    """
    return list(map(sympy.Symbol, feature_names_in))
def pysr2sympy(
    equation: str,
    *,
    feature_names_in: Optional[ArrayLike[str]] = None,
    extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
    """Parse a PySR equation string into a SymPy expression.

    Args:
        equation: Equation text whose function names are drawn from the
            operator names in ``sympy_mappings`` (or ``extra_sympy_mappings``).
        feature_names_in: Variable names that may appear in ``equation``; each
            is bound to a ``sympy.Symbol`` of the same name. Defaults to no
            variables.
        extra_sympy_mappings: Additional name -> callable bindings, e.g. for
            custom operators. They take precedence over the built-in
            ``sympy_mappings``, so a caller can override a default translation.

    Returns:
        The result of ``sympify`` on ``equation``, requested with
        ``evaluate=False`` where the installed SymPy accepts that keyword.

    Raises:
        TypeError: If parsing raises a ``TypeError`` other than the
            unsupported-``evaluate`` case, which is retried without it.
    """
    if feature_names_in is None:
        feature_names_in = []
    # Later entries win on key collisions: feature symbols first, then the
    # built-in operator table, then caller-supplied extras. Previously the
    # built-ins came last and silently shadowed any explicit extra mapping
    # that reused a built-in name.
    local_sympy_mappings = {
        **create_sympy_symbols_map(feature_names_in),
        **sympy_mappings,
        **(extra_sympy_mappings if extra_sympy_mappings is not None else {}),
    }
    try:
        return sympify(equation, locals=local_sympy_mappings, evaluate=False)
    except TypeError as e:
        # Some SymPy versions' sympify reject ``evaluate``; retry without it.
        if "got an unexpected keyword argument 'evaluate'" in str(e):
            return sympify(equation, locals=local_sympy_mappings)
        raise TypeError(f"Error processing equation '{equation}'") from e
def assert_valid_sympy_symbol(var_name: str) -> None:
    """Reject a variable name that collides with an operator or SymPy name.

    Args:
        var_name: Candidate variable name.

    Raises:
        ValueError: If ``var_name`` is a key of ``sympy_mappings`` or a name
            defined in the top-level ``sympy`` module namespace.
    """
    # Checked in the same order as before: the operator table first, then the
    # sympy module namespace (vars(sympy) is sympy.__dict__).
    collides = var_name in sympy_mappings or var_name in vars(sympy)
    if collides:
        raise ValueError(f"Variable name {var_name} is already a function name.")
|