Jiading Fang
add define
fc16538
raw
history blame
5.94 kB
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
import torch.nn.functional as F
from vidar.utils.decorators import iterate1
def to_global_pose(pose, zero_origin=False):
    """Re-express every pose in ``pose`` relative to the entry stored under key 0.

    ``pose`` is modified in place and also returned.

    Parameters
    ----------
    pose : mapping
        Container indexed by key, where key ``0`` is the reference pose.
        NOTE(review): the values are presumably project ``Pose`` objects exposing
        ``.T``, ``.device``, ``.dtype``, ``len()``, indexing and ``*`` -- confirm.
    zero_origin : bool
        If True, the first batch entry of the reference pose is overwritten
        with a 4x4 identity before anything else happens.

    Returns
    -------
    The same ``pose`` container, after the in-place update.
    """
    if zero_origin:
        identity = torch.eye(4, device=pose[0].device, dtype=pose[0].dtype)
        pose[0].T[[0]] = identity

    # Compose every later batch entry of the reference pose with its first entry
    for idx in range(1, len(pose[0])):
        chained = pose[0][idx] * pose[0][0]
        pose[0].T[[idx]] = chained.T.float()

    # All remaining keys are composed with the (updated) reference pose
    for key in pose.keys():
        if key == 0:
            continue
        pose[key] = pose[key] * pose[0]

    return pose
# def to_global_pose(pose, zero_origin=False):
# """Get global pose coordinates from current and context poses"""
# if zero_origin:
# pose[(0, 0)].T = torch.eye(4, device=pose[(0, 0)].device, dtype=pose[(0, 0)].dtype). \
# repeat(pose[(0, 0)].shape[0], 1, 1)
# for key in pose.keys():
# if key[0] == 0 and key[1] != 0:
# pose[key].T = (pose[key] * pose[(0, 0)]).T
# for key in pose.keys():
# if key[0] != 0:
# pose[key] = pose[key] * pose[(0, 0)]
# return pose
def euler2mat(angle):
    """Build [B,3,3] rotation matrices from batched Euler angles.

    Parameters
    ----------
    angle : torch.Tensor
        Angles indexed as [B,3], with columns read as (x, y, z) in radians.

    Returns
    -------
    torch.Tensor
        [B,3,3] matrices computed as Rx @ Ry @ Rz.
    """
    batch = angle.size(0)
    rx, ry, rz = angle[:, 0], angle[:, 1], angle[:, 2]

    # Constant matrix entries: detached from the graph, same shape as one angle column
    zero = rz.detach() * 0
    one = zero.detach() + 1

    def as_matrix(entries):
        # Nine [B] entries in row-major order -> [B,3,3]
        return torch.stack(entries, dim=1).view(batch, 3, 3)

    cos_z, sin_z = torch.cos(rz), torch.sin(rz)
    rot_z = as_matrix([cos_z, -sin_z, zero,
                       sin_z, cos_z, zero,
                       zero, zero, one])

    cos_y, sin_y = torch.cos(ry), torch.sin(ry)
    rot_y = as_matrix([cos_y, zero, sin_y,
                       zero, one, zero,
                       -sin_y, zero, cos_y])

    cos_x, sin_x = torch.cos(rx), torch.sin(rx)
    rot_x = as_matrix([one, zero, zero,
                       zero, cos_x, -sin_x,
                       zero, sin_x, cos_x])

    return torch.bmm(torch.bmm(rot_x, rot_y), rot_z)
def pose_vec2mat(vec, mode='euler'):
    """Convert translation + rotation vectors to a [B,3,4] transformation matrix.

    Note that the result is the 3x4 matrix ``[R | t]``, not a homogeneous
    4x4 matrix: no ``[0, 0, 0, 1]`` row is appended.

    Parameters
    ----------
    vec : torch.Tensor
        Pose vectors indexed as [B,K]; the first three columns are the
        translation and the remaining columns are handed to ``euler2mat``.
    mode : str or None
        Rotation parametrization. ``'euler'`` is the only supported value;
        ``None`` returns ``vec`` untouched.

    Returns
    -------
    torch.Tensor
        [B,3,4] matrix with the rotation in the first three columns and the
        translation in the last one (or ``vec`` itself when ``mode`` is None).

    Raises
    ------
    ValueError
        If ``mode`` is neither ``None`` nor ``'euler'``.
    """
    if mode is None:
        return vec
    # Reject unsupported modes up front, before doing any tensor work
    if mode != 'euler':
        raise ValueError('Rotation mode not supported {}'.format(mode))
    trans, rot = vec[:, :3].unsqueeze(-1), vec[:, 3:]
    rot_mat = euler2mat(rot)
    return torch.cat([rot_mat, trans], dim=2)  # [B,3,4]
