File size: 3,097 Bytes
b2d7f41
75c23d4
b2d7f41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21fb25d
2e51033
 
5dbc3b7
21fb25d
 
 
 
b2d7f41
 
 
fb5f0a1
 
 
 
 
 
b2d7f41
b896bd3
b2d7f41
 
 
 
 
fb5f0a1
 
 
 
b896bd3
fb5f0a1
 
b2d7f41
fb5f0a1
 
b2d7f41
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Define utilities to export to sympy"""

from typing import Callable, Dict, List, Optional

import sympy
from sympy import sympify

# Table translating operator names to sympy-compatible callables.
# `pysr2sympy` passes this table as `locals` to `sympify`, so each operator
# name appearing in an equation string resolves to the callable given here.
# Entries written as lambdas accept exactly the arguments they list.
sympy_mappings = dict(
    # Basic arithmetic and powers.
    div=lambda x, y: x / y,
    mult=lambda x, y: x * y,
    sqrt=lambda x: sympy.sqrt(x),
    sqrt_abs=lambda x: sympy.sqrt(abs(x)),
    square=lambda x: x**2,
    cube=lambda x: x**3,
    plus=lambda x, y: x + y,
    sub=lambda x, y: x - y,
    neg=lambda x: -x,
    pow=lambda x, y: x**y,
    pow_abs=lambda x, y: abs(x) ** y,
    # Trigonometric, hyperbolic and exponential functions.
    cos=sympy.cos,
    sin=sympy.sin,
    tan=sympy.tan,
    cosh=sympy.cosh,
    sinh=sympy.sinh,
    tanh=sympy.tanh,
    exp=sympy.exp,
    acos=sympy.acos,
    asin=sympy.asin,
    atan=sympy.atan,
    acosh=lambda x: sympy.acosh(x),
    # |x| + 1 keeps the argument >= 1.
    acosh_abs=lambda x: sympy.acosh(abs(x) + 1),
    asinh=sympy.asinh,
    # Both entries wrap the argument into [-1, 1) via Mod(x + 1, 2) - 1.
    # NOTE(review): plain "atanh" shares the "atanh_clip" wrapping — confirm intended.
    atanh=lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
    atanh_clip=lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
    # Miscellaneous elementwise functions.
    abs=abs,
    mod=sympy.Mod,
    erf=sympy.erf,
    erfc=sympy.erfc,
    # Logarithms; the *_abs variants take |x| first.
    log=lambda x: sympy.log(x),
    log10=lambda x: sympy.log(x, 10),
    log2=lambda x: sympy.log(x, 2),
    log1p=lambda x: sympy.log(x + 1),
    log_abs=lambda x: sympy.log(abs(x)),
    log10_abs=lambda x: sympy.log(abs(x), 10),
    log2_abs=lambda x: sympy.log(abs(x), 2),
    log1p_abs=lambda x: sympy.log(abs(x) + 1),
    # Rounding and sign.
    floor=sympy.floor,
    ceil=sympy.ceiling,
    sign=sympy.sign,
    gamma=sympy.gamma,
    # ceiling(x - 0.5): exact halves round down (2.5 -> 2, -2.5 -> -3).
    round=lambda x: sympy.ceiling(x - 0.5),
    # Piecewise-defined comparisons and logic; "true" results are 1.0, else 0.0.
    max=lambda x, y: sympy.Piecewise((y, x < y), (x, True)),
    min=lambda x, y: sympy.Piecewise((x, x < y), (y, True)),
    greater=lambda x, y: sympy.Piecewise((1.0, x > y), (0.0, True)),
    cond=lambda x, y: sympy.Piecewise((y, x > 0), (0.0, True)),
    logical_or=lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)),
    logical_and=lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)),
    relu=lambda x: sympy.Piecewise((0.0, x < 0), (x, True)),
)


def create_sympy_symbols_map(
    feature_names_in: List[str],
) -> Dict[str, sympy.Symbol]:
    """Build a lookup from each feature name to a same-named sympy Symbol.

    Args:
        feature_names_in: Names of the input variables.

    Returns:
        Dict keyed by feature name, valued by ``sympy.Symbol(name)``. A name
        that appears more than once yields a single entry.
    """
    name_to_symbol: Dict[str, sympy.Symbol] = {}
    for name in feature_names_in:
        name_to_symbol[name] = sympy.Symbol(name)
    return name_to_symbol


def create_sympy_symbols(
    feature_names_in: List[str],
) -> List[sympy.Symbol]:
    """Return one sympy Symbol per feature name, in input order.

    Args:
        feature_names_in: Names of the input variables.

    Returns:
        List of ``sympy.Symbol(name)`` for each name, duplicates preserved.
    """
    return list(map(sympy.Symbol, feature_names_in))


def pysr2sympy(
    equation: str,
    *,
    feature_names_in: Optional[List[str]] = None,
    extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
    """Convert an equation string into a sympy expression.

    Names in ``equation`` are resolved through a local namespace handed to
    ``sympify``. The namespace is built in three layers; a later layer wins
    when a name appears in more than one:

    1. ``feature_names_in``: each name bound to a same-named ``sympy.Symbol``.
    2. The built-in ``sympy_mappings`` operator table.
    3. ``extra_sympy_mappings``, supplied by the caller.

    Args:
        equation: The equation text to parse.
        feature_names_in: Variable names that may occur in ``equation``.
            Defaults to no variables.
        extra_sympy_mappings: Additional name -> callable mappings, for example
            custom operators. These override built-in entries of the same name.

    Returns:
        Whatever ``sympify`` produces for ``equation``.

    Note:
        ``sympify`` evaluates string input with ``eval``, so never pass
        untrusted text as ``equation``.
    """
    if feature_names_in is None:
        feature_names_in = []
    if extra_sympy_mappings is None:
        extra_sympy_mappings = {}

    local_sympy_mappings = {
        # Lowest priority: a variable can never shadow an operator name.
        **create_sympy_symbols_map(feature_names_in),
        **sympy_mappings,
        # Highest priority: a caller-supplied mapping must not be silently
        # discarded just because its key matches a built-in operator.
        **extra_sympy_mappings,
    }

    return sympify(equation, locals=local_sympy_mappings)


def assert_valid_sympy_symbol(var_name: str) -> None:
    """Reject a variable name that collides with a known function name.

    Args:
        var_name: Candidate variable name.

    Raises:
        ValueError: If ``var_name`` is a key of ``sympy_mappings`` or an
            attribute name in the ``sympy`` module's namespace.
    """
    if var_name not in sympy_mappings and var_name not in vars(sympy):
        return
    raise ValueError(f"Variable name {var_name} is already a function name.")