Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
962c25c
1
Parent(s):
4d5aec3
Add test for mod mapping in torch
Browse files- test/test_torch.py +17 -0
test/test_torch.py
CHANGED
@@ -51,3 +51,20 @@ class TestTorch(unittest.TestCase):
|
|
51 |
np.square(np.cos(X[:, 1])), # Selection 1st feature
|
52 |
decimal=4,
|
53 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
np.square(np.cos(X[:, 1])), # Selection 1st feature
|
52 |
decimal=4,
|
53 |
)
|
54 |
+
|
55 |
+
def test_mod_mapping(self):
|
56 |
+
x, y, z = sympy.symbols("x y z")
|
57 |
+
expression = x ** 2 + sympy.atanh(sympy.Mod(y + 1, 2) - 1) * 3.2 * z
|
58 |
+
|
59 |
+
module = sympy2torch(expression, [x, y, z])
|
60 |
+
|
61 |
+
X = torch.rand(100, 3).float() * 10
|
62 |
+
|
63 |
+
true_out = (
|
64 |
+
X[:, 0] ** 2 + torch.atanh(torch.fmod(X[:, 1] + 1, 2) - 1) * 3.2 * X[:, 2]
|
65 |
+
)
|
66 |
+
torch_out = module(X)
|
67 |
+
|
68 |
+
np.testing.assert_array_almost_equal(
|
69 |
+
true_out.detach(), torch_out.detach(), decimal=4
|
70 |
+
)
|