# rerun-viewer / visualization / et_visualizer.py
# Last commit 44f6fd9 by abreza: "remove unused codes"
# (File-viewer residue: raw | history | blame | 4.59 kB)
import tempfile
import os
import spaces
import numpy as np
import torch
import torch.nn.functional as F
from evo.tools.file_interface import read_kitti_poses_file
from pathlib import Path
import rerun as rr
from typing import Optional, Dict
from visualization.logger import SimulationLogger
from scipy.spatial.transform import Rotation
def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
    """Load a KITTI-format camera trajectory plus a character feature file.

    The trajectory is converted to a per-frame 9-value feature (6D rotation
    made of the first two rotation-matrix columns, then the translation) and
    laid out as ``(9, num_cams)``. The character array from ``char_file`` is
    padded or cropped along its first axis to ``num_cams`` rows and then
    transposed. ``F.pad`` crops when the requested padding is negative, so
    inputs longer than ``num_cams`` are truncated rather than rejected.

    Args:
        traj_file: Path to a KITTI poses file readable by
            ``read_kitti_poses_file``.
        char_file: Path to a ``.npy`` file. Its first axis is treated as the
            frame axis (assumed 2-D; TODO confirm against the data).
        num_cams: Target number of frames after padding/cropping.

    Returns:
        Dict with the source file names, the padded trajectory and character
        features, a float ``padding_mask`` of length ``num_cams`` (1 for frames
        present in the trajectory, 0 for padding), and the raw float32 pose
        matrices under ``"raw_matrix_trajectory"``.
    """
    poses = read_kitti_poses_file(traj_file).poses_se3
    se3 = torch.from_numpy(np.array(poses)).to(torch.float32)

    # 6D rotation: first rotation column followed by the second one.
    rotation_6d = se3[:, :3, :2].transpose(1, 2).reshape(-1, 6)
    translation = se3[:, :3, 3].clone()
    # Rows = 6 rotation values then x, y, z; columns = frames.
    traj_feat = torch.cat((rotation_6d, translation), dim=1).T
    n_frames = traj_feat.shape[1]

    mask = torch.ones(num_cams)
    mask[n_frames:] = 0

    char_feat = torch.from_numpy(np.load(char_file)).to(torch.float32)
    char_pad = num_cams - char_feat.shape[0]

    return {
        "traj_filename": Path(traj_file).name,
        "char_filename": Path(char_file).name,
        "traj_feat": F.pad(traj_feat, (0, num_cams - n_frames)),
        "char_feat": F.pad(char_feat, (0, 0, 0, char_pad)).T,
        "padding_mask": mask,
        "raw_matrix_trajectory": se3,
    }
class ETLogger(SimulationLogger):
    """Rerun logger for E.T. camera trajectories and character features.

    Constructing an instance calls ``rr.init("et_visualization")`` and marks
    the ``world`` entity as a right-handed, Y-up coordinate system.
    """

    def __init__(self):
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
        # Pinhole intrinsics for the 640x480 image logged per camera frame:
        # 500 px focal length, principal point at the image center.
        self.K = np.array([
            [500, 0, 320],
            [0, 500, 240],
            [0, 0, 1]
        ])

    def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
        """Log camera positions, path segments and per-frame camera poses.

        Args:
            trajectory: Stack of pose matrices indexed as ``[k, :3, :3]``
                (rotation) and ``[k, :3, 3]`` (translation).
            padding_mask: Per-frame validity mask. Only its sum is used, so
                the valid frames are assumed to be a prefix of ``trajectory``.
        """
        valid_frames = int(padding_mask.sum())
        valid_trajectory = trajectory[:valid_frames]
        positions = valid_trajectory[:, :3, 3]

        rr.log(
            "world/trajectory/points",
            rr.Points3D(
                positions,
                colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0])
            ),
            timeless=True
        )

        if len(positions) > 1:
            # One 2-point segment per pair of consecutive positions.
            lines = np.stack([positions[:-1], positions[1:]], axis=1)
            rr.log(
                "world/trajectory/line",
                rr.LineStrips3D(
                    lines,
                    colors=[(0.0, 0.8, 0.8, 1.0)]
                ),
                timeless=True
            )

        for k in range(valid_frames):
            rr.set_time_sequence("frame_idx", k)
            translation = valid_trajectory[k, :3, 3]
            # SciPy's as_quat() returns scalar-last (x, y, z, w) order.
            rotation_q = Rotation.from_matrix(
                valid_trajectory[k, :3, :3]).as_quat()
            rr.log(
                "world/camera",
                rr.Transform3D(
                    translation=translation,
                    rotation=rr.Quaternion(xyzw=rotation_q),
                ),
            )
            rr.log(
                "world/camera/image",
                rr.Pinhole(
                    image_from_camera=self.K,
                    width=640,
                    height=480,
                ),
            )

    def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
        """Log the character features of all valid frames as static 3D points.

        Args:
            char_feature: Array laid out ``(features, frames)``. The frame axis
                is the one sliced by the padding mask. Each frame's feature
                vector is presumably concatenated xyz triplets (TODO confirm
                against the dataset).
            padding_mask: Per-frame validity mask. Only its sum is used, so
                the valid frames are assumed to be a prefix.
        """
        valid_frames = int(padding_mask.sum())
        # Move frames to the leading axis before splitting into 3-vectors.
        # Reshaping the (features, frames) layout directly would put values
        # from different frames into the same point.
        per_frame = char_feature[:, :valid_frames].T
        if per_frame.size == 0:
            # No valid frames (or no features): nothing to draw.
            return
        points = per_frame.reshape(-1, 3)
        rr.log(
            "world/character",
            rr.Points3D(
                points,
                colors=np.full((points.shape[0], 4), [0.8, 0.2, 0.2, 1.0])
            ),
            timeless=True
        )
@spaces.GPU
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Render an E.T. trajectory/character pair into a Rerun ``.rrd`` file.

    Args:
        traj_file: Path to the KITTI-format camera trajectory.
        char_file: Path to the ``.npy`` character feature file.

    Returns:
        Path to ``et_visualization.rrd`` inside a fresh temporary directory,
        or ``None`` if any step fails. On failure the error and its traceback
        are printed, and the temporary directory (if already created) is
        removed so failed calls do not leave empty directories behind.
    """
    import shutil
    import traceback

    temp_dir = None
    try:
        data = load_trajectory_data(traj_file, char_file)
        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")

        padding_mask = data["padding_mask"].numpy()
        logger = ETLogger()
        logger.log_trajectory(
            data["raw_matrix_trajectory"].numpy(),
            padding_mask
        )
        logger.log_character(
            data["char_feat"].numpy(),
            padding_mask
        )
        rr.save(rrd_path)
        return rrd_path
    except Exception as e:
        # Best-effort boundary: callers expect None on failure, but keep the
        # traceback so the root cause is diagnosable.
        print(f"Error visualizing E.T. data: {e}")
        traceback.print_exc()
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None