Spaces:
Runtime error
Runtime error
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
import torch | |
import torch.nn.functional as F | |
from vidar.utils.decorators import iterate1 | |
def to_global_pose(pose, zero_origin=False):
    """Compose every pose in `pose` with the reference pose stored under key 0.

    Parameters
    ----------
    pose : dict-like
        Mapping from key to pose object. The entry under key 0 must expose
        `.T`, `.device` and `.dtype`, and support `len()`, integer indexing
        and `*` (presumably vidar `Pose` objects -- confirm against caller).
    zero_origin : bool
        If True, row 0 of `pose[0].T` is first overwritten with a 4x4 identity.

    Returns
    -------
    The same `pose` mapping; it is modified in place.
    """
    # NOTE(review): indentation was lost in the scraped source. The two loops
    # are assumed to run regardless of `zero_origin` -- confirm upstream.
    if zero_origin:
        identity = torch.eye(4, device=pose[0].device, dtype=pose[0].dtype)
        pose[0].T[[0]] = identity

    # Re-express the remaining entries of the reference pose relative to entry 0.
    for idx in range(1, len(pose[0])):
        composed = pose[0][idx] * pose[0][0]
        pose[0].T[[idx]] = composed.T.float()

    # Chain all other poses with the (now updated) reference pose.
    for key in pose.keys():
        if key == 0:
            continue
        pose[key] = pose[key] * pose[0]

    return pose
# def to_global_pose(pose, zero_origin=False): | |
# """Get global pose coordinates from current and context poses""" | |
# if zero_origin: | |
# pose[(0, 0)].T = torch.eye(4, device=pose[(0, 0)].device, dtype=pose[(0, 0)].dtype). \ | |
# repeat(pose[(0, 0)].shape[0], 1, 1) | |
# for key in pose.keys(): | |
# if key[0] == 0 and key[1] != 0: | |
# pose[key].T = (pose[key] * pose[(0, 0)]).T | |
# for key in pose.keys(): | |
# if key[0] != 0: | |
# pose[key] = pose[key] * pose[(0, 0)] | |
# return pose | |
def euler2mat(angle):
    """Build [B,3,3] rotation matrices from per-axis angles.

    Parameters
    ----------
    angle : torch.Tensor
        Tensor whose columns 0, 1, 2 hold the rotation angles (radians, as
        consumed by torch.cos / torch.sin) about the x, y and z axes.

    Returns
    -------
    torch.Tensor
        [B,3,3] product Rx @ Ry @ Rz of the three elementary rotations.
    """
    batch = angle.size(0)
    ax, ay, az = angle[:, 0], angle[:, 1], angle[:, 2]

    # Constant entries are detached so they carry no gradient.
    zero = az.detach() * 0
    one = zero.detach() + 1

    def _as_matrix(*entries):
        # Nine [B] vectors, listed row-major, become one [B,3,3] tensor.
        return torch.stack(entries, dim=1).view(batch, 3, 3)

    cz, sz = torch.cos(az), torch.sin(az)
    rot_z = _as_matrix(cz, -sz, zero,
                       sz, cz, zero,
                       zero, zero, one)

    cy, sy = torch.cos(ay), torch.sin(ay)
    rot_y = _as_matrix(cy, zero, sy,
                       zero, one, zero,
                       -sy, zero, cy)

    cx, sx = torch.cos(ax), torch.sin(ax)
    rot_x = _as_matrix(one, zero, zero,
                       zero, cx, -sx,
                       zero, sx, cx)

    return rot_x.bmm(rot_y).bmm(rot_z)
def pose_vec2mat(vec, mode='euler'):
    """Convert a translation + rotation vector into a [B,3,4] transformation matrix.

    Parameters
    ----------
    vec : torch.Tensor
        Columns 0-2 are the translation, the remaining columns are the
        rotation parameters.
    mode : str or None
        'euler' converts the rotation columns with `euler2mat`;
        None returns `vec` unchanged.

    Returns
    -------
    torch.Tensor
        `vec` itself when `mode` is None, otherwise the rotation matrix with
        the translation appended as a fourth column ([B,3,4]).

    Raises
    ------
    ValueError
        If `mode` is neither None nor 'euler'.
    """
    if mode is None:
        return vec

    trans = vec[:, :3].unsqueeze(-1)
    rot = vec[:, 3:]

    if mode != 'euler':
        raise ValueError('Rotation mode not supported {}'.format(mode))

    return torch.cat([euler2mat(rot), trans], dim=2)  # [B,3,4]
