# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

import numpy as np
import torch
import torch.nn.functional as tfunc

from vidar.utils.tensor import pixel_grid, cat_channel_ones


def bearing_grid(rgb, intrinsics):
    """
    Create a homogeneous bearing grid from camera intrinsics and a base image

    Parameters
    ----------
    rgb : torch.Tensor
        Base image, used only for its dimensions and device [B,3,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]

    Returns
    -------
    grid : torch.Tensor
        Bearing grid [B,3,H,W]
    """
    # Create pixel grid from base image
    b, _, h, w = rgb.shape
    grid = pixel_grid((h, w), b).to(rgb.device)
    # Per-sample focal lengths and principal point, shaped [B,1,1] for broadcasting
    fx = intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)
    # Normalize pixel grid with camera parameters
    grid[:, 0] = (grid[:, 0] - cx) / fx
    grid[:, 1] = (grid[:, 1] - cy) / fy
    # Return bearing grid (with 1s as extra dimension)
    return cat_channel_ones(grid)


def mult_rotation_bearing(rotation, bearing):
    """
    Rotates a bearing grid

    Parameters
    ----------
    rotation : torch.Tensor
        Rotation matrix [B,3,3]
    bearing : torch.Tensor
        Bearing grid [B,3,H,W]

    Returns
    -------
    rot_bearing : torch.Tensor
        Rotated bearing grid [B,3,H,W]
    """
    # Flatten spatial dimensions; reshape (rather than view) so that
    # non-contiguous bearing grids are accepted as well
    flat_bearing = bearing.reshape(bearing.shape[0], 3, -1)
    # Multiply rotation and bearing
    product = torch.bmm(rotation, flat_bearing)
    # Return product with bearing shape
    return product.view(bearing.shape)


def pre_triangulation(ref_bearings, ref_translations, tgt_flows, intrinsics, concat=True):
    """
    Triangulates bearings and flows

    Parameters
    ----------
    ref_bearings : list[torch.Tensor]
        Reference bearings [B,3,H,W]
    ref_translations : list[torch.Tensor]
        Reference translations [B,3]
    tgt_flows : list[torch.Tensor]
        Target optical flow values [B,2,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    concat : Bool
        True if cross product results are concatenated

    Returns
    -------
    rs : torch.Tensor or list[torch.Tensor]
        Bearing x translation cross product [B,3,H,W] per view
        (a list, or stacked along the channel dimension if concatenated)
    ss : torch.Tensor or list[torch.Tensor]
        Bearing x bearing cross product [B,3,H,W] per view
        (a list, or stacked along the channel dimension if concatenated)
    """
    # Get target bearings from flow
    tgt_bearings = [flow2bearing(flow, intrinsics, normalize=True)
                    for flow in tgt_flows]
    # Bearings x translation cross product
    # (translation is broadcast over the spatial dimensions)
    rs = [torch.cross(tgt_bearing, ref_translation[:, :, None, None].expand_as(tgt_bearing), dim=1)
          for tgt_bearing, ref_translation in zip(tgt_bearings, ref_translations)]
    # Bearings x bearings cross product
    ss = [torch.cross(tgt_bearing, ref_bearing, dim=1)
          for tgt_bearing, ref_bearing in zip(tgt_bearings, ref_bearings)]
    if concat:
        # If results are to be concatenated
        return torch.cat(rs, dim=1), torch.cat(ss, dim=1)
    else:
        # Otherwise, return as lists
        return rs, ss


