Spaces:
Runtime error
Runtime error
File size: 3,584 Bytes
fc16538 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
from vidar.geometry.pose_utils import invert_pose, pose_vec2mat
class Pose:
"""
Pose class, that encapsulates a [4,4] transformation matrix
for a specific reference frame
"""
def __init__(self, mat):
"""
Initializes a Pose object.
Parameters
----------
mat : torch.Tensor
Transformation matrix [B,4,4]
"""
assert tuple(mat.shape[-2:]) == (4, 4)
if mat.dim() == 2:
mat = mat.unsqueeze(0)
assert mat.dim() == 3
self.mat = mat
def __len__(self):
"""Batch size of the transformation matrix"""
return len(self.mat)
def __getitem__(self, i):
return Pose(self.mat[i].unsqueeze(0)).to(self.device)
@property
def device(self):
"""Return pose device"""
return self.mat.device
@classmethod
def identity(cls, N=1, device=None, dtype=torch.float):
"""Initializes as a [4,4] identity matrix"""
return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1]))
@classmethod
def from_vec(cls, vec, mode):
"""Initializes from a [B,6] batch vector"""
mat = pose_vec2mat(vec, mode) # [B,3,4]
pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1])
pose[:, :3, :3] = mat[:, :3, :3]
pose[:, :3, -1] = mat[:, :3, -1]
return cls(pose)
@property
def shape(self):
"""Returns the transformation matrix shape"""
return self.mat.shape
def item(self):
"""Returns the transformation matrix"""
return self.mat
def repeat(self, *args, **kwargs):
"""Repeats the transformation matrix multiple times"""
self.mat = self.mat.repeat(*args, **kwargs)
return self
def inverse(self):
"""Returns a new Pose that is the inverse of this one"""
return Pose(invert_pose(self.mat))
def to(self, *args, **kwargs):
"""Moves object to a specific device"""
self.mat = self.mat.to(*args, **kwargs)
return self
def transform_pose(self, pose):
"""Creates a new pose object that compounds this and another one (self * pose)"""
assert tuple(pose.shape[-2:]) == (4, 4)
return Pose(self.mat.bmm(pose.item()))
def transform_points(self, points):
"""Transforms 3D points using this object"""
assert 2 < points.dim() <= 4 and points.shape[1] == 3, \
'Wrong dimensions for transform_points'
# Determine if input is a grid
is_grid = points.dim() == 4
# If it's a grid, flatten it
points_flat = points.view(points.shape[0], 3, -1) if is_grid else points
# Tranform points
out = self.mat[:, :3, :3].bmm(points_flat) + \
self.mat[:, :3, -1].unsqueeze(-1)
# Return transformed points
return out.view(points.shape) if is_grid else out
def __matmul__(self, other):
"""Transforms the input (Pose or 3D points) using this object"""
if isinstance(other, Pose):
return self.transform_pose(other)
elif isinstance(other, torch.Tensor):
if other.shape[1] == 3 and other.dim() > 2:
assert other.dim() == 3 or other.dim() == 4
return self.transform_points(other)
else:
raise ValueError('Unknown tensor dimensions {}'.format(other.shape))
else:
raise NotImplementedError()
|