# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

from functools import reduce

import torch
import torch.nn.functional as tfn

from vidar.utils.decorators import iterate1
from vidar.utils.types import is_tensor, is_dict, is_seq


@iterate1
def interpolate(tensor, size, scale_factor, mode, align_corners):
    """
    Interpolate a tensor to a different resolution

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor [B,?,H,W]
    size : Tuple or torch.Tensor
        Interpolation size (H,W). If a tensor is given, its last two dimensions are used
    scale_factor : Float
        Scale factor for interpolation
    mode : String
        Interpolation mode
    align_corners : Bool
        Corner alignment flag

    Returns
    -------
    tensor : torch.Tensor
        Interpolated tensor [B,?,h,w]
    """
    # A tensor can be passed as reference for the output resolution
    if is_tensor(size):
        size = size.shape[-2:]
    return tfn.interpolate(
        tensor, size=size, scale_factor=scale_factor,
        mode=mode, align_corners=align_corners,
        recompute_scale_factor=False,
    )


def masked_average(loss, mask, eps=1e-7):
    """
    Calculates the average of a tensor considering mask information

    Parameters
    ----------
    loss : torch.Tensor
        Values to be averaged
    mask : torch.Tensor
        Mask multiplied with the values (must be broadcastable with loss)
    eps : Float
        Small value added to the denominator to avoid division by zero

    Returns
    -------
    average : torch.Tensor
        Sum of masked values divided by the sum of the mask (plus eps)
    """
    return (loss * mask).sum() / (mask.sum() + eps)


def multiply_mask(data, mask):
    """Multiplies a tensor with a mask (data is returned unchanged if either input is None)"""
    return data if (data is None or mask is None) else data * mask


def multiply_args(*args):
    """Multiplies all input arguments that are not None (returns None if there are none)"""
    valids = [v for v in args if v is not None]
    return None if not valids else reduce((lambda x, y: x * y), valids)


def grid_sample(tensor, grid, padding_mode, mode, align_corners):
    """
    Samples a tensor using a grid (thin wrapper around torch.nn.functional.grid_sample)

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor to be sampled
    grid : torch.Tensor
        Sampling grid
    padding_mode : String
        Padding mode for out-of-bounds grid locations
    mode : String
        Interpolation mode
    align_corners : Bool
        Corner alignment flag

    Returns
    -------
    tensor : torch.Tensor
        Sampled tensor
    """
    return tfn.grid_sample(
        tensor, grid, padding_mode=padding_mode, mode=mode, align_corners=align_corners)


def pixel_grid(hw, b=None, with_ones=False, device=None, normalize=False):
    """
    Creates a pixel grid for image operations

    Parameters
    ----------
    hw : Tuple or torch.Tensor
        Height/width of the grid. If a tensor is given, its first dimension is used
        as batch size and its last two dimensions as height/width
    b : Int
        Batch size (if None, the grid is not batched)
    with_ones : Bool
        Stack an extra channel with 1s
    device : String or torch.Tensor
        Device where the grid will be created (if a tensor is given, its device is used)
    normalize : Bool
        Whether the grid is normalized between [-1,1]

    Returns
    -------
    grid : torch.Tensor
        Output pixel grid [B,2,H,W] ([2,H,W] if b is None, with 3 channels if with_ones)
    """
    if is_tensor(hw):
        b, hw = hw.shape[0], hw.shape[-2:]
    if is_tensor(device):
        device = device.device
    hi, hf = 0, hw[0] - 1
    wi, wf = 0, hw[1] - 1
    yy, xx = torch.meshgrid([torch.linspace(hi, hf, hw[0], device=device),
                             torch.linspace(wi, wf, hw[1], device=device)], indexing='ij')
    # Channels are stacked as (x, y), optionally followed by ones
    if with_ones:
        grid = torch.stack([xx, yy, torch.ones(hw, device=device)], 0)
    else:
        grid = torch.stack([xx, yy], 0)
    # norm_pixel_grid indexes channels on dimension 1, so normalization is done with a
    # batch dimension in place (otherwise an unbatched grid would be normalized along H)
    grid = grid.unsqueeze(0)
    if normalize:
        grid = norm_pixel_grid(grid, in_place=True)
    # Remove the temporary batch dimension or repeat it to the requested batch size
    return grid.squeeze(0) if b is None else grid.repeat(b, 1, 1, 1)


def norm_pixel_grid(grid, hw=None, in_place=False):
    """
    Normalize a pixel grid to be between [-1,1]

    Parameters
    ----------
    grid : torch.Tensor
        Grid to be normalized [B,2,H,W] (channel 0 is x, channel 1 is y)
    hw : Tuple
        Height/Width for normalization (if None, the last two grid dimensions are used)
    in_place : Bool
        Whether the operation is done in place or not

    Returns
    -------
    grid : torch.Tensor
        Normalized grid [B,2,H,W]
    """
    if hw is None:
        hw = grid.shape[-2:]
    if not in_place:
        grid = grid.clone()
    # x coordinates are normalized by width, y coordinates by height
    grid[:, 0] = 2.0 * grid[:, 0] / (hw[1] - 1) - 1.0
    grid[:, 1] = 2.0 * grid[:, 1] / (hw[0] - 1) - 1.0
    return grid


def unnorm_pixel_grid(grid, hw=None, in_place=False):
    """
    Unnormalize a pixel grid from [-1,1] to be between [0,W-1] (x) and [0,H-1] (y)

    Parameters
    ----------
    grid : torch.Tensor
        Grid to be unnormalized [B,2,H,W] (channel 0 is x, channel 1 is y)
    hw : Tuple
        Height/width for unnormalization (if None, the last two grid dimensions are used)
    in_place : Bool
        Whether the operation is done in place or not

    Returns
    -------
    grid : torch.Tensor
        Unnormalized grid [B,2,H,W]
    """
    if hw is None:
        hw = grid.shape[-2:]
    if not in_place:
        grid = grid.clone()
    # x coordinates are scaled by width, y coordinates by height
    grid[:, 0] = 0.5 * (hw[1] - 1) * (grid[:, 0] + 1)
    grid[:, 1] = 0.5 * (hw[0] - 1) * (grid[:, 1] + 1)
    return grid


