"""Visualization utilities for 3D reconstruction results using Viser.

Provides tools to visualize predicted camera poses, 3D point clouds, and confidence
thresholding through an interactive web interface.
"""

import threading
import time
from typing import List, Optional

import numpy as np
import torch
import viser
import viser.transforms as tf
from tqdm.auto import tqdm

def viser_wrapper(
    pred_dict: dict,
    port: Optional[int] = None,
    init_conf_threshold: float = 3.0,
) -> None:
    """Visualize predicted camera poses and world points in an interactive viser scene.

    Args:
        pred_dict: Dictionary of predictions with keys "images",
            "pred_world_points", "pred_world_points_conf", and
            "last_pred_extrinsic".
        port: Optional port number for the viser server. If None, viser's default port is used.
        init_conf_threshold: Initial confidence threshold used to filter the point cloud.
    """
    print(f"Starting viser server on port {port}")

    if port is None:
        server = viser.ViserServer(host="0.0.0.0")
    else:
        server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Unpack and preprocess inputs
    images = pred_dict["images"]
    world_points = pred_dict["pred_world_points"]
    conf = pred_dict["pred_world_points_conf"]
    extrinsics = pred_dict["last_pred_extrinsic"]
    
    # Handle batch dimension if present
    if len(images.shape) > 4:
        images = images[0]
        world_points = world_points[0]
        conf = conf[0]
        extrinsics = extrinsics[0]

    colors = images.transpose(0, 2, 3, 1)  # (S, 3, H, W) -> (S, H, W, 3)

    # Reshape for visualization
    S, H, W, _ = world_points.shape
    colors = (colors.reshape(-1, 3) * 255).astype(np.uint8)  # Convert to 0-255 range
    conf = conf.reshape(-1)
    world_points = world_points.reshape(-1, 3)

    # Calculate camera poses in world coordinates
    cam_to_world = closed_form_inverse_se3(extrinsics)
    extrinsics = cam_to_world[:, :3, :]

    # Center scene for better visualization
    scene_center = np.mean(world_points, axis=0)
    world_points -= scene_center
    extrinsics[..., -1] -= scene_center

    points = world_points

    # Per-point frame index, used to filter the point cloud by source frame.
    frame_indices = np.arange(S)
    frame_indices = frame_indices[:, None, None]  # Shape: (S, 1, 1)
    frame_indices = np.tile(frame_indices, (1, H, W))  # Shape: (S, H, W)
    frame_indices = frame_indices.reshape(-1)

    # GUI controls for filtering and display.
    gui_points_conf = server.gui.add_slider(
        "Confidence Threshold",
        min=0.1,
        max=20,
        step=0.05,
        initial_value=init_conf_threshold,
    )

    gui_point_size = server.gui.add_slider(
        "Point size", min=0.00001, max=0.01, step=0.0001, initial_value=0.00001
    )

    # Dropdown to restrict the point cloud to a single frame.
    gui_frame_selector = server.gui.add_dropdown(
        "Filter by Frame",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All",
    )

    # Initial mask shows all points passing confidence threshold
    init_conf_mask = conf > init_conf_threshold
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points[init_conf_mask],
        colors=colors[init_conf_mask],
        point_size=gui_point_size.value,
        point_shape="circle",
    )



    frames: List[viser.FrameHandle] = []

    def visualize_frames(extrinsics: np.ndarray, intrinsics: Optional[np.ndarray], images: np.ndarray) -> None:
        """Add a coordinate frame and camera frustum to the scene for every view.

        `intrinsics` is currently unused; a fixed focal length is assumed when
        building the frustums below.
        """
        extrinsics = np.copy(extrinsics)
        # Remove existing image frames.
        for frame in frames:
            frame.remove()
        frames.clear()


        def attach_callback(
            frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle
        ) -> None:
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        for img_id in tqdm(range(S)):
            cam_to_world = extrinsics[img_id]  # (3, 4)

            # tf.SE3.from_matrix expects a homogeneous 4x4 matrix, so pad the
            # 3x4 camera-to-world matrix with the [0, 0, 0, 1] row.
            cam_to_world_4x4 = np.eye(4)
            cam_to_world_4x4[:3, :] = cam_to_world
            T_world_camera = tf.SE3.from_matrix(cam_to_world_4x4)

            frame = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame)

            img = images[img_id]
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            H, W = img.shape[:2]
            # Intrinsics are not passed in, so assume a fixed focal length of
            # 1.1 * H (roughly a 49-degree vertical field of view).
            fy = 1.1 * H
            frustum = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=2 * np.arctan2(H / 2, fy),
                aspect=W / H,
                scale=0.05,
                image=img,
                line_width=1.0,
            )
            
            attach_callback(frustum, frame)


    def update_point_cloud() -> None:
        """Re-filter the displayed points by confidence threshold and frame."""
        conf_mask = conf > gui_points_conf.value
        if gui_frame_selector.value == "All":
            combined_mask = conf_mask
        else:
            selected_idx = int(gui_frame_selector.value)
            combined_mask = conf_mask & (frame_indices == selected_idx)
        point_cloud.points = points[combined_mask]
        point_cloud.colors = colors[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_point_size.on_update
    def _(_) -> None:
        point_cloud.point_size = gui_point_size.value

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()


    # Initial visualization
    visualize_frames(extrinsics, None, images)

    # Spawn a daemon background thread; note that the caller is still
    # responsible for keeping the process alive so the viewer stays reachable.
    def server_loop():
        while True:
            time.sleep(1e-3)  # Small sleep to prevent CPU hogging

    thread = threading.Thread(target=server_loop, daemon=True)
    thread.start()
    
    

def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    If `R` and `T` are provided, they must correspond to the rotation and translation
    components of `se3`. Otherwise, they will be extracted from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Inverted SE3 matrices with the same type and device as `se3`.

    Shapes:
        se3: (N, 4, 4) or (N, 3, 4)
        R: (N, 3, 3)
        T: (N, 3, 1)
    """
    # Check if se3 is a numpy array or a torch tensor
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes
    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    # Transpose R
    if is_numpy:
        # Compute the transpose of the rotation for NumPy
        R_transposed = np.transpose(R, (0, 2, 1))
        # -R^T t for NumPy
        top_right = -np.matmul(R_transposed, T)
        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right

    return inverted_matrix
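

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original pipeline):
    # it assumes predictions were saved to an .npz file with the same keys
    # that viser_wrapper reads ("images", "pred_world_points",
    # "pred_world_points_conf", "last_pred_extrinsic"). The file name and
    # port below are placeholders.
    data = np.load("predictions.npz")
    pred_dict = {key: data[key] for key in data.files}
    viser_wrapper(pred_dict, port=8080, init_conf_threshold=3.0)

    # viser_wrapper returns after starting the server, so keep the process
    # alive here so the web viewer stays reachable.
    while True:
        time.sleep(1.0)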