Jiading Fang
add define
fc16538
raw
history blame
5.53 kB
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import torch
from pytorch3d.transforms.rotation_conversions import \
matrix_to_euler_angles, euler_angles_to_matrix
from vidar.utils.data import keys_in
from vidar.utils.decorators import iterate1, iterate12
from vidar.utils.types import is_tensor, is_list, is_seq
def flip_lr_fn(tensor):
    """Mirror a tensor horizontally by reversing its last dimension.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor; its last dimension is the one that gets reversed
        (presumably the image width -- confirm against callers).

    Returns
    -------
    torch.Tensor
        New tensor with the last dimension reversed.
    """
    last_dim = -1
    return torch.flip(tensor, dims=[last_dim])
def flip_flow_lr_fn(flow):
    """Mirror a flow tensor horizontally and negate its first channel.

    Parameters
    ----------
    flow : torch.Tensor
        Tensor with at least 4 dimensions; dimension 3 is reversed and the
        first entry along dimension 1 is negated (presumably a Bx2xHxW flow
        map whose channel 0 is the horizontal displacement -- confirm).

    Returns
    -------
    torch.Tensor
        New contiguous tensor; the input is left untouched.
    """
    # torch.flip returns a copy, so the in-place negation below is safe
    mirrored = torch.flip(flow, dims=[3])
    # A left-right mirror reverses the sign of the first flow channel
    mirrored[:, :1, :, :].mul_(-1)
    return mirrored.contiguous()
def flip_intrinsics_lr_fn(K, shape):
    """Mirror a batch of 3x3 intrinsic matrices horizontally.

    Parameters
    ----------
    K : torch.Tensor
        Intrinsics indexed as [batch, row, col]; only entry (0, 2) of each
        matrix is modified.
    shape : sequence of int
        Shape whose last entry is used as the reference size
        (presumably the image width -- confirm against callers).

    Returns
    -------
    torch.Tensor
        Clone of K where entry (0, 2) is replaced by shape[-1] minus itself.
    """
    # Work on a copy so the caller's intrinsics stay unchanged
    flipped = K.clone()
    width = shape[-1]
    flipped[:, 0, 2] = width - flipped[:, 0, 2]
    return flipped
def flip_pose_lr_fn(T):
    """Mirror a batch of 4x4 transformation matrices from left to right.

    The rotation block is converted to XYZ Euler angles, the second and third
    angles are negated, and the rotation is rebuilt; entry (0, 3) of each
    matrix (the first translation component) is negated as well.

    Parameters
    ----------
    T : torch.Tensor
        Transformations indexed as [batch, row, col], with the rotation in
        the top-left 3x3 block and the translation in the last column.

    Returns
    -------
    torch.Tensor
        New tensor holding the flipped transformations. The input is not
        modified, matching the behavior of flip_intrinsics_lr_fn.
    """
    # Clone first: the writes below would otherwise alter the caller's tensor
    # (and fail outright on leaf tensors that require gradients)
    T = T.clone()
    # Negate the second and third Euler angles, keeping the first one
    angles = matrix_to_euler_angles(T[:, :3, :3], convention='XYZ')
    angles[:, [1, 2]] = angles[:, [1, 2]] * -1
    T[:, :3, :3] = euler_angles_to_matrix(angles, convention='XYZ')
    # Negate the first translation component
    T[:, 0, -1] = -T[:, 0, -1]
    return T
@iterate1
def flip_lr(tensor, flip=True):
    """Flip a tensor (or list of tensors) from left to right.

    Parameters
    ----------
    tensor : torch.Tensor or list
        Tensor to flip, or a list whose elements are flipped recursively.
        5-dimensional tensors are processed slice by slice along dimension 1.
    flip : bool
        When False the input is returned unchanged.

    Returns
    -------
    torch.Tensor or list
        Flipped tensor(s), or the original input if flip is False.
    """
    # Nothing to do when flipping is disabled
    if not flip:
        return tensor
    # Lists are handled element by element
    if is_list(tensor):
        return [flip_lr(item) for item in tensor]
    # Anything that is not 5-dimensional is flipped directly
    if tensor.dim() != 5:
        return flip_lr_fn(tensor)
    # 5-dimensional input: flip each slice along dimension 1 and re-stack
    slices = [flip_lr_fn(tensor[:, idx]) for idx in range(tensor.shape[1])]
    return torch.stack(slices, 1)