def match_scales(image, targets, num_scales, mode='bilinear', align_corners=True):
    """
    Creates multiple resolution versions of the input to match another list of tensors

    Parameters
    ----------
    image : torch.Tensor
        Input image [B,?,H,W]
    targets : list[torch.Tensor]
        Target resolutions (the last two dimensions of each tensor are used)
    num_scales : int
        Number of scales to consider
    mode : String
        Interpolation mode
    align_corners : Bool
        Corner alignment flag

    Returns
    -------
    images : list[torch.Tensor]
        List containing tensors in the required resolutions
    """
    images = []
    image_shape = image.shape[-2:]
    # For all scales
    for i in range(num_scales):
        # Only the spatial dimensions are compared (batch and channels may differ)
        target_shape = targets[i].shape[-2:]
        if same_shape(image_shape, target_shape):
            # If image shape is equal to target shape, reuse the image
            images.append(image)
        else:
            # Otherwise, interpolate
            images.append(interpolate_image(
                image, target_shape, mode=mode, align_corners=align_corners))
    # Return scaled images
    return images


def cat_channel_ones(tensor, n=1):
    """
    Concatenate tensor with an extra channel of ones

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to be concatenated
    n : Int
        Dimension along which the ones will be concatenated

    Returns
    -------
    cat_tensor : torch.Tensor
        Concatenated tensor (size of dimension n is increased by 1)
    """
    # Get tensor shape with 1 channel
    shape = list(tensor.shape)
    shape[n] = 1
    # Return concatenation of tensor with ones
    return torch.cat([tensor, torch.ones(
        shape, device=tensor.device, dtype=tensor.dtype)], n)


def same_shape(shape1, shape2):
    """Checks if two shapes are the same (same length and same value in every dimension)"""
    return len(shape1) == len(shape2) and \
        all(dim1 == dim2 for dim1, dim2 in zip(shape1, shape2))


def interpolate_image(image, shape=None, scale_factor=None, mode='bilinear',
                      align_corners=True, recompute_scale_factor=False):
    """
    Interpolate an image to a different resolution

    Parameters
    ----------
    image : torch.Tensor
        Image to be interpolated [B,?,h,w]
    shape : torch.Tensor or tuple
        Output shape [H,W] (if it has more than two dimensions, the last two are used)
    scale_factor : Float
        Scale factor for output shape
    mode : String
        Interpolation mode
    align_corners : Bool
        True if corners will be aligned after interpolation
    recompute_scale_factor : Bool
        True if scale factor is recomputed

    Returns
    -------
    image : torch.Tensor
        Interpolated image [B,?,H,W] (the input itself if it already has the requested shape)
    """
    assert shape is not None or scale_factor is not None, 'Invalid option for interpolate_image'
    # Nearest interpolation does not accept a corner alignment flag
    if mode == 'nearest':
        align_corners = None
    if shape is not None:
        # Take last two dimensions as shape
        if is_tensor(shape):
            shape = shape.shape
        if len(shape) > 2:
            shape = shape[-2:]
        # If the shapes are the same, do nothing
        if same_shape(image.shape[-2:], shape):
            return image
    # Interpolate image to match the shape
    return tfn.interpolate(image, size=shape, scale_factor=scale_factor,
                           mode=mode, align_corners=align_corners,
                           recompute_scale_factor=recompute_scale_factor)


def _assert_match(key, val1, val2, atol, rtol):
    """Asserts that two values match (allclose for tensors, equality otherwise)"""
    if is_tensor(val1):
        assert torch.allclose(val1, val2, atol=atol, rtol=rtol), \
            f'Assert error in {key} : {val1.mean().item()} x {val2.mean().item()}'
    else:
        assert val1 == val2, \
            f'Assert error in {key} : {val1} x {val2}'


def check_assert(pred, gt, atol=1e-5, rtol=1e-5):
    """
    Check two dictionaries with allclose assertions

    Only keys present in both dictionaries are compared. Nested dictionaries are
    checked recursively and sequences are checked element by element.

    Parameters
    ----------
    pred : Dict
        Dictionary with predictions
    gt : Dict
        Dictionary with ground-truth
    atol : Float
        Absolute tolerance
    rtol : Float
        Relative tolerance

    Raises
    ------
    AssertionError
        If any pair of compared values does not match
    """
    for key in gt.keys():
        # Keys missing from the predictions are skipped
        if key not in pred.keys():
            continue
        if is_dict(pred[key]):
            # Tolerances are forwarded so nested entries use the same thresholds
            check_assert(pred[key], gt[key], atol=atol, rtol=rtol)
        elif is_seq(pred[key]):
            for val1, val2 in zip(pred[key], gt[key]):
                _assert_match(key, val1, val2, atol, rtol)
        else:
            _assert_match(key, pred[key], gt[key], atol, rtol)


def interleave(data, b):
    """
    Interleave data considering multiple batches

    Each entry along the first dimension is repeated b consecutive times.

    Parameters
    ----------
    data : torch.Tensor
        Input tensor [N,...]
    b : Int
        Number of repetitions for each entry

    Returns
    -------
    data : torch.Tensor
        Interleaved tensor [N*b,...]
    """
    data_interleave = data.unsqueeze(1).expand(-1, b, *data.shape[1:])
    return data_interleave.reshape(-1, *data.shape[1:])