# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

import torch
import torch.nn.functional as tfn

from vidar.utils.data import make_list
from vidar.utils.flow_triangulation_support import bearing_grid, mult_rotation_bearing, triangulation
from vidar.utils.tensor import pixel_grid, norm_pixel_grid, unnorm_pixel_grid
from vidar.utils.types import is_list


def warp_from_coords(tensor, coords, mode='bilinear',
                     padding_mode='zeros', align_corners=True):
    """
    Warp an image from a coordinate map

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor for warping [B,?,H,W]
    coords : torch.Tensor
        Warping coordinates [B,2,H,W]
    mode : String
        Warping mode
    padding_mode : String
        Padding mode
    align_corners : Bool
        Align corners flag

    Returns
    -------
    warp : torch.Tensor
        Warped tensor [B,?,H,W]
    """
    # Sample from the input tensor at the given coordinates ([B,2,H,W] -> [B,H,W,2] for grid_sample)
    warp = tfn.grid_sample(tensor, coords.permute(0, 2, 3, 1),
                           mode=mode, padding_mode=padding_mode,
                           align_corners=align_corners)
    # Return warped tensor
    return warp
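
# Usage sketch (shapes are illustrative assumptions): warping with an identity
# grid in normalized [-1, 1] coordinates returns the input unchanged (up to
# interpolation), which makes a quick sanity check:
#
#   images = torch.rand(2, 3, 64, 64)
#   identity = norm_pixel_grid(pixel_grid(images, device=images))
#   warped = warp_from_coords(images, identity)   # ~= images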


def coords_from_optical_flow(optflow):
    """
    Get warping coordinates from optical flow

    Parameters
    ----------
    optflow : torch.Tensor
        Input optical flow tensor [B,2,H,W]

    Returns
    -------
    coords : torch.Tensor
        Warping coordinates [B,2,H,W]
    """
    # Create coordinate grid and add optical flow displacements
    coords = pixel_grid(optflow, device=optflow) + optflow
    # Normalize and return coordinate grid
    return norm_pixel_grid(coords)
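
# Sketch (assumed shapes): a zero optical flow maps every pixel to itself, so
# the result is the identity warp grid in normalized coordinates:
#
#   flow = torch.zeros(1, 2, 32, 32)
#   coords = coords_from_optical_flow(flow)   # identity grid in [-1, 1]
#   # warp_from_coords(rgb, coords) then returns `rgb` unchanged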


def warp_depth_from_motion(ref_depth, tgt_depth, ref_cam):
    """
    Warp depth map using motion (depth + ego-motion) information

    Parameters
    ----------
    ref_depth : torch.Tensor
        Reference depth map [B,1,H,W]
    tgt_depth : torch.Tensor
        Target depth map [B,1,H,W]
    ref_cam : Camera
        Reference camera

    Returns
    -------
    warp : torch.Tensor
        Warped depth map [B,1,H,W]
    """
    # Reproject reference depth values into the target frame before warping
    ref_depth = reproject_depth_from_motion(ref_depth, ref_cam)
    return warp_from_motion(ref_depth, tgt_depth, ref_cam)


def reproject_depth_from_motion(ref_depth, ref_cam):
    """
    Calculate reprojected depth from motion (depth + ego-motion) information

    Parameters
    ----------
    ref_depth : torch.Tensor
        Reference depth map [B,1,H,W]
    ref_cam : Camera
        Reference camera

    Returns
    -------
    depth : torch.Tensor
        Reprojected depth map [B,1,H,W]
    """
    # Lift reference depth to 3D points, then project and keep the z values
    ref_points = ref_cam.reconstruct_depth_map(ref_depth, to_world=True)
    return ref_cam.project_points(ref_points, from_world=False, return_z=True)[1]


def warp_from_motion(ref_rgb, tgt_depth, ref_cam):
    """
    Warp image using motion (depth + ego-motion) information

    Parameters
    ----------
    ref_rgb : torch.Tensor
        Reference image [B,3,H,W]
    tgt_depth : torch.Tensor
        Target depth map [B,1,H,W]
    ref_cam : Camera
        Reference camera

    Returns
    -------
    warp : torch.Tensor
        Warped image [B,3,H,W]
    """
    # Lift target depth to 3D points and project them onto the reference camera
    tgt_points = ref_cam.reconstruct_depth_map(tgt_depth, to_world=False)
    return warp_from_coords(ref_rgb, ref_cam.project_points(tgt_points, from_world=True).permute(0, 3, 1, 2))
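
# Sketch of motion-based view synthesis (tensor shapes and `ref_cam` are
# assumptions; the Camera is expected to hold intrinsics and relative pose):
#
#   ref_rgb = torch.rand(1, 3, 192, 640)           # reference view image
#   tgt_depth = torch.rand(1, 1, 192, 640) * 80.0  # target view depth
#   synth = warp_from_motion(ref_rgb, tgt_depth, ref_cam)   # [1,3,192,640]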


def coords_from_motion(ref_camera, tgt_depth, tgt_camera):
    """
    Get coordinates from motion (depth + ego-motion) information

    Parameters
    ----------
    ref_camera : Camera
        Reference camera
    tgt_depth : torch.Tensor
        Target depth map [B,1,H,W]
    tgt_camera : Camera
        Target camera

    Returns
    -------
    coords : torch.Tensor
        Warping coordinates [B,2,H,W]
    """
    # If there are multiple reference cameras, iterate for each
    if is_list(ref_camera):
        return [coords_from_motion(camera, tgt_depth, tgt_camera)
                for camera in ref_camera]
    # If there are multiple depth maps, iterate for each
    if is_list(tgt_depth):
        return [coords_from_motion(ref_camera, depth, tgt_camera)
                for depth in tgt_depth]
    # Lift target depth to world points and project onto the reference camera
    world_points = tgt_camera.reconstruct_depth_map(tgt_depth, to_world=True)
    return ref_camera.project_points(world_points, from_world=True).permute(0, 3, 1, 2)
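
# Sketch (hypothetical cameras): both arguments recurse over lists, so a list
# of reference cameras yields one coordinate map per camera:
#
#   coords = coords_from_motion([cam_prev, cam_next], tgt_depth, tgt_cam)
#   # -> [coords_prev, coords_next], each [B,2,H,W]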


def optflow_from_motion(ref_camera, tgt_depth):
    """
    Get optical flow from motion (depth + ego-motion) information

    Parameters
    ----------
    ref_camera : Camera
        Reference camera
    tgt_depth : torch.Tensor
        Target depth map [B,1,H,W]

    Returns
    -------
    optflow : torch.Tensor
        Optical flow map [B,2,H,W]
    """
    # Convert motion-induced warping coordinates into pixel displacements
    coords = ref_camera.coords_from_depth(tgt_depth).permute(0, 3, 1, 2)
    return optflow_from_coords(coords)
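
# Sketch (hypothetical inputs): this yields the "rigid flow" induced purely by
# depth and ego-motion, useful for comparison against an estimated flow:
#
#   rigid_flow = optflow_from_motion(ref_cam, tgt_depth)   # [B,2,H,W]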


def optflow_from_coords(coords):
    """
    Get optical flow from coordinates

    Parameters
    ----------
    coords : torch.Tensor
        Input warping coordinates [B,2,H,W]

    Returns
    -------
    optflow : torch.Tensor
        Optical flow map [B,2,H,W]
    """
    return unnorm_pixel_grid(coords) - pixel_grid(coords, device=coords)
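
# Round-trip sketch: this function inverts coords_from_optical_flow, since
# unnormalizing the grid and subtracting the base pixel grid recovers the flow:
#
#   flow = torch.randn(1, 2, 32, 32)
#   flow_rt = optflow_from_coords(coords_from_optical_flow(flow))   # ~= flow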


def warp_from_optflow(ref_rgb, tgt_optflow):
    """
    Warp image using optical flow information

    Parameters
    ----------
    ref_rgb : torch.Tensor
        Reference image [B,3,H,W]
    tgt_optflow : torch.Tensor
        Target optical flow [B,2,H,W]

    Returns
    -------
    warp : torch.Tensor
        Warped image [B,3,H,W]
    """
    coords = coords_from_optical_flow(tgt_optflow)
    return warp_from_coords(ref_rgb, coords, align_corners=True,
                            mode='bilinear', padding_mode='zeros')
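
# Sketch (assumed shapes): warping the reference image with the target flow
# synthesizes the target view, analogous to warp_from_motion but using flow
# instead of depth and pose:
#
#   synth = warp_from_optflow(ref_rgb, tgt_flow)   # [B,3,H,W]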


def reverse_optflow(tgt_optflow, ref_optflow):
    """
    Reverse optical flow

    Parameters
    ----------
    tgt_optflow : torch.Tensor
        Target optical flow [B,2,H,W]
    ref_optflow : torch.Tensor
        Reference optical flow [B,2,H,W]

    Returns
    -------
    optflow : torch.Tensor
        Reversed optical flow [B,2,H,W]
    """
    # Warp the target flow into the reference frame and negate it
    return -warp_from_optflow(tgt_optflow, ref_optflow)
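
# Sketch (hypothetical flows): the target flow is resampled at the locations
# given by the reference flow and negated; this approximates the reversed flow
# wherever the two flows are consistent (it is not exact under occlusion):
#
#   rev_flow = reverse_optflow(tgt_flow, ref_flow)   # [B,2,H,W]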


def mask_from_coords(coords, align_corners=True):
    """
    Get overlap mask from coordinates

    Parameters
    ----------
    coords : torch.Tensor
        Warping coordinates [B,2,H,W]
    align_corners : Bool
        Align corners flag

    Returns
    -------
    mask : torch.Tensor
        Overlap mask [B,1,H,W]
    """
    # If there are multiple coordinate maps, iterate for each
    if is_list(coords):
        return [mask_from_coords(coord, align_corners=align_corners) for coord in coords]
    # Warp an all-ones mask; zero padding marks out-of-view pixels
    b, _, h, w = coords.shape
    mask = torch.ones((b, 1, h, w), dtype=torch.float32, device=coords.device, requires_grad=False)
    mask = warp_from_coords(mask, coords, mode='nearest', padding_mode='zeros', align_corners=align_corners)
    return mask.bool()
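
# Sketch (hypothetical error map): the mask is True where warped pixels land
# inside the source image, so it can gate a photometric loss:
#
#   mask = mask_from_coords(coords)   # [B,1,H,W] boolean
#   loss = (error * mask.float()).sum() / mask.float().sum().clamp(min=1)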


def depth_from_optflow(rgb, intrinsics, pose_context, flows,
                       residual=False, clip_range=None):
    """
    Get depth from optical flow + camera information

    Parameters
    ----------
    rgb : torch.Tensor
        Base image [B,3,H,W]
    intrinsics : torch.Tensor
        Camera intrinsics [B,3,3]
    pose_context : torch.Tensor or list[torch.Tensor]
        List of relative context camera poses [B,4,4]
    flows : torch.Tensor or list[torch.Tensor]
        List of target optical flows [B,2,H,W]
    residual : Bool
        Return residual error with depth
    clip_range : Tuple
        Depth range clipping values

    Returns
    -------
    depth : torch.Tensor
        Depth map [B,1,H,W]
    """
    # Make lists if necessary
    flows = make_list(flows)
    pose_context = make_list(pose_context)
    # Extract rotations and translations
    rotations = [p[:, :3, :3] for p in pose_context]
    translations = [p[:, :3, -1] for p in pose_context]
    # Get bearings
    bearings = bearing_grid(rgb, intrinsics).to(rgb.device)
    rot_bearings = [mult_rotation_bearing(rotation, bearings)
                    for rotation in rotations]
    # Return triangulation results
    return triangulation(rot_bearings, translations, flows, intrinsics,
                         clip_range=clip_range, residual=residual)
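
# Sketch (all tensors are assumptions): triangulation takes the base image (for
# the bearing grid), intrinsics, and one relative pose per flow map; per the
# docstring, residual=True additionally returns the triangulation error:
#
#   depth = depth_from_optflow(rgb, intrinsics, pose_prev, flow_prev,
#                              clip_range=(0.1, 100.0))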