Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,550 Bytes
184193d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import torch
import torch.nn.functional as F
from einops import rearrange
# --- Intrinsics Transformations ---
def normalize_intrinsics(intrinsics, image_shape):
    """Rescale an intrinsics matrix from pixel units to image-size-normalized units.

    The x row (row 0) is divided by the image width ``image_shape[1]`` and the
    y row (row 1) by the image height ``image_shape[0]``; any remaining rows are
    unchanged. Leading batch dimensions are supported. The input tensor is not
    modified; a new tensor is returned.
    """
    height, width = image_shape[0], image_shape[1]
    scaled = intrinsics.clone()
    for row, extent in ((0, width), (1, height)):
        scaled[..., row, :] /= extent
    return scaled
def unnormalize_intrinsics(intrinsics, image_shape):
    """Convert an image-size-normalized intrinsics matrix back to pixel units.

    The x row (row 0) is multiplied by the image width ``image_shape[1]`` and
    the y row (row 1) by the image height ``image_shape[0]``; any remaining rows
    are unchanged. Leading batch dimensions are supported. The input tensor is
    not modified; a new tensor is returned.
    """
    height, width = image_shape[0], image_shape[1]
    scaled = intrinsics.clone()
    for row, extent in ((0, width), (1, height)):
        scaled[..., row, :] *= extent
    return scaled
# --- Projections ---
def homogenize_points(points):
    """Return ``points`` with a constant 1 appended along the last axis (xyz -> xyz1).

    The appended entries share the dtype and device of ``points``.
    """
    ones_column = torch.ones_like(points[..., -1:])
    return torch.cat((points, ones_column), dim=-1)
def normalize_homogenous_points(points):
    """Divide homogeneous vectors by their last component.

    Every entry along the last axis is divided by that axis' final entry, so the
    final entry becomes 1. A zero final entry yields inf/nan, as with ordinary
    tensor division.
    """
    last_component = points[..., -1:]
    return torch.div(points, last_component)
def pixel_space_to_camera_space(pixel_space_points, depth, intrinsics):
    """
    Convert pixel space points to camera space points.

    Each pixel coordinate (x, y) is lifted to (x, y, 1), multiplied by the
    inverse intrinsics of every view and then scaled by ``depth``.

    Args:
        pixel_space_points (torch.Tensor): Pixel coordinates with shape (h, w, 2),
            ordered (x, y). They are cast to the dtype and device of
            ``intrinsics``, so e.g. float64 intrinsics work with a float32 grid.
        depth (torch.Tensor): Depth map with shape (b, v, h, w, 1)
        intrinsics (torch.Tensor): Camera intrinsics with shape (b, v, 3, 3)
    Returns:
        torch.Tensor: Camera space points with shape (b, v, h, w, 3).
    """
    # torch.einsum requires both operands to share dtype (and device); it does
    # not promote, so align the pixel grid with the intrinsics first.
    pixel_space_points = pixel_space_points.to(dtype=intrinsics.dtype, device=intrinsics.device)
    # Homogenize (x, y) -> (x, y, 1).
    pixel_space_points = torch.cat(
        [pixel_space_points, torch.ones_like(pixel_space_points[..., :1])], dim=-1
    )
    camera_space_points = torch.einsum(
        'b v i j , h w j -> b v h w i', intrinsics.inverse(), pixel_space_points
    )
    return camera_space_points * depth
def camera_space_to_world_space(camera_space_points, c2w):
    """
    Map camera-frame points into world coordinates, one transform per view.

    Args:
        camera_space_points (torch.Tensor): Points in each camera's frame, shape (b, v, h, w, 3)
        c2w (torch.Tensor): Camera-to-world matrices, shape (b, v, 4, 4)
    Returns:
        torch.Tensor: World space points with shape (b, v, h, w, 3). The fourth
        (homogeneous) output coordinate is dropped without division.
    """
    ones = torch.ones_like(camera_space_points[..., :1])
    xyz1 = torch.cat([camera_space_points, ones], dim=-1)  # (b, v, h, w, 4)
    transformed = torch.einsum('b v i j , b v h w j -> b v h w i', c2w, xyz1)
    return transformed[..., :3]
