File size: 8,089 Bytes
fc16538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import torch
import torch.nn.functional as tfn

from vidar.utils.data import make_list
from vidar.utils.flow_triangulation_support import bearing_grid, mult_rotation_bearing, triangulation
from vidar.utils.tensor import pixel_grid, norm_pixel_grid, unnorm_pixel_grid
from vidar.utils.types import is_list


def warp_from_coords(tensor, coords, mode='bilinear',
                     padding_mode='zeros', align_corners=True):
    """
    Sample a tensor at the locations given by a coordinate map

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to be sampled [B,?,H,W]
    coords : torch.Tensor
        Sampling coordinates in grid_sample's normalized convention,
        channel-first [B,2,H,W]
    mode : String
        Interpolation mode forwarded to grid_sample
    padding_mode : String
        Padding mode for out-of-range coordinates, forwarded to grid_sample
    align_corners : Bool
        Align corners flag forwarded to grid_sample

    Returns
    -------
    warp : torch.Tensor
        Sampled tensor [B,?,H,W]
    """
    # grid_sample wants the grid channel-last ([B,H,W,2]), so move the
    # coordinate channel to the end before sampling
    grid = coords.permute(0, 2, 3, 1)
    return tfn.grid_sample(
        tensor,
        grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )


def coords_from_optical_flow(optflow):
    """
    Convert an optical flow field into normalized warping coordinates

    Parameters
    ----------
    optflow : torch.Tensor
        Input optical flow tensor [B,2,H,W]

    Returns
    -------
    coords : torch.Tensor
        Warping coordinates [B,2,H,W]
    """
    # Pixel positions matching the flow's shape (device taken from the flow tensor)
    base_grid = pixel_grid(optflow, device=optflow)
    # Displace every pixel position by its flow vector
    displaced = base_grid + optflow
    # Normalize so the result can be consumed by warp_from_coords
    return norm_pixel_grid(displaced)


def warp_depth_from_motion(ref_depth, tgt_depth, ref_cam):
    """
    Warp a reference depth map into the target view using depth + ego-motion

    The reference depth is first reprojected through ``ref_cam`` (see
    reproject_depth_from_motion), then sampled at the target pixel locations
    given by ``tgt_depth`` and ``ref_cam`` (see warp_from_motion).

    Parameters
    ----------
    ref_depth : torch.Tensor
        Reference depth map [B,1,H,W]
    tgt_depth : torch.Tensor
        Target depth map [B,1,H,W]
    ref_cam : Camera
        Reference camera

    Returns
    -------
    warp : torch.Tensor
        Warped depth map [B,1,H,W]
    """
    reprojected = reproject_depth_from_motion(ref_depth, ref_cam)
    warped = warp_from_motion(reprojected, tgt_depth, ref_cam)
    return warped


def reproject_depth_from_motion(ref_depth, ref_cam):
    """
    Calculate reprojected depth from motion (depth + ego-motion) information

    The reference depth map is lifted to 3D points with
    ``ref_cam.reconstruct_depth_map(..., to_world=True)``. Those points are
    then passed to ``ref_cam.project_points(..., from_world=False,
    return_z=True)``, and only the depth part of that result is kept.

    Parameters
    ----------
    ref_depth : torch.Tensor
        Reference depth map [B,1,H,W]
    ref_cam : Camera
        Reference camera

    Returns
    -------
    depth : torch.Tensor
        Reprojected depth values, i.e. element ``[1]`` of
        ``project_points(..., return_z=True)``. This is a depth map, not
        warping coordinates; presumably shaped [B,1,H,W] (TODO confirm
        against Camera.project_points).
    """
    ref_points = ref_cam.reconstruct_depth_map(ref_depth, to_world=True)
    # With return_z=True, project_points returns an indexable result whose
    # second element holds the projected depth; the coordinates are discarded
    projection = ref_cam.project_points(ref_points, from_world=False, return_z=True)
    return projection[1]


def warp_from_motion(ref_rgb, tgt_depth, ref_cam):
    """
    Warp a reference image into the target view using depth + ego-motion

    Parameters
    ----------
    ref_rgb : torch.Tensor
        Reference image [B,3,H,W]
    tgt_depth : torch.Tensor
        Target depth map [B,1,H,W]
    ref_cam : Camera
        Reference camera

    Returns
    -------
    warp : torch.Tensor
        Warped image [B,3,H,W]
    """
    # Lift the target depth to 3D points (to_world=False)
    points = ref_cam.reconstruct_depth_map(tgt_depth, to_world=False)
    # Project the points through ref_cam (from_world=True); the projection is
    # channel-last, so reorder it to [B,2,H,W] for warp_from_coords
    coords = ref_cam.project_points(points, from_world=True)
    coords = coords.permute(0, 3, 1, 2)
    return warp_from_coords(ref_rgb, coords)


def coords_from_motion(ref_camera, tgt_depth, tgt_camera):
    """
    Get warping coordinates from motion (depth + ego-motion) information

    Lists are expanded recursively: a list of reference cameras gives one
    result per camera, and a list of depth maps gives one result per map.
    When both are lists, the result is nested (cameras outer, depths inner).

    Parameters
    ----------
    ref_camera : Camera or list[Camera]
        Reference camera(s)
    tgt_depth : torch.Tensor or list[torch.Tensor]
        Target depth map(s) [B,1,H,W]
    tgt_camera : Camera
        Target camera

    Returns
    -------
    coords : torch.Tensor or list
        Warping coordinates [B,2,H,W] (nested lists when inputs are lists)
    """
    # Fan out over reference cameras first
    if is_list(ref_camera):
        return [coords_from_motion(cam, tgt_depth, tgt_camera)
                for cam in ref_camera]
    # Then fan out over target depth maps
    if is_list(tgt_depth):
        return [coords_from_motion(ref_camera, single_depth, tgt_camera)
                for single_depth in tgt_depth]
    # Lift target pixels to world points, then project them into the reference camera
    points_world = tgt_camera.reconstruct_depth_map(tgt_depth, to_world=True)
    projected = ref_camera.project_points(points_world, from_world=True)
    # Projection is channel-last; return channel-first [B,2,H,W]
    return projected.permute(0, 3, 1, 2)


