File size: 5,186 Bytes
2f38c9c
976f8d8
 
41e5fd5
9bfcbfa
e84bed4
41e5fd5
144e3ff
0501132
a2fd8f3
7d4300a
2f38c9c
51a6b05
 
144e3ff
 
 
51a6b05
2f38c9c
a2fd8f3
 
7d4300a
2f38c9c
 
 
144e3ff
2f38c9c
144e3ff
7d4300a
c7187a6
a2fd8f3
c7187a6
fbb7cf7
 
4b56660
fbb7cf7
 
 
 
 
c7187a6
 
 
593c674
c7187a6
 
 
 
593c674
 
c7187a6
 
fbb7cf7
c7187a6
 
 
144e3ff
c7187a6
a15823e
c7187a6
 
9bfcbfa
b07eb2d
fbb7cf7
4b56660
fbb7cf7
 
7d4300a
 
b444c7e
593c674
7d4300a
 
 
9bfcbfa
593c674
 
7d4300a
9bfcbfa
fbb7cf7
d398bf9
 
9bfcbfa
144e3ff
7d4300a
f5577ea
9bfcbfa
ce5b119
144e3ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e84bed4
144e3ff
 
 
 
 
 
 
 
 
 
 
7cda629
beaf20b
 
0501132
 
 
 
beaf20b
ce5b119
 
4b56660
7cda629
4b56660
7cda629
4b56660
7cda629
 
 
 
beaf20b
 
 
 
ce5b119
beaf20b
ce5b119
 
 
 
 
 
 
 
 
a15823e
 
a2fd8f3
 
ef66f4a
a2fd8f3
ef66f4a
 
 
a2fd8f3
 
ef66f4a
 
a2fd8f3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import unittest
from functools import partial

import numpy as np
import pandas as pd
import sympy  # type: ignore

import pysr
from pysr import PySRRegressor, sympy2jax