def camera_space_to_pixel_space(camera_space_points, intrinsics):
    """
    Project camera space points onto the image plane of each of the v2 views.

    Points are first divided by their z-coordinate (perspective division) and
    then multiplied by the intrinsics of the view indexed by the v2 axis.
    Points with z == 0 yield inf/nan coordinates; points behind the camera
    (z < 0) are not filtered out here.

    Args:
        camera_space_points (torch.Tensor): Camera space points with shape (b, v1, v2, h, w, 3)
        intrinsics (torch.Tensor): Camera intrinsics with shape (b, v2, 3, 3)
    Returns:
        torch.Tensor: Pixel space points (x, y) with shape (b, v1, v2, h, w, 2).
    """
    # Perspective division: (x, y, z) -> (x / z, y / z, 1).
    normalized = camera_space_points / camera_space_points[..., -1:]
    pixel_space_points = torch.einsum(
        'b u i j , b v u h w j -> b v u h w i', intrinsics, normalized
    )
    return pixel_space_points[..., :2]
def world_space_to_camera_space(world_space_points, c2w):
    """
    Convert world space points to camera space points for every view in ``c2w``.

    Each of the v1 point maps is transformed by the inverse of each of the v2
    camera-to-world matrices, giving one camera-space copy per (v1, v2) pair.

    Args:
        world_space_points (torch.Tensor): World space points with shape (b, v1, h, w, 3)
        c2w (torch.Tensor): Camera to world extrinsics matrix with shape (b, v2, 4, 4).
            It is inverted with a general matrix inverse.
    Returns:
        torch.Tensor: Camera space points with shape (b, v1, v2, h, w, 3). The
        fourth (homogeneous) output coordinate is dropped without division,
        which is exact when the last row of ``c2w`` is [0, 0, 0, 1].
    """
    ones = torch.ones_like(world_space_points[..., :1])
    homogeneous = torch.cat([world_space_points, ones], dim=-1)  # (b, v1, h, w, 4)
    w2c = c2w.inverse()  # (b, v2, 4, 4)
    camera_space_points = torch.einsum('b u i j , b v h w j -> b v u h w i', w2c, homogeneous)
    return camera_space_points[..., :3]
def unproject_depth(depth, intrinsics, c2w):
    """
    Lift a per-pixel depth map to a world-space point cloud.

    Pixel (column, row) is placed at the integer coordinate (x, y) = (column, row),
    back-projected through the inverse of ``intrinsics``, scaled by ``depth``
    and then mapped to world space with ``c2w``. For a standard pinhole K
    (last row [0, 0, 1]) this treats ``depth`` as z-depth.

    Args:
        depth: (b, v, h, w, 1)
        intrinsics: (b, v, 3, 3), expected in pixel units since the pixel grid
            is built in pixels.
        c2w: (b, v, 4, 4)
    Returns:
        torch.Tensor: World space points with shape (b, v, h, w, 3).
    """
    height, width = depth.shape[-3], depth.shape[-2]
    grid_options = dict(device=depth.device, dtype=torch.float32)
    rows, cols = torch.meshgrid(
        torch.arange(height, **grid_options),
        torch.arange(width, **grid_options),
        indexing='ij',
    )  # both (h, w)
    # Last axis is ordered (x, y) = (column, row).
    pixel_coords = torch.stack((cols, rows), dim=-1)  # (h, w, 2)
    cam_points = pixel_space_to_camera_space(pixel_coords, depth, intrinsics)  # (..., h, w, 3)
    return camera_space_to_world_space(cam_points, c2w)  # (..., h, w, 3)
