abreza commited on
Commit
98a8f68
·
1 Parent(s): 289635e

fix et issue

Browse files
Files changed (1) hide show
  1. visualization/et_visualizer.py +60 -82
visualization/et_visualizer.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
@@ -7,7 +10,6 @@ import rerun as rr
7
  from typing import Optional, Dict
8
  from visualization.logger import SimulationLogger
9
  from scipy.spatial.transform import Rotation
10
- from rerun.components import Material
11
 
12
 
13
  def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
@@ -53,137 +55,113 @@ class ETLogger(SimulationLogger):
53
  super().__init__()
54
  rr.init("et_visualization")
55
  rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
56
-
57
- # Define default camera parameters
58
- self.camera_width = 640 # default width
59
- self.camera_height = 480 # default height
60
- self.focal_length = 500 # default focal length
61
  self.K = np.array([
62
- [self.focal_length, 0, self.camera_width/2],
63
- [0, self.focal_length, self.camera_height/2],
64
  [0, 0, 1]
65
  ])
66
 
67
  def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
68
- """Log camera trajectory with enhanced visualization."""
69
  valid_frames = int(padding_mask.sum())
70
  valid_trajectory = trajectory[:valid_frames]
71
 
72
- # Log trajectory points with rainbow coloring
73
  positions = valid_trajectory[:, :3, 3]
74
- colors = np.zeros((len(positions), 4))
75
- colors[:, :3] = plt.cm.rainbow(
76
- np.linspace(0, 1, len(positions)))[:, :3]
77
- colors[:, 3] = 1.0 # Set alpha to 1
78
-
79
  rr.log(
80
  "world/trajectory/points",
81
  rr.Points3D(
82
  positions,
83
- colors=colors
84
  ),
85
  timeless=True
86
  )
87
 
88
- # Log trajectory line with gradient color
89
  if len(positions) > 1:
90
  lines = np.stack([positions[:-1], positions[1:]], axis=1)
91
- line_colors = np.zeros((len(lines), 4))
92
- line_colors[:, :3] = plt.cm.rainbow(
93
- np.linspace(0, 1, len(lines)))[:, :3]
94
- line_colors[:, 3] = 1.0
95
-
96
  rr.log(
97
  "world/trajectory/line",
98
  rr.LineStrips3D(
99
  lines,
100
- colors=line_colors
101
  ),
102
  timeless=True
103
  )
104
 
105
- # Log camera frustums
106
- for i in range(valid_frames):
107
- # Get camera position and rotation
108
- translation = valid_trajectory[i, :3, 3]
109
- rotation_matrix = valid_trajectory[i, :3, :3]
110
- rotation_quat = Rotation.from_matrix(rotation_matrix).as_quat()
111
 
112
- # Set time sequence for animation
113
- rr.set_time_sequence("frame_idx", i)
 
 
114
 
115
- # Log camera frustum
116
  rr.log(
117
- f"world/cameras/camera_{i}",
118
  rr.Transform3D(
119
  translation=translation,
120
- rotation=rr.Quaternion(xyzw=rotation_quat),
121
- )
122
  )
123
 
124
- # Add camera visualization
125
  rr.log(
126
- f"world/cameras/camera_{i}/frustum",
127
  rr.Pinhole(
128
  image_from_camera=self.K,
129
- width=self.camera_width,
130
- height=self.camera_height,
131
- focal_length=self.focal_length,
132
  ),
133
  )
134
 
135
- # Add coordinate axes for each camera
136
- rr.log(
137
- f"world/cameras/camera_{i}/axes",
138
- rr.Arrows3D(
139
- origins=np.zeros((3, 3)),
140
- vectors=np.eye(3) * 0.5, # 0.5 meter long axes
141
- colors=[[1, 0, 0, 1], [0, 1, 0, 1], [
142
- 0, 0, 1, 1]] # RGB colors for XYZ
143
- )
144
- )
145
-
146
  def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
147
- """Log character feature visualization with enhanced appearance."""
148
  valid_frames = int(padding_mask.sum())
149
  valid_char = char_feature[:, :valid_frames]
150
 
151
  if valid_char.shape[0] > 0:
152
- # Create gradient colors for character points
153
- num_points = valid_char.reshape(-1, 3).shape[0]
154
- colors = np.zeros((num_points, 4))
155
- colors[:, 0] = 0.8 # Red component
156
- colors[:, 1] = 0.2 # Green component
157
- colors[:, 2] = np.linspace(0.2, 0.8, num_points) # Blue gradient
158
- colors[:, 3] = 1.0 # Alpha
159
-
160
  rr.log(
161
  "world/character",
162
  rr.Points3D(
163
  valid_char.reshape(-1, 3),
164
- colors=colors,
165
- radii=0.05 # Add point size for better visibility
166
  ),
167
  timeless=True
168
  )