@iterate1
def invert_pose(T):
    """Invert a batch of [B,4,4] pose matrices in closed form.

    The result has the transposed upper-left 3x3 block as its rotation and
    ``-R^T t`` as its translation; the last row is ``[0, 0, 0, 1]``.
    NOTE(review): this equals the true matrix inverse only when the 3x3 block
    is orthonormal; ``torch.linalg.inv(T)`` is the general alternative.

    Parameters
    ----------
    T : torch.Tensor
        Pose matrices [B,4,4].

    Returns
    -------
    torch.Tensor
        Inverted poses [B,4,4], with the same device and dtype as ``T``.
    """
    batch = len(T)
    Tinv = torch.eye(4, device=T.device, dtype=T.dtype).repeat([batch, 1, 1])
    # Rotation part: transpose of the input rotation
    Tinv[:, :3, :3] = T[:, :3, :3].transpose(-2, -1)
    # Translation part: -R^T t, using the rotation just written into Tinv
    translation = T[:, :3, -1:]
    Tinv[:, :3, -1] = torch.bmm(-1. * Tinv[:, :3, :3], translation).squeeze(-1)
    return Tinv
def tvec_to_translation(tvec):
    """Embed translation vectors into [B,4,4] matrices with identity rotation.

    Parameters
    ----------
    tvec : torch.Tensor
        Translation vectors; the first dimension is the batch size and the
        data is reshaped to [B,3,1].

    Returns
    -------
    torch.Tensor
        [B,4,4] matrices on ``tvec``'s device (default dtype): identity,
        with the translation written into the first three rows of the last column.
    """
    num_poses = tvec.shape[0]
    identity = torch.eye(4, device=tvec.device)
    matrices = identity.repeat(num_poses, 1, 1)
    # Fill the translation column
    matrices[:, :3, 3:] = tvec.contiguous().view(-1, 3, 1)
    return matrices
def euler2rot(euler):
    """Convert axis-angle vectors to homogeneous [B,4,4] rotation matrices.

    Despite the function and parameter names, the input is handled as an
    axis-angle vector (Rodrigues formula): its L2 norm is the rotation angle
    and its direction is the rotation axis.

    Parameters
    ----------
    euler : torch.Tensor
        Rotation vectors indexed as [B,1,3] (the norm is taken over dim 2).

    Returns
    -------
    torch.Tensor
        [B,4,4] matrices on ``euler``'s device (default dtype), with the
        rotation in the upper-left 3x3 block and 1 in the bottom-right corner.
    """
    angle = torch.norm(euler, 2, 2, True)
    # Small epsilon keeps a zero rotation from dividing by zero
    axis = euler / (angle + 1e-7)

    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    cos1_a = 1 - cos_a

    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    x_sin = x * sin_a
    y_sin = y * sin_a
    z_sin = z * sin_a

    x_cos1 = x * cos1_a
    y_cos1 = y * cos1_a
    z_cos1 = z * cos1_a

    xx_cos1 = x * x_cos1
    yy_cos1 = y * y_cos1
    zz_cos1 = z * z_cos1
    xy_cos1 = x * y_cos1
    yz_cos1 = y * z_cos1
    zx_cos1 = z * x_cos1

    batch_size = euler.shape[0]
    # Allocate straight on the target device instead of on CPU followed by a copy
    rot = torch.zeros((batch_size, 4, 4), device=euler.device)

    rot[:, 0, 0] = torch.squeeze(xx_cos1 + cos_a)
    rot[:, 0, 1] = torch.squeeze(xy_cos1 - z_sin)
    rot[:, 0, 2] = torch.squeeze(zx_cos1 + y_sin)
    rot[:, 1, 0] = torch.squeeze(xy_cos1 + z_sin)
    rot[:, 1, 1] = torch.squeeze(yy_cos1 + cos_a)
    rot[:, 1, 2] = torch.squeeze(yz_cos1 - x_sin)
    rot[:, 2, 0] = torch.squeeze(zx_cos1 - y_sin)
    rot[:, 2, 1] = torch.squeeze(yz_cos1 + x_sin)
    rot[:, 2, 2] = torch.squeeze(zz_cos1 + cos_a)
    rot[:, 3, 3] = 1

    return rot
