vidar/geometry/camera_full.py
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
from torch_scatter import scatter_min
from vidar.geometry.camera import Camera
from vidar.utils.tensor import unnorm_pixel_grid
class CameraFull(Camera):
"""Camera class with additional functionality"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.convert_matrix = torch.tensor(
[[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
dtype=torch.float32,
).unsqueeze(0)
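        # The matrix above flips the sign of the y and z axes; switch()/bwd()/
        # fwd() below conjugate poses with it to move between two camera-axis
        # conventions (e.g. y-down/z-forward vs. y-up/z-backward; the exact
        # pair is an inference from how the matrix is used, not stated here).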
@staticmethod
def from_list(cams):
"""Create cameras from a list"""
K = torch.cat([cam.K for cam in cams], 0)
Twc = torch.cat([cam.Twc.T for cam in cams], 0)
return CameraFull(K=K, Twc=Twc, hw=cams[0].hw)
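    # Usage sketch: given single-camera instances cam0 and cam1 with the same
    # hw, `CameraFull.from_list([cam0, cam1])` yields one CameraFull whose K
    # and Twc are stacked along the batch dimension (dim 0).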
def switch(self):
"""Switch camera between conventions"""
T = self.convert_matrix.to(self.device)
Twc = T @ self.Twc.T @ T
return type(self)(K=self.K, Twc=Twc, hw=self.hw)
def bwd(self):
"""Switch camera to the backwards convention"""
T = self.convert_matrix.to(self.device)
Tcw = T @ self.Twc.T @ T
return type(self)(K=self.K, Tcw=Tcw, hw=self.hw)
def fwd(self):
"""Switch camera to the forward convention"""
T = self.convert_matrix.to(self.device)
Twc = T @ self.Tcw.T @ T
return type(self)(K=self.K, Twc=Twc, hw=self.hw)
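    # Note: the conversion matrix is involutive (T @ T = I), so switch(),
    # bwd() and fwd() all apply the same conjugation T @ pose @ T and differ
    # only in whether the pose is read/written as Twc or Tcw; in particular,
    # cam.bwd().fwd() recovers a camera with the original Twc.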
    def look_at(self, at, up=torch.tensor([0.0, 1.0, 0.0])):
        """
        Set a direction for the camera to point (in-place)
        Parameters
        ----------
        at : torch.Tensor
            Where the camera should be pointing at [3] (broadcast across the batch)
        up : torch.Tensor
            Up direction [3] (broadcast across the batch)
        """
eps = 1e-5
eye = self.Tcw.T[:, :3, -1]
at = at.unsqueeze(0)
up = up.unsqueeze(0).to(at.device)
z_axis = at - eye
z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps
up = up.expand(z_axis.shape)
        x_axis = torch.cross(up, z_axis, dim=-1)
        x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps
        y_axis = torch.cross(z_axis, x_axis, dim=-1)
        y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps
R = torch.stack((x_axis, y_axis, z_axis), dim=-1)
Tcw = self.Tcw
Tcw.T[:, :3, :3] = R
self.Twc = Tcw.inverse()
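    # Usage sketch (illustrative values): aim every camera in the batch at the
    # world origin, keeping the default y-up vector:
    #   cam.look_at(torch.tensor([0.0, 0.0, 0.0]))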
def get_origin(self, flatten=False):
"""Return camera origin"""
orig = self.Tcw.T[:, :3, -1].view(len(self), 3, 1, 1).repeat(1, 1, *self.hw)
if flatten:
orig = orig.reshape(len(self), 3, -1).permute(0, 2, 1)
return orig
def get_viewdirs(self, normalize=False, flatten=False, to_world=False):
"""Return camera viewing rays"""
ones = torch.ones((len(self), 1, *self.hw), dtype=self.dtype, device=self.device)
rays = self.reconstruct_depth_map(ones, to_world=False)
if normalize:
rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
if to_world:
rays = self.to_world(rays).reshape(len(self), 3, *self.hw)
if flatten:
rays = rays.reshape(len(self), 3, -1).permute(0, 2, 1)
return rays
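    # Usage sketch: `cam.get_viewdirs(normalize=True, to_world=True,
    # flatten=True)` returns unit ray directions in world coordinates shaped
    # [B,HW,3], which pairs with `get_origin(flatten=True)` for ray marching.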
def get_render_rays(self, near=None, far=None, n_rays=None, gt=None):
"""
Get render rays
Parameters
----------
near : Float
Near plane
far : Float
Far plane
n_rays : Int
Number of rays
gt : torch.Tensor
Ground-truth values for concatenation
Returns
-------
rays : torch.Tensor
Camera viewing rays
"""
b = len(self)
ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
rays = self.reconstruct_depth_map(ones, to_world=False)
rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
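        # Flip the y and z components of the ray directions, mirroring the
        # sign convention of convert_matrix in __init__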
        rays[:, 1] = -rays[:, 1]
        rays[:, 2] = -rays[:, 2]
orig = self.pose[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
rays = self.no_translation().inverted_pose().to_world(rays).reshape(b, 3, *self.hw)
info = [orig, rays]
if near is not None:
info = info + [near * ones]
if far is not None:
info = info + [far * ones]
if gt is not None:
info = info + [gt]
rays = torch.cat(info, 1)
rays = rays.permute(0, 2, 3, 1).reshape(b, -1, rays.shape[1])
if n_rays is not None:
idx = torch.randint(0, self.n_pixels, (n_rays,))
rays = rays[:, idx, :]
return rays
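    # The returned tensor is [B,HW,C] (or [B,n_rays,C] when n_rays is given),
    # with channels packed in the concatenation order of `info` above:
    # [origin(3), direction(3), near(1) if given, far(1) if given, gt if given].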
def get_plucker(self):
"""Get plucker vectors"""
b = len(self)
ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
rays = self.reconstruct_depth_map(ones, to_world=False)
rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
        orig = self.Tcw.T[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
        orig = orig.view(b, 3, -1).permute(0, 2, 1)
        rays = rays.view(b, 3, -1).permute(0, 2, 1)
cross = torch.cross(orig, rays, dim=-1)
plucker = torch.cat((rays, cross), dim=-1)
return plucker
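    # A Plücker line encodes a ray with origin o and unit direction d as the
    # 6-vector (d, o x d); the moment o x d is invariant to sliding o along
    # the ray (since d x d = 0), making it a point-independent ray encoding.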
def project_pointcloud(self, pcl_src, rgb_src, thr=1):
"""
Project pointcloud to the camera plane
Parameters
----------
pcl_src : torch.Tensor
Input 3D pointcloud
rgb_src : torch.Tensor
Pointcloud color information
        thr : Int
            Minimum number of valid projected points required to output colors
Returns
-------
rgb_tgt : torch.Tensor
Projected image [B,3,H,W]
depth_tgt : torch.Tensor
Projected depth map [B,1,H,W]
"""
if rgb_src.dim() == 4:
rgb_src = rgb_src.view(*rgb_src.shape[:2], -1)
# Get projected coordinates and depth values
uv_all, z_all = self.project_points(pcl_src, return_z=True, from_world=True)
rgbs_tgt, depths_tgt = [], []
b = pcl_src.shape[0]
for i in range(b):
uv, z = uv_all[i].reshape(-1, 2), z_all[i].reshape(-1, 1)
# Remove out-of-bounds coordinates and points behind the camera
idx = (uv[:, 0] >= -1) & (uv[:, 0] <= 1) & \
(uv[:, 1] >= -1) & (uv[:, 1] <= 1) & (z[:, 0] > 0.0)
            # Unnormalize to pixel coordinates and flatten to linear indices for the scatter operation
uv = (unnorm_pixel_grid(uv[idx], self.hw)).round().long()
uv = uv[:, 0] + uv[:, 1] * self.hw[1]
# Min scatter operation (only keep the closest depth)
depth_tgt = 1e10 * torch.ones((self.hw[0] * self.hw[1], 1), device=pcl_src.device)
depth_tgt, argmin = scatter_min(src=z[idx], index=uv.unsqueeze(1), dim=0, out=depth_tgt)
depth_tgt[depth_tgt == 1e10] = 0.
num_valid = (depth_tgt > 0).sum()
if num_valid > thr:
                # Pixels that received no points are detected via the sentinel
                # argmin.max(), remapped to index 0, and their color is marked
                # invalid (-1) below
                invalid = argmin == argmin.max()
argmin[invalid] = 0
rgb_tgt = rgb_src[i].permute(1, 0)[idx][argmin]
rgb_tgt[invalid] = -1
else:
rgb_tgt = -1 * torch.ones(1, self.n_pixels, 3, device=self.device, dtype=self.dtype)
# Reshape outputs
rgb_tgt = rgb_tgt.reshape(1, self.hw[0], self.hw[1], 3).permute(0, 3, 1, 2)
depth_tgt = depth_tgt.reshape(1, 1, self.hw[0], self.hw[1])
rgbs_tgt.append(rgb_tgt)
depths_tgt.append(depth_tgt)
rgb_tgt = torch.cat(rgbs_tgt, 0)
depth_tgt = torch.cat(depths_tgt, 0)
return rgb_tgt, depth_tgt
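# Minimal usage sketch (not part of the original module). The intrinsics,
# pose, and pointcloud shapes below are assumptions based on how CameraFull
# is used above; the base Camera constructor is assumed to accept K [B,3,3],
# Twc [B,4,4] and hw=(H, W), as in from_list().
if __name__ == '__main__':
    K = torch.eye(3).unsqueeze(0)                 # [1,3,3] pinhole intrinsics
    K[:, 0, 0] = K[:, 1, 1] = 100.0               # focal lengths (assumed)
    K[:, 0, 2], K[:, 1, 2] = 64.0, 48.0           # principal point (assumed)
    Twc = torch.eye(4).unsqueeze(0)               # [1,4,4] identity pose
    cam = CameraFull(K=K, Twc=Twc, hw=(96, 128))
    # Random colored points in front of the camera, assumed shape [B,3,N]
    pcl = torch.randn(1, 3, 1000) + torch.tensor([0.0, 0.0, 5.0]).view(1, 3, 1)
    rgb = torch.rand(1, 3, 1000)
    img, depth = cam.project_pointcloud(pcl, rgb)
    print(img.shape, depth.shape)                 # expected [1,3,96,128], [1,1,96,128]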