jadechoghari commited on
Commit
713a07a
1 Parent(s): 59c0616

Create math_utils.py

Browse files

we are flattening the directory, since HF only supports flat imports

Files changed (1) hide show
  1. math_utils.py +123 -0
math_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # MIT License
7
+
8
+ # Copyright (c) 2022 Petr Kellnhofer
9
+
10
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ # of this software and associated documentation files (the "Software"), to deal
12
+ # in the Software without restriction, including without limitation the rights
13
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ # copies of the Software, and to permit persons to whom the Software is
15
+ # furnished to do so, subject to the following conditions:
16
+
17
+ # The above copyright notice and this permission notice shall be included in all
18
+ # copies or substantial portions of the Software.
19
+
20
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ # SOFTWARE.
27
+
28
+ import torch
29
+
30
+ def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
31
+ """
32
+ Left-multiplies MxM @ NxM. Returns NxM.
33
+ """
34
+ res = torch.matmul(vectors4, matrix.T)
35
+ return res
36
+
37
+
38
+ def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
39
+ """
40
+ Normalize vector lengths.
41
+ """
42
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
43
+
44
+ def torch_dot(x: torch.Tensor, y: torch.Tensor):
45
+ """
46
+ Dot product of two tensors.
47
+ """
48
+ return (x * y).sum(-1)
49
+
50
+
51
+ def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
52
+ """
53
+ Author: Petr Kellnhofer
54
+ Intersects rays with the [-1, 1] NDC volume.
55
+ Returns min and max distance of entry.
56
+ Returns -1 for no intersection.
57
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
58
+ """
59
+ o_shape = rays_o.shape
60
+ rays_o = rays_o.detach().reshape(-1, 3)
61
+ rays_d = rays_d.detach().reshape(-1, 3)
62
+
63
+
64
+ bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
65
+ bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
66
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
67
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
68
+
69
+ # Precompute inverse for stability.
70
+ invdir = 1 / rays_d
71
+ sign = (invdir < 0).long()
72
+
73
+ # Intersect with YZ plane.
74
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
75
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
76
+
77
+ # Intersect with XZ plane.
78
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
79
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
80
+
81
+ # Resolve parallel rays.
82
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
83
+
84
+ # Use the shortest intersection.
85
+ tmin = torch.max(tmin, tymin)
86
+ tmax = torch.min(tmax, tymax)
87
+
88
+ # Intersect with XY plane.
89
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
90
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
91
+
92
+ # Resolve parallel rays.
93
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
94
+
95
+ # Use the shortest intersection.
96
+ tmin = torch.max(tmin, tzmin)
97
+ tmax = torch.min(tmax, tzmax)
98
+
99
+ # Mark invalid.
100
+ tmin[torch.logical_not(is_valid)] = -1
101
+ tmax[torch.logical_not(is_valid)] = -2
102
+
103
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
104
+
105
+
106
+ def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
107
+ """
108
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
109
+ Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
110
+ """
111
+ # create a tensor of 'num' steps from 0 to 1
112
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
113
+
114
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
115
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
116
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
117
+ for i in range(start.ndim):
118
+ steps = steps.unsqueeze(-1)
119
+
120
+ # the output starts at 'start' and increments until 'stop' in each dimension
121
+ out = start[None] + steps * (stop - start)[None]
122
+
123
+ return out