def vec2mat(euler, translation, invert=False):
    """Compose a rotation and a translation into a [B,4,4] transformation matrix.

    Parameters
    ----------
    euler : torch.Tensor
        Rotation parameters, forwarded unchanged to ``euler2rot``.
    translation : torch.Tensor
        Translation vectors, forwarded to ``tvec_to_translation``
        (the caller's tensor is never modified).
    invert : bool
        If False, return ``T @ R``. If True, return ``R^T @ T(-t)``,
        i.e. the transposed rotation applied after the negated translation.

    Returns
    -------
    torch.Tensor
        [B,4,4] transformation matrices.
    """
    rotation = euler2rot(euler)
    if invert:
        # Inverse transform: transposed rotation times negated translation
        rotation_t = rotation.transpose(1, 2)
        neg_translation = tvec_to_translation(-translation)
        return torch.matmul(rotation_t, neg_translation)
    return torch.matmul(tvec_to_translation(translation.clone()), rotation)
def rot2quat(R):
    """Convert [B,3,3] rotation matrices to [B,4] quaternions.

    The scalar term ``sqrt(1 + trace(R)) / 2`` is stored in the LAST column and
    the three vector terms in columns 0..2.
    NOTE(review): ``quat2rot`` in this file reads the scalar term from column 0,
    so the two functions do not share a layout -- confirm the intended convention.

    The vector terms are divided by the scalar term, so a matrix whose trace
    is -1 (scalar term 0) produces a division by zero.

    Parameters
    ----------
    R : torch.Tensor
        Rotation matrices [B,3,3].

    Returns
    -------
    torch.Tensor
        Quaternions [B,4] on ``R``'s device (default dtype).
    """
    batch, _, _ = R.shape
    quat = torch.ones((batch, 4), device=R.device)

    # Scalar term from the trace
    quat[:, 3] = torch.sqrt(1.0 + R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]) / 2

    # (output column, added entry, subtracted entry) for each vector term
    vector_terms = (
        (0, (2, 1), (1, 2)),
        (1, (0, 2), (2, 0)),
        (2, (1, 0), (0, 1)),
    )
    for col, (pr, pc), (nr, nc) in vector_terms:
        quat[:, col] = (R[:, pr, pc] - R[:, nr, nc]) / (4 * quat[:, 3])

    return quat
def quat2rot(q):
    """Convert [B,4] quaternions to [B,3,3] rotation matrices.

    The input is L2-normalized along dim 1 first, then read as
    (real, i, j, k), i.e. with the scalar term in column 0.
    NOTE(review): ``rot2quat`` in this file writes the scalar term to column 3,
    so the two functions do not share a layout -- confirm the intended convention.

    Parameters
    ----------
    q : torch.Tensor
        Quaternions [B,4]; they need not be unit length.

    Returns
    -------
    torch.Tensor
        Rotation matrices [B,3,3] on ``q``'s device (default dtype).
    """
    batch, _ = q.shape
    q = F.normalize(q, dim=1)
    qr, qi, qj, qk = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    # Squares and pairwise products shared between matrix entries
    ii, jj, kk = qi ** 2, qj ** 2, qk ** 2
    ij, ik, jk = qj * qi, qi * qk, qj * qk
    ir, jr, kr = qi * qr, qr * qj, qk * qr

    R = torch.ones((batch, 3, 3), device=q.device)
    R[:, 0, 0] = 1 - 2 * (jj + kk)
    R[:, 0, 1] = 2 * (ij - kr)
    R[:, 0, 2] = 2 * (ik + jr)
    R[:, 1, 0] = 2 * (ij + kr)
    R[:, 1, 1] = 1 - 2 * (ii + kk)
    R[:, 1, 2] = 2 * (jk - ir)
    R[:, 2, 0] = 2 * (ik - jr)
    R[:, 2, 1] = 2 * (jk + ir)
    R[:, 2, 2] = 1 - 2 * (ii + jj)
    return R