File size: 7,819 Bytes
d2c7df2
d18011f
a06bfc4
d18011f
976f8d8
d2c7df2
e84bed4
a06bfc4
7d4300a
d18011f
 
 
7d4300a
d18011f
 
7d4300a
fb950bb
 
d18011f
fb950bb
7d4300a
fb950bb
 
 
d18011f
fb950bb
d18011f
fb950bb
 
d18011f
7d4300a
d18011f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d5aec3
f2a7a62
 
 
d18011f
 
 
aefe008
7d4300a
d18011f
 
 
 
fb950bb
d18011f
 
 
 
f119733
 
 
 
 
d18011f
7d4300a
 
 
 
 
 
 
d18011f
 
 
 
 
 
 
 
c3293a8
 
 
 
 
d18011f
 
 
 
 
99fff5c
 
 
 
 
 
 
 
d18011f
 
 
 
 
7d4300a
 
 
 
 
 
d18011f
 
 
fb950bb
d18011f
 
 
 
 
 
 
 
 
 
 
5af6354
aefe008
7d4300a
 
 
 
fb950bb
7d4300a
d18011f
 
 
 
 
7d4300a
 
 
d18011f
 
 
 
fb950bb
 
 
 
8c55475
 
7d4300a
d18011f
a06bfc4
5af6354
 
a06bfc4
7d4300a
9068541
 
 
 
 
d18011f
fb950bb
 
 
7d4300a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Fork of https://github.com/patrick-kidger/sympytorch

import collections as co
import functools as ft

import numpy as np  # noqa: F401
import sympy  # type: ignore


def _reduce(fn):
    def fn_(*args):
        return ft.reduce(fn, args)

    return fn_


# Module-level state for the lazy torch import performed by _initialize_torch().
torch = None  # rebound to the torch module by _initialize_torch()
SingleSymPyModule = None  # rebound to the module class by _initialize_torch()
torch_initialized = False  # guard flag checked by _initialize_torch()


