# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import torch
from torch_scatter import scatter_min

from vidar.geometry.camera import Camera
from vidar.utils.tensor import unnorm_pixel_grid


class CameraFull(Camera):
    """Camera class with additional functionality"""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
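        # Axis-flip matrix diag(1, -1, -1, 1), used to convert poses between
        # camera conventions (negates the y and z axes)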
        self.convert_matrix = torch.tensor(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
            dtype=torch.float32,
        ).unsqueeze(0)

    @staticmethod
    def from_list(cams):
        """Create cameras from a list"""
        K = torch.cat([cam.K for cam in cams], 0)
        Twc = torch.cat([cam.Twc.T for cam in cams], 0)
        return CameraFull(K=K, Twc=Twc, hw=cams[0].hw)

    def switch(self):
        """Switch camera between conventions"""
        T = self.convert_matrix.to(self.device)
        Twc = T @ self.Twc.T @ T
        return type(self)(K=self.K, Twc=Twc, hw=self.hw)

    def bwd(self):
        """Switch camera to the backwards convention"""
        T = self.convert_matrix.to(self.device)
        Tcw = T @ self.Twc.T @ T
        return type(self)(K=self.K, Tcw=Tcw, hw=self.hw)

    def fwd(self):
        """Switch camera to the forward convention"""
        T = self.convert_matrix.to(self.device)
        Twc = T @ self.Tcw.T @ T
        return type(self)(K=self.K, Twc=Twc, hw=self.hw)

    def look_at(self, at, up=torch.Tensor([0, 1, 0])):
        """
        Set a direction for the camera to point (in-place)

        Parameters
        ----------
        at : torch.Tensor
            Target point the camera should look at [3]
        up : torch.Tensor
            Up direction [3]
        """
        eps = 1e-5
        eye = self.Tcw.T[:, :3, -1]

        at = at.unsqueeze(0)
        up = up.unsqueeze(0).to(at.device)

        z_axis = at - eye
        z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps

        up = up.expand(z_axis.shape)
        x_axis = torch.cross(up, z_axis)
        x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps

        y_axis = torch.cross(z_axis, x_axis)
        y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps

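        # Rotation matrix with the camera axes (x, y, z) as its columns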
        R = torch.stack((x_axis, y_axis, z_axis), dim=-1)

        Tcw = self.Tcw
        Tcw.T[:, :3, :3] = R
        self.Twc = Tcw.inverse()

    def get_origin(self, flatten=False):
        """Return camera origin"""
        orig = self.Tcw.T[:, :3, -1].view(len(self), 3, 1, 1).repeat(1, 1, *self.hw)
        if flatten:
            orig = orig.reshape(len(self), 3, -1).permute(0, 2, 1)
        return orig

    def get_viewdirs(self, normalize=False, flatten=False, to_world=False):
        """Return camera viewing rays"""
        ones = torch.ones((len(self), 1, *self.hw), dtype=self.dtype, device=self.device)
        rays = self.reconstruct_depth_map(ones, to_world=False)
        if normalize:
            rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
        if to_world:
            rays = self.to_world(rays).reshape(len(self), 3, *self.hw)
        if flatten:
            rays = rays.reshape(len(self), 3, -1).permute(0, 2, 1)
        return rays

    def get_render_rays(self, near=None, far=None, n_rays=None, gt=None):
        """
        Get render rays

        Parameters
        ----------
        near : Float
            Near plane
        far : Float
            Far plane
        n_rays : Int
            Number of rays
        gt : torch.Tensor
            Ground-truth values for concatenation

        Returns
        -------
        rays : torch.Tensor
            Camera viewing rays
        """
        b = len(self)

        ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)

        rays = self.reconstruct_depth_map(ones, to_world=False)
        rays = rays / torch.norm(rays, dim=1).unsqueeze(1)

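        # Flip the y and z components to switch ray directions to the opposite camera convention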
        rays[:, 1] = - rays[:, 1]
        rays[:, 2] = - rays[:, 2]

        orig = self.pose[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
        rays = self.no_translation().inverted_pose().to_world(rays).reshape(b, 3, *self.hw)

        info = [orig, rays]
        if near is not None:
            info = info + [near * ones]
        if far is not None:
            info = info + [far * ones]
        if gt is not None:
            info = info + [gt]

        rays = torch.cat(info, 1)
        rays = rays.permute(0, 2, 3, 1).reshape(b, -1, rays.shape[1])

        if n_rays is not None:
            idx = torch.randint(0, self.n_pixels, (n_rays,))
            rays = rays[:, idx, :]

        return rays

    def get_plucker(self):
        """Get plucker vectors"""
        b = len(self)
        ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
        rays = self.reconstruct_depth_map(ones, to_world=False)
        rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
        orig = self.Tcw.T[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)

        orig = orig.view(b, 3, -1).permute(0, 2, 1)
        rays = rays.view(b, 3, -1).permute(0, 2, 1)

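        # Plücker coordinates: ray direction plus moment (origin x direction)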
        cross = torch.cross(orig, rays, dim=-1)
        plucker = torch.cat((rays, cross), dim=-1)

        return plucker

    def project_pointcloud(self, pcl_src, rgb_src, thr=1):
        """
        Project pointcloud to the camera plane

        Parameters
        ----------
        pcl_src : torch.Tensor
            Input 3D pointcloud
        rgb_src : torch.Tensor
            Pointcloud color information
        thr : Int
            Threshold for the number of valid points

        Returns
        -------
        rgb_tgt : torch.Tensor
            Projected image [B,3,H,W]
        depth_tgt : torch.Tensor
            Projected depth map [B,1,H,W]
        """
        if rgb_src.dim() == 4:
            rgb_src = rgb_src.view(*rgb_src.shape[:2], -1)

        # Get projected coordinates and depth values
        uv_all, z_all = self.project_points(pcl_src, return_z=True, from_world=True)

        rgbs_tgt, depths_tgt = [], []

        b = pcl_src.shape[0]
        for i in range(b):
            uv, z = uv_all[i].reshape(-1, 2), z_all[i].reshape(-1, 1)

            # Remove out-of-bounds coordinates and points behind the camera
            idx = (uv[:, 0] >= -1) & (uv[:, 0] <= 1) & \
                  (uv[:, 1] >= -1) & (uv[:, 1] <= 1) & (z[:, 0] > 0.0)

            # Unnormalize coordinates and convert them to flat pixel indices for the scatter operation
            uv = (unnorm_pixel_grid(uv[idx], self.hw)).round().long()
            uv = uv[:, 0] + uv[:, 1] * self.hw[1]

            # Min scatter operation (only keep the closest depth)
            depth_tgt = 1e10 * torch.ones((self.hw[0] * self.hw[1], 1), device=pcl_src.device)
            depth_tgt, argmin = scatter_min(src=z[idx], index=uv.unsqueeze(1), dim=0, out=depth_tgt)
            depth_tgt[depth_tgt == 1e10] = 0.

            num_valid = (depth_tgt > 0).sum()
            if num_valid > thr:

                # Pixels that received no points: reset their argmin and flag their color as invalid (-1)
                invalid = argmin == argmin.max()
                argmin[invalid] = 0
                rgb_tgt = rgb_src[i].permute(1, 0)[idx][argmin]
                rgb_tgt[invalid] = -1

            else:

                rgb_tgt = -1 * torch.ones(1, self.n_pixels, 3, device=self.device, dtype=self.dtype)

            # Reshape outputs
            rgb_tgt = rgb_tgt.reshape(1, self.hw[0], self.hw[1], 3).permute(0, 3, 1, 2)
            depth_tgt = depth_tgt.reshape(1, 1, self.hw[0], self.hw[1])

            rgbs_tgt.append(rgb_tgt)
            depths_tgt.append(depth_tgt)

        rgb_tgt = torch.cat(rgbs_tgt, 0)
        depth_tgt = torch.cat(depths_tgt, 0)

        return rgb_tgt, depth_tgt
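

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API).
    # It assumes the base Camera constructor accepts K [B,3,3], Twc [B,4,4]
    # as a tensor, and hw=(H, W), matching how `from_list` and `switch` call
    # it above; the shapes noted below follow from the code in this file.
    B, H, W = 1, 4, 6

    K = torch.eye(3).unsqueeze(0).repeat(B, 1, 1)
    K[:, 0, 0] = 100.0    # fx
    K[:, 1, 1] = 100.0    # fy
    K[:, 0, 2] = W / 2.0  # cx
    K[:, 1, 2] = H / 2.0  # cy

    cam = CameraFull(K=K, Twc=torch.eye(4).unsqueeze(0).repeat(B, 1, 1), hw=(H, W))

    dirs = cam.get_viewdirs(normalize=True, flatten=True)     # [B, H*W, 3]
    rays = cam.get_render_rays(near=0.1, far=10.0, n_rays=8)  # [B, 8, 3+3+1+1]
    print(dirs.shape, rays.shape)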