class TestJAX(unittest.TestCase):
    """Tests for turning SymPy / PySR expressions into JAX callables."""

    def setUp(self):
        """Seed NumPy's global RNG and keep a handle to `jax.numpy`."""
        np.random.seed(0)
        # jax is imported inside the test methods rather than at module level.
        import jax.numpy as jnp

        self.jnp = jnp

    def _fit_and_reload_jax_export(self, X, y):
        """Fit a short search, write a hand-made equation table, export to JAX.

        Shared by the two pipeline tests, which differ only in the container
        type of `X`.

        Parameters
        ----------
        X
            Feature matrix handed to `PySRRegressor.fit`.
        y
            Target vector handed to `PySRRegressor.fit`.

        Returns
        -------
        The object returned by `model.jax()`; the callers index it with the
        keys `"callable"` and `"parameters"`.
        """
        regressor = PySRRegressor(
            progress=False, max_evals=10000, output_jax_format=True
        )
        regressor.fit(X, y)

        hall_of_fame = pd.DataFrame(
            {
                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
                "Loss": [1.0, 0.1, 1e-5],
                "Complexity": [1, 2, 3],
            }
        )
        column_order = "Complexity Loss Equation".split(" ")
        hall_of_fame[column_order].to_csv("equation_file.csv.bkup")

        # NOTE(review): the table is written to the `.bkup` name while refresh
        # is pointed at the plain name; presumably `refresh` picks up the
        # backup next to the checkpoint — confirm against PySRRegressor.refresh.
        regressor.refresh(checkpoint_file="equation_file.csv")
        return regressor.jax()

    def test_sympy2jax(self):
        """Compare `sympy2jax` on `cos(x) + y` with a direct jax.numpy result."""
        from jax import random as jax_random

        jnp = self.jnp
        sym_x, sym_y, sym_z = sympy.symbols("x y z")
        expression = 1.0 * sympy.cos(sym_x) + sym_y

        inputs = jax_random.normal(jax_random.PRNGKey(0), (1000, 2))
        expected = 1.0 * jnp.cos(inputs[:, 0]) + inputs[:, 1]

        # `z` is passed as a symbol even though the expression never uses it
        # and the input only has two columns.
        jax_fn, jax_params = sympy2jax(expression, [sym_x, sym_y, sym_z])
        matches = jnp.isclose(jax_fn(inputs, jax_params), expected)
        self.assertTrue(jnp.all(matches).item())

    def test_pipeline_pandas(self):
        """Run the fit / refresh / export pipeline on a pandas DataFrame."""
        frame = pd.DataFrame(np.random.randn(100, 10))
        target = np.ones(frame.shape[0])
        exported = self._fit_and_reload_jax_export(frame, target)

        jax_fn = exported["callable"]
        predicted = jax_fn(self.jnp.array(frame), exported["parameters"])

        # The expected values are `square(cos(x1))`, the last row of the
        # hand-made table, evaluated on column 1.
        np.testing.assert_almost_equal(
            np.array(predicted),
            np.square(np.cos(frame.values[:, 1])),  # Select feature 1
            decimal=3,
        )

    def test_pipeline(self):
        """Run the fit / refresh / export pipeline on a NumPy array."""
        features = np.random.randn(100, 10)
        target = np.ones(features.shape[0])
        exported = self._fit_and_reload_jax_export(features, target)

        jax_fn = exported["callable"]
        predicted = jax_fn(self.jnp.array(features), exported["parameters"])

        # The expected values are `square(cos(x1))`, the last row of the
        # hand-made table, evaluated on column 1.
        np.testing.assert_almost_equal(
            np.array(predicted),
            np.square(np.cos(features[:, 1])),  # Select feature 1
            decimal=3,
        )

    def test_avoid_simplification(self):
        """Round-trip an expression with a constant-only subterm through JAX.

        The string is parsed with `pysr2sympy`, converted with `sympy2jax`, and
        the result is compared with the same formula written in NumPy.
        """
        parsed = pysr.export_sympy.pysr2sympy(
            "square(exp(sign(0.44796443))) + 1.5 * x1",
            feature_names_in=["x1"],
            extra_sympy_mappings={"square": lambda x: x**2},
        )
        jax_fn, jax_params = pysr.export_jax.sympy2jax(parsed, [sympy.symbols("x1")])

        rng = np.random.RandomState(0)
        samples = rng.randn(10, 1)
        actual = np.array(jax_fn(self.jnp.array(samples), jax_params))
        expected = np.square(np.exp(np.sign(0.44796443))) + 1.5 * samples[:, 0]
        np.testing.assert_almost_equal(actual, expected, decimal=3)

    def test_issue_656(self):
        """Check that `exp(1) + x1` converted by `sympy2jax` equals e + x1.

        Named after issue 656; the issue itself is not visible from this file.
        """
        import sympy  # type: ignore

        x1 = sympy.symbols("x1")
        e_plus_x1 = sympy.exp(1) + x1
        jax_fn, jax_params = pysr.export_jax.sympy2jax(
            e_plus_x1, [sympy.symbols("x1")]
        )

        rng = np.random.RandomState(0)
        samples = rng.randn(10, 1)
        actual = np.array(jax_fn(self.jnp.array(samples), jax_params))
        expected = np.exp(1) + samples[:, 0]
        np.testing.assert_almost_equal(actual, expected, decimal=3)

    def test_feature_selection_custom_operators(self):
        """Fit with feature selection and a custom operator, then compare exports.

        The target is `k15**2 + 2 * cos_approx(k20)` over columns `k10`..`k20`.
        Both `model.predict` and the JAX callable must reproduce it to three
        decimals.
        """
        rng = np.random.RandomState(0)
        columns = {f"k{i}": rng.randn(2000) for i in range(10, 21)}
        frame = pd.DataFrame(columns)

        def cos_approx(x):
            return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720

        target = frame["k15"] ** 2 + 2 * cos_approx(frame["k20"])

        # The same polynomial is spelled three times: for the search backend,
        # for SymPy, and for JAX.
        jax_mappings = {
            "cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
        }
        regressor = PySRRegressor(
            progress=False,
            unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
            select_k_features=3,
            maxsize=10,
            early_stop_condition=1e-5,
            extra_sympy_mappings={"cos_approx": cos_approx},
            extra_jax_mappings=jax_mappings,
            random_state=0,
            deterministic=True,
            procs=0,
            multithreading=False,
        )
        np.random.seed(0)
        regressor.fit(frame.values, target.values)
        jax_fn, jax_params = regressor.jax().values()

        predict_with_jax = partial(jax_fn, parameters=jax_params)

        numpy_result = regressor.predict(frame.values)
        jax_result = predict_with_jax(frame.values)

        np.testing.assert_almost_equal(target.values, numpy_result, decimal=3)
        np.testing.assert_almost_equal(target.values, jax_result, decimal=3)


def runtests(just_tests=False):
    """Run every test case in this module, or just list the test case classes.

    Parameters
    ----------
    just_tests : bool
        When true, nothing is executed and the list of test case classes is
        returned instead.

    Returns
    -------
    The list of test case classes if `just_tests` is true, otherwise the
    result object returned by `unittest.TextTestRunner.run`.
    """
    test_cases = [TestJAX]
    if not just_tests:
        suite = unittest.TestSuite()
        loader = unittest.TestLoader()
        for case in test_cases:
            suite.addTests(loader.loadTestsFromTestCase(case))
        return unittest.TextTestRunner().run(suite)
    return test_cases