# File: customdiffusion360/sgm/modules/utils_cameraray.py
# Origin: Hugging Face repository "customdiffusion360" (first commit, ad7bc89; 15.1 kB)
#### Code taken from: https://github.com/mayankgrwl97/gbt
"""Utils for ray manipulation"""
import numpy as np
import torch
from pytorch3d.renderer.implicit.raysampling import RayBundle as RayBundle
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import PerspectiveCameras
############################# RAY BUNDLE UTILITIES #############################
def is_scalar(x):
    """Return True if the provided variable is a scalar.

    A value counts as a scalar when it is a Python ``float``/``int``, a NumPy
    numeric or boolean scalar (e.g. ``np.float32(1.0)``, ``np.int64(3)``), or
    a NumPy array / torch tensor with zero dimensions.

    Args:
        x: scalar or array-like (numpy array or torch tensor)

    Returns:
        bool: True if x is of the type scalar, or array-like with 0 dimension.
            False, otherwise.
    """
    # np.number / np.bool_ cover NumPy scalar objects, which are neither
    # Python numbers (except float64) nor np.ndarray instances.
    if isinstance(x, (float, int, np.number, np.bool_)):
        return True
    if isinstance(x, np.ndarray):
        return x.ndim == 0
    if isinstance(x, torch.Tensor):
        return x.dim() == 0
    return False
def transform_rays(reference_R, reference_T, rays):
    """Re-express world-frame rays in the frame of a reference camera.

    The PyTorch3D row-vector convention is used: X_cam = X_world @ R + T.
    Ray origins are rotated and translated; ray directions are only rotated.

    Args:
        reference_R: world2cam rotation matrix for reference camera (B, 3, 3)
        reference_T: world2cam translation vector for reference camera (B, 3)
        rays: (origin, direction) defined in world reference frame (B, V, N, 6)

    Returns:
        torch.Tensor: Transformed rays w.r.t. reference camera (B, V, N, 6)
    """
    batch, num_views, num_rays, ray_dim = rays.shape
    assert (
        ray_dim == 6
    ), "First 3 dimensions should be origin; Last 3 dimensions should be direction"

    # Fold the view axis into the ray axis so a single bmm per component
    # handles every view of a batch element.
    flat_rays = rays.reshape(batch, num_views * num_rays, ray_dim)
    world_origins = flat_rays[..., :3]
    world_directions = flat_rays[..., 3:]

    # Write into a copy so the caller's tensor is left untouched.
    transformed = flat_rays.clone()
    translation = reference_T.unsqueeze(-2)  # (B, 1, 3), broadcast over rays
    transformed[..., :3] = torch.bmm(world_origins, reference_R) + translation
    transformed[..., 3:] = torch.bmm(world_directions, reference_R)

    return transformed.reshape(batch, num_views, num_rays, ray_dim)
def get_directional_raybundle(cameras, x_pos_ndc, y_pos_ndc, max_depth=1):
    """Build a RayBundle of unit-direction rays through the given NDC positions.

    Each (x, y) position is unprojected at depth ``max_depth`` into world
    coordinates; the ray direction is the normalized vector from the camera
    center to that point. All tensors in the returned bundle live on the CPU.

    Args:
        cameras: camera object providing ``device``, ``unproject_points`` and
            ``get_camera_center`` (presumably a PyTorch3D camera batch).
        x_pos_ndc: scalar, sequence or tensor of N x-coordinates in NDC.
        y_pos_ndc: scalar, sequence or tensor of N y-coordinates in NDC.
        max_depth: scalar depth at which the positions are unprojected.

    Returns:
        RayBundle: origins, directions, lengths (every ray gets ``[0, 3]``)
            and the xy positions.

    NOTE(review): the shape comments below assume ``unproject_points``
    returns (N, 3); for a multi-camera batch the result may carry an extra
    leading batch axis — confirm against the camera implementation.
    """
    # Promote scalar coordinates to one-element sequences.
    x_pos_ndc = [x_pos_ndc] if is_scalar(x_pos_ndc) else x_pos_ndc
    y_pos_ndc = [y_pos_ndc] if is_scalar(y_pos_ndc) else y_pos_ndc
    assert is_scalar(max_depth)

    xs = x_pos_ndc if isinstance(x_pos_ndc, torch.Tensor) else torch.tensor(x_pos_ndc)  # (N, )
    ys = y_pos_ndc if isinstance(y_pos_ndc, torch.Tensor) else torch.tensor(y_pos_ndc)  # (N, )
    depths = torch.ones_like(xs) * max_depth
    xy_depth = torch.stack((xs, ys, depths), dim=-1)  # (N, 3)
    num_points = xy_depth.shape[0]

    world_points = cameras.unproject_points(
        xy_depth.to(cameras.device), world_coordinates=True, from_ndc=True
    )
    world_points = world_points.unsqueeze(0).to("cpu")  # leading axis added

    camera_centers = cameras.get_camera_center()[:, None, :]  # (B, 1, 3)
    origins = camera_centers.expand(-1, num_points, -1).to("cpu")  # (B, N, 3)

    offsets = world_points - origins
    directions = offsets / offsets.norm(dim=-1).unsqueeze(-1)  # unit length

    # Fixed [near, far] = [0, 3] pair repeated for every ray: (1, N, 2).
    lengths = torch.tensor([[0, 3]]).unsqueeze(0).expand(-1, num_points, -1).to("cpu")
    xys = xy_depth[:, :2].unsqueeze(0).to("cpu")  # (1, N, 2)

    return RayBundle(
        origins=origins,
        directions=directions,
        lengths=lengths,
        xys=xys,
    )
