#### Code taken from: https://github.com/mayankgrwl97/gbt
"""Utils for ray manipulation"""
import numpy as np
import torch
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer.implicit.raysampling import RayBundle

############################# RAY BUNDLE UTILITIES #############################
def is_scalar(x):
    """Returns True if the provided variable is a scalar.

    Args:
        x: Scalar or array-like (numpy array or torch tensor).

    Returns:
        bool: True if x is a Python number or a 0-dimensional array/tensor,
            False otherwise.
    """
    if isinstance(x, (int, float)):
        return True
    if isinstance(x, np.ndarray) and np.ndim(x) == 0:
        return True
    if isinstance(x, torch.Tensor) and x.dim() == 0:
        return True
    return False
def transform_rays(reference_R, reference_T, rays):
    """Transforms world-space rays into the reference camera's frame.

    PyTorch3D convention is used: X_cam = X_world @ R + T

    Args:
        reference_R: world2cam rotation matrix of the reference camera (B, 3, 3)
        reference_T: world2cam translation vector of the reference camera (B, 3)
        rays: (origin, direction) rays in the world frame (B, V, N, 6)

    Returns:
        torch.Tensor: Transformed rays w.r.t. the reference camera (B, V, N, 6)
    """
    batch, num_views, num_rays, ray_dim = rays.shape
    assert (
        ray_dim == 6
    ), "First 3 dimensions should be origin; last 3 dimensions should be direction"
    rays = rays.reshape(batch, num_views * num_rays, ray_dim)
    rays_out = rays.clone()
    # Origins are rotated and translated; directions are only rotated.
    rays_out[..., :3] = torch.bmm(rays[..., :3], reference_R) + reference_T.unsqueeze(-2)
    rays_out[..., 3:] = torch.bmm(rays[..., 3:], reference_R)
    rays_out = rays_out.reshape(batch, num_views, num_rays, ray_dim)
    return rays_out
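# Hedged sanity check (not from the original repo): with an identity rotation and
# zero translation the world and camera frames coincide, so transform_rays should
# return its input unchanged. All shapes and values below are illustrative.
def _example_transform_rays():
    batch, views, n = 2, 3, 4
    rays = torch.randn(batch, views, n, 6)
    eye_R = torch.eye(3).expand(batch, 3, 3)  # identity world2cam rotation
    zero_T = torch.zeros(batch, 3)  # zero world2cam translation
    assert torch.allclose(transform_rays(eye_R, zero_T, rays), rays)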
def get_directional_raybundle(cameras, x_pos_ndc, y_pos_ndc, max_depth=1):
    """Casts rays from the camera center(s) through the given NDC locations.

    Args:
        cameras: (Batched) PyTorch3D cameras.
        x_pos_ndc: Scalar or (N,) x coordinates in NDC space.
        y_pos_ndc: Scalar or (N,) y coordinates in NDC space.
        max_depth: Depth at which the NDC points are unprojected.

    Returns:
        RayBundle: Rays on CPU with unit-norm directions.
    """
    if is_scalar(x_pos_ndc):
        x_pos_ndc = [x_pos_ndc]
    if is_scalar(y_pos_ndc):
        y_pos_ndc = [y_pos_ndc]
    assert is_scalar(max_depth)
    if not isinstance(x_pos_ndc, torch.Tensor):
        x_pos_ndc = torch.tensor(x_pos_ndc)  # (N,)
    if not isinstance(y_pos_ndc, torch.Tensor):
        y_pos_ndc = torch.tensor(y_pos_ndc)  # (N,)
    xy_depth = torch.stack(
        (x_pos_ndc, y_pos_ndc, torch.ones_like(x_pos_ndc) * max_depth), dim=-1
    )  # (N, 3)
    num_points = xy_depth.shape[0]
    unprojected = cameras.unproject_points(
        xy_depth.to(cameras.device), world_coordinates=True, from_ndc=True
    )  # (N, 3)
    unprojected = unprojected.unsqueeze(0).to("cpu")  # (B, N, 3)
    origins = (
        cameras.get_camera_center()[:, None, :].expand(-1, num_points, -1).to("cpu")
    )  # (B, N, 3)
    directions = unprojected - origins  # (B, N, 3)
    directions = directions / directions.norm(dim=-1, keepdim=True)  # unit vectors
    # Near/far bounds for each ray (fixed to [0, 3] here).
    lengths = torch.tensor([[0, 3]]).unsqueeze(0).expand(-1, num_points, -1)  # (B, N, 2)
    xys = xy_depth[:, :2].unsqueeze(0)  # (B, N, 2)
    return RayBundle(origins=origins, directions=directions, lengths=lengths, xys=xys)
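# Hedged usage sketch: a default PerspectiveCameras (identity pose, assumed here
# purely for illustration) casting rays through the NDC center and one corner.
def _example_directional_raybundle():
    cam = PerspectiveCameras()
    bundle = get_directional_raybundle(cam, x_pos_ndc=[0.0, 1.0], y_pos_ndc=[0.0, 1.0])
    print(bundle.origins.shape, bundle.directions.shape)  # (1, 2, 3) and (1, 2, 3)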
def get_patch_raybundle(
    cameras, num_patches_x, num_patches_y, max_depth=1, stratified=False
):
    """Casts one ray per patch of a num_patches_y x num_patches_x image grid.

    With stratified=False, each ray passes through its patch center. With
    stratified=True, the sample positions are jittered uniformly within
    disjoint intervals around the grid of patch centers.
    """
    # PyTorch3D NDC coordinates run from +1 to -1 along each image axis.
    horizontal_patch_edges = torch.linspace(1, -1, num_patches_x + 1)
    vertical_patch_edges = torch.linspace(1, -1, num_patches_y + 1)
    if stratified:
        # Midpoints between consecutive edges (the patch centers).
        horizontal_patch_edges_center = (
            horizontal_patch_edges[..., 1:] + horizontal_patch_edges[..., :-1]
        ) / 2.0
        # Per-sample jitter intervals: [edge_0, center_0], [center_0, center_1], ...
        horizontal_patch_edges_upper = torch.cat(
            [horizontal_patch_edges_center, horizontal_patch_edges[..., -1:]], -1
        )
        horizontal_patch_edges_lower = torch.cat(
            [horizontal_patch_edges[..., :1], horizontal_patch_edges_center], -1
        )
        horizontal_positions = (
            horizontal_patch_edges_lower
            + (horizontal_patch_edges_upper - horizontal_patch_edges_lower)
            * torch.rand_like(horizontal_patch_edges_lower)
        )[..., :-1]
        vertical_patch_edges_center = (
            vertical_patch_edges[..., 1:] + vertical_patch_edges[..., :-1]
        ) / 2.0
        vertical_patch_edges_upper = torch.cat(
            [vertical_patch_edges_center, vertical_patch_edges[..., -1:]], -1
        )
        vertical_patch_edges_lower = torch.cat(
            [vertical_patch_edges[..., :1], vertical_patch_edges_center], -1
        )
        vertical_positions = (
            vertical_patch_edges_lower
            + (vertical_patch_edges_upper - vertical_patch_edges_lower)
            * torch.rand_like(vertical_patch_edges_lower)
        )[..., :-1]
    else:
        horizontal_positions = (
            horizontal_patch_edges[:-1] + horizontal_patch_edges[1:]
        ) / 2  # (num_patches_x,): center of each patch
        vertical_positions = (
            vertical_patch_edges[:-1] + vertical_patch_edges[1:]
        ) / 2  # (num_patches_y,): center of each patch
    h_pos, v_pos = torch.meshgrid(
        horizontal_positions, vertical_positions, indexing="xy"
    )  # each (num_patches_y, num_patches_x)
    h_pos = h_pos.reshape(-1)  # (num_patches_y * num_patches_x,)
    v_pos = v_pos.reshape(-1)  # (num_patches_y * num_patches_x,)
    return get_directional_raybundle(
        cameras=cameras, x_pos_ndc=h_pos, y_pos_ndc=v_pos, max_depth=max_depth
    )
def get_patch_rays(
    cameras_list,
    num_patches_x,
    num_patches_y,
    device,
    return_xys=False,
    stratified=False,
):
    """Returns patch rays for the given camera viewpoints.

    Args:
        cameras_list (list[list[pytorch3d.renderer.cameras.CamerasBase]]):
            Nested list of cameras of shape (batch_size, num_input_views).
        num_patches_x: Number of patches in the x-direction (horizontal).
        num_patches_y: Number of patches in the y-direction (vertical).
        device: Device the returned rays are moved to.
        return_xys: If True, also return the NDC xy locations of the rays.
        stratified: If True, jitter rays within patches (see get_patch_raybundle).

    Returns:
        torch.Tensor: Patch rays of shape (batch_size, num_views, num_patches, 6).
    """
    batch, numviews = len(cameras_list), len(cameras_list[0])
    cameras = join_cameras_as_batch(
        [cam for cam_batch in cameras_list for cam in cam_batch]
    )
    raybundle = get_patch_raybundle(
        cameras,
        num_patches_y=num_patches_y,
        num_patches_x=num_patches_x,
        stratified=stratified,
    )
    if return_xys:
        xys = raybundle.xys
    patch_rays = torch.cat(
        (raybundle.origins.unsqueeze(0), raybundle.directions), dim=-1
    )
    patch_rays = patch_rays.reshape(
        batch, numviews, num_patches_x * num_patches_y, 6
    ).to(device)
    if return_xys:
        return patch_rays, xys
    return patch_rays
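# Hedged usage sketch: a batch of one scene with two (arbitrary, illustrative)
# camera poses, sampled on a 16x16 patch grid, yielding 256 rays per view.
def _example_patch_rays():
    cams = [[PerspectiveCameras(), PerspectiveCameras(T=torch.tensor([[0.0, 0.0, 1.0]]))]]
    rays = get_patch_rays(cams, num_patches_x=16, num_patches_y=16, device="cpu")
    print(rays.shape)  # torch.Size([1, 2, 256, 6])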
############################ RAY PARAMETERIZATION ##############################
def get_plucker_parameterization(ray):
    """Returns the Plücker representation of rays given in (origin, direction) form.

    Args:
        ray (torch.Tensor): Tensor of shape (..., 6) holding (origin, direction).

    Returns:
        torch.Tensor: Tensor of shape (..., 6) holding the Plücker (D, O x D)
            representation.
    """
    ray = ray.clone()  # avoid mutating the caller's tensor
    ray_origins = ray[..., :3]
    ray_directions = ray[..., 3:]
    # Normalize ray directions to unit vectors.
    ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
    plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
    plucker_parameterization = torch.cat([ray_directions, plucker_normal], dim=-1)
    return plucker_parameterization
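# Hedged property check (not from the original repo): Plücker coordinates are
# invariant to sliding the origin along the ray, since (o + t*d) x d == o x d.
def _example_plucker_invariance():
    ray = torch.randn(8, 6)
    shifted = ray.clone()
    shifted[..., :3] += 2.5 * shifted[..., 3:]  # slide origins along their rays
    assert torch.allclose(
        get_plucker_parameterization(ray),
        get_plucker_parameterization(shifted),
        atol=1e-5,
    )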
def positional_encoding(ray, n_freqs=10, start_freq=0):
    """Sinusoidal positional embeddings. For details see Section 5.1 of
    NeRF: https://arxiv.org/pdf/2003.08934.pdf

    Args:
        ray: (B, P, d) tensor of rays.
        n_freqs: Number of frequency bands.
        start_freq: Unused; overridden below so the bands are centered on 2^0.

    Returns:
        torch.Tensor: (B, P, d * 2 * n_freqs) embeddings mapping each input
            coordinate to n_freqs sine and n_freqs cosine features.
    """
    # NOTE: the start_freq argument is overridden here, so the frequency bands
    # span 2^{-n_freqs/2} * pi up to 2^{n_freqs/2 - 1} * pi.
    start_freq = -1 * (n_freqs / 2)
    freq_bands = 2.0 ** torch.arange(start_freq, start_freq + n_freqs) * np.pi
    sin_encodings = [torch.sin(ray * freq) for freq in freq_bands]
    cos_encodings = [torch.cos(ray * freq) for freq in freq_bands]
    pos_embeddings = torch.cat(
        sin_encodings + cos_encodings, dim=-1
    )  # (B, P, d * 2 * n_freqs)
    return pos_embeddings
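# Hedged shape check: every input coordinate maps to n_freqs sine plus n_freqs
# cosine features, so (B, P, 6) Plücker rays become (B, P, 6 * 2 * 10) = (B, P, 120).
def _example_positional_encoding():
    emb = positional_encoding(torch.randn(2, 64, 6), n_freqs=10)
    print(emb.shape)  # torch.Size([2, 64, 120])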
def convert_to_target_space(input_cameras, input_rays):
    """Transforms world-space rays into the frame of each batch's first camera.

    Args:
        input_cameras: Nested list of cameras of shape (b, N).
        input_rays: (b, N, hw, 6) rays in world coordinates.

    Returns:
        torch.Tensor: (b, N, hw, 6) rays in the reference (first) camera's frame.
    """
    # The reference camera is the first camera of each batch element, so the
    # reference R/T can be assembled once, outside the per-view loop.
    reference_cameras = [cameras[0] for cameras in input_cameras]
    reference_R = torch.cat(
        [camera.R.to(input_rays.device) for camera in reference_cameras], dim=0
    )  # (B, 3, 3)
    reference_T = torch.cat(
        [camera.T.to(input_rays.device) for camera in reference_cameras], dim=0
    )  # (B, 3)
    input_rays_transformed = []
    for i in range(len(input_cameras[0])):
        input_rays_transformed.append(
            transform_rays(
                reference_R=reference_R,
                reference_T=reference_T,
                rays=input_rays[:, i : i + 1],
            )
        )
    return torch.cat(input_rays_transformed, 1)
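# Hedged sanity check: with identity cameras, world and camera frames coincide and
# the conversion is a no-op. Cameras and shapes are illustrative assumptions.
def _example_convert_to_target_space():
    b, n, hw = 2, 3, 16
    cams = [[PerspectiveCameras() for _ in range(n)] for _ in range(b)]
    rays = torch.randn(b, n, hw, 6)
    out = convert_to_target_space(cams, rays)
    assert out.shape == (b, n, hw, 6) and torch.allclose(out, rays)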
def convert_to_view_space(input_cameras, input_rays):
    """Transforms world-space rays into every input view's camera frame.

    Args:
        input_cameras: Nested list of cameras of shape (b, N).
        input_rays: (b, hw, 6) rays in world coordinates.

    Returns:
        torch.Tensor: (b, N, hw, 6) rays expressed in each view's frame.
    """
    input_rays_transformed = []
    for i in range(len(input_cameras[0])):
        # Use the i-th camera of each batch element as the reference.
        reference_cameras = [cameras[i] for cameras in input_cameras]
        reference_R = torch.cat(
            [camera.R.to(input_rays.device) for camera in reference_cameras], dim=0
        )  # (B, 3, 3)
        reference_T = torch.cat(
            [camera.T.to(input_rays.device) for camera in reference_cameras], dim=0
        )  # (B, 3)
        input_rays_transformed.append(
            transform_rays(
                reference_R=reference_R,
                reference_T=reference_T,
                rays=input_rays.unsqueeze(1),
            )
        )
    return torch.cat(input_rays_transformed, 1)
def convert_to_view_space_points(input_cameras, input_points):
    """Transforms world-space points into every input view's camera frame.

    Args:
        input_cameras: Nested list of cameras of shape (b, N).
        input_points: (b, hw, d, 3) points in world coordinates.

    Returns:
        torch.Tensor: (b, N, hw, d, 3) points expressed in each view's frame.
    """
    input_points_transformed = []
    for i in range(len(input_cameras[0])):
        reference_cameras = [cameras[i] for cameras in input_cameras]
        reference_R = torch.cat(
            [camera.R.to(input_points.device) for camera in reference_cameras], dim=0
        )  # (B, 3, 3)
        reference_T = torch.cat(
            [camera.T.to(input_points.device) for camera in reference_cameras], dim=0
        )  # (B, 3)
        # X_cam = X_world @ R + T, applied over the (hw, d) grid of points.
        input_points_clone = torch.einsum(
            "bsdj,bjk->bsdk", input_points, reference_R
        ) + reference_T.reshape(-1, 1, 1, 3)
        input_points_transformed.append(input_points_clone.unsqueeze(1))
    return torch.cat(input_points_transformed, dim=1)
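# Hedged shape sketch: world points of shape (b, hw, d, 3) are re-expressed in all
# N view frames. Default identity cameras are assumed purely for illustration.
def _example_view_space_points():
    b, n, hw, d = 2, 3, 16, 4
    cams = [[PerspectiveCameras() for _ in range(n)] for _ in range(b)]
    out = convert_to_view_space_points(cams, torch.randn(b, hw, d, 3))
    print(out.shape)  # torch.Size([2, 3, 16, 4, 3])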
def interpolate_translate_interpolate_xaxis(cam1, interp_start, interp_end, interp_step):
    """Returns cameras shifted along cam1's local x-axis by each offset in
    np.arange(interp_start, interp_end, interp_step)."""
    cameras = []
    for i in np.arange(interp_start, interp_end, interp_step):
        # Map the camera-local offset (i, 0, 0) to world space: the new camera center.
        viewtoworld = cam1.get_world_to_view_transform().inverse()
        offset = torch.tensor([[i, 0.0, 0.0]], dtype=torch.float32, device=cam1.device)
        newc = viewtoworld.transform_points(offset)
        rt = cam1.R[0]
        # Recover the translation from the new center C via T = -C @ R
        # (PyTorch3D convention X_cam = X_world @ R + T).
        new_t = -rt.T @ newc.T
        cameras.append(
            PerspectiveCameras(
                R=cam1.R,
                T=new_t.T,
                focal_length=cam1.focal_length,
                principal_point=cam1.principal_point,
                image_size=((512, 512),),  # PyTorch3D expects (height, width)
            )
        )
    return cameras
def interpolate_translate_interpolate_yaxis(cam1, interp_start, interp_end, interp_step):
    """Returns cameras shifted along cam1's local y-axis by each offset in
    np.arange(interp_start, interp_end, interp_step)."""
    cameras = []
    for i in np.arange(interp_start, interp_end, interp_step):
        # Map the camera-local offset (0, i, 0) to world space: the new camera center.
        viewtoworld = cam1.get_world_to_view_transform().inverse()
        offset = torch.tensor([[0.0, i, 0.0]], dtype=torch.float32, device=cam1.device)
        newc = viewtoworld.transform_points(offset)
        rt = cam1.R[0]
        # Recover the translation from the new center C via T = -C @ R.
        new_t = -rt.T @ newc.T
        cameras.append(
            PerspectiveCameras(
                R=cam1.R,
                T=new_t.T,
                focal_length=cam1.focal_length,
                principal_point=cam1.principal_point,
                image_size=((512, 512),),  # PyTorch3D expects (height, width)
            )
        )
    return cameras
def interpolate_translate_interpolate_zaxis(cam1, interp_start, interp_end, interp_step):
    """Returns cameras shifted along cam1's local z-axis by each offset in
    np.arange(interp_start, interp_end, interp_step)."""
    cameras = []
    for i in np.arange(interp_start, interp_end, interp_step):
        # Map the camera-local offset (0, 0, i) to world space: the new camera center.
        viewtoworld = cam1.get_world_to_view_transform().inverse()
        offset = torch.tensor([[0.0, 0.0, i]], dtype=torch.float32, device=cam1.device)
        newc = viewtoworld.transform_points(offset)
        rt = cam1.R[0]
        # Recover the translation from the new center C via T = -C @ R.
        new_t = -rt.T @ newc.T
        cameras.append(
            PerspectiveCameras(
                R=cam1.R,
                T=new_t.T,
                focal_length=cam1.focal_length,
                principal_point=cam1.principal_point,
                image_size=((512, 512),),  # PyTorch3D expects (height, width)
            )
        )
    return cameras
def interpolatefocal(cam1, interp_start, interp_end, interp_step):
    """Returns cameras with cam1's focal length scaled by each factor in
    np.arange(interp_start, interp_end, interp_step)."""
    cameras = []
    for i in np.arange(interp_start, interp_end, interp_step):
        cameras.append(
            PerspectiveCameras(
                R=cam1.R,
                T=cam1.T,
                focal_length=cam1.focal_length * i,
                principal_point=cam1.principal_point,
                image_size=((512, 512),),  # PyTorch3D expects (height, width)
            )
        )
    return cameras
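# Hedged usage sketch: a short dolly along the camera's local x-axis; the range
# and step here are illustrative values only.
def _example_camera_sweep():
    cam = PerspectiveCameras()
    trajectory = interpolate_translate_interpolate_xaxis(cam, -0.25, 0.25, 0.1)
    print(len(trajectory))  # 5 cameras for this range and step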