# Jiading Fang
# add define
# fc16538
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from functools import lru_cache
import torch
import torch.nn as nn
from vidar.arch.networks.layers.fsm.camera_utils import scale_intrinsics, invert_intrinsics
from vidar.arch.networks.layers.fsm.pose import Pose
from vidar.utils.tensor import pixel_grid
from vidar.utils.types import is_tensor, is_list
class Camera(nn.Module):
    """
    Differentiable camera class implementing reconstruction and projection
    functions for a pinhole model.

    Attributes
    ----------
    K : torch.Tensor
        Camera intrinsics [B,3,3]
    Tcw : Pose
        Pose applied to world-frame points in `project` to bring them into
        the camera frame
    hw : tuple or None
        Image (height, width) the intrinsics refer to, if known
    """

    def __init__(self, K, Tcw=None, Twc=None, hw=None):
        """
        Initializes the Camera class

        Parameters
        ----------
        K : torch.Tensor
            Camera intrinsics [B,3,3]
        Tcw : Pose or torch.Tensor
            Pose transformation [B,4,4] stored as-is; it is the one applied to
            world-frame points when projecting
        Twc : Pose or torch.Tensor
            Inverse pose transformation [B,4,4]; it is inverted and stored as Tcw.
            Only one of Tcw / Twc may be given; if neither is, identity is used
        hw : tuple or torch.Tensor
            Camera (height, width), or a tensor whose last two dimensions are
            (height, width)
        """
        super().__init__()
        assert Tcw is None or Twc is None, 'You should provide either Tcw or Twc'
        self.K = K
        # Keep only the trailing (height, width) pair
        if hw is None:
            self.hw = None
        elif is_tensor(hw):
            self.hw = hw.shape[-2:]
        else:
            self.hw = hw[-2:]
        if Tcw is not None:
            self.Tcw = Tcw if isinstance(Tcw, Pose) else Pose(Tcw)
        elif Twc is not None:
            self.Tcw = Twc.inverse() if isinstance(Twc, Pose) else Pose(Twc).inverse()
        else:
            self.Tcw = Pose.identity(len(self.K))
        # Per-instance caches for derived quantities, stored as
        # (source object, derived value) so that they are dropped together
        # with the camera and can be invalidated when the source changes
        self._Twc_cache = None
        self._Kinv_cache = None

    def __len__(self):
        """Batch size of the camera intrinsics"""
        return len(self.K)

    def __getitem__(self, idx):
        """Return single camera from a batch position"""
        return Camera(K=self.K[idx].unsqueeze(0),
                      hw=self.hw, Tcw=self.Tcw[idx]).to(self.device)

    @property
    def wh(self):
        """Return camera (width, height), or None if the size is unknown"""
        return None if self.hw is None else self.hw[::-1]

    @property
    def pose(self):
        """Return camera pose (matrix of the inverse of Tcw)"""
        return self.Twc.mat

    @property
    def device(self):
        """Return camera device (the device of the intrinsics)"""
        return self.K.device

    def invert_pose(self):
        """Return new camera with inverted pose (same intrinsics and image size)"""
        # hw is forwarded so the new camera rescales its intrinsics
        # exactly like this one does
        return Camera(K=self.K, Tcw=self.Twc, hw=self.hw)

    def to(self, *args, **kwargs):
        """Moves object to a specific device (in place) and returns itself"""
        self.K = self.K.to(*args, **kwargs)
        self.Tcw = self.Tcw.to(*args, **kwargs)
        # Derived values were computed from the previous tensors and may
        # live on another device / dtype, so they must be recomputed
        self._Twc_cache = None
        self._Kinv_cache = None
        return self

    @property
    def fx(self):
        """Focal length in x"""
        return self.K[:, 0, 0]

    @property
    def fy(self):
        """Focal length in y"""
        return self.K[:, 1, 1]

    @property
    def cx(self):
        """Principal point in x"""
        return self.K[:, 0, 2]

    @property
    def cy(self):
        """Principal point in y"""
        return self.K[:, 1, 2]

    @property
    def Twc(self):
        """Inverse of Tcw (applied to camera-frame points in `reconstruct`)"""
        cached = getattr(self, '_Twc_cache', None)
        # Recompute if never computed or if Tcw was replaced since then
        if cached is None or cached[0] is not self.Tcw:
            cached = (self.Tcw, self.Tcw.inverse())
            self._Twc_cache = cached
        return cached[1]

    @property
    def Kinv(self):
        """Inverse intrinsics (for lifting)"""
        cached = getattr(self, '_Kinv_cache', None)
        # Recompute if never computed or if K was replaced since then
        if cached is None or cached[0] is not self.K:
            cached = (self.K, invert_intrinsics(self.K))
            self._Kinv_cache = cached
        return cached[1]

    def equal(self, cam):
        """Check if two cameras are the same (intrinsics and pose, up to tolerance)"""
        return torch.allclose(self.K, cam.K) and \
            torch.allclose(self.Tcw.mat, cam.Tcw.mat)

    def scaled(self, x_scale, y_scale=None):
        """
        Returns a scaled version of the camera (changing intrinsics)

        Parameters
        ----------
        x_scale : float
            Resize scale in x
        y_scale : float
            Resize scale in y. If None, use the same as x_scale

        Returns
        -------
        camera : Camera
            Scaled version of the current camera (the same object if both
            scales are 1)
        """
        # If single value is provided, use for both dimensions
        if y_scale is None:
            y_scale = x_scale
        # If no scaling is necessary, return same camera
        if x_scale == 1. and y_scale == 1.:
            return self
        # Scale a copy of the intrinsics so this camera is left untouched
        K = scale_intrinsics(self.K.clone(), x_scale, y_scale)
        # Scale image dimensions (truncated to integers)
        hw = None if self.hw is None else (int(self.hw[0] * y_scale),
                                           int(self.hw[1] * x_scale))
        # Return scaled camera
        return Camera(K=K, Tcw=self.Tcw, hw=hw)

    def scaled_K(self, shape):
        """
        Return scaled intrinsics to match a shape

        Parameters
        ----------
        shape : tuple or torch.Size
            Target shape; only its last two entries (height, width) are used

        Returns
        -------
        K : torch.Tensor
            Intrinsics rescaled from the camera size to the target size, or
            the stored intrinsics if the camera size is unknown
        """
        if self.hw is None:
            return self.K
        y_scale, x_scale = [target / current
                            for target, current in zip(shape[-2:], self.hw)]
        # Clone defensively, as `scaled` does
        # NOTE(review): confirm whether scale_intrinsics mutates its input
        return scale_intrinsics(self.K.clone(), x_scale, y_scale)

    def scaled_Kinv(self, shape):
        """Return scaled inverse intrinsics to match a shape"""
        return invert_intrinsics(self.scaled_K(shape))

    def reconstruct(self, depth, frame='w', scene_flow=None, return_grid=False):
        """
        Reconstructs pixel-wise 3D points from a depth map.

        Parameters
        ----------
        depth : torch.Tensor or list
            Depth map for the camera [B,1,H,W], or a list of depth maps
        frame : str
            Reference frame: 'c' for camera and 'w' for world
        scene_flow : torch.Tensor
            Optional per-point scene flow to be added (camera reference frame) [B,3,H,W]
        return_grid : bool
            Return pixel grid as well

        Returns
        -------
        points : torch.Tensor
            Pixel-wise 3D points [B,3,H,W]; a (points, grid) tuple if
            return_grid is set; a list of those if depth is a list

        Raises
        ------
        ValueError
            If frame is neither 'c' nor 'w'
        """
        # If depth is a list, return each reconstruction
        if is_list(depth):
            return [self.reconstruct(d, frame, scene_flow, return_grid) for d in depth]
        # Dimension assertions
        assert depth.dim() == 4 and depth.shape[1] == 1, \
            'Wrong dimensions for camera reconstruction'
        # Create flat index grid [B,3,H,W]
        B, _, H, W = depth.shape
        grid = pixel_grid((H, W), B, device=depth.device, normalize=False, with_ones=True)
        flat_grid = grid.view(B, 3, -1)
        # Get inverse intrinsics (rescaled to the depth resolution if the
        # camera size is known)
        Kinv = self.Kinv if self.hw is None else self.scaled_Kinv(depth.shape)
        # Estimate the outward rays in the camera frame
        Xnorm = (Kinv.bmm(flat_grid)).view(B, 3, H, W)
        # Scale rays to metric depth
        Xc = Xnorm * depth
        # Add scene flow if provided
        if scene_flow is not None:
            Xc = Xc + scene_flow
        # If in camera frame of reference
        if frame == 'c':
            pass
        # If in world frame of reference
        elif frame == 'w':
            Xc = self.Twc @ Xc
        # If none of the above
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))
        # Return points and grid if requested
        return (Xc, grid) if return_grid else Xc

    def project(self, X, frame='w', normalize=True, return_z=False):
        """
        Projects 3D points onto the image plane

        Parameters
        ----------
        X : torch.Tensor
            3D points to be projected, as a grid [B,3,H,W] or flat [B,3,N]
        frame : str
            Reference frame: 'c' for camera and 'w' for world
        normalize : bool
            Normalize coordinates to [-1,1] (only possible when the image
            size is known). Normalized coordinates that fall outside
            [-1,1] are replaced by the sentinel value 2
        return_z : bool
            Return the projected z coordinate as well

        Returns
        -------
        points : torch.Tensor
            2D projected points, [B,H,W,2] for grid input or [B,N,2] for
            flat input. If return_z is set, a (points, depth) tuple where
            depth is [B,1,H,W] or [B,1,N]

        Raises
        ------
        ValueError
            If frame is neither 'c' nor 'w'
        """
        assert 2 < X.dim() <= 4 and X.shape[1] == 3, \
            'Wrong dimensions for camera projection'
        # Determine if input is a grid
        is_grid = X.dim() == 4
        # If it's a grid, flatten it (reshape also accepts non-contiguous input)
        X_flat = X.reshape(X.shape[0], 3, -1) if is_grid else X
        # Get dimensions
        hw = X.shape[2:] if is_grid else self.hw
        # Get intrinsics
        K = self.scaled_K(X.shape) if is_grid else self.K
        # Project 3D points onto the camera image plane
        if frame == 'c':
            Xc = K.bmm(X_flat)
        elif frame == 'w':
            Xc = K.bmm(self.Tcw @ X_flat)
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))
        # Extract coordinates (depth is clamped to avoid dividing by zero)
        Z = Xc[:, 2].clamp(min=1e-5)
        XZ = Xc[:, 0] / Z
        YZ = Xc[:, 1] / Z
        # Normalize points
        if normalize and hw is not None:
            XZ = 2 * XZ / (hw[1] - 1) - 1.
            YZ = 2 * YZ / (hw[0] - 1) - 1.
            # Mark out-of-bounds pixels with a sentinel. Each axis is tested
            # on its own coordinate, and the masks carry no gradient
            Xmask = ((XZ > 1) | (XZ < -1)).detach()
            XZ = XZ.masked_fill(Xmask, 2.)
            Ymask = ((YZ > 1) | (YZ < -1)).detach()
            YZ = YZ.masked_fill(Ymask, 2.)
        # Stack X and Y coordinates
        XY = torch.stack([XZ, YZ], dim=-1)
        # Reshape coordinates to a grid if possible
        if is_grid and hw is not None:
            XY = XY.view(X.shape[0], hw[0], hw[1], 2)
        # If also returning depth
        if return_z:
            # Reshape depth values to a grid if possible
            if is_grid and hw is not None:
                Z = Z.view(X.shape[0], hw[0], hw[1], 1).permute(0, 3, 1, 2)
            # Otherwise, reshape to an array
            else:
                Z = Z.view(X.shape[0], -1, 1).permute(0, 2, 1)
            # Return coordinates and depth values
            return XY, Z
        # Return coordinates
        return XY