def get_patch_raybundle(
    cameras, num_patches_x, num_patches_y, max_depth=1, stratified=False
):
    """Build a RayBundle with one ray per image patch.

    The NDC range [1, -1] is split into ``num_patches_x`` by ``num_patches_y``
    patches. One NDC position is chosen per patch along each axis and the
    resulting grid is handed to ``get_directional_raybundle``.

    Args:
        cameras: camera object forwarded to ``get_directional_raybundle``.
        num_patches_x: number of patches along the horizontal axis.
        num_patches_y: number of patches along the vertical axis.
        max_depth: scalar depth forwarded to ``get_directional_raybundle``.
        stratified: if False, use the patch centers. If True, draw each
            position uniformly at random: the first between the leading edge
            and the first patch center, every later one between two
            consecutive patch centers.

    Returns:
        The ray bundle produced by ``get_directional_raybundle`` for the
        ``num_patches_y * num_patches_x`` flattened grid positions.
    """

    def _positions_along_axis(edges):
        # edges: (n + 1,) patch boundaries running from +1 down to -1.
        centers = (edges[..., 1:] + edges[..., :-1]) / 2.0  # (n, )
        if not stratified:
            return centers
        interval_end = torch.cat([centers, edges[..., -1:]], -1)  # (n + 1, )
        interval_start = torch.cat([edges[..., :1], centers], -1)  # (n + 1, )
        jitter = torch.rand_like(interval_start)
        sampled = interval_start + (interval_end - interval_start) * jitter
        # The last interval (final center to trailing edge) is discarded.
        return sampled[..., :-1]

    # Horizontal positions are sampled before vertical ones, which fixes the
    # order in which random numbers are consumed.
    horizontal_positions = _positions_along_axis(
        torch.linspace(1, -1, num_patches_x + 1)
    )  # (num_patches_x, )
    vertical_positions = _positions_along_axis(
        torch.linspace(1, -1, num_patches_y + 1)
    )  # (num_patches_y, )

    grid_x, grid_y = torch.meshgrid(
        horizontal_positions, vertical_positions, indexing='xy'
    )  # each (num_patches_y, num_patches_x)

    return get_directional_raybundle(
        cameras=cameras,
        x_pos_ndc=grid_x.reshape(-1),  # (num_patches_y * num_patches_x)
        y_pos_ndc=grid_y.reshape(-1),  # (num_patches_y * num_patches_x)
        max_depth=max_depth,
    )
def get_patch_rays(
    cameras_list,
    num_patches_x,
    num_patches_y,
    device,
    return_xys=False,
    stratified=False,
):
    """Return patch rays for a nested list of camera viewpoints.

    Args:
        cameras_list(list[list[pytorch3d.renderer.cameras.BaseCameras]]):
            Cameras indexed as [batch][view]; the number of views is read
            from the first batch entry.
        num_patches_x: Number of patches in the x-direction (horizontal)
        num_patches_y: Number of patches in the y-direction (vertical)
        device: Device the returned rays are moved to.
        return_xys: If True, also return the ``xys`` of the ray bundle.
        stratified: Forwarded to ``get_patch_raybundle``.

    Returns:
        torch.tensor: Patch rays of shape (batch_size, num_views, num_patches, 6),
            laid out as (origin, direction). When ``return_xys`` is True a
            tuple ``(patch_rays, xys)`` is returned; ``xys`` is taken from the
            bundle as-is and is not moved to ``device``.
    """
    batch_size = len(cameras_list)
    views_per_item = len(cameras_list[0])

    # Flatten [batch][view] into a single camera batch (batch-major order).
    flat_cameras = join_cameras_as_batch(
        [camera for view_group in cameras_list for camera in view_group]
    )
    bundle = get_patch_raybundle(
        flat_cameras,
        num_patches_y=num_patches_y,
        num_patches_x=num_patches_x,
        stratified=stratified,
    )

    # Origins get an extra leading axis so their rank matches the directions
    # before the two are joined along the last dimension.
    rays = torch.cat((bundle.origins.unsqueeze(0), bundle.directions), dim=-1)
    rays = rays.reshape(
        batch_size, views_per_item, num_patches_x * num_patches_y, 6
    ).to(device)

    if return_xys:
        return rays, bundle.xys
    return rays