def _initialize_torch():
    """Import torch on first use and build the torch-backed module classes.

    torch is imported here rather than at module import time, so this module
    can be imported without loading torch. The first successful call binds the
    module-level names ``torch`` and ``SingleSymPyModule`` and sets
    ``torch_initialized``; later calls return immediately.
    """
    global torch_initialized
    global torch
    global SingleSymPyModule

    # Way to lazy load torch, only if this is called,
    # but still allow this module to be loaded in __init__
    if torch_initialized:
        return

    import torch as _torch

    torch = _torch

    # Maps the head of a sympy expression (``expr.func``) to the torch callable
    # used to evaluate it.
    _global_func_lookup = {
        sympy.Mul: _reduce(torch.mul),
        sympy.Add: _reduce(torch.add),
        sympy.div: torch.div,
        sympy.Abs: torch.abs,
        sympy.sign: torch.sign,
        # Note: May raise error for ints.
        sympy.ceiling: torch.ceil,
        sympy.floor: torch.floor,
        sympy.log: torch.log,
        sympy.exp: torch.exp,
        sympy.sqrt: torch.sqrt,
        sympy.cos: torch.cos,
        sympy.acos: torch.acos,
        sympy.sin: torch.sin,
        sympy.asin: torch.asin,
        sympy.tan: torch.tan,
        sympy.atan: torch.atan,
        sympy.atan2: torch.atan2,
        # Note: May give NaN for complex results.
        sympy.cosh: torch.cosh,
        sympy.acosh: torch.acosh,
        sympy.sinh: torch.sinh,
        sympy.asinh: torch.asinh,
        sympy.tanh: torch.tanh,
        sympy.atanh: torch.atanh,
        sympy.Pow: torch.pow,
        sympy.re: torch.real,
        sympy.im: torch.imag,
        sympy.arg: torch.angle,
        # Note: May raise error for ints and complexes
        sympy.erf: torch.erf,
        sympy.loggamma: torch.lgamma,
        sympy.Eq: torch.eq,
        sympy.Ne: torch.ne,
        sympy.StrictGreaterThan: torch.gt,
        sympy.StrictLessThan: torch.lt,
        sympy.LessThan: torch.le,
        sympy.GreaterThan: torch.ge,
        sympy.And: torch.logical_and,
        sympy.Or: torch.logical_or,
        sympy.Not: torch.logical_not,
        # Max/Min can carry more than two arguments; fold the two-tensor
        # elementwise form pairwise, as is done for Mul and Add.
        sympy.Max: _reduce(torch.max),
        sympy.Min: _reduce(torch.min),
        sympy.Mod: torch.remainder,
        sympy.Heaviside: torch.heaviside,
        sympy.core.numbers.Half: (lambda: 0.5),
        sympy.core.numbers.One: (lambda: 1.0),
    }

    class _Node(torch.nn.Module):
        """Forked from https://github.com/patrick-kidger/sympytorch

        One node of a sympy expression tree, evaluated with torch.
        """

        def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
            """Build the node for ``expr`` and, recursively, for its arguments.

            Args:
                expr: sympy expression this node represents.
                _memodict: dict mapping already-converted sympy
                    sub-expressions to their nodes, so that repeated
                    sub-expressions share a single node.
                _func_lookup: mapping from ``expr.func`` to a torch callable.
                **kwargs: forwarded to ``torch.nn.Module.__init__`` and to
                    every child node.

            Raises:
                ValueError: if an ``UnevaluatedExpr`` wraps anything other
                    than a single float.
                KeyError: if ``expr.func`` has no entry in ``_func_lookup``.
            """
            super().__init__(**kwargs)

            self._sympy_func = expr.func

            if issubclass(expr.func, sympy.Float):
                # Float constants become trainable parameters.
                self._value = torch.nn.Parameter(torch.tensor(float(expr)))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Rational):
                # This is some fraction fixed in the operator.
                # Stored as a plain Python float, not as a parameter.
                self._value = float(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.UnevaluatedExpr):
                if len(expr.args) != 1 or not issubclass(
                    expr.args[0].func, sympy.Float
                ):
                    raise ValueError(
                        "UnevaluatedExpr should only be used to wrap floats."
                    )
                # Registered as a buffer rather than a Parameter.
                self.register_buffer("_value", torch.tensor(float(expr.args[0])))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Integer):
                # Can get here if expr is one of the Integer special cases,
                # e.g. NegativeOne
                # NOTE(review): sympy.Integer subclasses sympy.Rational, so
                # integers look like they are already caught by the Rational
                # branch above and this branch appears unreachable. Confirm
                # the intended int/float behavior before reordering.
                self._value = int(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.NumberSymbol):
                # Can get here from exp(1) or exact pi
                self._value = float(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Symbol):
                self._name = expr.name
                self._torch_func = lambda value: value
                # The single "argument" is a function that reads this
                # symbol's value, by name, from the dict given to forward().
                self._args = ((lambda memodict: memodict[expr.name]),)
            else:
                try:
                    self._torch_func = _func_lookup[expr.func]
                except KeyError as err:
                    raise KeyError(
                        f"Function {expr.func} was not found in Torch function mappings. "
                        "Please add it to extra_torch_mappings in the format, e.g., "
                        "{sympy.sqrt: torch.sqrt}."
                    ) from err
                args = []
                for arg in expr.args:
                    try:
                        # Reuse the node already built for an identical
                        # sub-expression.
                        arg_ = _memodict[arg]
                    except KeyError:
                        arg_ = type(self)(
                            expr=arg,
                            _memodict=_memodict,
                            _func_lookup=_func_lookup,
                            **kwargs,
                        )
                        _memodict[arg] = arg_
                    args.append(arg_)
                self._args = torch.nn.ModuleList(args)

        def forward(self, memodict):
            """Evaluate this node.

            Args:
                memodict: dict holding the symbol values keyed by symbol
                    name. It is also used as a per-call cache: the result of
                    each entry of ``self._args`` is stored in it, keyed by
                    that entry, so the dict is mutated.

            Returns:
                The result of applying this node's torch function to the
                evaluated arguments.
            """
            args = []
            for arg in self._args:
                try:
                    arg_ = memodict[arg]
                except KeyError:
                    # Entries are child nodes or, for symbols, the lookup
                    # function; both are called with the memodict.
                    arg_ = arg(memodict)
                    memodict[arg] = arg_
                args.append(arg_)
            return self._torch_func(*args)

    class _SingleSymPyModule(torch.nn.Module):
        """Forked from https://github.com/patrick-kidger/sympytorch

        torch module that evaluates one sympy expression on the columns of
        an input matrix.
        """

        def __init__(
            self, expression, symbols_in, selection=None, extra_funcs=None, **kwargs
        ):
            """Convert ``expression`` into a tree of ``_Node`` modules.

            Args:
                expression: sympy expression to convert.
                symbols_in: symbols, in the order of the input columns; they
                    are stored as strings.
                selection: optional column index applied to the input in
                    ``forward`` before symbols are assigned.
                extra_funcs: optional mapping of additional sympy functions
                    to torch callables.
                **kwargs: forwarded to ``torch.nn.Module.__init__``.
            """
            super().__init__(**kwargs)

            if extra_funcs is None:
                extra_funcs = {}
            # ChainMap searches its maps in order, so for a key present in
            # both, the entry from _global_func_lookup is the one used.
            _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)

            _memodict = {}
            self._node = _Node(
                expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
            )
            self._expression_string = str(expression)
            self._selection = selection
            self.symbols_in = [str(symbol) for symbol in symbols_in]

        def __repr__(self):
            return f"{type(self).__name__}(expression={self._expression_string})"

        def forward(self, X):
            """Evaluate the expression on ``X``.

            Column ``i`` of ``X`` (after applying ``selection``, if set) is
            bound to ``symbols_in[i]``.
            """
            if self._selection is not None:
                X = X[:, self._selection]
            symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
            return self._node(symbols)

    SingleSymPyModule = _SingleSymPyModule

    # Set the guard last, so a failure above (e.g. torch not installed) leaves
    # it False and the next call retries instead of using half-built state.
    torch_initialized = True


def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
    """Build a torch module, with trainable parameters, from a sympy expression.

    The resulting module expects a matrix X as input, in which each column
    holds the values of the symbol at the same position in `symbols_in`.
    """
    # Make sure torch is loaded and SingleSymPyModule is bound before use.
    _initialize_torch()

    module_cls = SingleSymPyModule
    return module_cls(
        expression,
        symbols_in,
        selection=selection,
        extra_funcs=extra_torch_mappings,
    )