Spaces:
Sleeping
Sleeping
fix et issue
Browse files- visualization/et_visualizer.py +60 -82
visualization/et_visualizer.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
@@ -7,7 +10,6 @@ import rerun as rr
|
|
7 |
from typing import Optional, Dict
|
8 |
from visualization.logger import SimulationLogger
|
9 |
from scipy.spatial.transform import Rotation
|
10 |
-
from rerun.components import Material
|
11 |
|
12 |
|
13 |
def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
|
@@ -53,137 +55,113 @@ class ETLogger(SimulationLogger):
|
|
53 |
super().__init__()
|
54 |
rr.init("et_visualization")
|
55 |
rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
|
56 |
-
|
57 |
-
# Define default camera parameters
|
58 |
-
self.camera_width = 640 # default width
|
59 |
-
self.camera_height = 480 # default height
|
60 |
-
self.focal_length = 500 # default focal length
|
61 |
self.K = np.array([
|
62 |
-
[
|
63 |
-
[0,
|
64 |
[0, 0, 1]
|
65 |
])
|
66 |
|
67 |
def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
|
68 |
-
"""Log camera trajectory
|
69 |
valid_frames = int(padding_mask.sum())
|
70 |
valid_trajectory = trajectory[:valid_frames]
|
71 |
|
72 |
-
# Log trajectory points
|
73 |
positions = valid_trajectory[:, :3, 3]
|
74 |
-
colors = np.zeros((len(positions), 4))
|
75 |
-
colors[:, :3] = plt.cm.rainbow(
|
76 |
-
np.linspace(0, 1, len(positions)))[:, :3]
|
77 |
-
colors[:, 3] = 1.0 # Set alpha to 1
|
78 |
-
|
79 |
rr.log(
|
80 |
"world/trajectory/points",
|
81 |
rr.Points3D(
|
82 |
positions,
|
83 |
-
colors=
|
84 |
),
|
85 |
timeless=True
|
86 |
)
|
87 |
|
88 |
-
# Log trajectory line
|
89 |
if len(positions) > 1:
|
90 |
lines = np.stack([positions[:-1], positions[1:]], axis=1)
|
91 |
-
line_colors = np.zeros((len(lines), 4))
|
92 |
-
line_colors[:, :3] = plt.cm.rainbow(
|
93 |
-
np.linspace(0, 1, len(lines)))[:, :3]
|
94 |
-
line_colors[:, 3] = 1.0
|
95 |
-
|
96 |
rr.log(
|
97 |
"world/trajectory/line",
|
98 |
rr.LineStrips3D(
|
99 |
lines,
|
100 |
-
colors=
|
101 |
),
|
102 |
timeless=True
|
103 |
)
|
104 |
|
105 |
-
# Log
|
106 |
-
for
|
107 |
-
#
|
108 |
-
|
109 |
-
rotation_matrix = valid_trajectory[i, :3, :3]
|
110 |
-
rotation_quat = Rotation.from_matrix(rotation_matrix).as_quat()
|
111 |
|
112 |
-
#
|
113 |
-
|
|
|
|
|
114 |
|
115 |
-
# Log camera
|
116 |
rr.log(
|
117 |
-
f"world/
|
118 |
rr.Transform3D(
|
119 |
translation=translation,
|
120 |
-
rotation=rr.Quaternion(xyzw=
|
121 |
-
)
|
122 |
)
|
123 |
|
124 |
-
#
|
125 |
rr.log(
|
126 |
-
f"world/
|
127 |
rr.Pinhole(
|
128 |
image_from_camera=self.K,
|
129 |
-
width=
|
130 |
-
height=
|
131 |
-
focal_length=self.focal_length,
|
132 |
),
|
133 |
)
|
134 |
|
135 |
-
# Add coordinate axes for each camera
|
136 |
-
rr.log(
|
137 |
-
f"world/cameras/camera_{i}/axes",
|
138 |
-
rr.Arrows3D(
|
139 |
-
origins=np.zeros((3, 3)),
|
140 |
-
vectors=np.eye(3) * 0.5, # 0.5 meter long axes
|
141 |
-
colors=[[1, 0, 0, 1], [0, 1, 0, 1], [
|
142 |
-
0, 0, 1, 1]] # RGB colors for XYZ
|
143 |
-
)
|
144 |
-
)
|
145 |
-
|
146 |
def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
|
147 |
-
"""Log character feature visualization
|
148 |
valid_frames = int(padding_mask.sum())
|
149 |
valid_char = char_feature[:, :valid_frames]
|
150 |
|
151 |
if valid_char.shape[0] > 0:
|
152 |
-
# Create gradient colors for character points
|
153 |
-
num_points = valid_char.reshape(-1, 3).shape[0]
|
154 |
-
colors = np.zeros((num_points, 4))
|
155 |
-
colors[:, 0] = 0.8 # Red component
|
156 |
-
colors[:, 1] = 0.2 # Green component
|
157 |
-
colors[:, 2] = np.linspace(0.2, 0.8, num_points) # Blue gradient
|
158 |
-
colors[:, 3] = 1.0 # Alpha
|
159 |
-
|
160 |
rr.log(
|
161 |
"world/character",
|
162 |
rr.Points3D(
|
163 |
valid_char.reshape(-1, 3),
|
164 |
-
colors=
|
165 |
-
|
166 |
),
|
167 |
timeless=True
|
168 |
)
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tempfile
|
2 |
+
import os
|
3 |
+
import spaces
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
|
|
10 |
from typing import Optional, Dict
|
11 |
from visualization.logger import SimulationLogger
|
12 |
from scipy.spatial.transform import Rotation
|
|
|
13 |
|
14 |
|
15 |
def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
|
|
|
55 |
super().__init__()
|
56 |
rr.init("et_visualization")
|
57 |
rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
|
58 |
+
# Default camera intrinsics
|
|
|
|
|
|
|
|
|
59 |
self.K = np.array([
|
60 |
+
[500, 0, 320],
|
61 |
+
[0, 500, 240],
|
62 |
[0, 0, 1]
|
63 |
])
|
64 |
|
65 |
def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
|
66 |
+
"""Log camera trajectory."""
|
67 |
valid_frames = int(padding_mask.sum())
|
68 |
valid_trajectory = trajectory[:valid_frames]
|
69 |
|
70 |
+
# Log trajectory points
|
71 |
positions = valid_trajectory[:, :3, 3]
|
|
|
|
|
|
|
|
|
|
|
72 |
rr.log(
|
73 |
"world/trajectory/points",
|
74 |
rr.Points3D(
|
75 |
positions,
|
76 |
+
colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0])
|
77 |
),
|
78 |
timeless=True
|
79 |
)
|
80 |
|
81 |
+
# Log trajectory line
|
82 |
if len(positions) > 1:
|
83 |
lines = np.stack([positions[:-1], positions[1:]], axis=1)
|
|
|
|
|
|
|
|
|
|
|
84 |
rr.log(
|
85 |
"world/trajectory/line",
|
86 |
rr.LineStrips3D(
|
87 |
lines,
|
88 |
+
colors=[(0.0, 0.8, 0.8, 1.0)]
|
89 |
),
|
90 |
timeless=True
|
91 |
)
|
92 |
|
93 |
+
# Log cameras
|
94 |
+
for k in range(valid_frames):
|
95 |
+
# Set time sequence
|
96 |
+
rr.set_time_sequence("frame_idx", k)
|
|
|
|
|
97 |
|
98 |
+
# Get camera pose
|
99 |
+
translation = valid_trajectory[k, :3, 3]
|
100 |
+
rotation_q = Rotation.from_matrix(
|
101 |
+
valid_trajectory[k, :3, :3]).as_quat()
|
102 |
|
103 |
+
# Log camera transform
|
104 |
rr.log(
|
105 |
+
f"world/camera",
|
106 |
rr.Transform3D(
|
107 |
translation=translation,
|
108 |
+
rotation=rr.Quaternion(xyzw=rotation_q),
|
109 |
+
),
|
110 |
)
|
111 |
|
112 |
+
# Log camera frustum
|
113 |
rr.log(
|
114 |
+
f"world/camera/image",
|
115 |
rr.Pinhole(
|
116 |
image_from_camera=self.K,
|
117 |
+
width=640,
|
118 |
+
height=480,
|
|
|
119 |
),
|
120 |
)
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
|
123 |
+
"""Log character feature visualization."""
|
124 |
valid_frames = int(padding_mask.sum())
|
125 |
valid_char = char_feature[:, :valid_frames]
|
126 |
|
127 |
if valid_char.shape[0] > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
rr.log(
|
129 |
"world/character",
|
130 |
rr.Points3D(
|
131 |
valid_char.reshape(-1, 3),
|
132 |
+
colors=np.full(
|
133 |
+
(valid_char.reshape(-1, 3).shape[0], 4), [0.8, 0.2, 0.2, 1.0])
|
134 |
),
|
135 |
timeless=True
|
136 |
)
|
137 |
|
138 |
+
|
139 |
+
@spaces.GPU
|
140 |
+
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
|
141 |
+
"""Visualize E.T. dataset using Rerun."""
|
142 |
+
try:
|
143 |
+
# Load data
|
144 |
+
data = load_trajectory_data(traj_file, char_file)
|
145 |
+
|
146 |
+
# Create temporary file for RRD
|
147 |
+
temp_dir = tempfile.mkdtemp()
|
148 |
+
rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
|
149 |
+
|
150 |
+
# Initialize logger and log data
|
151 |
+
logger = ETLogger()
|
152 |
+
logger.log_trajectory(
|
153 |
+
data["raw_matrix_trajectory"].numpy(),
|
154 |
+
data["padding_mask"].numpy()
|
155 |
+
)
|
156 |
+
logger.log_character(
|
157 |
+
data["char_feat"].numpy(),
|
158 |
+
data["padding_mask"].numpy()
|
159 |
+
)
|
160 |
+
|
161 |
+
# Save visualization
|
162 |
+
rr.save(rrd_path)
|
163 |
+
return rrd_path
|
164 |
+
|
165 |
+
except Exception as e:
|
166 |
+
print(f"Error visualizing E.T. data: {str(e)}")
|
167 |
+
return None
|