File size: 4,586 Bytes
98a8f68
 
 
161b8e5
 
 
 
 
 
 
 
289635e
161b8e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44f6fd9
289635e
98a8f68
 
289635e
 
 
161b8e5
 
 
 
 
 
 
 
 
98a8f68
161b8e5
 
 
 
 
 
 
 
 
 
98a8f68
161b8e5
 
 
 
98a8f68
44f6fd9
98a8f68
289635e
98a8f68
 
 
289635e
 
98a8f68
289635e
 
98a8f68
 
289635e
 
 
98a8f68
289635e
 
98a8f68
 
289635e
 
 
161b8e5
 
 
 
 
 
 
 
 
98a8f68
 
161b8e5
 
 
 
98a8f68
 
 
 
44f6fd9
98a8f68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import tempfile
import os
import spaces
import numpy as np
import torch
import torch.nn.functional as F
from evo.tools.file_interface import read_kitti_poses_file
from pathlib import Path
import rerun as rr
from typing import Optional, Dict
from visualization.logger import SimulationLogger
from scipy.spatial.transform import Rotation


def load_trajectory_data(traj_file: str, char_file: str, num_cams: int = 30) -> Dict:
    """Load a camera trajectory and a character feature as fixed-length tensors.

    The trajectory is read with ``read_kitti_poses_file``. Each pose becomes a
    9-vector: the first two columns of its rotation matrix (a 6D rotation
    encoding, column after column) followed by its translation. Both the
    trajectory feature and the character feature are zero-padded along the
    frame axis up to ``num_cams`` frames.

    NOTE(review): when an input has more than ``num_cams`` frames, the pad
    amount handed to ``F.pad`` is negative (presumably cropping to
    ``num_cams``). ``raw_matrix_trajectory`` is never cropped. Confirm this is
    intended.

    Args:
        traj_file: Path to a KITTI-format pose file.
        char_file: Path to a 2-D NumPy array file loadable with ``np.load``.
            Rows are treated as frames.
        num_cams: Target number of frames after padding.

    Returns:
        A dict with these keys:

        - ``traj_filename`` / ``char_filename``: base names of the inputs.
        - ``traj_feat``: float32 tensor of shape (9, num_cams).
        - ``char_feat``: float32 tensor of shape (D, num_cams), feature-major.
        - ``padding_mask``: tensor of shape (num_cams,). Entries are 1 for
          real trajectory frames and 0 for padding.
        - ``raw_matrix_trajectory``: the unpadded float32 pose matrices,
          one per frame.
    """
    poses = np.array(read_kitti_poses_file(traj_file).poses_se3)
    matrix_trajectory = torch.from_numpy(poses).to(torch.float32)
    num_frames = matrix_trajectory.shape[0]

    # 6D rotation: first two rotation columns, laid out column by column.
    rotation_6d = matrix_trajectory[:, :3, :2].transpose(1, 2).reshape(-1, 6)
    translation = matrix_trajectory[:, :3, 3].clone()
    # Frame-major (N, 9) -> feature-major (9, N).
    trajectory_feature = torch.cat([rotation_6d, translation], dim=1).permute(1, 0)

    padded_traj = F.pad(trajectory_feature, (0, num_cams - num_frames))

    # 1 marks a real trajectory frame, 0 marks padding.
    padding_mask = torch.zeros(num_cams)
    padding_mask[:num_frames] = 1

    char_feature = torch.from_numpy(np.load(char_file)).to(torch.float32)
    char_pad = num_cams - char_feature.shape[0]
    # Pad the frame axis (dim 0), then switch to feature-major (D, num_cams).
    padded_char = F.pad(char_feature, (0, 0, 0, char_pad)).permute(1, 0)

    return {
        "traj_filename": Path(traj_file).name,
        "char_filename": Path(char_file).name,
        "traj_feat": padded_traj,
        "char_feat": padded_char,
        "padding_mask": padding_mask,
        "raw_matrix_trajectory": matrix_trajectory,
    }


