File size: 7,247 Bytes
fc16538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import numpy as np
import torch
import torch.nn.functional as tfunc

from vidar.utils.tensor import pixel_grid, cat_channel_ones


def bearing_grid(rgb, intrinsics):
    """
    Build a homogeneous bearing grid for every pixel of a base image

    Parameters
    ----------
    rgb : torch.Tensor
        Base image, only used for its batch size, spatial size and device [B,3,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]

    Returns
    -------
    grid : torch.Tensor
        Bearing grid [B,3,H,W]; channel 0 is shifted by cx and scaled by 1/fx,
        channel 1 by cy and 1/fy, and a channel of ones is appended
    """
    batch, _, height, width = rgb.shape
    # Assumes pixel_grid returns [B,2,H,W] pixel coordinates (x first) - TODO confirm
    grid = pixel_grid((height, width), batch).to(rgb.device)
    # Per-batch principal point and focal lengths, shaped [B,1,1] to broadcast over H,W
    cx = intrinsics[:, 0, 2, None, None]
    cy = intrinsics[:, 1, 2, None, None]
    fx = intrinsics[:, 0, 0, None, None]
    fy = intrinsics[:, 1, 1, None, None]
    grid[:, 0] = (grid[:, 0] - cx) / fx
    grid[:, 1] = (grid[:, 1] - cy) / fy
    # Append the homogeneous coordinate
    return cat_channel_ones(grid)


def mult_rotation_bearing(rotation, bearing):
    """
    Rotate every bearing vector of a bearing grid

    Parameters
    ----------
    rotation : torch.Tensor
        Rotation matrix [B,3,3]
    bearing : torch.Tensor
        Bearing grid [B,3,H,W]; may be non-contiguous

    Returns
    -------
    rot_bearing : torch.Tensor
        Rotated bearing grid [B,3,H,W]
    """
    batch = bearing.shape[0]
    # Flatten the spatial dims to [B,3,H*W] so a batched matmul rotates all pixels.
    # reshape (not view) so non-contiguous inputs (transposed/expanded grids) work too.
    flat = bearing.reshape(batch, 3, -1)
    product = torch.bmm(rotation, flat)
    # bmm output is contiguous, so view back to the input's spatial layout is safe
    return product.view(bearing.shape)


def pre_triangulation(ref_bearings, ref_translations, tgt_flows,
                      intrinsics, concat=True):
    """
    Compute the per-pixel cross products needed for flow triangulation

    Parameters
    ----------
    ref_bearings : list[torch.Tensor]
        Reference bearings [B,3,H,W]
    ref_translations : list[torch.Tensor]
        Reference translations [B,3]
    tgt_flows : list[torch.Tensor]
        Target optical flow values [B,2,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    concat : Bool
        If True, concatenate each list of cross products along the channel dim

    Returns
    -------
    rs : torch.Tensor or list[torch.Tensor]
        Target bearing x reference translation cross products [B,3,H,W] each
    ss : torch.Tensor or list[torch.Tensor]
        Target bearing x reference bearing cross products [B,3,H,W] each
    """
    # Unit-normalized target bearings, one per flow field
    tgt_bearings = []
    for flow in tgt_flows:
        tgt_bearings.append(flow2bearing(flow, intrinsics, normalize=True))

    rs = []
    for bearing, translation in zip(tgt_bearings, ref_translations):
        # Broadcast the translation over the spatial grid so the cross product is per-pixel
        translation_map = translation[:, :, None, None].expand_as(bearing)
        rs.append(torch.cross(bearing, translation_map, dim=1))

    ss = []
    for bearing, ref_bearing in zip(tgt_bearings, ref_bearings):
        ss.append(torch.cross(bearing, ref_bearing, dim=1))

    if not concat:
        return rs, ss
    return torch.cat(rs, dim=1), torch.cat(ss, dim=1)