############################ RAY PARAMETERIZATION ##############################
def get_plucker_parameterization(ray):
    """Convert rays from (origin, direction) to the plucker representation.

    Directions are normalized to unit length first; the second half of the
    output is the cross product of the origin with that unit direction.

    Args:
        ray(torch.Tensor): Tensor of shape (..., 6) with the (origin, direction) representation

    Returns:
        torch.Tensor: Tensor of shape (..., 6) with the plucker (D, OxD) representation
    """
    origins, raw_directions = ray[..., :3], ray[..., 3:]

    # Normalize ray directions to unit vectors (out of place; input untouched).
    lengths = raw_directions.norm(dim=-1).unsqueeze(-1)
    unit_directions = raw_directions / lengths

    moment = torch.cross(origins, unit_directions, dim=-1)
    return torch.cat([unit_directions, moment], dim=-1)
def positional_encoding(ray, n_freqs=10, start_freq=0):
    """
    Sinusoidal positional embeddings. For more details see Section 5.1 of
    NeRFs: https://arxiv.org/pdf/2003.08934.pdf

    The frequency bands are ``2**k * pi`` for the ``n_freqs`` consecutive
    exponents starting at ``-n_freqs / 2``.

    Args:
        ray: (B,P,d)
        n_freqs: num of frequency bands
        start_freq: accepted for interface compatibility only.
            NOTE: the passed value is ignored; the starting exponent is
            always ``-n_freqs / 2``.

    Returns:
        pos_embeddings: (B, P, d * 2 * n_freqs); all sine encodings (one
            d-sized chunk per band) followed by all cosine encodings.
    """
    first_exponent = -1 * (n_freqs / 2)
    exponents = torch.arange(first_exponent, first_exponent + n_freqs)
    freq_bands = 2.0 ** exponents * np.pi

    # Scale once per band, then reuse the scaled input for sin and cos.
    scaled_inputs = [ray * band for band in freq_bands]
    sines = [torch.sin(scaled) for scaled in scaled_inputs]
    cosines = [torch.cos(scaled) for scaled in scaled_inputs]

    return torch.cat(sines + cosines, dim=-1)  # B, P, d * 2n_freqs
def convert_to_target_space(input_cameras, input_rays):
    """Transform the rays of every view into the frame of the first camera.

    For each batch element the reference camera is ``input_cameras[b][0]``;
    its world2cam rotation and translation are applied to the rays of all
    views via ``transform_rays``.

    Args:
        input_cameras: nested list of cameras indexed as [b][N]; each camera
            exposes ``R`` (1, 3, 3) and ``T`` (1, 3).
        input_rays: world-frame rays of shape (b, N, hw, 6).

    Returns:
        torch.Tensor: transformed rays of shape (b, N, hw, 6).
    """
    device = input_rays.device

    # The reference camera is the same for every view, so gather its R and T
    # (and move them to the rays' device) once instead of once per view.
    reference_cameras = [cameras[0] for cameras in input_cameras]
    reference_R = torch.cat(
        [camera.R.to(device) for camera in reference_cameras], dim=0
    )  # (B, 3, 3)
    reference_T = torch.cat(
        [camera.T.to(device) for camera in reference_cameras], dim=0
    )  # (B, 3)

    input_rays_transformed = [
        transform_rays(
            reference_R=reference_R,
            reference_T=reference_T,
            rays=input_rays[:, i : i + 1],  # keep the view axis: (b, 1, hw, 6)
        )
        for i in range(len(input_cameras[0]))
    ]
    return torch.cat(input_rays_transformed, 1)
