Spaces:
Runtime error
Runtime error
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
import torch | |
from torch_scatter import scatter_min | |
from vidar.geometry.camera import Camera | |
from vidar.utils.tensor import unnorm_pixel_grid | |
class CameraFull(Camera):
    """Camera class with additional functionality.

    Extends the base ``Camera`` with:

    - convention switching (``switch`` / ``bwd`` / ``fwd``) via a fixed
      ``diag(1, -1, -1, 1)`` matrix that flips the y and z axes;
    - an in-place ``look_at`` helper;
    - viewing-ray, render-ray and Plücker-vector generation;
    - z-buffered projection of a coloured point cloud onto the image plane.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # diag(1, -1, -1, 1) as a [1,4,4] tensor: flips the y and z axes so a
        # pose can be moved between axis conventions (see switch/bwd/fwd).
        self.convert_matrix = torch.tensor(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]],
            dtype=torch.float32,
        ).unsqueeze(0)

    @staticmethod
    def from_list(cams):
        """Create a single batched camera from a list of cameras.

        Intrinsics and ``Twc`` poses are concatenated along the batch
        dimension; the image size of the first camera is used for all.
        """
        K = torch.cat([cam.K for cam in cams], 0)
        Twc = torch.cat([cam.Twc.T for cam in cams], 0)
        return CameraFull(K=K, Twc=Twc, hw=cams[0].hw)

    def switch(self):
        """Switch camera between conventions (conjugates ``Twc`` by the flip matrix)"""
        T = self.convert_matrix.to(self.device)
        Twc = T @ self.Twc.T @ T
        return type(self)(K=self.K, Twc=Twc, hw=self.hw)

    def bwd(self):
        """Switch camera to the backwards convention (flipped ``Twc`` becomes ``Tcw``)"""
        T = self.convert_matrix.to(self.device)
        Tcw = T @ self.Twc.T @ T
        return type(self)(K=self.K, Tcw=Tcw, hw=self.hw)

    def fwd(self):
        """Switch camera to the forward convention (flipped ``Tcw`` becomes ``Twc``)"""
        T = self.convert_matrix.to(self.device)
        Twc = T @ self.Tcw.T @ T
        return type(self)(K=self.K, Twc=Twc, hw=self.hw)

    def look_at(self, at, up=None):
        """
        Set a direction for the camera to point (in-place)

        The camera position is kept; only the rotation of ``Tcw`` is replaced
        and ``Twc`` is recomputed from it.

        Parameters
        ----------
        at : torch.Tensor
            Where the camera should be pointing at, [3] (shared by every
            camera in the batch) or [B,3]
        up : torch.Tensor
            Up direction, [3] or [B,3]. Defaults to ``[0, 1, 0]``.
        """
        eps = 1e-5
        if up is None:
            # Built per call instead of as a default argument so no tensor
            # object is shared between calls.
            up = torch.tensor([0., 1., 0.])
        eye = self.Tcw.T[:, :3, -1]  # camera positions [B,3]
        # Broadcasting handles both [3] and [B,3] targets
        z_axis = at - eye
        z_axis /= z_axis.norm(dim=-1, keepdim=True) + eps
        up = up.to(device=z_axis.device, dtype=z_axis.dtype).expand(z_axis.shape)
        # dim=-1 is explicit: without it torch.cross uses the first dimension
        # of size 3, which is the batch dimension when there are 3 cameras.
        x_axis = torch.cross(up, z_axis, dim=-1)
        x_axis /= x_axis.norm(dim=-1, keepdim=True) + eps
        y_axis = torch.cross(z_axis, x_axis, dim=-1)
        y_axis /= y_axis.norm(dim=-1, keepdim=True) + eps
        # Columns of R are the camera x/y/z axes
        R = torch.stack((x_axis, y_axis, z_axis), dim=-1)
        # NOTE(review): relies on self.Tcw returning a fresh pose object that
        # can be edited and inverted back into Twc — confirm in Camera.
        Tcw = self.Tcw
        Tcw.T[:, :3, :3] = R
        self.Twc = Tcw.inverse()

    def get_origin(self, flatten=False):
        """Return camera origin repeated per pixel, [B,3,H,W] or [B,H*W,3] if flattened"""
        orig = self.Tcw.T[:, :3, -1].view(len(self), 3, 1, 1).repeat(1, 1, *self.hw)
        if flatten:
            orig = orig.reshape(len(self), 3, -1).permute(0, 2, 1)
        return orig

    def get_viewdirs(self, normalize=False, flatten=False, to_world=False):
        """Return camera viewing rays, [B,3,H,W] or [B,H*W,3] if flattened

        Rays are obtained by back-projecting a unit depth map; optionally
        normalized to unit length and/or transformed to world coordinates.
        """
        ones = torch.ones((len(self), 1, *self.hw), dtype=self.dtype, device=self.device)
        rays = self.reconstruct_depth_map(ones, to_world=False)
        if normalize:
            rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
        if to_world:
            rays = self.to_world(rays).reshape(len(self), 3, *self.hw)
        if flatten:
            rays = rays.reshape(len(self), 3, -1).permute(0, 2, 1)
        return rays

    def get_render_rays(self, near=None, far=None, n_rays=None, gt=None):
        """
        Get render rays

        Parameters
        ----------
        near : Float
            Near plane
        far : Float
            Far plane
        n_rays : Int
            Number of rays (randomly sampled with replacement when given)
        gt : torch.Tensor
            Ground-truth values for concatenation, [B,C,H,W]

        Returns
        -------
        rays : torch.Tensor
            Camera viewing rays [B,N,C'], channels ordered as origin (3),
            direction (3), then near, far and gt when provided
        """
        b = len(self)
        ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
        rays = self.reconstruct_depth_map(ones, to_world=False)
        rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
        # Flip y and z (same axis flip as convert_matrix)
        rays[:, 1] = - rays[:, 1]
        rays[:, 2] = - rays[:, 2]
        # NOTE(review): relies on Camera.pose / no_translation / inverted_pose
        # semantics defined outside this class — confirm they still exist.
        orig = self.pose[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
        rays = self.no_translation().inverted_pose().to_world(rays).reshape(b, 3, *self.hw)
        info = [orig, rays]
        if near is not None:
            info = info + [near * ones]
        if far is not None:
            info = info + [far * ones]
        if gt is not None:
            info = info + [gt]
        rays = torch.cat(info, 1)
        rays = rays.permute(0, 2, 3, 1).reshape(b, -1, rays.shape[1])
        if n_rays is not None:
            idx = torch.randint(0, self.n_pixels, (n_rays,))
            rays = rays[:, idx, :]
        return rays

    def get_plucker(self):
        """Get plucker vectors

        Returns
        -------
        plucker : torch.Tensor
            [B,H*W,6] tensor holding the unit ray direction ``d`` followed by
            the moment ``o x d``, where ``o`` is the camera origin.
        """
        b = len(self)
        ones = torch.ones((b, 1, *self.hw), dtype=self.dtype, device=self.device)
        # NOTE(review): directions are in camera coordinates (to_world=False)
        # while the origin below is the camera position taken from Tcw —
        # confirm this frame mix is intended.
        rays = self.reconstruct_depth_map(ones, to_world=False)
        rays = rays / torch.norm(rays, dim=1).unsqueeze(1)
        orig = self.Tcw.T[:, :3, -1].view(b, 3, 1, 1).repeat(1, 1, *self.hw)
        # Flatten per batch element; the former hard-coded batch of 1
        # interleaved data from different cameras when b > 1.
        orig = orig.reshape(b, 3, -1).permute(0, 2, 1)
        rays = rays.reshape(b, 3, -1).permute(0, 2, 1)
        cross = torch.cross(orig, rays, dim=-1)
        plucker = torch.cat((rays, cross), dim=-1)
        return plucker

    def project_pointcloud(self, pcl_src, rgb_src, thr=1):
        """
        Project pointcloud to the camera plane

        Points are splatted with a z-buffer: each pixel keeps the closest
        point (smallest positive depth) that projects onto it.

        Parameters
        ----------
        pcl_src : torch.Tensor
            Input 3D pointcloud (world coordinates, batch first)
        rgb_src : torch.Tensor
            Pointcloud color information, [B,3,N] or [B,3,H',W'] (flattened)
        thr : Int
            Threshold for the number of valid points: if no more than ``thr``
            pixels receive a point, the whole image is marked invalid

        Returns
        -------
        rgb_tgt : torch.Tensor
            Projected image [B,3,H,W] (-1 where no point landed)
        depth_tgt : torch.Tensor
            Projected depth map [B,1,H,W] (0 where no point landed)
        """
        if rgb_src.dim() == 4:
            rgb_src = rgb_src.reshape(*rgb_src.shape[:2], -1)
        # Get projected coordinates and depth values
        uv_all, z_all = self.project_points(pcl_src, return_z=True, from_world=True)
        h, w = self.hw[0], self.hw[1]
        far_value = 1e10  # initial z-buffer value, larger than any real depth
        rgbs_tgt, depths_tgt = [], []
        for i in range(pcl_src.shape[0]):
            uv, z = uv_all[i].reshape(-1, 2), z_all[i].reshape(-1, 1)
            # Remove out-of-bounds coordinates and points behind the camera
            idx = (uv[:, 0] >= -1) & (uv[:, 0] <= 1) & \
                  (uv[:, 1] >= -1) & (uv[:, 1] <= 1) & (z[:, 0] > 0.0)
            z_valid = z[idx]
            # Unnormalize and flatten coordinates to linear pixel indices
            uv = (unnorm_pixel_grid(uv[idx], self.hw)).round().long()
            uv = uv[:, 0] + uv[:, 1] * w
            # Min scatter operation (only keep the closest depth)
            depth_tgt = torch.full(
                (h * w, 1), far_value, device=pcl_src.device, dtype=z_valid.dtype)
            depth_tgt, argmin = scatter_min(
                src=z_valid, index=uv.unsqueeze(1), dim=0, out=depth_tgt)
            # torch_scatter leaves argmin == src.size(dim) for bins that no
            # point updated. Comparing against argmin.max() instead would
            # drop a real point whenever every pixel is covered.
            invalid = argmin == z_valid.shape[0]
            depth_tgt[invalid] = 0.
            num_valid = (depth_tgt > 0).sum()
            if num_valid > thr:
                # Point empty pixels at a dummy index, gather colours, then
                # overwrite the empty pixels with -1
                argmin[invalid] = 0
                rgb_tgt = rgb_src[i].permute(1, 0)[idx][argmin]
                rgb_tgt[invalid] = -1
            else:
                rgb_tgt = -1 * torch.ones(1, self.n_pixels, 3, device=self.device, dtype=self.dtype)
            # Reshape outputs
            rgb_tgt = rgb_tgt.reshape(1, h, w, 3).permute(0, 3, 1, 2)
            depth_tgt = depth_tgt.reshape(1, 1, h, w)
            rgbs_tgt.append(rgb_tgt)
            depths_tgt.append(depth_tgt)
        rgb_tgt = torch.cat(rgbs_tgt, 0)
        depth_tgt = torch.cat(depths_tgt, 0)
        return rgb_tgt, depth_tgt