def invert_pose(T):
    """Invert a batch of [B,4,4] pose matrices.

    The inverse is assembled in closed form as [R^T | -R^T t], where R is the
    top-left 3x3 block of `T` and t the first three entries of its last
    column. This equals the general matrix inverse (torch.linalg.inv) only
    when R is orthonormal and the bottom row is [0, 0, 0, 1].

    Parameters
    ----------
    T : torch.Tensor
        [B,4,4] poses.

    Returns
    -------
    torch.Tensor
        New [B,4,4] tensor on the same device and with the same dtype as `T`.
    """
    n = T.shape[0]
    Tinv = torch.eye(4, device=T.device, dtype=T.dtype).repeat(n, 1, 1)

    rot, trans = T[:, :3, :3], T[:, :3, -1]
    Tinv[:, :3, :3] = rot.transpose(-2, -1)
    Tinv[:, :3, -1] = (-1. * Tinv[:, :3, :3]).bmm(trans.unsqueeze(-1)).squeeze(-1)
    return Tinv
def tvec_to_translation(tvec):
    """Build [B,4,4] pure-translation matrices (identity rotation).

    Parameters
    ----------
    tvec : torch.Tensor
        Translations; the first dimension is the batch size and the tensor
        is reshaped to [-1,3,1] before being written.

    Returns
    -------
    torch.Tensor
        [B,4,4] identity matrices with the translation in rows 0-2 of the
        last column. Created with the default dtype on `tvec`'s device.
    """
    n = tvec.shape[0]
    column = tvec.contiguous().view(-1, 3, 1)

    out = torch.eye(4).to(device=tvec.device).repeat(n, 1, 1)
    out[:, :3, 3:] = column
    return out
def euler2rot(euler):
    """Convert rotation vectors to [B,4,4] homogeneous rotation matrices.

    The input is treated as axis-angle: its L2 norm along dim 2 is the
    rotation angle and the normalized vector is the rotation axis.

    Parameters
    ----------
    euler : torch.Tensor
        Rotation vectors with the three components on dim 2 (the norm is
        taken over dim 2; presumably [B,1,3] -- confirm against callers).

    Returns
    -------
    torch.Tensor
        [B,4,4] matrices with the rotation in the top-left 3x3 block and 1 at
        position [3,3]. Created with the default dtype on `euler`'s device.
    """
    theta = torch.norm(euler, p=2, dim=2, keepdim=True)
    unit = euler / (theta + 1e-7)  # small offset avoids division by zero

    c = torch.cos(theta)
    s = torch.sin(theta)
    one_minus_c = 1 - c

    ux, uy, uz = (unit[..., k].unsqueeze(1) for k in range(3))

    sx, sy, sz = ux * s, uy * s, uz * s
    cx, cy, cz = ux * one_minus_c, uy * one_minus_c, uz * one_minus_c

    xx, yy, zz = ux * cx, uy * cy, uz * cz
    xy, yz, zx = ux * cy, uy * cz, uz * cx

    # Row-major entries of the 3x3 rotation block.
    block = (
        (xx + c, xy - sz, zx + sy),
        (xy + sz, yy + c, yz - sx),
        (zx - sy, yz + sx, zz + c),
    )

    out = torch.zeros((euler.shape[0], 4, 4)).to(device=euler.device)
    for r, row in enumerate(block):
        for col, value in enumerate(row):
            out[:, r, col] = torch.squeeze(value)
    out[:, 3, 3] = 1
    return out
def vec2mat(euler, translation, invert=False):
    """Combine a rotation vector and a translation into a [B,4,4] transformation matrix.

    Parameters
    ----------
    euler : torch.Tensor
        Rotation parameters, forwarded to `euler2rot`.
    translation : torch.Tensor
        Translation, forwarded (as a copy) to `tvec_to_translation`.
    invert : bool
        If False, returns T @ R. If True, the rotation is transposed, the
        translation negated and the product order swapped (R^T @ T(-t)).

    Returns
    -------
    torch.Tensor
        [B,4,4] transformation matrices.
    """
    rotation = euler2rot(euler)
    offset = translation.clone()  # keep the caller's tensor untouched

    if not invert:
        return torch.matmul(tvec_to_translation(offset), rotation)

    offset *= -1
    return torch.matmul(rotation.transpose(1, 2), tvec_to_translation(offset))
