# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

from functools import lru_cache

import torch
import torch.nn as nn

from vidar.arch.networks.layers.fsm.camera_utils import scale_intrinsics, invert_intrinsics
from vidar.arch.networks.layers.fsm.pose import Pose
from vidar.utils.tensor import pixel_grid
from vidar.utils.types import is_tensor, is_list


class Camera(nn.Module):
    """
    Differentiable camera class implementing reconstruction and projection
    functions for a pinhole model.
    """
    def __init__(self, K, Tcw=None, Twc=None, hw=None):
        """
        Initializes the Camera class

        Parameters
        ----------
        K : torch.Tensor
            Camera intrinsics [B,3,3]
        Tcw : Pose or torch.Tensor
            Camera -> World pose transformation [B,4,4]
        Twc : Pose or torch.Tensor
            World -> Camera pose transformation [B,4,4]
        hw : tuple or torch.Tensor
            Camera height and width (H, W), or a tensor whose last two dimensions provide them
        """
        super().__init__()
        assert Tcw is None or Twc is None, 'You should provide either Tcw or Twc'
        self.K = K
        self.hw = None if hw is None else hw.shape[-2:] if is_tensor(hw) else hw[-2:]
        if Tcw is not None:
            self.Tcw = Tcw if isinstance(Tcw, Pose) else Pose(Tcw)
        elif Twc is not None:
            self.Tcw = Twc.inverse() if isinstance(Twc, Pose) else Pose(Twc).inverse()
        else:
            self.Tcw = Pose.identity(len(self.K))

    def __len__(self):
        """Batch size of the camera intrinsics"""
        return len(self.K)

    def __getitem__(self, idx):
        """Return single camera from a batch position"""
        return Camera(K=self.K[idx].unsqueeze(0),
                      hw=self.hw, Tcw=self.Tcw[idx]).to(self.device)
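
    # Illustrative usage sketch (not part of the original file): construct a camera
    # from intrinsics and a world->camera pose; a Twc tensor is inverted into Tcw
    # internally, as in __init__ above.
    #   cam = Camera(K=K, Twc=Twc_mat, hw=(H, W))   # Twc_mat: [B,4,4]
    #   cam0 = cam[0]                                # single camera from the batch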

    @property
    def wh(self):
        """Return camera width and height"""
        return None if self.hw is None else self.hw[::-1]

    @property
    def pose(self):
        """Return camera pose"""
        return self.Twc.mat

    @property
    def device(self):
        """Return camera device"""
        return self.K.device

    def invert_pose(self):
        """Return new camera with inverted pose"""
        return Camera(K=self.K, Tcw=self.Twc)

    def to(self, *args, **kwargs):
        """Moves object to a specific device"""
        self.K = self.K.to(*args, **kwargs)
        self.Tcw = self.Tcw.to(*args, **kwargs)
        return self

    @property
    def fx(self):
        """Focal length in x"""
        return self.K[:, 0, 0]

    @property
    def fy(self):
        """Focal length in y"""
        return self.K[:, 1, 1]

    @property
    def cx(self):
        """Principal point in x"""
        return self.K[:, 0, 2]

    @property
    def cy(self):
        """Principal point in y"""
        return self.K[:, 1, 2]

    @property
    @lru_cache()
    def Twc(self):
        """World -> Camera pose transformation (inverse of Tcw)"""
        return self.Tcw.inverse()

    @property
    @lru_cache()
    def Kinv(self):
        """Inverse intrinsics (for lifting)"""
        return invert_intrinsics(self.K)

    def equal(self, cam):
        """Check if two cameras are the same"""
        return torch.allclose(self.K, cam.K) and \
               torch.allclose(self.Tcw.mat, cam.Tcw.mat)

    def scaled(self, x_scale, y_scale=None):
        """
        Returns a scaled version of the camera (changing intrinsics)

        Parameters
        ----------
        x_scale : float
            Resize scale in x
        y_scale : float
            Resize scale in y. If None, use the same as x_scale

        Returns
        -------
        camera : Camera
            Scaled version of the current camera
        """
        # If single value is provided, use for both dimensions
        if y_scale is None:
            y_scale = x_scale
        # If no scaling is necessary, return same camera
        if x_scale == 1. and y_scale == 1.:
            return self
        # Scale intrinsics
        K = scale_intrinsics(self.K.clone(), x_scale, y_scale)
        # Scale image dimensions
        hw = None if self.hw is None else (int(self.hw[0] * y_scale),
                                           int(self.hw[1] * x_scale))
        # Return scaled camera
        return Camera(K=K, Tcw=self.Tcw, hw=hw)
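
    # Illustrative usage sketch (not part of the original file); assumes a `cam`
    # built with hw=(H, W), so the stored image shape is rescaled together with
    # the intrinsics via scale_intrinsics:
    #   cam_half = cam.scaled(0.5)          # intrinsics and hw halved
    #   cam_aniso = cam.scaled(0.5, 0.25)   # different scales per axis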

    def scaled_K(self, shape):
        """Return scaled intrinsics to match a shape"""
        if self.hw is None:
            return self.K
        else:
            y_scale, x_scale = [sh / hw for sh, hw in zip(shape[-2:], self.hw)]
            return scale_intrinsics(self.K, x_scale, y_scale)

    def scaled_Kinv(self, shape):
        """Return scaled inverse intrinsics to match a shape"""
        return invert_intrinsics(self.scaled_K(shape))

    def reconstruct(self, depth, frame='w', scene_flow=None, return_grid=False):
        """
        Reconstructs pixel-wise 3D points from a depth map.

        Parameters
        ----------
        depth : torch.Tensor
            Depth map for the camera [B,1,H,W]
        frame : str
            Reference frame: 'c' for camera and 'w' for world
        scene_flow : torch.Tensor
            Optional per-point scene flow to be added (camera reference frame) [B,3,H,W]
        return_grid : bool
            Return pixel grid as well

        Returns
        -------
        points : torch.Tensor
            Pixel-wise 3D points [B,3,H,W]
        """
        # If depth is a list, return each reconstruction
        if is_list(depth):
            return [self.reconstruct(d, frame, scene_flow, return_grid) for d in depth]
        # Dimension assertions
        assert depth.dim() == 4 and depth.shape[1] == 1, \
            'Wrong dimensions for camera reconstruction'
        # Create flat index grid [B,3,H,W]
        B, _, H, W = depth.shape
        grid = pixel_grid((H, W), B, device=depth.device, normalize=False, with_ones=True)
        flat_grid = grid.view(B, 3, -1)
        # Get inverse intrinsics
        Kinv = self.Kinv if self.hw is None else self.scaled_Kinv(depth.shape)
        # Estimate the outward rays in the camera frame
        Xnorm = (Kinv.bmm(flat_grid)).view(B, 3, H, W)
        # Scale rays to metric depth
        Xc = Xnorm * depth
        # Add scene flow if provided
        if scene_flow is not None:
            Xc = Xc + scene_flow
        # If in camera frame of reference
        if frame == 'c':
            pass
        # If in world frame of reference
        elif frame == 'w':
            Xc = self.Twc @ Xc
        # If none of the above
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))
        # Return points and grid if requested
        return (Xc, grid) if return_grid else Xc
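
    # Illustrative usage sketch (not part of the original file); shapes follow
    # the docstring above and `cam`, B, H, W are assumed to exist:
    #   depth = torch.rand(B, 1, H, W) * 10.     # metric depth [B,1,H,W]
    #   Xw = cam.reconstruct(depth, frame='w')   # world-frame points [B,3,H,W]
    #   Xc = cam.reconstruct(depth, frame='c')   # camera-frame points [B,3,H,W]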

    def project(self, X, frame='w', normalize=True, return_z=False):
        """
        Projects 3D points onto the image plane

        Parameters
        ----------
        X : torch.Tensor
            3D points to be projected [B,3,H,W]
        frame : str
            Reference frame: 'c' for camera and 'w' for world
        normalize : bool
            Normalize grid coordinates
        return_z : bool
            Return the projected z coordinate as well

        Returns
        -------
        points : torch.Tensor
            2D projected points that are within the image boundaries [B,H,W,2]
        depth : torch.Tensor
            Projected depth values (only when return_z is True)
        """
        assert 2 < X.dim() <= 4 and X.shape[1] == 3, \
            'Wrong dimensions for camera projection'
        # Determine if input is a grid
        is_grid = X.dim() == 4
        # If it's a grid, flatten it
        X_flat = X.view(X.shape[0], 3, -1) if is_grid else X
        # Get dimensions
        hw = X.shape[2:] if is_grid else self.hw
        # Get intrinsics
        K = self.scaled_K(X.shape) if is_grid else self.K
        # Project 3D points onto the camera image plane
        if frame == 'c':
            Xc = K.bmm(X_flat)
        elif frame == 'w':
            Xc = K.bmm(self.Tcw @ X_flat)
        else:
            raise ValueError('Unknown reference frame {}'.format(frame))
        # Extract coordinates
        Z = Xc[:, 2].clamp(min=1e-5)
        XZ = Xc[:, 0] / Z
        YZ = Xc[:, 1] / Z
        # Normalize points to [-1,1]
        if normalize and hw is not None:
            XZ = 2 * XZ / (hw[1] - 1) - 1.
            YZ = 2 * YZ / (hw[0] - 1) - 1.
            # Clamp out-of-bounds pixels
            Xmask = ((XZ > 1) + (XZ < -1)).detach()
            XZ[Xmask] = 2.
            Ymask = ((YZ > 1) + (YZ < -1)).detach()
            YZ[Ymask] = 2.
        # Stack X and Y coordinates
        XY = torch.stack([XZ, YZ], dim=-1)
        # Reshape coordinates to a grid if possible
        if is_grid and hw is not None:
            XY = XY.view(X.shape[0], hw[0], hw[1], 2)
        # If also returning depth
        if return_z:
            # Reshape depth values to a grid if possible
            if is_grid and hw is not None:
                Z = Z.view(X.shape[0], hw[0], hw[1], 1).permute(0, 3, 1, 2)
            # Otherwise, reshape to an array
            else:
                Z = Z.view(X.shape[0], -1, 1).permute(0, 2, 1)
            # Return coordinates and depth values
            return XY, Z
        else:
            # Return coordinates
            return XY
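

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It assumes the
# `vidar` package is importable and that `Pose`/`pixel_grid` behave as used
# above; the intrinsics values below are arbitrary.
if __name__ == '__main__':
    B, H, W = 2, 64, 96
    # Simple pinhole intrinsics [B,3,3]
    K = torch.tensor([[100., 0., W / 2.],
                      [0., 100., H / 2.],
                      [0., 0., 1.]]).unsqueeze(0).repeat(B, 1, 1)
    # Identity pose is used when neither Tcw nor Twc is given
    cam = Camera(K=K, hw=(H, W))
    # Constant depth map [B,1,H,W]
    depth = torch.full((B, 1, H, W), 5.)
    # Lift pixels to world-frame 3D points, then re-project them
    points = cam.reconstruct(depth, frame='w')   # [B,3,H,W]
    coords = cam.project(points, frame='w')      # normalized grid [B,H,W,2]
    print(points.shape, coords.shape)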