Spaces:
Runtime error
Runtime error
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
import torch | |
from pytorch3d.transforms.rotation_conversions import \ | |
matrix_to_euler_angles, euler_angles_to_matrix | |
from vidar.utils.data import keys_in | |
from vidar.utils.decorators import iterate1, iterate12 | |
from vidar.utils.types import is_tensor, is_list, is_seq | |
def flip_lr_fn(tensor):
    """Mirror a tensor along its last dimension.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor of any rank; only the last dimension is reversed.

    Returns
    -------
    torch.Tensor
        A new tensor with the last dimension reversed (the input is untouched).
    """
    return torch.flip(tensor, dims=(-1,))
def flip_flow_lr_fn(flow):
    """Mirror a flow tensor left/right and negate its first channel.

    Parameters
    ----------
    flow : torch.Tensor
        Flow tensor indexed as [B, C, H, W] by this code (dimension 3 is
        reversed, channel 0 is negated). Presumably C == 2 with channel 0
        the horizontal component - TODO confirm.

    Returns
    -------
    torch.Tensor
        A new, contiguous flipped flow tensor (the input is untouched).
    """
    # torch.flip returns a copy, so the in-place negation below is safe
    flipped = torch.flip(flow, [3])
    # Reversing the horizontal axis reverses the sign of the first channel
    flipped[:, :1] *= -1
    return flipped.contiguous()
def flip_intrinsics_lr_fn(K, shape):
    """Mirror batched 3x3 intrinsics left/right by reflecting the x principal point.

    Parameters
    ----------
    K : torch.Tensor
        Intrinsics indexed as [B, 3, 3]; entry [:, 0, 2] is rewritten.
    shape : sequence
        Image shape; its last element is used as the image width.

    Returns
    -------
    torch.Tensor
        A modified copy of ``K`` with ``K[:, 0, 2] = width - K[:, 0, 2]``.
    """
    width = shape[-1]
    flipped = K.clone()
    flipped[:, 0, 2] = width - flipped[:, 0, 2]
    return flipped
def flip_pose_lr_fn(T):
    """Function to flip a 4x4 transformation matrix from left to right.

    Mirroring the x axis corresponds to conjugating the transform by
    S = diag(-1, 1, 1, 1), i.e. ``T' = S @ T @ S``. Element-wise this
    negates the off-diagonal entries of the first row and first column of
    the rotation block, and the x component of the translation.

    Parameters
    ----------
    T : torch.Tensor
        Transformation matrices stored in the last two dimensions
        (e.g. [B,4,4]; extra leading dimensions are also accepted).
        Modified in place.

    Returns
    -------
    torch.Tensor
        The same tensor ``T``, now holding the flipped transformations.
    """
    # Apply S @ T @ S exactly instead of round-tripping through XYZ Euler
    # angles (negating the Y and Z angles). Both agree in exact arithmetic,
    # but the Euler decomposition loses precision near gimbal lock and its
    # asin yields NaN if the rotation entry drifts slightly outside [-1, 1].
    T[..., 0, 1:3] *= -1
    T[..., 1:3, 0] *= -1
    # Negate the x translation (last column, first row)
    T[..., 0, -1] *= -1
    return T
def flip_lr(tensor, flip=True):
    """Flip a tensor from left to right.

    Parameters
    ----------
    tensor : torch.Tensor or list
        Tensor to mirror along its last dimension. Lists are processed
        element-wise (recursively).
    flip : bool
        If False, ``tensor`` is returned unchanged.

    Returns
    -------
    torch.Tensor or list
        Flipped tensor, or a list of flipped tensors.
    """
    # Not flipping option
    if not flip:
        return tensor
    # If it's a list, repeat
    if is_list(tensor):
        return [flip_lr(t) for t in tensor]
    # Flipping the last dimension does not depend on rank, so 5D inputs
    # need no per-frame loop + torch.stack: a single flip gives the same result
    return flip_lr_fn(tensor)
def flip_flow_lr(flow, flip=True):
    """Flip a flow tensor (or a list of them) from left to right.

    Parameters
    ----------
    flow : torch.Tensor or list
        Flow tensor; a 5D tensor is treated as a stack of 4D flows along
        dimension 1, each flipped separately. Lists are processed recursively.
    flip : bool
        If False, ``flow`` is returned unchanged.

    Returns
    -------
    torch.Tensor or list
        Flipped flow(s).
    """
    if not flip:
        return flow
    if is_list(flow):
        return [flip_flow_lr(item) for item in flow]
    if flow.dim() != 5:
        return flip_flow_lr_fn(flow)
    # Sequence of flows: flip each frame on its own and restack along dim 1
    frames = [flip_flow_lr_fn(flow[:, idx]) for idx in range(flow.shape[1])]
    return torch.stack(frames, 1)
def flip_intrinsics_lr(K, shape, flip=True):
    """Flip 3x3 camera intrinsics from left to right.

    Parameters
    ----------
    K : torch.Tensor
        Intrinsics; a 4D tensor is treated as a stack of [B,3,3] matrices
        along dimension 1, each flipped separately.
    shape : torch.Tensor or sequence
        Image shape (or a tensor whose ``.shape`` is used); its last entry
        is taken as the image width.
    flip : bool
        If False, ``K`` is returned unchanged.

    Returns
    -------
    torch.Tensor
        Flipped intrinsics.
    """
    if not flip:
        return K
    # Accept either a tensor (take its shape) or a shape directly
    dims = shape.shape if is_tensor(shape) else shape
    if K.dim() != 4:
        return flip_intrinsics_lr_fn(K, dims)
    frames = [flip_intrinsics_lr_fn(K[:, idx], dims) for idx in range(K.shape[1])]
    return torch.stack(frames, 1)
