rerun-viewer / visualization /et_visualizer.py
abreza's picture
remove padding
c065381
import tempfile
import os
import spaces
import numpy as np
import torch
import torch.nn.functional as F
from evo.tools.file_interface import read_kitti_poses_file
from pathlib import Path
import rerun as rr
from typing import Optional, Dict
from visualization.logger import SimulationLogger
from scipy.spatial.transform import Rotation
def load_trajectory_data(traj_file: str, char_file: str) -> Dict:
    """Load a camera trajectory and a character feature array from disk.

    Args:
        traj_file: Path to a pose file readable by evo's
            ``read_kitti_poses_file``.
        char_file: Path to an array file readable by ``numpy.load``.

    Returns:
        A dict with four entries:
            ``traj_filename``: base name of ``traj_file``.
            ``char_filename``: base name of ``char_file``.
            ``char_feat``: the loaded character array as a float32 tensor.
            ``matrix_trajectory``: the trajectory's stacked ``poses_se3``
                matrices as a float32 tensor.
    """
    # The trajectory file is read first, then the character file.
    pose_stack = np.array(read_kitti_poses_file(traj_file).poses_se3)
    poses_tensor = torch.from_numpy(pose_stack).float()

    char_array = np.load(char_file)
    char_tensor = torch.from_numpy(char_array).float()

    return dict(
        traj_filename=Path(traj_file).name,
        char_filename=Path(char_file).name,
        char_feat=char_tensor,
        matrix_trajectory=poses_tensor,
    )
class ETLogger(SimulationLogger):
    """Rerun logger for a camera trajectory and a set of character points."""

    def __init__(self):
        """Start the ``et_visualization`` recording and set the world frame."""
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
        # 3x3 matrix handed to rr.Pinhole as ``image_from_camera``; the
        # (320, 240) entries sit at the centre of the 640x480 image used below.
        self.K = np.array([[500, 0, 320], [0, 500, 240], [0, 0, 1]])

    def log_trajectory(self, trajectory: np.ndarray):
        """Log the trajectory as static geometry plus a per-frame camera.

        Args:
            trajectory: Array indexed as ``trajectory[k, :3, :3]`` (rotation)
                and ``trajectory[k, :3, 3]`` (translation) for each frame
                ``k`` -- presumably N stacked 4x4 pose matrices.
        """
        cyan = (0.0, 0.8, 0.8, 1.0)
        positions = trajectory[:, :3, 3]
        point_count = len(positions)

        # Static point cloud of every camera position.
        point_cloud = rr.Points3D(
            positions, colors=np.full((point_count, 4), cyan)
        )
        rr.log("world/trajectory/points", point_cloud, timeless=True)

        # Static poly-line: one two-point strip per consecutive position pair.
        if point_count > 1:
            segments = np.stack([positions[:-1], positions[1:]], axis=1)
            rr.log(
                "world/trajectory/line",
                rr.LineStrips3D(segments, colors=[cyan]),
                timeless=True,
            )

        # Time-varying camera pose and pinhole, one entry per frame index.
        for frame_idx, pose in enumerate(trajectory):
            rr.set_time_sequence("frame_idx", frame_idx)
            quat_xyzw = Rotation.from_matrix(pose[:3, :3]).as_quat()
            camera_pose = rr.Transform3D(
                translation=pose[:3, 3],
                rotation=rr.Quaternion(xyzw=quat_xyzw),
            )
            rr.log("world/camera", camera_pose)
            pinhole = rr.Pinhole(
                image_from_camera=self.K, width=640, height=480
            )
            rr.log("world/camera/image", pinhole)

    def log_character(self, char_feature: np.ndarray):
        """Log the character feature array as a static red point cloud.

        Args:
            char_feature: Array whose elements are regrouped into rows of
                three coordinates via ``reshape(-1, 3)``.
        """
        points = char_feature.reshape(-1, 3)
        red = np.full((points.shape[0], 4), [0.8, 0.2, 0.2, 1.0])
        rr.log("world/character", rr.Points3D(points, colors=red), timeless=True)
@spaces.GPU
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Render a trajectory/character pair into a Rerun ``.rrd`` recording.

    Args:
        traj_file: Path to the trajectory file, forwarded to
            ``load_trajectory_data``.
        char_file: Path to the character feature file, forwarded to
            ``load_trajectory_data``.

    Returns:
        Path to ``et_visualization.rrd`` inside a freshly created temporary
        directory, or ``None`` if any step raised. On success the directory
        is left in place for the caller, who owns its cleanup.
    """
    # Function-scope import keeps this block self-contained; the module's
    # top-level imports do not include shutil.
    import shutil

    temp_dir = None
    try:
        data = load_trajectory_data(traj_file, char_file)
        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
        logger = ETLogger()
        logger.log_trajectory(data["matrix_trajectory"].numpy())
        logger.log_character(data["char_feat"].numpy())
        rr.save(rrd_path)
        return rrd_path
    except Exception as e:
        # Best-effort boundary: report the failure and signal it with None.
        print(f"Error visualizing E.T. data: {str(e)}")
        # mkdtemp() never removes its directory; without this, every failed
        # call after the directory was created would leak it.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None