# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from abc import ABC
from copy import deepcopy
import torch
import torch.nn as nn
from einops import rearrange
from vidar.geometry.camera_utils import invert_intrinsics, scale_intrinsics
from vidar.geometry.pose import Pose
from vidar.geometry.pose_utils import invert_pose
from vidar.utils.tensor import pixel_grid, same_shape, cat_channel_ones, norm_pixel_grid, interpolate, interleave
from vidar.utils.types import is_tensor, is_seq
class Camera(nn.Module, ABC):
"""
Camera class for 3D reconstruction
Parameters
----------
K : torch.Tensor
Camera intrinsics [B,3,3]
hw : Tuple
Camera height and width
Twc : Pose or torch.Tensor
Camera pose (world to camera) [B,4,4]
Tcw : Pose or torch.Tensor
Camera pose (camera to world) [B,4,4]
"""
def __init__(self, K, hw, Twc=None, Tcw=None):
super().__init__()
# Asserts
assert Twc is None or Tcw is None
# Fold if multi-batch
if K.dim() == 4:
K = rearrange(K, 'b n h w -> (b n) h w')
if Twc is not None:
Twc = rearrange(Twc, 'b n h w -> (b n) h w')
if Tcw is not None:
Tcw = rearrange(Tcw, 'b n h w -> (b n) h w')
# Intrinsics
if same_shape(K.shape[-2:], (3, 3)):
self._K = torch.eye(4, dtype=K.dtype, device=K.device).repeat(K.shape[0], 1, 1)
self._K[:, :3, :3] = K
else:
self._K = K
# Pose
if Twc is None and Tcw is None:
self._Twc = torch.eye(4, dtype=K.dtype, device=K.device).unsqueeze(0).repeat(K.shape[0], 1, 1)
else:
self._Twc = invert_pose(Tcw) if Tcw is not None else Twc
if is_tensor(self._Twc):
self._Twc = Pose(self._Twc)
# Resolution
self._hw = hw
if is_tensor(self._hw):
self._hw = self._hw.shape[-2:]
def __getitem__(self, idx):
"""Return batch-wise pose"""
if is_seq(idx):
return type(self).from_list([self.__getitem__(i) for i in idx])
else:
return type(self)(
K=self._K[[idx]],
Twc=self._Twc[[idx]] if self._Twc is not None else None,
hw=self._hw,
)
def __len__(self):
"""Return length as intrinsics batch"""
return self._K.shape[0]
def __eq__(self, cam):
"""Check if two cameras are the same"""
if not isinstance(cam, type(self)):
return False
if self._hw[0] != cam.hw[0] or self._hw[1] != cam.hw[1]:
return False
if not torch.allclose(self._K, cam.K):
return False
if not torch.allclose(self._Twc.T, cam.Twc.T):
return False
return True
def clone(self):
"""Return a copy of this camera"""
return deepcopy(self)
@property
def pose(self):
"""Return camera pose (world to camera)"""
return self._Twc.T
@property
def K(self):
"""Return camera intrinsics"""
return self._K
@K.setter
def K(self, K):
"""Set camera intrinsics"""
self._K = K
@property
def invK(self):
"""Return inverse of camera intrinsics"""
return invert_intrinsics(self._K)
@property
def batch_size(self):
"""Return batch size"""
return self._Twc.T.shape[0]
@property
def hw(self):
"""Return camera height and width"""
return self._hw
@hw.setter
def hw(self, hw):
"""Set camera height and width"""
self._hw = hw
@property
def wh(self):
"""Get camera width and height"""
return self._hw[::-1]
@property
def n_pixels(self):
"""Return number of pixels"""
return self._hw[0] * self._hw[1]
@property
def fx(self):
"""Return horizontal focal length"""
return self._K[:, 0, 0]
@property
def fy(self):
"""Return vertical focal length"""
return self._K[:, 1, 1]
@property
def cx(self):
"""Return horizontal principal point"""
return self._K[:, 0, 2]
@property
def cy(self):
"""Return vertical principal point"""
return self._K[:, 1, 2]
@property
def fxy(self):
"""Return focal length"""
return torch.tensor([self.fx, self.fy], dtype=self.dtype, device=self.device)
@property
def cxy(self):
"""Return principal points"""
return self._K[:, :2, 2]
# return torch.tensor([self.cx, self.cy], dtype=self.dtype, device=self.device)
@property
def Tcw(self):
"""Return camera pose (camera to world)"""
return None if self._Twc is None else self._Twc.inverse()
@Tcw.setter
def Tcw(self, Tcw):
"""Set camera pose (camera to world)"""
self._Twc = Tcw.inverse()
@property
def Twc(self):
"""Return camera pose (world to camera)"""
return self._Twc
@Twc.setter
def Twc(self, Twc):
"""Set camera pose (world to camera)"""
self._Twc = Twc
@property
def dtype(self):
"""Return tensor type"""
return self._K.dtype
@property
def device(self):
"""Return device"""
return self._K.device
def detach_pose(self):
"""Detach pose from the graph"""
return type(self)(K=self._K, hw=self._hw,
Twc=self._Twc.detach() if self._Twc is not None else None)
def detach_K(self):
"""Detach intrinsics from the graph"""
return type(self)(K=self._K.detach(), hw=self._hw, Twc=self._Twc)
def detach(self):
"""Detach camera from the graph"""
return type(self)(K=self._K.detach(), hw=self._hw,
Twc=self._Twc.detach() if self._Twc is not None else None)
def inverted_pose(self):
"""Invert camera pose"""
return type(self)(K=self._K, hw=self._hw,
Twc=self._Twc.inverse() if self._Twc is not None else None)
def no_translation(self):
"""Return new camera without translation"""
Twc = self.pose.clone()
Twc[:, :-1, -1] = 0
return type(self)(K=self._K, hw=self._hw, Twc=Twc)
@staticmethod
def from_dict(K, hw, Twc=None):
"""Create cameras from a pose dictionary"""
return {key: Camera(K=K[0], hw=hw[0], Twc=val) for key, val in Twc.items()}
# @staticmethod
# def from_dict(K, hw, Twc=None):
# return {key: Camera(K=K[(0, 0)], hw=hw[(0, 0)], Twc=val) for key, val in Twc.items()}
@staticmethod
def from_list(cams):
"""Create cameras from a list"""
K = torch.cat([cam.K for cam in cams], 0)
Twc = torch.cat([cam.Twc.T for cam in cams], 0)
return Camera(K=K, Twc=Twc, hw=cams[0].hw)
def scaled(self, scale_factor):
"""Return a scaled camera"""
if scale_factor is None or scale_factor == 1:
return self
if is_seq(scale_factor):
if len(scale_factor) == 4:
scale_factor = scale_factor[-2:]
scale_factor = [float(scale_factor[i]) / float(self._hw[i]) for i in range(2)]
else:
scale_factor = [scale_factor] * 2
return type(self)(
K=scale_intrinsics(self._K, scale_factor),
hw=[int(self._hw[i] * scale_factor[i]) for i in range(len(self._hw))],
Twc=self._Twc
)
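# Sketch of the expected behavior (illustrative values, not from this file): for a
# camera `cam` with hw = (192, 640),
#   cam.scaled(0.5)                # -> hw (96, 320), per-axis factor 0.5
#   cam.scaled((1, 3, 96, 320))    # -> target shape; factors computed per axis
# assuming scale_intrinsics applies the per-axis factors to fx/fy and cx/cy.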
def offset_start(self, start):
"""Offset camera intrinsics based on a crop"""
new_cam = self.clone()
start = start.to(self.device)
new_cam.K[:, 0, 2] -= start[:, 1]
new_cam.K[:, 1, 2] -= start[:, 0]
return new_cam
def interpolate(self, rgb):
"""Interpolate an image to fit the camera"""
if rgb.dim() == 5:
rgb = rearrange(rgb, 'b n c h w -> (b n) c h w')
return interpolate(rgb, scale_factor=None, size=self.hw, mode='bilinear', align_corners=True)
def interleave_K(self, b):
"""Interleave camera intrinsics to fit multiple batches"""
return type(self)(
K=interleave(self._K, b),
Twc=self._Twc,
hw=self._hw,
)
def interleave_Twc(self, b):
"""Interleave camera pose to fit multiple batches"""
return type(self)(
K=self._K,
Twc=interleave(self._Twc, b),
hw=self._hw,
)
def interleave(self, b):
"""Interleave camera to fit multiple batches"""
return type(self)(
K=interleave(self._K, b),
Twc=interleave(self._Twc, b),
hw=self._hw,
)
def Pwc(self, from_world=True):
"""Return projection matrix"""
return self._K[:, :3] if not from_world or self._Twc is None else \
torch.matmul(self._K, self._Twc.T)[:, :3]
def to_world(self, points):
"""Transform points to world coordinates"""
if points.dim() > 3:
points = points.reshape(points.shape[0], 3, -1)
return points if self.Tcw is None else self.Tcw * points
def from_world(self, points):
"""Transform points back to camera coordinates"""
if points.dim() > 3:
points = points.reshape(points.shape[0], 3, -1)
return points if self._Twc is None else \
torch.matmul(self._Twc.T, cat_channel_ones(points, 1))[:, :3]
def to(self, *args, **kwargs):
"""Copy camera to device"""
self._K = self._K.to(*args, **kwargs)
if self._Twc is not None:
self._Twc = self._Twc.to(*args, **kwargs)
return self
def cuda(self, *args, **kwargs):
"""Copy camera to CUDA"""
return self.to('cuda')
def relative_to(self, cam):
"""Create a new camera relative to another camera"""
return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc.inverse())
def global_from(self, cam):
"""Create a new camera in global coordinates relative to another camera"""
return Camera(K=self._K, hw=self._hw, Twc=self._Twc * cam.Twc)
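# Sketch of how relative_to / global_from compose (assuming Pose multiplication
# composes the underlying 4x4 transforms in the order written):
#   cam_rel = cam_i.relative_to(cam_0)     # Twc_i * Tcw_0: cam_0's frame becomes the world
#   cam_back = cam_rel.global_from(cam_0)  # Twc_rel * Twc_0 == Twc_i, recovering the original
# so cam_0.relative_to(cam_0) has an identity pose.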
def reconstruct_depth_map(self, depth, to_world=False):
"""
Reconstruct a depth map from the camera viewpoint
Parameters
----------
depth : torch.Tensor
Input depth map [B,1,H,W]
to_world : Bool
Transform points to world coordinates
Returns
-------
points : torch.Tensor
Output 3D points [B,3,H,W]
"""
if depth is None:
return None
b, _, h, w = depth.shape
grid = pixel_grid(depth, with_ones=True, device=depth.device).view(b, 3, -1)
points = depth.view(b, 1, -1) * torch.matmul(self.invK[:, :3, :3], grid)
if to_world and self.Tcw is not None:
points = self.Tcw * points
return points.view(b, 3, h, w)
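# Sketch of typical use (illustrative names): given a depth map `depth` [B,1,H,W]
# aligned with this camera,
#   points = cam.reconstruct_depth_map(depth, to_world=True)   # [B,3,H,W]
# lifts each pixel via depth * K^-1 [u, v, 1]^T and, when requested, applies the
# camera-to-world pose Tcw.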
def reconstruct_cost_volume(self, volume, to_world=True, flatten=True):
"""
Reconstruct a cost volume from the camera viewpoint
Parameters
----------
volume : torch.Tensor
Input cost volume [B,1,D,H,W]
to_world : Bool
Transform points to world coordinates
flatten: Bool
Flatten volume points
Returns
-------
points : torch.Tensor
Output 3D points [B,3,D,H,W]
"""
c, d, h, w = volume.shape
grid = pixel_grid((h, w), with_ones=True, device=volume.device).view(3, -1).repeat(1, d)
points = torch.stack([
(volume.view(c, -1) * torch.matmul(invK[:3, :3].unsqueeze(0), grid)).view(3, d * h * w)
for invK in self.invK], 0)
if to_world and self.Tcw is not None:
points = self.Tcw * points
if flatten:
return points.view(-1, 3, d, h * w).permute(0, 2, 1, 3)
else:
return points.view(-1, 3, d, h, w)
def project_points(self, points, from_world=True, normalize=True, return_z=False):
"""
Project points back to image plane
Parameters
----------
points : torch.Tensor
Input 3D points [B,3,H,W] or [B,3,N]
from_world : Bool
Whether points are in the global frame
normalize : Bool
Whether projections should be normalized to [-1,1]
return_z : Bool
Whether the projected depth is returned as well
Returns
-------
coords : torch.Tensor
Projected 2D coordinates [B,H,W,2] or [B,N,2]
depth : torch.Tensor
Projected depth [B,1,H,W] or [B,N]
"""
is_depth_map = points.dim() == 4
hw = self._hw if not is_depth_map else points.shape[-2:]
if is_depth_map:
points = points.reshape(points.shape[0], 3, -1)
b, _, n = points.shape
points = torch.matmul(self.Pwc(from_world), cat_channel_ones(points, 1))
coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7)
depth = points[:, 2]
if not is_depth_map:
if normalize:
coords = norm_pixel_grid(coords, hw=self._hw, in_place=True)
invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \
(coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth < 0)
coords[invalid.unsqueeze(1).repeat(1, 2, 1)] = -2
if return_z:
return coords.permute(0, 2, 1), depth
else:
return coords.permute(0, 2, 1)
coords = coords.view(b, 2, *hw)
depth = depth.view(b, 1, *hw)
if normalize:
coords = norm_pixel_grid(coords, hw=self._hw, in_place=True)
invalid = (coords[:, 0] < -1) | (coords[:, 0] > 1) | \
(coords[:, 1] < -1) | (coords[:, 1] > 1) | (depth[:, 0] < 0)
coords[invalid.unsqueeze(1).repeat(1, 2, 1, 1)] = -2
if return_z:
return coords.permute(0, 2, 3, 1), depth
else:
return coords.permute(0, 2, 3, 1)
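# Sketch of a cross-view warp built from reconstruct + project (illustrative names;
# this is what coords_from_depth below wraps):
#   world = cam_src.reconstruct_depth_map(depth_src, to_world=True)  # [B,3,H,W]
#   grid = cam_ref.project_points(world, from_world=True)            # [B,H,W,2] in [-1,1]
#   warped = torch.nn.functional.grid_sample(rgb_ref, grid, align_corners=True)
# Invalid projections are pushed to -2, so grid_sample with zero padding samples zeros.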
def project_cost_volume(self, points, from_world=True, normalize=True):
"""
Project points back to image plane
Parameters
----------
points : torch.Tensor
Input 3D points [B,3,H,W] or [B,3,N]
from_world : Bool
Whether points are in the global frame
normalize : Bool
Whether projections should be normalized to [-1,1]
Returns
-------
coords : torch.Tensor
Projected 2D coordinates [B,D,H,W,2]
"""
if points.dim() == 4:
points = points.permute(0, 2, 1, 3).reshape(points.shape[0], 3, -1)
b, _, n = points.shape
points = torch.matmul(self.Pwc(from_world), cat_channel_ones(points, 1))
coords = points[:, :2] / (points[:, 2].unsqueeze(1) + 1e-7)
coords = coords.view(b, 2, -1, *self._hw).permute(0, 2, 3, 4, 1)
if normalize:
coords[..., 0] /= self._hw[1] - 1
coords[..., 1] /= self._hw[0] - 1
return 2 * coords - 1
else:
return coords
def coords_from_cost_volume(self, volume, ref_cam=None):
"""
Get warp coordinates from a cost volume
Parameters
----------
volume : torch.Tensor
Input cost volume [B,1,D,H,W]
ref_cam : Camera
Optional reference camera, used to generate cross-camera coordinates
Returns
-------
coords : torch.Tensor
Projected 2D coordinates [B,D,H,W,2]
"""
if ref_cam is None:
return self.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=False), from_world=True)
else:
return ref_cam.project_cost_volume(self.reconstruct_cost_volume(volume, to_world=True), from_world=True)
def coords_from_depth(self, depth, ref_cam=None):
"""
Get warp coordinates from a depth map
Parameters
----------
depth : torch.Tensor
Input depth map [B,1,H,W]
ref_cam : Camera
Optional reference camera, used to generate cross-camera coordinates
Returns
-------
coords : torch.Tensor
Projected 2D coordinates [B,H,W,2]
"""
if ref_cam is None:
return self.project_points(self.reconstruct_depth_map(depth, to_world=False), from_world=True)
else:
return ref_cam.project_points(self.reconstruct_depth_map(depth, to_world=True), from_world=True)
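# Minimal smoke test (a sketch, not part of the original module; assumes the vidar
# utility functions behave as their names suggest): builds a single camera with
# assumed intrinsics and resolution, back-projects a constant depth map, and
# re-projects the points into a normalized warp grid.
if __name__ == '__main__':
    B, H, W = 1, 96, 320
    K = torch.tensor([[[200.0, 0.0, W / 2],
                       [0.0, 200.0, H / 2],
                       [0.0, 0.0, 1.0]]])
    cam = Camera(K=K, hw=(H, W))                    # identity pose (camera == world)
    depth = torch.full((B, 1, H, W), 10.0)          # constant 10m depth
    points = cam.reconstruct_depth_map(depth, to_world=True)  # [B,3,H,W] world points
    coords = cam.project_points(points, from_world=True)      # [B,H,W,2] in [-1,1]
    print(points.shape, coords.shape, float(coords.min()), float(coords.max()))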