# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

from functools import lru_cache

import torch
import torch.nn as nn

from vidar.arch.networks.layers.fsm.camera_utils import scale_intrinsics, invert_intrinsics
from vidar.arch.networks.layers.fsm.pose import Pose
from vidar.utils.tensor import pixel_grid
from vidar.utils.types import is_tensor, is_list


class Camera(nn.Module):
    """
    Differentiable camera class implementing reconstruction and projection
    functions for a pinhole model.
    """
    def __init__(self, K, Tcw=None, Twc=None, hw=None):
        """
        Initializes the Camera class

        Parameters
        ----------
        K : torch.Tensor
            Camera intrinsics [B,3,3]
        Tcw : Pose or torch.Tensor
            World -> Camera pose transformation [B,4,4]
            (maps world-frame points into the camera frame, as used in project)
        Twc : Pose or torch.Tensor
            Camera -> World pose transformation [B,4,4]
            (maps camera-frame points into the world frame, as used in reconstruct)
        hw : tuple or torch.Tensor
            Camera height and width (H, W), or a tensor whose last two
            dimensions are (H, W)
        """
        super().__init__()
        assert Tcw is None or Twc is None, 'Provide either Tcw or Twc, not both'
        self.K = K
        self.hw = None if hw is None else hw.shape[-2:] if is_tensor(hw) else hw[-2:]
        if Tcw is not None:
            self.Tcw = Tcw if isinstance(Tcw, Pose) else Pose(Tcw)
        elif Twc is not None:
            self.Tcw = Twc.inverse() if isinstance(Twc, Pose) else Pose(Twc).inverse()
        else:
            self.Tcw = Pose.identity(len(self.K))
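
    # A minimal construction sketch (illustrative values, not from this repo):
    #   K = torch.tensor([[[720., 0., 320.],
    #                      [0., 720., 240.],
    #                      [0., 0., 1.]]])    # [1,3,3] pinhole intrinsics
    #   cam = Camera(K=K, hw=(480, 640))      # identity pose by default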

    def __len__(self):
        """Batch size of the camera intrinsics"""
        return len(self.K)

    def __getitem__(self, idx):
        """Return single camera from a batch position"""
        return Camera(K=self.K[idx].unsqueeze(0),
                      hw=self.hw, Tcw=self.Tcw[idx]).to(self.device)

    @property
    def wh(self):
        """Return camera width and height"""
        return None if self.hw is None else self.hw[::-1]

    @property
    def pose(self):
        """Return camera pose"""
        return self.Twc.mat

    @property
    def device(self):
        """Return camera device"""
        return self.K.device

    def invert_pose(self):
        """Return a new camera with inverted pose (intrinsics and size kept)"""
        return Camera(K=self.K, Tcw=self.Twc, hw=self.hw)

    def to(self, *args, **kwargs):
        """Moves object to a specific device"""
        self.K = self.K.to(*args, **kwargs)
        self.Tcw = self.Tcw.to(*args, **kwargs)
        return self
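
    # Note: Twc and Kinv below are memoized via lru_cache; if they were
    # accessed before calling to(), the cached values stay on the old device.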

    @property
    def fx(self):
        """Focal length in x"""
        return self.K[:, 0, 0]

    @property
    def fy(self):
        """Focal length in y"""
        return self.K[:, 1, 1]

    @property
    def cx(self):
        """Principal point in x"""
        return self.K[:, 0, 2]

    @property
    def cy(self):
        """Principal point in y"""
        return self.K[:, 1, 2]

    @property
    @lru_cache()
    def Twc(self):
        """World -> Camera pose transformation (inverse of Tcw)"""
        return self.Tcw.inverse()

    @property
    @lru_cache()
    def Kinv(self):
        """Inverse intrinsics (for lifting)"""
        return invert_intrinsics(self.K)

    def equal(self, cam):
        """Check if two cameras are the same"""
        return torch.allclose(self.K, cam.K) and \
               torch.allclose(self.Tcw.mat, cam.Tcw.mat)

    def scaled(self, x_scale, y_scale=None):
        """
        Returns a scaled version of the camera (changing intrinsics)

        Parameters
        ----------
        x_scale : float
            Resize scale in x
        y_scale : float
            Resize scale in y. If None, use the same as x_scale

        Returns
        -------
        camera : Camera
            Scaled version of the current camera
        """
        # If single value is provided, use for both dimensions
        if y_scale is None:
            y_scale = x_scale
        # If no scaling is necessary, return same camera
        if x_scale == 1. and y_scale == 1.:
            return self
        # Scale intrinsics
        K = scale_intrinsics(self.K.clone(), x_scale, y_scale)
        # Scale image dimensions
        hw = None if self.hw is None else (int(self.hw[0] * y_scale),
                                           int(self.hw[1] * x_scale))
        # Return scaled camera
        return Camera(K=K, Tcw=self.Tcw, hw=hw)
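
    # Usage sketch (hypothetical numbers): for an image downsampled by 2 in
    # both dimensions, cam.scaled(0.5) rescales the intrinsics and the stored
    # (H, W) to match:
    #   cam_half = cam.scaled(0.5)   # cam_half.hw == (H // 2, W // 2)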

    def scaled_K(self, shape):
        """Return scaled intrinsics to match a shape"""
        if self.hw is None:
            return self.K
        else:
            y_scale, x_scale = [sh / hw for sh, hw in zip(shape[-2:], self.hw)]
            return scale_intrinsics(self.K, x_scale, y_scale)

    def scaled_Kinv(self, shape):
        """Return scaled inverse intrinsics to match a shape"""
        return invert_intrinsics(self.scaled_K(shape))
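
    # Example (hypothetical shapes): for a camera with hw == (480, 640),
    # scaled_K(feat.shape) with feat of shape [B,C,240,320] rescales the
    # intrinsics by 0.5 in each axis, so projections match the feature map.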

    def reconstruct(self, depth, frame='w', scene_flow=None, return_grid=False):
        """
        Reconstructs pixel-wise 3D points from a depth map.

        Parameters
        ----------
        depth : torch.Tensor
            Depth map for the camera [B,1,H,W]
        frame : str
            Reference frame: 'c' for camera and 'w' for world (default: 'w')
        scene_flow : torch.Tensor
            Optional per-point scene flow to be added (camera reference frame) [B,3,H,W]
        return_grid : bool
            Return pixel grid as well

        Returns
        -------
        points : torch.Tensor
            Pixel-wise 3D points [B,3,H,W]
        grid : torch.Tensor
            Homogeneous pixel grid [B,3,H,W], returned only if return_grid is True
        """
        # If depth is a list, return each reconstruction
        if is_list(depth):
            return [self.reconstruct(d, frame, scene_flow, return_grid) for d in depth]
        # Dimension assertions
        assert depth.dim() == 4 and depth.shape[1] == 1, \
            'Wrong dimensions for camera reconstruction'

        # Create flat index grid [B,3,H,W]
        B, _, H, W = depth.shape
        grid = pixel_grid((H, W), B, device=depth.device, normalize=False, with_ones=True)
        flat_grid = grid.view(B, 3, -1)

        # Get inverse intrinsics
        Kinv = self.Kinv if self.hw is None else self.scaled_Kinv(depth.shape)

        # Estimate the outward rays in the camera frame
        Xnorm = (Kinv.bmm(flat_grid)).view(B, 3, H, W)
        # Scale rays to metric depth
        Xc = Xnorm * depth

        # Add scene flow if provided
        if scene_flow is not None:
            Xc = Xc + scene_flow

        # If in camera frame of reference
        if frame == 'c':
            pass
        # If in world frame of reference
        elif frame == 'w':
            Xc = self.Twc @ Xc
        # If none of the above
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))
        # Return points and grid if requested
        return (Xc, grid) if return_grid else Xc
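
    # In matrix form, reconstruct() implements standard pinhole back-projection:
    # each pixel (u, v) at depth d is lifted as  X_cam = d * Kinv @ [u, v, 1]^T,
    # then optionally mapped to world coordinates as  X_world = Twc @ X_cam.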

    def project(self, X, frame='w', normalize=True, return_z=False):
        """
        Projects 3D points onto the image plane

        Parameters
        ----------
        X : torch.Tensor
            3D points to be projected [B,3,H,W]
        frame : str
            Reference frame: 'c' for camera and 'w' for world (default: 'w')
        normalize : bool
            Normalize grid coordinates
        return_z : bool
            Return the projected z coordinate as well

        Returns
        -------
        points : torch.Tensor
            Normalized 2D projected coordinates [B,H,W,2]; out-of-bounds points
            are set to 2, outside the valid [-1,1] range
        """
        assert 2 < X.dim() <= 4 and X.shape[1] == 3, \
            'Wrong dimensions for camera projection'

        # Determine if input is a grid
        is_grid = X.dim() == 4
        # If it's a grid, flatten it
        X_flat = X.view(X.shape[0], 3, -1) if is_grid else X

        # Get dimensions
        hw = X.shape[2:] if is_grid else self.hw
        # Get intrinsics
        K = self.scaled_K(X.shape) if is_grid else self.K

        # Project 3D points onto the camera image plane
        if frame == 'c':
            Xc = K.bmm(X_flat)
        elif frame == 'w':
            Xc = K.bmm(self.Tcw @ X_flat)
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))

        # Extract coordinates
        Z = Xc[:, 2].clamp(min=1e-5)
        XZ = Xc[:, 0] / Z
        YZ = Xc[:, 1] / Z

        # Normalize points
        if normalize and hw is not None:
            XZ = 2 * XZ / (hw[1] - 1) - 1.
            YZ = 2 * YZ / (hw[0] - 1) - 1.

        # Flag out-of-bounds pixels and push them outside the valid [-1,1] range
        Xmask = ((XZ > 1) + (XZ < -1)).detach()
        XZ[Xmask] = 2.
        Ymask = ((YZ > 1) + (YZ < -1)).detach()
        YZ[Ymask] = 2.

        # Stack X and Y coordinates
        XY = torch.stack([XZ, YZ], dim=-1)
        # Reshape coordinates to a grid if possible
        if is_grid and hw is not None:
            XY = XY.view(X.shape[0], hw[0], hw[1], 2)

        # If also returning depth
        if return_z:
            # Reshape depth values to a grid if possible
            if is_grid and hw is not None:
                Z = Z.view(X.shape[0], hw[0], hw[1], 1).permute(0, 3, 1, 2)
            # Otherwise, reshape to an array
            else:
                Z = Z.view(X.shape[0], -1, 1).permute(0, 2, 1)
            # Return coordinates and depth values
            return XY, Z
        else:
            # Return coordinates
            return XY
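

if __name__ == '__main__':
    # Round-trip smoke test: a minimal sketch assuming the vidar utilities
    # imported above (Pose, pixel_grid, intrinsics helpers) are available.
    # Intrinsics, resolution, and depth values below are made up for
    # illustration.
    B, H, W = 1, 48, 64
    K = torch.tensor([[[60., 0., 32.],
                       [0., 60., 24.],
                       [0., 0., 1.]]])
    cam = Camera(K=K, hw=(H, W))
    depth = torch.ones(B, 1, H, W)
    # Lift every pixel to 3D, then project back: with an identity pose the
    # coordinates should land exactly on the normalized pixel grid.
    points = cam.reconstruct(depth, frame='w')   # [B,3,H,W]
    coords = cam.project(points, frame='w')      # [B,H,W,2] in [-1,1]
    print(coords.shape, coords.min().item(), coords.max().item())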