169
 
170
- # Add a semi-transparent hull around character points
171
- try:
172
- from scipy.spatial import ConvexHull
173
- points = valid_char.reshape(-1, 3)
174
- hull = ConvexHull(points)
175
-
176
- rr.log(
177
- "world/character/hull",
178
- rr.Mesh3D(
179
- vertex_positions=points[hull.vertices],
180
- indices=hull.simplices,
181
- mesh_material=Material(
182
- # Semi-transparent red
183
- albedo_factor=[0.8, 0.2, 0.2, 0.3]
184
- )
185
- ),
186
- timeless=True
187
- )
188
- except Exception:
189
- pass # Skip hull visualization if it fails
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import os
3
+ import spaces
4
  import numpy as np
5
  import torch
6
  import torch.nn.functional as F
 
10
  from typing import Optional, Dict
11
  from visualization.logger import SimulationLogger
12
  from scipy.spatial.transform import Rotation
 
13
 
14
 
15
  def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
 
55
  super().__init__()
56
  rr.init("et_visualization")
57
  rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)
58
+ # Default camera intrinsics
 
 
 
 
59
  self.K = np.array([
60
+ [500, 0, 320],
61
+ [0, 500, 240],
62
  [0, 0, 1]
63
  ])
64
 
65
  def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
66
+ """Log camera trajectory."""
67
  valid_frames = int(padding_mask.sum())
68
  valid_trajectory = trajectory[:valid_frames]
69
 
70
+ # Log trajectory points
71
  positions = valid_trajectory[:, :3, 3]
 
 
 
 
 
72
  rr.log(
73
  "world/trajectory/points",
74
  rr.Points3D(
75
  positions,
76
+ colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0])
77
  ),
78
  timeless=True
79
  )
80
 
81
+ # Log trajectory line
82
  if len(positions) > 1:
83
  lines = np.stack([positions[:-1], positions[1:]], axis=1)
 
 
 
 
 
84
  rr.log(
85
  "world/trajectory/line",
86
  rr.LineStrips3D(
87
  lines,
88
+ colors=[(0.0, 0.8, 0.8, 1.0)]
89
  ),
90
  timeless=True
91
  )
92
 
93
+ # Log cameras
94
+ for k in range(valid_frames):
95
+ # Set time sequence
96
+ rr.set_time_sequence("frame_idx", k)
 
 
97
 
98
+ # Get camera pose
99
+ translation = valid_trajectory[k, :3, 3]
100
+ rotation_q = Rotation.from_matrix(
101
+ valid_trajectory[k, :3, :3]).as_quat()
102
 
103
+ # Log camera transform
104
  rr.log(
105
+ f"world/camera",
106
  rr.Transform3D(
107
  translation=translation,
108
+ rotation=rr.Quaternion(xyzw=rotation_q),
109
+ ),
110
  )
111
 
112
+ # Log camera frustum
113
  rr.log(
114
+ f"world/camera/image",
115
  rr.Pinhole(
116
  image_from_camera=self.K,
117
+ width=640,
118
+ height=480,
 
119
  ),
120
  )
121
 
 
 
 
 
 
 
 
 
 
 
 
122
  def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
123
+ """Log character feature visualization."""
124
  valid_frames = int(padding_mask.sum())
125
  valid_char = char_feature[:, :valid_frames]
126
 
127
  if valid_char.shape[0] > 0:
 
 
 
 
 
 
 
 
128
  rr.log(
129
  "world/character",
130
  rr.Points3D(
131
  valid_char.reshape(-1, 3),
132
+ colors=np.full(
133
+ (valid_char.reshape(-1, 3).shape[0], 4), [0.8, 0.2, 0.2, 1.0])
134
  ),
135
  timeless=True
136
  )
137
 
138
+
139
+ @spaces.GPU
140
+ def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
141
+ """Visualize E.T. dataset using Rerun."""
142
+ try:
143
+ # Load data
144
+ data = load_trajectory_data(traj_file, char_file)
145
+
146
+ # Create temporary file for RRD
147
+ temp_dir = tempfile.mkdtemp()
148
+ rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
149
+
150
+ # Initialize logger and log data
151
+ logger = ETLogger()
152
+ logger.log_trajectory(
153
+ data["raw_matrix_trajectory"].numpy(),
154
+ data["padding_mask"].numpy()
155
+ )
156
+ logger.log_character(
157
+ data["char_feat"].numpy(),
158
+ data["padding_mask"].numpy()
159
+ )
160
+
161
+ # Save visualization
162
+ rr.save(rrd_path)
163
+ return rrd_path
164
+
165
+ except Exception as e:
166
+ print(f"Error visualizing E.T. data: {str(e)}")
167
+ return None