# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import numpy as np
import torch

from vidar.geometry.camera import Camera
from vidar.utils.tensor import grid_sample
from vidar.utils.types import is_tensor


def warp_bins(rgb, cam, bins):
    """
    Warp an image based on depth bins

    Parameters
    ----------
    rgb : torch.Tensor
        Input image for warping [B,?,H,W]
    cam : Camera
        Input camera
    bins : torch.Tensor
        Depth bins for warping

    Returns
    -------
    warped : torch.Tensor
        Warped images for each depth bin
    """
    # Build one constant-depth plane per bin and stack them along dim 1,
    # giving a [1,D,H,W] volume on the same device as the image.
    ones = torch.ones((1, *cam.hw), device=rgb.device)
    volume = torch.stack([depth * ones for depth in bins], 1)
    coords_volume = cam.coords_from_cost_volume(volume)
    # The image is repeated once per bin so that each copy is sampled with
    # the coordinates of a different depth plane. Only the first entry of
    # the coordinate volume is used.
    # NOTE(review): rgb.repeat(len(bins), ...) lines up with one set of
    # coordinates per bin only if rgb has batch size 1 - confirm.
    return grid_sample(
        rgb.repeat(len(bins), 1, 1, 1), coords_volume[0].type(rgb.dtype),
        padding_mode='zeros', mode='bilinear', align_corners=True)


def sample(grid, pred):
    """
    Sample a grid based on predictions

    For every pixel, the grid is sampled along its first dimension at the
    position given by the prediction for that pixel.

    Parameters
    ----------
    grid : torch.Tensor
        Grid to be sampled [N,?,H,W], where N is the dimension being sampled
    pred : torch.Tensor
        Predicted positions along N, as a 3-dimensional tensor
        (presumably [1,H,W] with values in [0,N-1] - TODO confirm)

    Returns
    -------
    values : torch.Tensor
        Sampled grid [1,1,H,W]
    """
    n, _, h, w = grid.shape
    # One sampling location per pixel; the same value is used for both
    # coordinates of the 2D sampling location.
    coords = pred.permute(1, 2, 0).reshape(-1, 1, 1, 1).repeat(1, 1, 1, 2)
    # Map positions from [0,N-1] to the normalized [-1,1] range
    coords = 2 * coords / (n - 1) - 1
    # Each pixel becomes its own N x 2 "image" (two identical columns), so the
    # 2D sampler effectively interpolates along N only.
    # NOTE(review): assumes grid_sample follows the semantics of
    # torch.nn.functional.grid_sample and that grid has a single channel.
    grid = grid.permute(2, 3, 0, 1).reshape(-1, 1, n, 1).repeat(1, 1, 1, 2)
    values = grid_sample(
        grid, coords,
        padding_mode='zeros', mode='bilinear', align_corners=True)
    return values.reshape(h, w, 1, 1).permute(2, 3, 0, 1)


def compute_depth_bin(min_depth, max_depth, num_bins, i):
    """
    Calculate a single SID depth bin

    Parameters
    ----------
    min_depth : Float
        Minimum depth value
    max_depth : Float
        Maximum depth value
    num_bins : Int
        Number of depth bins
    i : torch.Tensor
        Index of the depth bin in the interval (must be a tensor, because
        the result is passed through torch.exp)

    Returns
    -------
    bin : torch.Tensor
        Corresponding depth bin, clamped to [min_depth, max_depth]
    """
    # Bins are uniformly spaced in log-depth between min_depth and max_depth
    log_depth = np.log(min_depth) + \
        np.log(max_depth / min_depth) * i / (num_bins - 1)
    return torch.exp(log_depth).clamp(min=min_depth, max=max_depth)


def uncompute_depth_bin(min_depth, max_depth, num_bins, depth):
    """
    Recover the SID bin index from a depth value

    Parameters
    ----------
    min_depth : Float
        Minimum depth value
    max_depth : Float
        Maximum depth value
    num_bins : Int
        Number of depth bins
    depth : torch.Tensor
        Depth value

    Returns
    -------
    index : torch.Tensor
        Index for the depth value in the SID interval
    """
    # Position of the depth inside the interval in log space
    # (0 at min_depth, 1 at max_depth)
    ratio = (torch.log(depth) - np.log(min_depth)) / \
        np.log(max_depth / min_depth)
    # NOTE(review): the clamp is applied to the normalized ratio and not to
    # the final index, so depths above max_depth produce indices larger than
    # num_bins - 1 (up to num_bins * (num_bins - 1)). Kept as is to preserve
    # behavior; confirm whether the index itself was meant to be clamped.
    return (num_bins - 1) * ratio.clamp(min=0, max=num_bins)


def compute_depth_bins(min_depth, max_depth, num_bins, mode):
    """
    Compute depth bins for an interval

    Parameters
    ----------
    min_depth : Float
        Minimum depth value
    max_depth : Float
        Maximum depth value
    num_bins : Int
        Number of depth bins
    mode : String
        Depth discretization mode ('inverse', 'linear' or 'sid')

    Returns
    -------
    bins : torch.Tensor
        Discretized depth bins (float32)

    Raises
    ------
    NotImplementedError
        If the discretization mode is not supported
    """
    # Numpy is used below, so tensor bounds are moved to the CPU first
    if is_tensor(min_depth):
        min_depth = min_depth.detach().cpu()
    if is_tensor(max_depth):
        max_depth = max_depth.detach().cpu()
    if mode == 'inverse':
        # Uniform in inverse depth; reversed before inverting so that the
        # bins go from min_depth to max_depth
        depth_bins = 1. / np.linspace(
            1. / max_depth, 1. / min_depth, num_bins)[::-1]
    elif mode == 'linear':
        # Uniform in depth
        depth_bins = np.linspace(
            min_depth, max_depth, num_bins)
    elif mode == 'sid':
        # Uniform in log-depth (same formula as compute_depth_bin)
        log_min = np.log(min_depth)
        log_ratio = np.log(max_depth / min_depth)
        depth_bins = np.array(
            [np.exp(log_min + log_ratio * i / (num_bins - 1))
             for i in range(num_bins)])
    else:
        raise NotImplementedError(
            f'Invalid depth discretization mode: {mode}')
    return torch.from_numpy(depth_bins).float()