import tempfile
import os
import spaces
import numpy as np
import torch
import torch.nn.functional as F
from evo.tools.file_interface import read_kitti_poses_file
from pathlib import Path
import rerun as rr
from typing import Optional, Dict
from visualization.logger import SimulationLogger
from scipy.spatial.transform import Rotation


def load_trajectory_data(traj_file: str, char_file: str) -> Dict:
    """Load a camera trajectory and a character feature array from disk.

    Args:
        traj_file: Path to a pose file, parsed with evo's
            ``read_kitti_poses_file``.
        char_file: Path to an array file, loaded with ``np.load``.

    Returns:
        A dict with four entries:

        * ``"traj_filename"`` / ``"char_filename"``: base names of the two
          input paths.
        * ``"char_feat"``: the loaded character array as a float32 tensor.
        * ``"matrix_trajectory"``: the trajectory's ``poses_se3`` stacked
          into a float32 tensor.
    """
    trajectory = read_kitti_poses_file(traj_file)
    matrix_trajectory = torch.from_numpy(
        np.array(trajectory.poses_se3)).to(torch.float32)

    char_feature = torch.from_numpy(np.load(char_file)).to(torch.float32)

    return {
        "traj_filename": Path(traj_file).name,
        "char_filename": Path(char_file).name,
        "char_feat": char_feature,
        "matrix_trajectory": matrix_trajectory
    }


class ETLogger(SimulationLogger):
    """Logs a camera trajectory and character points to a rerun recording."""

    def __init__(self):
        """Start the ``et_visualization`` recording and set the world axes."""
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)

        # Intrinsic matrix handed to rr.Pinhole in log_trajectory, where it
        # is paired with a 640x480 image size.
        self.K = np.array([
            [500, 0, 320],
            [0, 500, 240],
            [0, 0, 1]
        ])

    def log_trajectory(self, trajectory: np.ndarray):
        """Log the camera path and one camera pose per frame.

        Args:
            trajectory: Array indexed as ``trajectory[k, :3, :3]`` (rotation)
                and ``trajectory[k, :3, 3]`` (translation) for frame ``k`` --
                presumably a stack of 4x4 pose matrices; confirm with caller.
        """
        positions = trajectory[:, :3, 3]

        rr.log(
            "world/trajectory/points",
            rr.Points3D(
                positions,
                colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0])
            ),
            timeless=True
        )

        if len(positions) > 1:
            # One two-point strip per pair of consecutive positions.
            lines = np.stack([positions[:-1], positions[1:]], axis=1)
            rr.log(
                "world/trajectory/line",
                rr.LineStrips3D(
                    lines,
                    colors=[(0.0, 0.8, 0.8, 1.0)]
                ),
                timeless=True
            )

        for frame_idx, pose in enumerate(trajectory):
            # Everything logged below is stamped with this frame index.
            rr.set_time_sequence("frame_idx", frame_idx)

            translation = pose[:3, 3]
            # The quaternion is passed to rerun as xyzw.
            rotation_q = Rotation.from_matrix(pose[:3, :3]).as_quat()

            rr.log(
                "world/camera",
                rr.Transform3D(
                    translation=translation,
                    rotation=rr.Quaternion(xyzw=rotation_q),
                ),
            )

            rr.log(
                "world/camera/image",
                rr.Pinhole(
                    image_from_camera=self.K,
                    width=640,
                    height=480,
                ),
            )

    def log_character(self, char_feature: np.ndarray):
        """Log the character feature array as red 3D points.

        Args:
            char_feature: Array that is reshaped to ``(-1, 3)``, so its total
                size must be a multiple of 3.
        """
        # Reshape once and reuse for both the points and the colour count.
        points = char_feature.reshape(-1, 3)
        rr.log(
            "world/character",
            rr.Points3D(
                points,
                colors=np.full((points.shape[0], 4), [0.8, 0.2, 0.2, 1.0])
            ),
            timeless=True
        )
@spaces.GPU
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Build a rerun recording for a trajectory/character pair.

    Args:
        traj_file: Path to the trajectory pose file.
        char_file: Path to the character feature array file.

    Returns:
        Path of the ``et_visualization.rrd`` file written into a fresh
        temporary directory, or ``None`` if any step raised (the error is
        printed and not re-raised).
    """
    import shutil  # local import: only needed to clean up after a failure

    temp_dir = None
    try:
        data = load_trajectory_data(traj_file, char_file)

        # The directory is intentionally not removed on success: the caller
        # receives a path inside it.
        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")

        logger = ETLogger()
        logger.log_trajectory(data["matrix_trajectory"].numpy())
        logger.log_character(data["char_feat"].numpy())

        # NOTE(review): rerun's docs suggest calling rr.save before logging;
        # verify that data logged above reaches the file with the installed
        # rerun version.
        rr.save(rrd_path)

        return rrd_path

    except Exception as e:
        print(f"Error visualizing E.T. data: {e}")
        # Nobody else knows this path once None is returned, so remove the
        # directory here instead of leaking it.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None