def flip_pose_lr(pose, flip=True):
    """Flip a dictionary of 4x4 transformation matrices from left to right.

    Parameters
    ----------
    pose : dict
        Mapping from key to pose tensor, presumably [B,4,4] or [B,N,4,4]
        (TODO confirm against callers). Key 0 is special: a 3D entry is left
        as-is, and for a 4D entry only frames ``[:, 1:]`` are flipped (the
        first frame is kept unchanged).
    flip : bool
        If False, ``pose`` is returned unchanged.

    Returns
    -------
    dict
        A new dictionary with the flipped poses. Unlike before, neither the
        input dictionary nor its tensors are modified, so callers such as
        ``flip_batch`` no longer silently flip the source batch's poses.

    Raises
    ------
    ValueError
        If the tensor stored under key 0 is neither 3D nor 4D.
    """
    # Not flipping option
    if not flip:
        return pose
    # Work on a shallow copy so the caller's dictionary is left untouched
    flipped_pose = pose.copy()
    for key, value in pose.items():
        # Get pose key
        if key == 0:
            if value.dim() == 3:
                continue
            elif value.dim() == 4:
                T = value[:, 1:].clone()
            else:
                raise ValueError('Invalid pose dimension')
        else:
            T = value.clone()
        # Flip pose (flip_pose_lr_fn handles one [B,4,4] stack at a time)
        if T.dim() == 4:
            T = torch.stack([flip_pose_lr_fn(T[:, i])
                             for i in range(T.shape[1])], 1)
        else:
            T = flip_pose_lr_fn(T)
        # Store flipped value; for key 0 rebuild the tensor instead of
        # writing into the caller's tensor in place (also autograd-safe)
        if key == 0:
            flipped_pose[key] = torch.cat([value[:, :1], T], 1)
        else:
            flipped_pose[key] = T
    # Return flipped pose
    return flipped_pose
def flip_batch(batch, flip=True):
    """Flip a batch (or a sequence of batches) from left to right.

    Only the keys listed below are carried into the returned dictionary;
    any other key in ``batch`` is dropped. If ``'intrinsics'`` is present,
    ``batch['rgb']`` must be present too, since its shape supplies the width.

    Parameters
    ----------
    batch : dict or sequence of dict
        Batch to flip; sequences are processed element-wise.
    flip : bool
        If False, ``batch`` is returned unchanged.

    Returns
    -------
    dict or list
        Flipped batch, or a list of flipped batches.
    """
    if not flip:
        return batch
    if is_seq(batch):
        return [flip_batch(sample) for sample in batch]
    # (keys, handler) pairs, applied in this order
    handlers = (
        # Metadata is carried over untouched
        (['idx', 'filename', 'splitname'], lambda value: value),
        # Image-like tensors are mirrored along their last dimension
        (['rgb', 'mask', 'input_depth', 'depth', 'semantic'], flip_lr),
        # Intrinsics take the width from the (unflipped) RGB tensor
        (['intrinsics'], lambda value: flip_intrinsics_lr(value, batch['rgb'])),
        (['pose'], flip_pose_lr),
    )
    flipped_batch = {}
    for keys, handler in handlers:
        for key in keys_in(batch, keys):
            flipped_batch[key] = handler(batch[key])
    return flipped_batch
def flip_predictions(predictions, flip=True):
    """Flip a predictions dictionary from left to right.

    Keys starting with ``'depth'`` are flipped with ``flip_lr`` and keys
    starting with ``'pose'`` with ``flip_pose_lr``; all other keys are
    omitted from the result.

    Parameters
    ----------
    predictions : dict
        Predictions keyed by string name.
    flip : bool
        If False, ``predictions`` is returned unchanged.

    Returns
    -------
    dict
        New dictionary containing only the flipped depth/pose entries.
    """
    if not flip:
        return predictions
    flipped = {}
    for name, value in predictions.items():
        if name.startswith('depth'):
            flipped[name] = flip_lr(value)
        elif name.startswith('pose'):
            flipped[name] = flip_pose_lr(value)
    return flipped
def flip_output(output, flip=True):
    """Flip a model output (or a sequence of outputs) from left to right.

    ``'loss'`` and ``'metrics'`` are copied unchanged when present, and
    ``output['predictions']`` (required) is flipped with ``flip_predictions``.
    Other keys are not carried over.

    Parameters
    ----------
    output : dict or sequence of dict
        Output to flip; sequences are processed element-wise.
    flip : bool
        If False, ``output`` is returned unchanged.

    Returns
    -------
    dict or list
        Flipped output, or a list of flipped outputs.
    """
    if not flip:
        return output
    if is_seq(output):
        return [flip_output(item) for item in output]
    flipped_output = {key: output[key] for key in keys_in(output, ['loss', 'metrics'])}
    flipped_output['predictions'] = flip_predictions(output['predictions'])
    return flipped_output