def convert_to_view_space(input_cameras, input_rays):
    """Transform one set of rays into the frame of every camera view.

    For view ``i`` the reference camera of batch element ``b`` is
    ``input_cameras[b][i]``; the same input rays are transformed once per
    view via ``transform_rays`` and the results are joined along axis 1.

    Args:
        input_cameras: nested list of cameras indexed as [b][N]; each camera
            exposes ``R`` (1, 3, 3) and ``T`` (1, 3).
        input_rays: world-frame rays of shape (b, hw, 6).

    Returns:
        torch.Tensor: transformed rays of shape (b, N, hw, 6).
    """
    device = input_rays.device
    rays_with_view_axis = input_rays.unsqueeze(1)  # (b, 1, hw, 6)

    per_view_rays = []
    for view_idx in range(len(input_cameras[0])):
        view_cameras = [cameras[view_idx] for cameras in input_cameras]
        rotations = torch.cat(
            [camera.R.to(device) for camera in view_cameras], dim=0
        )  # (B, 3, 3)
        translations = torch.cat(
            [camera.T.to(device) for camera in view_cameras], dim=0
        )  # (B, 3)
        per_view_rays.append(
            transform_rays(
                reference_R=rotations,
                reference_T=translations,
                rays=rays_with_view_axis,
            )
        )
    return torch.cat(per_view_rays, 1)
def convert_to_view_space_points(input_cameras, input_points):
    """Transform world-frame points into the frame of every camera view.

    Applies ``X_cam = X_world @ R + T`` with the R and T of
    ``input_cameras[b][i]`` for each batch element ``b`` and view ``i``.

    Args:
        input_cameras: nested list of cameras indexed as [b][N]; each camera
            exposes ``R`` (1, 3, 3) and ``T`` (1, 3).
        input_points: points of shape (b, hw, d, 3).

    Returns:
        torch.Tensor: (b, N, hw, d, 3), the input points expressed in the
            frame of each reference view.
    """
    device = input_points.device
    num_views = len(input_cameras[0])

    per_view_points = []
    for view_idx in range(num_views):
        view_cameras = [cameras[view_idx] for cameras in input_cameras]
        rotations = torch.cat(
            [camera.R.to(device) for camera in view_cameras], dim=0
        )  # (B, 3, 3)
        translations = torch.cat(
            [camera.T.to(device) for camera in view_cameras], dim=0
        )  # (B, 3)

        rotated = torch.einsum("bsdj,bjk->bsdk", input_points, rotations)
        # Broadcast the per-batch translation over the hw and d axes.
        per_view_points.append(rotated + translations.reshape(-1, 1, 1, 3))

    # Stacking on axis 1 inserts the view axis: (b, N, hw, d, 3).
    return torch.stack(per_view_points, dim=1)
def interpolate_translate_interpolate_xaxis(
    cam1, interp_start, interp_end, interp_step, image_size=512
):
    """Create cameras whose centers are shifted along cam1's view-space x axis.

    For every offset in ``np.arange(interp_start, interp_end, interp_step)``
    the view-space point ``(offset, 0, 0)`` is mapped to world coordinates and
    used as the new camera center; rotation, focal length and principal point
    are copied from ``cam1``.

    Args:
        cam1: source camera. Only ``cam1.R[0]`` is used to derive the new
            translation, so this presumably expects a single camera — TODO
            confirm against callers.
        interp_start: first offset (inclusive).
        interp_end: end of the offset range (exclusive).
        interp_step: spacing between consecutive offsets.
        image_size: image size passed to each new ``PerspectiveCameras``
            (defaults to 512, the previously hard-coded value).

    Returns:
        list[PerspectiveCameras]: one camera per offset.
    """
    # Neither depends on the offset, so compute them once.
    view_to_world = cam1.get_world_to_view_transform().inverse()
    rotation = cam1.R[0]  # (3, 3)

    cameras = []
    for offset in np.arange(interp_start, interp_end, interp_step):
        center_in_view = torch.tensor(
            [[float(offset), 0.0, 0.0]], dtype=torch.float32, device=cam1.device
        )  # (1, 3)
        center_in_world = view_to_world.transform_points(center_in_view)
        # With X_cam = X_world @ R + T, a camera centered at C has T = -C @ R,
        # i.e. T^T = -R^T @ C^T.
        new_t = -rotation.T @ center_in_world.T  # (3, 1)
        cameras.append(
            PerspectiveCameras(
                R=cam1.R,
                T=new_t.T,
                focal_length=cam1.focal_length,
                principal_point=cam1.principal_point,
                image_size=image_size,
            )
        )
    return cameras