def rot2quat(R):
    """Convert [B,3,3] rotation matrices to [B,4] quaternions.

    The quaternion is returned in (x, y, z, w) order: the scalar part is in
    column 3.

    For rotations whose angle is at or near 180 degrees the scalar part w is
    (close to) zero, so the direct formula `(R21 - R12) / (4 * w)` divides by
    zero and yields NaN/inf. In that case the quaternion is recovered from the
    largest diagonal entry instead. Everywhere else the direct formula is
    used unchanged, so w is non-negative there.

    Parameters
    ----------
    R : torch.Tensor
        [B,3,3] rotation matrices.

    Returns
    -------
    torch.Tensor
        [B,4] quaternions, created with the default dtype on `R`'s device.
    """
    b, _, _ = R.shape
    eps = 1e-4  # below this value of 1 + trace the w-based formula is ill-conditioned

    R00, R01, R02 = R[:, 0, 0], R[:, 0, 1], R[:, 0, 2]
    R10, R11, R12 = R[:, 1, 0], R[:, 1, 1], R[:, 1, 2]
    R20, R21, R22 = R[:, 2, 0], R[:, 2, 1], R[:, 2, 2]

    def _half_sqrt(value):
        # The clamp keeps sqrt (and its gradient) finite in branches that end
        # up unselected; it is the identity in the branch that is selected.
        return torch.sqrt(torch.clamp(value, min=eps)) / 2

    w_term = 1.0 + R00 + R11 + R22
    w = _half_sqrt(w_term)
    x = _half_sqrt(1.0 + R00 - R11 - R22)
    y = _half_sqrt(1.0 - R00 + R11 - R22)
    z = _half_sqrt(1.0 - R00 - R11 + R22)

    # One candidate quaternion per pivot component, each as [B,4] in (x,y,z,w) order.
    from_w = torch.stack([(R21 - R12) / (4 * w), (R02 - R20) / (4 * w), (R10 - R01) / (4 * w), w], dim=1)
    from_x = torch.stack([x, (R01 + R10) / (4 * x), (R02 + R20) / (4 * x), (R21 - R12) / (4 * x)], dim=1)
    from_y = torch.stack([(R01 + R10) / (4 * y), y, (R12 + R21) / (4 * y), (R02 - R20) / (4 * y)], dim=1)
    from_z = torch.stack([(R02 + R20) / (4 * z), (R12 + R21) / (4 * z), z, (R10 - R01) / (4 * z)], dim=1)
    candidates = torch.stack([from_w, from_x, from_y, from_z], dim=1)  # [B,4,4]

    # Branch 0 (the original formula) whenever it is well-conditioned;
    # otherwise pivot on the largest diagonal entry (branches 1..3).
    fallback = torch.stack([R00, R11, R22], dim=1).argmax(dim=1) + 1
    branch = torch.where(w_term > eps, torch.zeros_like(fallback), fallback)

    q = torch.ones((b, 4), device=R.device)
    q[:] = candidates[torch.arange(b, device=R.device), branch]
    return q
def quat2rot(q):
    """Convert [B,4] quaternions to [B,3,3] rotation matrices.

    The quaternion is normalized along dim 1 first. Column 0 is read as the
    real (scalar) part and columns 1-3 as the i, j, k parts.

    Parameters
    ----------
    q : torch.Tensor
        [B,4] quaternions; they do not need to be unit length.

    Returns
    -------
    torch.Tensor
        [B,3,3] rotation matrices, created with the default dtype on `q`'s
        device.
    """
    n, _ = q.shape
    unit = F.normalize(q, dim=1)
    qr, qi, qj, qk = (unit[:, k] for k in range(4))

    ii, jj, kk = qi ** 2, qj ** 2, qk ** 2

    # Row-major entries of the rotation matrix.
    entries = (
        (1 - 2 * (jj + kk), 2 * (qj * qi - qk * qr), 2 * (qi * qk + qr * qj)),
        (2 * (qj * qi + qk * qr), 1 - 2 * (ii + kk), 2 * (qj * qk - qi * qr)),
        (2 * (qk * qi - qj * qr), 2 * (qj * qk + qi * qr), 1 - 2 * (ii + jj)),
    )

    R = torch.ones((n, 3, 3), device=q.device)
    for r, row in enumerate(entries):
        for c, value in enumerate(row):
            R[:, r, c] = value
    return R