def depth_ls2views(r, s, clip_range=None):
    """
    Least-squares depth estimation from two views

    Parameters
    ----------
    r : torch.Tensor
        Bearing x translation cross product between images [B,3,H,W]
    s : torch.Tensor
        Bearing x bearing cross product between images [B,3,H,W]
    clip_range : Tuple
        Depth clipping range (min, max)

    Returns
    -------
    depth : torch.Tensor
        Calculated depth [B,1,H,W]
    error : torch.Tensor
        Calculated error [B,1,H,W]
    hessian : torch.Tensor
        Calculated hessian [B,1,H,W]
    """
    # Calculate matrices
    hessian = (s * s).sum(dim=1, keepdim=True)
    depth = -(s * r).sum(dim=1, keepdim=True) / (hessian + 1e-30)
    error = (r * r).sum(dim=1, keepdim=True) - hessian * (depth ** 2)
    # Clip depth and other matrices if requested
    if clip_range is not None:
        # Depths on or outside the clipping bounds are considered invalid
        invalid_mask = (depth <= clip_range[0]) | (depth >= clip_range[1])
        # Use out-of-place masking: depth and hessian are saved by autograd
        # for the error computation above, so modifying them in-place would
        # make a later backward pass through error fail
        depth = depth.masked_fill(invalid_mask, 0)
        error = error.masked_fill(invalid_mask, 0)
        hessian = hessian.masked_fill(invalid_mask, 0)
    # Return calculated matrices
    return depth, error, hessian


def flow2bearing(flow, intrinsics, normalize=True):
    """
    Convert optical flow to bearings

    Parameters
    ----------
    flow : torch.Tensor
        Input optical flow [B,2,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    normalize : Bool
        True if bearings are normalized

    Returns
    -------
    bearings : torch.Tensor
        Calculated bearings [B,3,H,W]
    """
    # Create initial grid of pixel coordinates
    height, width = flow.shape[2:]
    xx, yy = np.meshgrid(range(width), range(height))
    xx = torch.from_numpy(xx).to(flow.device)
    yy = torch.from_numpy(yy).to(flow.device)
    # Per-sample focal lengths and principal point, shaped [B,1,1] for broadcasting
    fx = intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
    fy = intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
    cx = intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)
    # Initialize bearing matrix
    bearings = torch.zeros_like(flow)
    # Populate bearings: matched pixel (grid + flow) normalized by intrinsics
    bearings[:, 0] = (flow[:, 0] + xx - cx) / fx
    bearings[:, 1] = (flow[:, 1] + yy - cy) / fy
    # Stack 1s as the last dimension
    bearings = cat_channel_ones(bearings)
    # Normalize if necessary (L2 norm over the channel dimension)
    if normalize:
        bearings = tfunc.normalize(bearings)
    # Return bearings
    return bearings


def triangulation(ref_bearings, ref_translations, tgt_flows, intrinsics, clip_range=None, residual=False):
    """
    Triangulate optical flow points to produce depth estimates

    Parameters
    ----------
    ref_bearings : list[torch.Tensor]
        Reference bearings [B,3,H,W]
    ref_translations : list[torch.Tensor]
        Reference translations [B,3]
    tgt_flows : list[torch.Tensor]
        Target optical flow to reference [B,2,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    clip_range : Tuple
        Depth clipping range (min, max)
    residual : Bool
        True to return residual error and square root of Hessian

    Returns
    -------
    depth : torch.Tensor
        Estimated depth [B,1,H,W]. This is the only output if residual is False.
        If residual is True, the output is the tuple (depth, (error, sqrt_hessian))
    error : torch.Tensor
        Estimated error [B,1,H,W] (only if residual is True)
    sqrt_hessian : torch.Tensor
        Square root of Hessian [B,1,H,W] (only if residual is True)
    """
    # Pre-triangulate flows
    rs, ss = pre_triangulation(
        ref_bearings, ref_translations, tgt_flows, intrinsics, concat=False)
    # Calculate list of per-view triangulations (depth, error, hessian)
    outputs = [depth_ls2views(r, s, clip_range=clip_range)
               for r, s in zip(rs, ss)]
    # Calculate predicted hessian and hessian-weighted average depth
    hessian = sum(output[2] for output in outputs)
    depth = sum(output[0] * output[2] for output in outputs) / (hessian + 1e-12)
    # Return depth only if residuals are not requested
    if not residual:
        return depth
    # Residual error, clamped at zero before the square root
    squared_error = sum(output[2] * (depth - output[0]) ** 2 + output[1]
                        for output in outputs)
    error = torch.sqrt(squared_error.clamp_min(0))
    sqrt_hessian = torch.sqrt(hessian)
    # Return depth + residual error and hessian matrix
    return depth, (error, sqrt_hessian)