@iterate1
def flip_flow_lr(flow, flip=True):
    """Flip a flow tensor (or list of flow tensors) from left to right.

    Parameters
    ----------
    flow : torch.Tensor or list
        Flow tensor to flip, or a list whose elements are flipped recursively.
        5-dimensional tensors are processed slice by slice along dimension 1.
    flip : bool
        When False the input is returned unchanged.

    Returns
    -------
    torch.Tensor or list
        Flipped flow(s), or the original input if flip is False.
    """
    # Nothing to do when flipping is disabled
    if not flip:
        return flow
    # Lists are handled element by element
    if is_list(flow):
        return [flip_flow_lr(item) for item in flow]
    # Anything that is not 5-dimensional is flipped directly
    if flow.dim() != 5:
        return flip_flow_lr_fn(flow)
    # 5-dimensional input: flip each slice along dimension 1 and re-stack
    slices = [flip_flow_lr_fn(flow[:, idx]) for idx in range(flow.shape[1])]
    return torch.stack(slices, 1)
@iterate12
def flip_intrinsics_lr(K, shape, flip=True):
    """Flip 3x3 camera intrinsic matrices from left to right.

    Parameters
    ----------
    K : torch.Tensor
        Intrinsics; 4-dimensional inputs are processed slice by slice along
        dimension 1, anything else is passed to flip_intrinsics_lr_fn as is.
    shape : torch.Tensor or sequence of int
        Reference shape, or a tensor whose shape is used instead.
    flip : bool
        When False the input is returned unchanged.

    Returns
    -------
    torch.Tensor
        Flipped intrinsics, or the original K if flip is False.
    """
    # Nothing to do when flipping is disabled
    if not flip:
        return K
    # If shape is a tensor, use its dimensions
    dims = shape.shape if is_tensor(shape) else shape
    # Anything that is not 4-dimensional is flipped directly
    if K.dim() != 4:
        return flip_intrinsics_lr_fn(K, dims)
    # 4-dimensional input: flip each slice along dimension 1 and re-stack
    slices = [flip_intrinsics_lr_fn(K[:, idx], dims)
              for idx in range(K.shape[1])]
    return torch.stack(slices, 1)
def flip_pose_lr(pose, flip=True):
    """Flip a dictionary of 4x4 transformation matrices from left to right.

    Parameters
    ----------
    pose : dict
        Maps keys to pose tensors with 3 dimensions (a batch of 4x4 matrices)
        or 4 dimensions (flipped slice by slice along dimension 1).
        Key 0 is special: a 3-dimensional entry is kept as is, and for a
        4-dimensional entry the first slice along dimension 1 is kept while
        the remaining slices are flipped (presumably because that slice is
        the reference pose -- confirm against callers).
    flip : bool
        When False the input dictionary is returned unchanged.

    Returns
    -------
    dict
        New dictionary with flipped poses. Neither the input dictionary nor
        its tensors are modified, so a caller that keeps the original (e.g.
        to build a flipped copy of a batch) does not see its poses change.

    Raises
    ------
    ValueError
        If the entry under key 0 has neither 3 nor 4 dimensions.
    """
    # Not flipping option
    if not flip:
        return pose
    # Collect results in a fresh dictionary instead of writing into the input
    flipped = {}
    for key, value in pose.items():
        # Select the part of this entry that has to be flipped
        if key == 0:
            if value.dim() == 3:
                # Kept untouched
                flipped[key] = value
                continue
            elif value.dim() == 4:
                T = value[:, 1:].clone()
            else:
                raise ValueError('Invalid pose dimension')
        else:
            T = value.clone()
        # Flip pose
        if T.dim() == 4:
            # torch.stack rejects an empty list, so skip when nothing to flip
            if T.shape[1] > 0:
                T = torch.stack([flip_pose_lr_fn(T[:, i])
                                 for i in range(T.shape[1])], 1)
        else:
            T = flip_pose_lr_fn(T)
        # Store flipped value, writing into a copy rather than the original
        if key == 0:
            merged = value.clone()
            merged[:, 1:] = T
            flipped[key] = merged
        else:
            flipped[key] = T
    # Return flipped pose
    return flipped