def optflow_from_motion(ref_camera, tgt_depth):
    """
    Get optical flow from motion (depth + ego-motion) information

    Parameters
    ----------
    ref_camera : Camera
        Reference camera
    tgt_depth : torch.Tensor
        Target depth map

    Returns
    -------
    optflow : torch.Tensor
        Optical flow map [B,2,H,W]
    """
    # coords_from_depth yields channel-last coordinates; reorder to [B,2,H,W]
    channel_last = ref_camera.coords_from_depth(tgt_depth)
    return optflow_from_coords(channel_last.permute(0, 3, 1, 2))


def optflow_from_coords(coords):
    """
    Convert normalized warping coordinates back into optical flow

    Parameters
    ----------
    coords : torch.Tensor
        Input warping coordinates [B,2,H,W]

    Returns
    -------
    optflow : torch.Tensor
        Optical flow map [B,2,H,W]
    """
    # Coordinates expressed in pixels (inverse of norm_pixel_grid)
    target_pixels = unnorm_pixel_grid(coords)
    # Original pixel positions on the same device as the coordinates
    source_pixels = pixel_grid(coords, device=coords)
    # Flow is the displacement from each pixel to its target location
    return target_pixels - source_pixels


def warp_from_optflow(ref_rgb, tgt_optflow):
    """
    Warp an image using an optical flow field

    The flow is converted to normalized coordinates, and the image is then
    sampled bilinearly with zero padding and ``align_corners=True``.

    Parameters
    ----------
    ref_rgb : torch.Tensor
        Reference image [B,3,H,W]
    tgt_optflow : torch.Tensor
        Target optical flow [B,2,H,W]

    Returns
    -------
    warp : torch.Tensor
        Warped image [B,3,H,W]
    """
    sample_coords = coords_from_optical_flow(tgt_optflow)
    return warp_from_coords(
        ref_rgb,
        sample_coords,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True,
    )


def reverse_optflow(tgt_optflow, ref_optflow):
    """
    Reverse optical flow

    ``tgt_optflow`` is sampled at the locations given by ``ref_optflow``
    (via warp_from_optflow), and the sampled flow is negated.

    Parameters
    ----------
    tgt_optflow : torch.Tensor
        Target optical flow [B,2,H,W], the field being sampled
    ref_optflow : torch.Tensor
        Reference optical flow [B,2,H,W], used as the sampling flow

    Returns
    -------
    optflow : torch.Tensor
        Reversed optical flow [B,2,H,W]
    """
    sampled = warp_from_optflow(tgt_optflow, ref_optflow)
    return -sampled


def mask_from_coords(coords, align_corners=True):
    """
    Get overlap mask from coordinates

    A tensor of ones is sampled at ``coords`` with nearest-neighbour
    interpolation and zero padding. Locations that fall outside the image
    therefore sample zero and become False in the returned mask.

    Parameters
    ----------
    coords : torch.Tensor or list[torch.Tensor]
        Warping coordinates [B,2,H,W]; a list is processed element-wise
    align_corners : Bool
        Align corners flag forwarded to grid_sample. It is also applied to
        every element when ``coords`` is a list.

    Returns
    -------
    mask : torch.Tensor or list[torch.Tensor]
        Boolean overlap mask [B,1,H,W] (a list when ``coords`` is a list)
    """
    if is_list(coords):
        # Forward align_corners so list inputs honour the caller's setting
        return [mask_from_coords(coord, align_corners=align_corners)
                for coord in coords]
    b, _, h, w = coords.shape
    # grid_sample requires input and grid to share a dtype, so build the ones
    # tensor in coords' dtype rather than forcing float32 (which would fail
    # for half/double coordinates)
    ones = torch.ones((b, 1, h, w), dtype=coords.dtype,
                      device=coords.device, requires_grad=False)
    mask = warp_from_coords(ones, coords, mode='nearest', padding_mode='zeros',
                            align_corners=align_corners)
    return mask.bool()


def depth_from_optflow(rgb, intrinsics, pose_context, flows,
                       residual=False, clip_range=None):
    """
    Triangulate depth from optical flow and known camera motion

    Parameters
    ----------
    rgb : torch.Tensor
        Base image [B,3,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    pose_context : torch.Tensor or list[torch.Tensor]
        Relative context camera pose(s) [B,4,4]
    flows : torch.Tensor or list[torch.Tensor]
        Target optical flow(s) [B,2,H,W]
    residual : Bool
        Return residual error with depth
    clip_range : Tuple
        Depth range clipping values

    Returns
    -------
    depth : torch.Tensor
        Depth map [B,1,H,W]
    """
    # Accept single tensors as well as lists
    flow_list = make_list(flows)
    pose_list = make_list(pose_context)
    # Split every pose matrix into its top-left 3x3 block and last-column translation
    rotation_list = [pose[:, :3, :3] for pose in pose_list]
    translation_list = [pose[:, :3, -1] for pose in pose_list]
    # Bearing grid built from the image and intrinsics, placed on the image's device
    base_bearings = bearing_grid(rgb, intrinsics).to(rgb.device)
    # Rotate the shared bearing grid once per context pose
    rotated = [mult_rotation_bearing(rot, base_bearings) for rot in rotation_list]
    return triangulation(rotated, translation_list, flow_list, intrinsics,
                         clip_range=clip_range, residual=residual)