Jiading Fang
add define
fc16538
raw
history blame
5.63 kB
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
from vidar.geometry.pose_utils import invert_pose, pose_vec2mat, to_global_pose, euler2mat
from vidar.utils.types import is_int
def from_dict_sample(T, to_global=False, zero_origin=False, to_matrix=False):
    """
    Wrap every entry of a sample dictionary into a Pose

    Parameters
    ----------
    T : Dict
        Dictionary containing input poses [B,4,4]
    to_global : Bool
        If True, the wrapped poses are passed through to_global_pose
    zero_origin : Bool
        Forwarded to to_global_pose (only used when to_global is True)
    to_matrix : Bool
        If True, the underlying tensors are returned instead of Pose objects

    Returns
    -------
    pose : Dict
        Dictionary containing output poses (Pose objects, or their tensors if to_matrix is True)
    """
    # Wrap each raw transformation in a Pose, keeping the input key order
    poses = {}
    for key, val in T.items():
        poses[key] = Pose(val)

    if to_global:
        poses = to_global_pose(poses, zero_origin=zero_origin)

    if not to_matrix:
        return poses

    # Unwrap back to plain tensors
    return {key: val.T for key, val in poses.items()}
def from_dict_batch(T, **kwargs):
    """
    Create poses from a batch dictionary

    Parameters
    ----------
    T : Dict
        Dictionary of batched input poses; the batch size is read from the entry with key 0
    kwargs : Dict
        Forwarded unchanged to from_dict_sample. The per-sample results are combined
        with torch.stack, so callers are expected to request tensors (to_matrix=True)

    Returns
    -------
    pose : Dict
        Dictionary with the per-sample outputs stacked along a new first dimension
    """
    batch_size = T[0].shape[0]

    # Process each batch element independently
    samples = []
    for b in range(batch_size):
        sample = {key: val[b] for key, val in T.items()}
        samples.append(from_dict_sample(sample, **kwargs))

    # Re-assemble the batch, key by key
    return {key: torch.stack([sample[key] for sample in samples], 0) for key in samples[0]}