def flip_batch(batch, flip=True):
    """Flip a batch (or sequence of batches) from left to right.

    Parameters
    ----------
    batch : dict or sequence
        Batch dictionary, or a sequence of batches flipped recursively.
        Intrinsics are flipped using batch['rgb'] as the shape reference.
    flip : bool
        When False the input is returned unchanged.

    Returns
    -------
    dict or list
        New dictionary holding only the handled keys (metadata, image-like
        tensors, intrinsics and pose), or a list of such dictionaries.
    """
    # Nothing to do when flipping is disabled
    if not flip:
        return batch
    # Sequences are handled element by element
    if is_seq(batch):
        return [flip_batch(sample) for sample in batch]
    # Metadata entries are copied over without flipping
    flipped = {name: batch[name]
               for name in keys_in(batch, ['idx', 'filename', 'splitname'])}
    # Image-like tensors are mirrored
    flipped.update({
        name: flip_lr(batch[name])
        for name in keys_in(batch, ['rgb', 'mask', 'input_depth', 'depth', 'semantic'])
    })
    # Intrinsics are flipped relative to the rgb tensor
    flipped.update({
        name: flip_intrinsics_lr(batch[name], batch['rgb'])
        for name in keys_in(batch, ['intrinsics'])
    })
    # Poses are flipped
    flipped.update({
        name: flip_pose_lr(batch[name])
        for name in keys_in(batch, ['pose'])
    })
    return flipped
def flip_predictions(predictions, flip=True):
    """Flip a dictionary of predictions from left to right.

    Parameters
    ----------
    predictions : dict
        Predictions keyed by name. Keys starting with 'depth' are flipped
        with flip_lr and keys starting with 'pose' with flip_pose_lr.
    flip : bool
        When False the input is returned unchanged.

    Returns
    -------
    dict
        New dictionary with the flipped depth and pose entries only;
        every other key is left out.
    """
    # Nothing to do when flipping is disabled
    if not flip:
        return predictions
    flipped = {}
    for name, value in predictions.items():
        # A key cannot match both prefixes, so the branches are exclusive
        if name.startswith('depth'):
            flipped[name] = flip_lr(value)
        elif name.startswith('pose'):
            flipped[name] = flip_pose_lr(value)
    return flipped
def flip_output(output, flip=True):
    """Flip a model output (or sequence of outputs) from left to right.

    Parameters
    ----------
    output : dict or sequence
        Output dictionary with a 'predictions' entry, or a sequence of such
        dictionaries flipped recursively.
    flip : bool
        When False the input is returned unchanged.

    Returns
    -------
    dict or list
        New dictionary with 'loss' and 'metrics' copied as they are (when
        present) and 'predictions' flipped, or a list of such dictionaries.
    """
    # Nothing to do when flipping is disabled
    if not flip:
        return output
    # Sequences are handled element by element
    if is_seq(output):
        return [flip_output(item) for item in output]
    # Loss and metrics are carried over without flipping
    flipped = {name: output[name]
               for name in keys_in(output, ['loss', 'metrics'])}
    # Predictions are the only entries that get flipped
    flipped['predictions'] = flip_predictions(output['predictions'])
    return flipped