def interpolate_translate_interpolate_yaxis(
    cam1, interp_start, interp_end, interp_step, image_size=512
):
    """Create cameras whose centers are shifted along cam1's view-space y axis.

    For every offset in ``np.arange(interp_start, interp_end, interp_step)``
    the view-space point ``(0, offset, 0)`` is mapped to world coordinates and
    used as the new camera center; rotation, focal length and principal point
    are copied from ``cam1``.

    Args:
        cam1: source camera. Only ``cam1.R[0]`` is used to derive the new
            translation, so this presumably expects a single camera — TODO
            confirm against callers.
        interp_start: first offset (inclusive).
        interp_end: end of the offset range (exclusive).
        interp_step: spacing between consecutive offsets.
        image_size: image size passed to each new ``PerspectiveCameras``
            (defaults to 512, the previously hard-coded value).

    Returns:
        list[PerspectiveCameras]: one camera per offset.
    """
    # Neither depends on the offset, so compute them once.
    view_to_world = cam1.get_world_to_view_transform().inverse()
    rotation = cam1.R[0]  # (3, 3)

    cameras = []
    for offset in np.arange(interp_start, interp_end, interp_step):
        center_in_view = torch.tensor(
            [[0.0, float(offset), 0.0]], dtype=torch.float32, device=cam1.device
        )  # (1, 3)
        center_in_world = view_to_world.transform_points(center_in_view)
        # With X_cam = X_world @ R + T, a camera centered at C has T = -C @ R,
        # i.e. T^T = -R^T @ C^T.
        new_t = -rotation.T @ center_in_world.T  # (3, 1)
        cameras.append(
            PerspectiveCameras(
                R=cam1.R,
                T=new_t.T,
                focal_length=cam1.focal_length,
                principal_point=cam1.principal_point,
                image_size=image_size,
            )
        )
    return cameras
def interpolate_translate_interpolate_zaxis(
    cam1, interp_start, interp_end, interp_step, image_size=512
):
    """Create cameras whose centers are shifted along cam1's view-space z axis.

    For every offset in ``np.arange(interp_start, interp_end, interp_step)``
    the view-space point ``(0, 0, offset)`` is mapped to world coordinates and
    used as the new camera center; rotation, focal length and principal point
    are copied from ``cam1``.

    Args:
        cam1: source camera. Only ``cam1.R[0]`` is used to derive the new
            translation, so this presumably expects a single camera — TODO
            confirm against callers.
        interp_start: first offset (inclusive).
        interp_end: end of the offset range (exclusive).
        interp_step: spacing between consecutive offsets.
        image_size: image size passed to each new ``PerspectiveCameras``
            (defaults to 512, the previously hard-coded value).

    Returns:
        list[PerspectiveCameras]: one camera per offset.
    """
    # Neither depends on the offset, so compute them once.
    view_to_world = cam1.get_world_to_view_transform().inverse()
    rotation = cam1.R[0]  # (3, 3)

    cameras = []
    for offset in np.arange(interp_start, interp_end, interp_step):
        center_in_view = torch.tensor(
            [[0.0, 0.0, float(offset)]], dtype=torch.float32, device=cam1.device
        )  # (1, 3)
        center_in_world = view_to_world.transform_points(center_in_view)
        # With X_cam = X_world @ R + T, a camera centered at C has T = -C @ R,
        # i.e. T^T = -R^T @ C^T.
        new_t = -rotation.T @ center_in_world.T  # (3, 1)
        cameras.append(
            PerspectiveCameras(
                R=cam1.R,
                T=new_t.T,
                focal_length=cam1.focal_length,
                principal_point=cam1.principal_point,
                image_size=image_size,
            )
        )
    return cameras
def interpolatefocal(cam1, interp_start, interp_end, interp_step, image_size=512):
    """Create copies of cam1 whose focal length is scaled by a range of factors.

    Args:
        cam1: source camera; its R, T and principal point are reused as-is.
        interp_start: first scale factor (inclusive).
        interp_end: end of the scale-factor range (exclusive).
        interp_step: spacing between consecutive scale factors.
        image_size: image size passed to each new ``PerspectiveCameras``
            (defaults to 512, the previously hard-coded value).

    Returns:
        list[PerspectiveCameras]: one camera per factor in
            ``np.arange(interp_start, interp_end, interp_step)``, with focal
            length ``cam1.focal_length * factor``.
    """
    return [
        PerspectiveCameras(
            R=cam1.R,
            T=cam1.T,
            focal_length=cam1.focal_length * scale,
            principal_point=cam1.principal_point,
            image_size=image_size,
        )
        for scale in np.arange(interp_start, interp_end, interp_step)
    ]