Spaces:
Sleeping
Sleeping
import tempfile | |
import os | |
import spaces | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from evo.tools.file_interface import read_kitti_poses_file | |
from pathlib import Path | |
import rerun as rr | |
from typing import Optional, Dict | |
from visualization.logger import SimulationLogger | |
from scipy.spatial.transform import Rotation | |
def load_trajectory_data(traj_file: str, char_file: str) -> Dict:
    """Load a camera trajectory and its paired character feature from disk.

    Args:
        traj_file: Path to a KITTI-format pose file, parsed with evo's
            ``read_kitti_poses_file``.
        char_file: Path to a NumPy ``.npy`` file holding the character
            feature array.

    Returns:
        A dict with:
            ``"traj_filename"`` / ``"char_filename"``: base names of the inputs.
            ``"char_feat"``: the character array as a float32 tensor.
            ``"matrix_trajectory"``: the ``poses_se3`` matrices stacked into a
            float32 tensor (presumably shape ``(N, 4, 4)`` — TODO confirm).
    """
    # Parse and convert the trajectory first, then the character file, so
    # failures surface in the same order as before.
    kitti_poses = read_kitti_poses_file(traj_file)
    stacked_poses = np.array(kitti_poses.poses_se3)
    pose_tensor = torch.from_numpy(stacked_poses).to(torch.float32)

    raw_character = np.load(char_file)
    character_tensor = torch.from_numpy(raw_character).to(torch.float32)

    result = {}
    result["traj_filename"] = Path(traj_file).name
    result["char_filename"] = Path(char_file).name
    result["char_feat"] = character_tensor
    result["matrix_trajectory"] = pose_tensor
    return result
class ETLogger(SimulationLogger):
    """Rerun logger for E.T. camera trajectories and character features.

    Constructing an instance initialises a Rerun recording named
    ``"et_visualization"`` and declares the ``world`` entity as a
    right-handed, Y-up coordinate system.
    """

    # Image size whose centre matches the principal point (320, 240) of ``K``.
    IMAGE_WIDTH = 640
    IMAGE_HEIGHT = 480

    def __init__(self):
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
        # Pinhole intrinsics: focal length 500 px, principal point (320, 240).
        self.K = np.array([
            [500, 0, 320],
            [0, 500, 240],
            [0, 0, 1],
        ])

    def log_trajectory(self, trajectory: np.ndarray):
        """Log a camera trajectory to Rerun.

        The camera positions are logged as timeless cyan points plus a
        polyline. Each pose is then logged as the ``world/camera`` transform
        on the ``frame_idx`` timeline, with a pinhole camera attached at
        ``world/camera/image``.

        Args:
            trajectory: Stack of homogeneous pose matrices, indexed as
                ``trajectory[k, :3, :3]`` (rotation) and
                ``trajectory[k, :3, 3]`` (translation). Presumably shape
                ``(N, 4, 4)`` — TODO confirm.
        """
        positions = trajectory[:, :3, 3]
        rr.log(
            "world/trajectory/points",
            rr.Points3D(
                positions,
                colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0]),
            ),
            timeless=True,
        )
        if len(positions) > 1:
            # One two-point segment per pair of consecutive positions.
            lines = np.stack([positions[:-1], positions[1:]], axis=1)
            rr.log(
                "world/trajectory/line",
                rr.LineStrips3D(lines, colors=[(0.0, 0.8, 0.8, 1.0)]),
                timeless=True,
            )

        if len(trajectory) == 0:
            # Nothing to animate; also avoids an empty Rotation.from_matrix call.
            return

        # The intrinsics never change, so log them once as timeless data
        # instead of re-logging identical values on every frame.
        rr.log(
            "world/camera/image",
            rr.Pinhole(
                image_from_camera=self.K,
                width=self.IMAGE_WIDTH,
                height=self.IMAGE_HEIGHT,
            ),
            timeless=True,
        )

        # Convert every rotation in one vectorised call. SciPy's as_quat()
        # returns scalar-last (x, y, z, w), matching rr.Quaternion(xyzw=...).
        quaternions = Rotation.from_matrix(trajectory[:, :3, :3]).as_quat()
        for k, (translation, rotation_q) in enumerate(zip(positions, quaternions)):
            rr.set_time_sequence("frame_idx", k)
            rr.log(
                "world/camera",
                rr.Transform3D(
                    translation=translation,
                    rotation=rr.Quaternion(xyzw=rotation_q),
                ),
            )

    def log_character(self, char_feature: np.ndarray):
        """Log the character feature as timeless red 3D points at ``world/character``.

        Args:
            char_feature: Array whose values are regrouped into consecutive
                (x, y, z) triples via ``reshape(-1, 3)``; its size must be a
                multiple of 3. NOTE(review): if the array is stored as
                ``(3, N)`` rather than ``(N, 3)``, this reshape interleaves
                coordinates — confirm the on-disk layout.
        """
        points = char_feature.reshape(-1, 3)
        rr.log(
            "world/character",
            rr.Points3D(
                points,
                colors=np.full((points.shape[0], 4), [0.8, 0.2, 0.2, 1.0]),
            ),
            timeless=True,
        )
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Render an E.T. trajectory/character pair into a Rerun ``.rrd`` file.

    Args:
        traj_file: Path to the KITTI-format camera trajectory file.
        char_file: Path to the ``.npy`` character feature file.

    Returns:
        Path to ``et_visualization.rrd`` inside a fresh temporary directory.
        On any failure the error is printed, the temporary directory (if one
        was created) is removed, and ``None`` is returned.
    """
    import shutil

    temp_dir = None
    try:
        data = load_trajectory_data(traj_file, char_file)
        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
        logger = ETLogger()
        # Attach the file sink before logging the bulk of the data, as the
        # Rerun docs recommend for rr.save.
        rr.save(rrd_path)
        logger.log_trajectory(data["matrix_trajectory"].numpy())
        logger.log_character(data["char_feat"].numpy())
        # Rerun writes to the file in the background. Closing the sink
        # flushes every pending message, so the caller never sees a
        # partially written recording.
        rr.disconnect()
        return rrd_path
    except Exception as e:
        # Top-level boundary for the UI: report the error and signal
        # failure with None rather than raising.
        print(f"Error visualizing E.T. data: {str(e)}")
        if temp_dir is not None:
            # Don't leak an empty or partial temp directory on failure.
            shutil.rmtree(temp_dir, ignore_errors=True)
        return None