def depth_ls2views(r, s, clip_range=None):
    """
    Least-squares depth estimation from two views

    Solves, per pixel, min_d ||r + d * s||^2 over the channel dimension.

    Parameters
    ----------
    r : torch.Tensor
        Bearing x translation cross product between images [B,3,H,W]
    s : torch.Tensor
        Bearing x bearing cross product between images [B,3,H,W]
    clip_range : Tuple
        Depth clipping range (min, max); depths outside the open interval
        (min, max) are marked invalid and zeroed in all outputs

    Returns
    -------
    depth : torch.Tensor
        Calculated depth [B,1,H,W]
    error : torch.Tensor
        Calculated error (least-squares residual) [B,1,H,W]
    hessian : torch.Tensor
        Calculated hessian [B,1,H,W]
    """
    hessian = (s * s).sum(dim=1, keepdim=True)
    # Tiny epsilon avoids division by zero where s vanishes
    depth = -(s * r).sum(dim=1, keepdim=True) / (hessian + 1e-30)
    # Residual at the optimum: ||r||^2 - H * d^2 (equals ||r + d s||^2 when d solves the LS)
    error = (r * r).sum(dim=1, keepdim=True) - hessian * (depth ** 2)

    if clip_range is not None:
        invalid_mask = (depth <= clip_range[0]) | (depth >= clip_range[1])
        # Out-of-place masked_fill: in-place assignment would modify `depth`/`hessian`,
        # which autograd saved to compute `error`, breaking backward.
        depth = depth.masked_fill(invalid_mask, 0)
        error = error.masked_fill(invalid_mask, 0)
        hessian = hessian.masked_fill(invalid_mask, 0)

    return depth, error, hessian


def flow2bearing(flow, intrinsics, normalize=True):
    """
    Convert optical flow to bearings of the matched pixels

    Each pixel (x, y) is displaced by its flow vector, then shifted by the
    principal point and scaled by the focal length from ``intrinsics``.

    Parameters
    ----------
    flow : torch.Tensor
        Input optical flow [B,2,H,W]; channel 0 is added to the column index,
        channel 1 to the row index
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    normalize : Bool
        True if bearings are L2-normalized along the channel dimension

    Returns
    -------
    bearings : torch.Tensor
        Calculated bearings [B,3,H,W]
    """
    height, width = flow.shape[2:]
    # Pixel coordinates built directly on the flow's device, avoiding a numpy
    # grid and a host-to-device copy on every call. Shapes [W] and [H,1]
    # broadcast against [B,H,W].
    cols = torch.arange(width, device=flow.device)
    rows = torch.arange(height, device=flow.device).unsqueeze(-1)
    match_x = flow[:, 0] + cols
    match_y = flow[:, 1] + rows
    # Per-batch intrinsics reshaped to [B,1,1] for broadcasting over H,W
    cx = intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)
    fx = intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
    bearings = torch.zeros_like(flow)
    bearings[:, 0] = (match_x - cx) / fx
    bearings[:, 1] = (match_y - cy) / fy
    # Append the homogeneous coordinate
    bearings = cat_channel_ones(bearings)
    if normalize:
        bearings = tfunc.normalize(bearings)
    return bearings


def triangulation(ref_bearings, ref_translations,
                  tgt_flows, intrinsics, clip_range=None, residual=False):
    """
    Triangulate optical flow matches into a fused depth estimate

    Each (reference, target) pair yields a least-squares depth. The pairs are
    fused as a Hessian-weighted average.

    Parameters
    ----------
    ref_bearings : list[torch.Tensor]
        Reference bearings [B,3,H,W]
    ref_translations : list[torch.Tensor]
        Reference translations [B,3]
    tgt_flows : list[torch.Tensor]
        Target optical flow to reference [B,2,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    clip_range : Tuple
        Depth clipping range, forwarded to each per-pair solve
    residual : Bool
        True to also return the residual error and square root of the Hessian

    Returns
    -------
    depth : torch.Tensor
        Estimated depth [B,1,H,W]; returned alone when residual is False
    (error, sqrt_hessian) : tuple[torch.Tensor, torch.Tensor]
        Only when residual is True, returned as ``depth, (error, sqrt_hessian)``;
        each is [B,1,H,W]
    """
    rs, ss = pre_triangulation(ref_bearings, ref_translations, tgt_flows, intrinsics, concat=False)
    # One least-squares solve per pair, each giving (depth, error, hessian)
    per_pair = [depth_ls2views(r, s, clip_range=clip_range) for r, s in zip(rs, ss)]
    pair_depths = [out[0] for out in per_pair]
    pair_errors = [out[1] for out in per_pair]
    pair_hessians = [out[2] for out in per_pair]

    # Hessian-weighted fusion; epsilon guards pixels where every pair was clipped
    hessian = sum(pair_hessians)
    depth = sum(d * h for d, h in zip(pair_depths, pair_hessians)) / (hessian + 1e-12)

    if not residual:
        return depth

    # Total residual of the fused depth: per-pair residual plus deviation from fused depth
    total = sum(h * (depth - d) ** 2 + e
                for d, e, h in zip(pair_depths, pair_errors, pair_hessians))
    error = torch.sqrt(total.clamp_min(0))
    return depth, (error, torch.sqrt(hessian))