# DIRECTOR-demo/utils/rotation_utils.py
import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
from torchtyping import TensorType
from itertools import product
# Placeholder dimension names for torchtyping shape annotations.
num_samples, num_cams = None, None
def rotvec_to_matrix(rotvec):
    """Convert axis-angle rotation vector(s) to rotation matrix(ces)."""
    return R.from_rotvec(rotvec).as_matrix()


def matrix_to_rotvec(mat):
    """Convert rotation matrix(ces) back to axis-angle rotation vector(s)."""
    return R.from_matrix(mat).as_rotvec()
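
# Round-trip sanity check (illustrative sketch): converting a rotation
# vector to a matrix and back should recover the input.
#   rv = np.array([0.0, 0.0, np.pi / 2])
#   matrix_to_rotvec(rotvec_to_matrix(rv))  # ~ array([0., 0., 1.5708])
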
def compose_rotvec(r1, r2):
    """
    # TODO: adapt to torch
    Compose two axis-angle rotation vectors: returns the rotation vector of
    R(r1) @ R(r2), i.e. r2 is applied first, then r1.
    """
    r1 = r1.cpu().numpy() if isinstance(r1, torch.Tensor) else r1
    r2 = r2.cpu().numpy() if isinstance(r2, torch.Tensor) else r2
    R1 = rotvec_to_matrix(r1)
    R2 = rotvec_to_matrix(r2)
    cR = np.einsum("...ij,...jk->...ik", R1, R2)
    return torch.from_numpy(matrix_to_rotvec(cR))
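
# Illustrative sketch: composing a rotation with its negation should give
# (numerically) the zero rotation.
#   r = torch.tensor([0.1, 0.2, 0.3])
#   compose_rotvec(r, -r)  # ~ tensor([0., 0., 0.], dtype=torch.float64)
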
def quat_to_rotvec(quat, eps=1e-6):
    """Convert quaternions in [w, x, y, z] order to axis-angle rotation vectors."""
    # Enforce w > 0 to ensure 0 <= angle <= pi
    flip = (quat[..., :1] < 0).float()
    quat = (-1 * quat) * flip + (1 - flip) * quat
    angle = 2 * torch.atan2(torch.linalg.norm(quat[..., 1:], dim=-1), quat[..., 0])
    angle2 = angle * angle
    # Taylor expansion of angle / sin(angle / 2) around 0, which avoids
    # dividing by a vanishing sine for small angles
    small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
    large_angle_scales = angle / torch.sin(angle / 2 + eps)
    small_angles = (angle <= 1e-3).float()
    rot_vec_scale = (
        small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
    )
    rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
    return rot_vec
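
# Illustrative sketch: a 90-degree rotation about z, as a [w, x, y, z]
# quaternion, should map to the rotation vector [0, 0, pi / 2].
#   q = torch.tensor([[0.7071068, 0.0, 0.0, 0.7071068]])
#   quat_to_rotvec(q)  # ~ tensor([[0.0000, 0.0000, 1.5708]])
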
# v: [batch, n]
def normalize_vector(v, return_mag=False):
    batch = v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))  # [batch]
    # Clamp the magnitude away from zero (replaces the deprecated
    # torch.autograd.Variable construction)
    v_mag = torch.clamp(v_mag, min=1e-8)
    v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
    v = v / v_mag
    if return_mag is True:
        return v, v_mag[:, 0]
    else:
        return v
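
# Illustrative sketch: row-wise normalization to unit length.
#   normalize_vector(torch.tensor([[3.0, 4.0]]))  # -> tensor([[0.6000, 0.8000]])
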
# u, v: [batch, 3]
def cross_product(u, v):
    batch = u.shape[0]
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
    out = torch.cat(
        (i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1
    )  # [batch, 3]
    return out
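
# Illustrative sketch: matches the standard cross product
# (e.g. torch.linalg.cross on the same inputs).
#   u = torch.tensor([[1.0, 0.0, 0.0]])
#   v = torch.tensor([[0.0, 1.0, 0.0]])
#   cross_product(u, v)  # -> tensor([[0., 0., 1.]])
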
def compute_rotation_matrix_from_ortho6d(ortho6d):
    """Map a continuous 6D rotation representation to a 3x3 rotation matrix
    via a Gram-Schmidt-like orthogonalization (Zhou et al., CVPR 2019)."""
    x_raw = ortho6d[:, 0:3]  # [batch, 3]
    y_raw = ortho6d[:, 3:6]  # [batch, 3]
    x = normalize_vector(x_raw)  # [batch, 3]
    z = cross_product(x, y_raw)  # [batch, 3]
    z = normalize_vector(z)  # [batch, 3]
    y = cross_product(z, x)  # [batch, 3]
    x = x.view(-1, 3, 1)
    y = y.view(-1, 3, 1)
    z = z.view(-1, 3, 1)
    matrix = torch.cat((x, y, z), 2)  # [batch, 3, 3]
    return matrix
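
# Illustrative sketch: the 6D vector made of the first two columns of the
# identity matrix should map back to the identity rotation.
#   sixd = torch.tensor([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0]])
#   compute_rotation_matrix_from_ortho6d(sixd)  # -> 3x3 identity (batch of 1)
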
def invert_rotvec(rotvec: TensorType["num_samples", 3]):
    """Invert rotations given as axis-angle vectors (same axis, negated angle)."""
    angle = torch.norm(rotvec, dim=-1)
    axis = rotvec / (angle.unsqueeze(-1) + 1e-6)
    inverted_rotvec = -angle.unsqueeze(-1) * axis
    return inverted_rotvec
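
# Illustrative sketch (up to the 1e-6 stabilizer, this is just negation):
#   invert_rotvec(torch.tensor([[0.0, 0.0, 0.5]]))  # ~ tensor([[0., 0., -0.5]])
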
def are_rotations(matrix: TensorType["num_samples", 3, 3]) -> TensorType["num_samples"]:
    """Check whether each matrix in the batch is a valid rotation matrix."""
# Check if the matrix is orthogonal
identity = torch.eye(3, device=matrix.device)
is_orthogonal = (
torch.isclose(torch.bmm(matrix, matrix.transpose(1, 2)), identity, atol=1e-6)
.all(dim=1)
.all(dim=1)
)
# Check if the determinant is 1
determinant = torch.det(matrix)
is_determinant_one = torch.isclose(
determinant, torch.tensor(1.0, device=matrix.device), atol=1e-6
)
return torch.logical_and(is_orthogonal, is_determinant_one)
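
# Illustrative sketch: a batched identity passes both checks.
#   are_rotations(torch.eye(3).unsqueeze(0))  # -> tensor([True])
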
def project_so3(
    matrix: TensorType["num_samples", 4, 4]
) -> TensorType["num_samples", 4, 4]:
    # Project the rotation block onto SO(3): scipy orthogonalizes
    # non-orthogonal inputs in from_matrix, so the from_matrix/as_matrix
    # round trip acts as the projection. The translation is passed through.
    # TODO: use torch
    rot = R.from_matrix(matrix[:, :3, :3].cpu().numpy()).as_matrix()
    projection = torch.eye(4).unsqueeze(0).repeat(matrix.shape[0], 1, 1).to(matrix)
    projection[:, :3, :3] = torch.from_numpy(rot).to(matrix)
    projection[:, :3, 3] = matrix[:, :3, 3]
    return projection
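
# Illustrative sketch: slightly perturb a valid pose, then project back.
#   T = torch.eye(4).unsqueeze(0)
#   T[:, :3, :3] += 0.01 * torch.randn(3, 3)
#   are_rotations(project_so3(T)[:, :3, :3])  # -> tensor([True])
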
def pairwise_geodesic(
    R_x: TensorType["num_samples", "num_cams", 3, 3],
    R_y: TensorType["num_samples", "num_cams", 3, 3],
    reduction: str = "mean",
    block_size: int = 200,
):
    """Pairwise geodesic distances between two batches of per-camera rotations,
    computed block-wise to bound memory usage."""

    def arange(start, stop, step, endpoint=True):
        # Like torch.arange, but optionally append the final boundary
        # (stop - 1) so the last, possibly smaller, chunk is not dropped
        arr = torch.arange(start, stop, step)
        if endpoint and arr[-1] != stop - 1:
            arr = torch.cat((arr, torch.tensor([stop - 1], dtype=arr.dtype)))
        return arr

    # Geodesic distance
    # https://math.stackexchange.com/questions/2113634/comparing-two-rotation-matrices
    num_samples, num_cams, _, _ = R_x.shape
    C = torch.zeros(num_samples, num_samples, device=R_x.device)
    chunk_indices = arange(0, num_samples + 1, block_size, endpoint=True)
    for i, j in product(
        range(chunk_indices.shape[0] - 1), range(chunk_indices.shape[0] - 1)
    ):
        start_x, stop_x = chunk_indices[i], chunk_indices[i + 1]
        start_y, stop_y = chunk_indices[j], chunk_indices[j + 1]
        r_x, r_y = R_x[start_x:stop_x], R_y[start_y:stop_y]
        # Relative rotation r_x[a, n] @ r_y[b, n]^T for every sample pair
        # (a, b), camera by camera
        r_xy = torch.einsum("anjk,bnlk->abnjl", r_x, r_y)  # [bx, by, num_cams, 3, 3]
        # The rotation angle (geodesic distance) follows from the trace;
        # dividing by pi normalizes distances to [0, 1]
        traces = r_xy.diagonal(dim1=-2, dim2=-1).sum(-1)
        c = torch.acos(torch.clamp((traces - 1) / 2, -1, 1)) / torch.pi
        # Reduce the per-camera distances for each sample pair
        if reduction == "mean":
            C[start_x:stop_x, start_y:stop_y] = c.mean(-1)
        elif reduction == "sum":
            C[start_x:stop_x, start_y:stop_y] = c.sum(-1)
        # Check for NaN values in the distances (NaN here implies NaN traces)
        if torch.isnan(c).any():
            raise ValueError("NaN values detected in traces")
    return C
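
# Minimal runnable sanity check (illustrative sketch): identical, all-identity
# rotation batches should yield an all-zero distance matrix.
if __name__ == "__main__":
    R_id = torch.eye(3).repeat(2, 5, 1, 1)  # 2 samples, 5 cameras each
    print(pairwise_geodesic(R_id, R_id))  # ~ 2x2 matrix of zeros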