Spaces:
Running
Running
File size: 7,819 Bytes
d2c7df2 d18011f a06bfc4 d18011f 976f8d8 d2c7df2 e84bed4 a06bfc4 7d4300a d18011f 7d4300a d18011f 7d4300a fb950bb d18011f fb950bb 7d4300a fb950bb d18011f fb950bb d18011f fb950bb d18011f 7d4300a d18011f 4d5aec3 f2a7a62 d18011f aefe008 7d4300a d18011f fb950bb d18011f f119733 d18011f 7d4300a d18011f c3293a8 d18011f 99fff5c d18011f 7d4300a d18011f fb950bb d18011f 5af6354 aefe008 7d4300a fb950bb 7d4300a d18011f 7d4300a d18011f fb950bb 8c55475 7d4300a d18011f a06bfc4 5af6354 a06bfc4 7d4300a 9068541 d18011f fb950bb 7d4300a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
# Fork of https://github.com/patrick-kidger/sympytorch
import collections as co
import functools as ft
import numpy as np # noqa: F401
import sympy # type: ignore
def _reduce(fn):
def fn_(*args):
return ft.reduce(fn, args)
return fn_
# Lazy-loading state shared with `_initialize_torch`.
# `torch` and `SingleSymPyModule` are placeholders that `_initialize_torch`
# rebinds to the real torch module and the generated module class.
torch = None
SingleSymPyModule = None
# Flag consulted by `_initialize_torch` to decide whether the lazy torch
# setup still needs to run.
torch_initialized = False
def _initialize_torch():
    """Lazily import torch and build the torch-backed expression classes.

    On the first successful call this function:

    * imports ``torch`` and binds it to the module-level ``torch`` name,
    * builds the default SymPy-to-torch function lookup table,
    * defines the ``_Node`` and ``_SingleSymPyModule`` classes (they subclass
      ``torch.nn.Module`` and therefore cannot be defined before the import),
    * publishes ``_SingleSymPyModule`` as module-level ``SingleSymPyModule``,
    * sets ``torch_initialized`` to ``True``.

    Later calls return immediately, so ``SingleSymPyModule`` stays the same
    class object for the lifetime of the process.
    """
    global torch_initialized
    global torch
    global SingleSymPyModule

    # Way to lazy load torch, only if this is called,
    # but still allow this module to be loaded in __init__
    if torch_initialized:
        return

    import torch as _torch

    torch = _torch

    # Default mapping from SymPy function/class to the torch callable that
    # evaluates it.
    _global_func_lookup = {
        sympy.Mul: _reduce(torch.mul),
        sympy.Add: _reduce(torch.add),
        sympy.div: torch.div,
        sympy.Abs: torch.abs,
        sympy.sign: torch.sign,
        # Note: May raise error for ints.
        sympy.ceiling: torch.ceil,
        sympy.floor: torch.floor,
        sympy.log: torch.log,
        sympy.exp: torch.exp,
        sympy.sqrt: torch.sqrt,
        sympy.cos: torch.cos,
        sympy.acos: torch.acos,
        sympy.sin: torch.sin,
        sympy.asin: torch.asin,
        sympy.tan: torch.tan,
        sympy.atan: torch.atan,
        sympy.atan2: torch.atan2,
        # Note: May give NaN for complex results.
        sympy.cosh: torch.cosh,
        sympy.acosh: torch.acosh,
        sympy.sinh: torch.sinh,
        sympy.asinh: torch.asinh,
        sympy.tanh: torch.tanh,
        sympy.atanh: torch.atanh,
        sympy.Pow: torch.pow,
        sympy.re: torch.real,
        sympy.im: torch.imag,
        sympy.arg: torch.angle,
        # Note: May raise error for ints and complexes
        sympy.erf: torch.erf,
        sympy.loggamma: torch.lgamma,
        sympy.Eq: torch.eq,
        sympy.Ne: torch.ne,
        sympy.StrictGreaterThan: torch.gt,
        sympy.StrictLessThan: torch.lt,
        sympy.LessThan: torch.le,
        sympy.GreaterThan: torch.ge,
        sympy.And: torch.logical_and,
        sympy.Or: torch.logical_or,
        sympy.Not: torch.logical_not,
        sympy.Max: torch.max,
        sympy.Min: torch.min,
        sympy.Mod: torch.remainder,
        sympy.Heaviside: torch.heaviside,
        sympy.core.numbers.Half: (lambda: 0.5),
        sympy.core.numbers.One: (lambda: 1.0),
    }

    class _Node(torch.nn.Module):
        """One node of a SymPy expression tree, evaluated with torch.

        Forked from https://github.com/patrick-kidger/sympytorch

        Each node stores a callable in ``_torch_func`` and its inputs in
        ``_args``. Leaf nodes (numbers) take no inputs, symbol nodes read
        their value from the dict passed to ``forward``, and every other
        node applies the torch callable looked up for ``expr.func`` to its
        child nodes.
        """

        def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
            """Build the node for ``expr`` and, recursively, its children.

            Args:
                expr: SymPy expression this node represents.
                _memodict: dict mapping already-converted sub-expressions
                    to their nodes, so repeated sub-expressions share one
                    node. It is updated in place.
                _func_lookup: mapping from SymPy function to torch callable.
                **kwargs: forwarded to ``torch.nn.Module.__init__`` and to
                    child nodes.

            Raises:
                ValueError: if an ``UnevaluatedExpr`` wraps anything other
                    than a single ``sympy.Float``.
                KeyError: if ``expr.func`` has no entry in ``_func_lookup``.
            """
            super().__init__(**kwargs)

            self._sympy_func = expr.func

            if issubclass(expr.func, sympy.Float):
                # Floating-point constants become trainable parameters.
                self._value = torch.nn.Parameter(torch.tensor(float(expr)))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Rational):
                # This is some fraction fixed in the operator.
                self._value = float(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.UnevaluatedExpr):
                if len(expr.args) != 1 or not issubclass(
                    expr.args[0].func, sympy.Float
                ):
                    raise ValueError(
                        "UnevaluatedExpr should only be used to wrap floats."
                    )
                # Registered as a buffer rather than a Parameter, so the
                # wrapped float is stored on the module but is not trained.
                self.register_buffer("_value", torch.tensor(float(expr.args[0])))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Integer):
                # Can get here if expr is one of the Integer special cases,
                # e.g. NegativeOne
                # NOTE(review): sympy.Integer subclasses sympy.Rational, so
                # the Rational branch above appears to catch these first and
                # this branch looks unreachable -- confirm before reordering,
                # since that would switch these constants from float to int.
                self._value = int(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.NumberSymbol):
                # Can get here from exp(1) or exact pi
                self._value = float(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Symbol):
                self._name = expr.name
                self._torch_func = lambda value: value
                # The single "argument" is a getter that pulls this symbol's
                # value out of the dict handed to ``forward``.
                self._args = ((lambda memodict: memodict[expr.name]),)
            else:
                try:
                    self._torch_func = _func_lookup[expr.func]
                except KeyError:
                    # Re-raise with an actionable message; the bare lookup
                    # error adds nothing, so its context is suppressed.
                    raise KeyError(
                        f"Function {expr.func} was not found in Torch function mappings. "
                        "Please add it to extra_torch_mappings in the format, e.g., "
                        "{sympy.sqrt: torch.sqrt}."
                    ) from None
                args = []
                for arg in expr.args:
                    try:
                        arg_ = _memodict[arg]
                    except KeyError:
                        arg_ = type(self)(
                            expr=arg,
                            _memodict=_memodict,
                            _func_lookup=_func_lookup,
                            **kwargs,
                        )
                        _memodict[arg] = arg_
                    args.append(arg_)
                # ModuleList registers the children as submodules.
                self._args = torch.nn.ModuleList(args)

        def forward(self, memodict):
            """Evaluate this node.

            Args:
                memodict: dict mapping symbol names to their values. It is
                    also used as a per-call cache: the result of every
                    evaluated argument is stored in it, keyed by the
                    argument object, so shared sub-expressions are computed
                    once. The dict is mutated in place.

            Returns:
                The result of applying this node's torch callable to its
                evaluated arguments.
            """
            args = []
            for arg in self._args:
                try:
                    arg_ = memodict[arg]
                except KeyError:
                    arg_ = arg(memodict)
                    memodict[arg] = arg_
                args.append(arg_)
            return self._torch_func(*args)

    class _SingleSymPyModule(torch.nn.Module):
        """Torch module that evaluates one SymPy expression on a matrix.

        Forked from https://github.com/patrick-kidger/sympytorch
        """

        def __init__(
            self, expression, symbols_in, selection=None, extra_funcs=None, **kwargs
        ):
            """Convert ``expression`` into a tree of ``_Node`` modules.

            Args:
                expression: SymPy expression to convert.
                symbols_in: symbols (or names) of the input columns; column
                    ``i`` of the input is bound to ``str(symbols_in[i])``.
                selection: optional index applied as ``X[:, selection]``
                    before columns are bound to symbols.
                extra_funcs: optional additional SymPy-to-torch mappings.
                **kwargs: forwarded to ``torch.nn.Module.__init__``.
            """
            super().__init__(**kwargs)

            if extra_funcs is None:
                extra_funcs = {}
            # ChainMap searches its maps in order, so the built-in table is
            # consulted first and ``extra_funcs`` supplies mappings that the
            # built-in table lacks.
            _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)

            _memodict = {}
            self._node = _Node(
                expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
            )
            self._expression_string = str(expression)
            self._selection = selection
            self.symbols_in = [str(symbol) for symbol in symbols_in]

        def __repr__(self):
            return f"{type(self).__name__}(expression={self._expression_string})"

        def forward(self, X):
            """Evaluate the expression on ``X``, one symbol per column."""
            if self._selection is not None:
                X = X[:, self._selection]
            symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
            return self._node(symbols)

    SingleSymPyModule = _SingleSymPyModule

    # Mark setup as done only after everything above succeeded: a failed
    # import can then be retried, and later calls reuse these classes
    # instead of redefining them.
    torch_initialized = True
def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
    """Build a torch module that evaluates a sympy expression.

    The returned module has trainable parameters and expects a matrix ``X``
    as input, where column ``i`` holds the values of ``symbols_in[i]``.

    Args:
        expression: sympy expression to convert.
        symbols_in: symbols, in the order of the columns of ``X``.
        selection: optional column index; when given, the module applies
            ``X[:, selection]`` before binding columns to symbols.
        extra_torch_mappings: optional dict of additional sympy-to-torch
            function mappings, e.g. ``{sympy.sqrt: torch.sqrt}``.

    Returns:
        A ``SingleSymPyModule`` instance wrapping ``expression``.
    """
    # torch is imported lazily; make sure the module class exists first.
    _initialize_torch()

    module_options = {
        "selection": selection,
        "extra_funcs": extra_torch_mappings,
    }
    return SingleSymPyModule(expression, symbols_in, **module_options)
|