@torch.no_grad()
def calculate_in_frustum_mask(depth_1, intrinsics_1, c2w_1, depth_2, intrinsics_2, c2w_2, depth_tolerance=1e-1):
    """
    Work out which pixels of the first set of views have a directly
    corresponding pixel in at least one view of the second set.

    A pixel of view i (first set) is kept when its depth is valid (> 1e-6) and
    there is a view j (second set) for which the unprojected 3D point
      1. lies in front of camera j and projects inside its image, and
      2. has a depth in camera j that agrees with ``depth_2`` sampled at the
         projected location (i.e. it is not occluded in view j).
    Both per-view conditions must hold for the *same* view j.

    Pixel coordinates follow ``unproject_depth``: pixel (column, row) sits at
    the integer coordinate (column, row), so an image of width W covers
    x in [-0.5, W - 0.5).

    Intrinsics whose entry [0, 0, 0, 2] (the principal-point x of the first
    view for a standard K) is below 1 are treated as normalized by the image
    size and converted to pixel units.

    Args:
        depth_1: (b, v1, h, w)
        intrinsics_1: (b, v1, 3, 3)
        c2w_1: (b, v1, 4, 4)
        depth_2: (b, v2, h2, w2); typically h2 == h and w2 == w
        intrinsics_2: (b, v2, 3, 3)
        c2w_2: (b, v2, 4, 4)
        depth_tolerance: absolute tolerance passed to torch.isclose when
            comparing depths.
    Returns:
        torch.Tensor: Boolean mask with shape (b, v1, h, w).
    """
    _, v1, h, w = depth_1.shape
    _, v2, h2, w2 = depth_2.shape

    # Heuristic: only the first batch element / view is inspected.
    if intrinsics_1[0, 0, 0, 2] < 1:
        intrinsics_1 = unnormalize_intrinsics(intrinsics_1, (h, w))
    if intrinsics_2[0, 0, 0, 2] < 1:
        intrinsics_2 = unnormalize_intrinsics(intrinsics_2, (h2, w2))

    # Unproject the first views and express the points in every second view.
    points_3d = unproject_depth(depth_1[..., None], intrinsics_1, c2w_1)  # (b, v1, h, w, 3)
    camera_points = world_space_to_camera_space(points_3d, c2w_2)  # (b, v1, v2, h, w, 3)
    points_2d = camera_space_to_pixel_space(camera_points, intrinsics_2)  # (b, v1, v2, h, w, 2)
    rendered_depth = camera_points[..., 2]  # (b, v1, v2, h, w)
    x, y = points_2d[..., 0], points_2d[..., 1]

    # Condition 1 (per view): in front of the camera and inside the image area.
    # NaN/inf coordinates (z == 0) compare False and are therefore excluded.
    in_frustum = (
        (rendered_depth > 0)
        & (x >= -0.5) & (x < w2 - 0.5)
        & (y >= -0.5) & (y < h2 - 0.5)
    )  # (b, v1, v2, h, w)

    # Condition 2: valid depth in the first view.
    non_zero_depth = depth_1 > 1e-6  # (b, v1, h, w)

    # Condition 3 (per view): depth agrees with depth_2 at the projected point.
    # With align_corners=False, grid_sample maps a normalized coordinate g to the
    # continuous pixel index ((g + 1) * size - 1) / 2, so g = (2 * p + 1) / size - 1
    # samples exactly at pixel coordinate p, matching unproject_depth's grid.
    grid = torch.stack(((2 * x + 1) / w2 - 1, (2 * y + 1) / h2 - 1), dim=-1)
    # Out-of-frustum entries are discarded below; replace them with 0 so that
    # inf/nan coordinates never reach grid_sample.
    grid = torch.where(in_frustum[..., None], grid, torch.zeros_like(grid))
    # Sample all (v1, v2) pairs in one call: batch over (b, v2), stack v1 along height.
    grid = rearrange(grid, 'b v1 v2 h w c -> (b v2) (v1 h) w c').to(depth_2.dtype)
    source = rearrange(depth_2, 'b v2 h w -> (b v2) 1 h w')
    # 'border' padding: points in the outer half-pixel ring use the edge pixel's
    # depth instead of being blended with zeros.
    sampled = F.grid_sample(source, grid, mode='bilinear', padding_mode='border', align_corners=False)
    sampled_depth = rearrange(sampled[:, 0], '(b v2) (v1 h) w -> b v1 v2 h w', v1=v1, v2=v2)
    matching_depth = torch.isclose(
        rendered_depth, sampled_depth.to(rendered_depth.dtype), atol=depth_tolerance
    )  # (b, v1, v2, h, w)

    # A pixel counts only if one second view satisfies conditions 1 and 3 together.
    observed = (in_frustum & matching_depth).any(dim=2)  # (b, v1, h, w)
    return observed & non_zero_depth