abreza commited on
Commit
c065381
·
1 Parent(s): 5481052

remove padding

Browse files
Files changed (1) hide show
  1. visualization/et_visualizer.py +20 -53
visualization/et_visualizer.py CHANGED
@@ -12,37 +12,18 @@ from visualization.logger import SimulationLogger
12
  from scipy.spatial.transform import Rotation
13
 
14
 
15
- def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
16
  trajectory = read_kitti_poses_file(traj_file)
17
  matrix_trajectory = torch.from_numpy(
18
  np.array(trajectory.poses_se3)).to(torch.float32)
19
 
20
- raw_trans = torch.clone(matrix_trajectory[:, :3, 3])
21
- raw_rot = matrix_trajectory[:, :3, :3]
22
-
23
- rot6d = raw_rot[:, :, :2].permute(0, 2, 1).reshape(-1, 6)
24
- trajectory_feature = torch.hstack([rot6d, raw_trans]).permute(1, 0)
25
-
26
- padded_trajectory_feature = F.pad(
27
- trajectory_feature,
28
- (0, num_cams - trajectory_feature.shape[1])
29
- )
30
-
31
- padding_mask = torch.ones((num_cams))
32
- padding_mask[trajectory_feature.shape[1]:] = 0
33
-
34
  char_feature = torch.from_numpy(np.load(char_file)).to(torch.float32)
35
- padding_size = num_cams - char_feature.shape[0]
36
- padded_char_feature = F.pad(
37
- char_feature, (0, 0, 0, padding_size)).permute(1, 0)
38
 
39
  return {
40
  "traj_filename": Path(traj_file).name,
41
  "char_filename": Path(char_file).name,
42
- "traj_feat": padded_trajectory_feature,
43
- "char_feat": padded_char_feature,
44
- "padding_mask": padding_mask,
45
- "raw_matrix_trajectory": matrix_trajectory
46
  }
47
 
48
 
@@ -58,11 +39,8 @@ class ETLogger(SimulationLogger):
58
  [0, 0, 1]
59
  ])
60
 
61
- def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
62
- valid_frames = int(padding_mask.sum())
63
- valid_trajectory = trajectory[:valid_frames]
64
-
65
- positions = valid_trajectory[:, :3, 3]
66
  rr.log(
67
  "world/trajectory/points",
68
  rr.Points3D(
@@ -83,13 +61,12 @@ class ETLogger(SimulationLogger):
83
  timeless=True
84
  )
85
 
86
- for k in range(valid_frames):
87
-
88
  rr.set_time_sequence("frame_idx", k)
89
 
90
- translation = valid_trajectory[k, :3, 3]
91
  rotation_q = Rotation.from_matrix(
92
- valid_trajectory[k, :3, :3]).as_quat()
93
 
94
  rr.log(
95
  f"world/camera",
@@ -108,20 +85,16 @@ class ETLogger(SimulationLogger):
108
  ),
109
  )
110
 
111
- def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
112
- valid_frames = int(padding_mask.sum())
113
- valid_char = char_feature[:, :valid_frames]
114
-
115
- if valid_char.shape[0] > 0:
116
- rr.log(
117
- "world/character",
118
- rr.Points3D(
119
- valid_char.reshape(-1, 3),
120
- colors=np.full(
121
- (valid_char.reshape(-1, 3).shape[0], 4), [0.8, 0.2, 0.2, 1.0])
122
- ),
123
- timeless=True
124
- )
125
 
126
 
127
  @spaces.GPU
@@ -134,14 +107,8 @@ def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
134
  rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
135
 
136
  logger = ETLogger()
137
- logger.log_trajectory(
138
- data["raw_matrix_trajectory"].numpy(),
139
- data["padding_mask"].numpy()
140
- )
141
- logger.log_character(
142
- data["char_feat"].numpy(),
143
- data["padding_mask"].numpy()
144
- )
145
 
146
  rr.save(rrd_path)
147
  return rrd_path
 
12
  from scipy.spatial.transform import Rotation
13
 
14
 
15
+ def load_trajectory_data(traj_file: str, char_file: str) -> Dict:
16
  trajectory = read_kitti_poses_file(traj_file)
17
  matrix_trajectory = torch.from_numpy(
18
  np.array(trajectory.poses_se3)).to(torch.float32)
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  char_feature = torch.from_numpy(np.load(char_file)).to(torch.float32)
 
 
 
21
 
22
  return {
23
  "traj_filename": Path(traj_file).name,
24
  "char_filename": Path(char_file).name,
25
+ "char_feat": char_feature,
26
+ "matrix_trajectory": matrix_trajectory
 
 
27
  }
28
 
29
 
 
39
  [0, 0, 1]
40
  ])
41
 
42
+ def log_trajectory(self, trajectory: np.ndarray):
43
+ positions = trajectory[:, :3, 3]
 
 
 
44
  rr.log(
45
  "world/trajectory/points",
46
  rr.Points3D(
 
61
  timeless=True
62
  )
63
 
64
+ for k in range(len(trajectory)):
 
65
  rr.set_time_sequence("frame_idx", k)
66
 
67
+ translation = trajectory[k, :3, 3]
68
  rotation_q = Rotation.from_matrix(
69
+ trajectory[k, :3, :3]).as_quat()
70
 
71
  rr.log(
72
  f"world/camera",
 
85
  ),
86
  )
87
 
88
+ def log_character(self, char_feature: np.ndarray):
89
+ rr.log(
90
+ "world/character",
91
+ rr.Points3D(
92
+ char_feature.reshape(-1, 3),
93
+ colors=np.full(
94
+ (char_feature.reshape(-1, 3).shape[0], 4), [0.8, 0.2, 0.2, 1.0])
95
+ ),
96
+ timeless=True
97
+ )
 
 
 
 
98
 
99
 
100
  @spaces.GPU
 
107
  rrd_path = os.path.join(temp_dir, "et_visualization.rrd")
108
 
109
  logger = ETLogger()
110
+ logger.log_trajectory(data["matrix_trajectory"].numpy())
111
+ logger.log_character(data["char_feat"].numpy())
 
 
 
 
 
 
112
 
113
  rr.save(rrd_path)
114
  return rrd_path