class ETLogger(SimulationLogger):
    """Rerun logger for E.T. camera trajectories and character features.

    Creating an instance calls ``rr.init("et_visualization")``. It also logs a
    static right-handed, Y-up view-coordinate convention on the ``world``
    entity.
    """

    def __init__(self):
        super().__init__()
        rr.init("et_visualization")
        rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Y_UP, timeless=True)

        # Fixed pinhole intrinsics used for every logged camera frustum:
        # 500 px focal length, principal point at the centre of the 640x480
        # image declared in log_trajectory.
        self.K = np.array([
            [500, 0, 320],
            [0, 500, 240],
            [0, 0, 1]
        ])

    def log_trajectory(self, trajectory: np.ndarray, padding_mask: np.ndarray):
        """Log camera positions, the connecting path, and a per-frame camera.

        Args:
            trajectory: Per-frame pose matrices, indexed as ``[k, :3, :3]``
                (rotation) and ``[k, :3, 3]`` (translation). Presumably
                (N, 4, 4) homogeneous poses; TODO confirm.
            padding_mask: 1 for valid frames and 0 for padding. Only the
                first ``sum(padding_mask)`` frames are logged.
        """
        valid_frames = int(padding_mask.sum())
        valid_trajectory = trajectory[:valid_frames]

        # Static cyan points at every valid camera position.
        positions = valid_trajectory[:, :3, 3]
        rr.log(
            "world/trajectory/points",
            rr.Points3D(
                positions,
                colors=np.full((len(positions), 4), [0.0, 0.8, 0.8, 1.0])
            ),
            timeless=True
        )

        # Static path: one two-point segment between each pair of
        # consecutive positions.
        if len(positions) > 1:
            lines = np.stack([positions[:-1], positions[1:]], axis=1)
            rr.log(
                "world/trajectory/line",
                rr.LineStrips3D(
                    lines,
                    colors=[(0.0, 0.8, 0.8, 1.0)]
                ),
                timeless=True
            )

        for k in range(valid_frames):
            # Each frame's camera pose lives on the "frame_idx" timeline.
            rr.set_time_sequence("frame_idx", k)

            translation = valid_trajectory[k, :3, 3]
            # scipy returns scalar-last (x, y, z, w) quaternions, matching
            # rr.Quaternion(xyzw=...).
            rotation_q = Rotation.from_matrix(
                valid_trajectory[k, :3, :3]).as_quat()

            rr.log(
                "world/camera",
                rr.Transform3D(
                    translation=translation,
                    rotation=rr.Quaternion(xyzw=rotation_q),
                ),
            )

            rr.log(
                "world/camera/image",
                rr.Pinhole(
                    image_from_camera=self.K,
                    width=640,
                    height=480,
                ),
            )

    def log_character(self, char_feature: np.ndarray, padding_mask: np.ndarray):
        """Log the valid character features as static red 3D points.

        Args:
            char_feature: Feature-major array of shape (D, T): columns are
                frames, as produced by ``load_trajectory_data`` (``char_feat``).
                Assumes each frame's D values are consecutive xyz triples
                (D == 3 for a single position per frame); TODO confirm.
            padding_mask: 1 for valid frames and 0 for padding. Only the
                first ``sum(padding_mask)`` columns are logged.
        """
        valid_frames = int(padding_mask.sum())
        # Transpose to frame-major (T, D) BEFORE splitting into xyz triples.
        # Reshaping the feature-major array directly would interleave
        # coordinates from different frames.
        valid_char = char_feature[:, :valid_frames].T

        if valid_char.size == 0:
            return

        points = valid_char.reshape(-1, 3)
        rr.log(
            "world/character",
            rr.Points3D(
                points,
                colors=np.full((points.shape[0], 4), [0.8, 0.2, 0.2, 1.0])
            ),
            timeless=True
        )


@spaces.GPU
def visualize_et_data(traj_file: str, char_file: str) -> Optional[str]:
    """Render an E.T. trajectory/character pair into a Rerun ``.rrd`` file.

    Args:
        traj_file: Path to the KITTI-format camera trajectory file.
        char_file: Path to the character feature array, loaded with ``np.load``.

    Returns:
        Path to ``et_visualization.rrd`` inside a fresh temporary directory,
        or ``None`` if any step raised. The temporary directory is not
        removed here, so the returned file stays readable by the caller.
    """
    try:
        data = load_trajectory_data(traj_file, char_file)
        padding_mask = data["padding_mask"].numpy()

        temp_dir = tempfile.mkdtemp()
        rrd_path = os.path.join(temp_dir, "et_visualization.rrd")

        logger = ETLogger()
        # Attach the file sink right after the recording is created inside
        # ETLogger and before the scene is logged. rerun's docs for rr.save
        # say to call it before logging any data.
        rr.save(rrd_path)

        logger.log_trajectory(
            data["raw_matrix_trajectory"].numpy(),
            padding_mask
        )
        logger.log_character(
            data["char_feat"].numpy(),
            padding_mask
        )

        return rrd_path

    except Exception as e:
        # Top-level boundary for the UI: report the failure and return None,
        # but keep the traceback so the root cause can be diagnosed.
        import traceback

        print(f"Error visualizing E.T. data: {str(e)}")
        traceback.print_exc()
        return None