class Pose:
    """
    Pose class for 3D operations

    Parameters
    ----------
    T : torch.Tensor or Int
        Transformation matrix [B,4,4], or batch size (poses initialized as identity).
        A tensor that is not 3-dimensional gets a leading batch dimension added.
    """
    def __init__(self, T=1):
        if is_int(T):
            T = torch.eye(4).repeat(T, 1, 1)
        self.T = T if T.dim() == 3 else T.unsqueeze(0)

    def __len__(self):
        """Return batch size"""
        return len(self.T)

    def __getitem__(self, i):
        """Return batch-wise pose (the batch dimension is kept)"""
        return Pose(self.T[[i]])

    def __mul__(self, data):
        """
        Transforms data (pose or 3D points)

        Parameters
        ----------
        data : Pose or torch.Tensor
            Another pose (composed via batched matrix product), or 3D points [B,3,N]
            (rotated and then translated)

        Returns
        -------
        out : Pose or torch.Tensor
            Composed pose, or transformed points

        Raises
        ------
        NotImplementedError
            If data is neither a Pose nor a torch.Tensor
        """
        if isinstance(data, Pose):
            return Pose(self.T.bmm(data.T))
        elif isinstance(data, torch.Tensor):
            return self.T[:, :3, :3].bmm(data) + self.T[:, :3, -1].unsqueeze(-1)
        else:
            raise NotImplementedError()

    def detach(self):
        """Return detached pose"""
        return Pose(self.T.detach())

    @property
    def shape(self):
        """Return pose shape"""
        return self.T.shape

    @property
    def device(self):
        """Return pose device"""
        return self.T.device

    @property
    def dtype(self):
        """Return pose type"""
        return self.T.dtype

    @classmethod
    def identity(cls, N=1, device=None, dtype=torch.float):
        """Initializes as a batch of N identity matrices [N,4,4]"""
        return cls(torch.eye(4, device=device, dtype=dtype).repeat([N, 1, 1]))

    @staticmethod
    def from_dict(T, to_global=False, zero_origin=False, to_matrix=False):
        """
        Create poses from a dictionary

        The entry with key 0 decides how the dictionary is interpreted:
        3 dimensions is treated as a single sample, 4 dimensions as a batch.
        In the batched case the outputs are stacked tensors, so to_matrix is
        always forced to True there. Any other rank returns None.
        """
        if T[0].dim() == 3:
            return from_dict_sample(T, to_global=to_global, zero_origin=zero_origin, to_matrix=to_matrix)
        elif T[0].dim() == 4:
            return from_dict_batch(T, to_global=to_global, zero_origin=zero_origin, to_matrix=True)
        # Unsupported rank: kept as a None return for backward compatibility
        return None

    @classmethod
    def from_vec(cls, vec, mode):
        """Initializes from a [B,6] batch vector (mode is forwarded to pose_vec2mat)"""
        mat = pose_vec2mat(vec, mode)
        pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1])
        # Copy rotation and translation into a full homogeneous matrix
        pose[:, :3, :3] = mat[:, :3, :3]
        pose[:, :3, -1] = mat[:, :3, -1]
        return cls(pose)

    def repeat(self, *args, **kwargs):
        """Repeats the transformation matrix multiple times (updates this pose and returns it)"""
        self.T = self.T.repeat(*args, **kwargs)
        return self

    def inverse(self):
        """Returns a new Pose that is the inverse of this one"""
        return Pose(invert_pose(self.T))

    def to(self, *args, **kwargs):
        """Copy pose to device/dtype (updates this pose and returns it)"""
        self.T = self.T.to(*args, **kwargs)
        return self

    def cuda(self, *args, **kwargs):
        """Copy pose to CUDA (arguments are forwarded to torch.Tensor.cuda)"""
        # Forward the arguments so that e.g. an explicit device index is honored
        self.T = self.T.cuda(*args, **kwargs)
        return self

    def translate(self, xyz):
        """
        Translate pose

        Parameters
        ----------
        xyz : torch.Tensor
            Offset added to the translation column, broadcastable to [B,3]

        Returns
        -------
        pose : Pose
            This pose, with the updated transformation
        """
        # Work on a copy so the tensor handed to the constructor (possibly a view of
        # a larger batch, or a leaf that requires grad) is never modified in place
        T = self.T.clone()
        T[:, :3, -1] = T[:, :3, -1] + xyz.to(self.device)
        self.T = T
        return self

    def rotate(self, rpw):
        """Rotate pose (rpw is converted to a rotation matrix with euler2mat)"""
        rot = euler2mat(rpw)
        # The rotation is applied on the inverted pose, which is then inverted back
        T = invert_pose(self.T).clone()
        T[:, :3, :3] = T[:, :3, :3] @ rot.to(self.device)
        self.T = invert_pose(T)
        return self

    def rotateRoll(self, r):
        """Rotate pose in the roll axis"""
        return self.rotate(torch.tensor([[0, 0, r]]))

    def rotatePitch(self, p):
        """Rotate pose in the pitch axis"""
        return self.rotate(torch.tensor([[p, 0, 0]]))

    def rotateYaw(self, w):
        """Rotate pose in the yaw axis"""
        return self.rotate(torch.tensor([[0, w, 0]]))

    def translateForward(self, t):
        """Translate pose forward (-t along the third axis)"""
        return self.translate(torch.tensor([[0, 0, -t]]))

    def translateBackward(self, t):
        """Translate pose backward (+t along the third axis)"""
        return self.translate(torch.tensor([[0, 0, +t]]))

    def translateLeft(self, t):
        """Translate pose left (+t along the first axis)"""
        return self.translate(torch.tensor([[+t, 0, 0]]))

    def translateRight(self, t):
        """Translate pose right (-t along the first axis)"""
        return self.translate(torch.tensor([[-t, 0, 0]]))

    def translateUp(self, t):
        """Translate pose up (+t along the second axis)"""
        return self.translate(torch.tensor([[0, +t, 0]]))

    def translateDown(self, t):
        """Translate pose down (-t along the second axis)"""
        return self.translate(torch.tensor([[0, -t, 0]]))