Commit 2529861 · Parent(s): febf487 · update

Files changed:
- app.py (+129, -88)
- demo_hf.py (+15, -12)
- gradio_util.py (+127, -129)
app.py
CHANGED
@@ -3,7 +3,6 @@ import cv2
 import torch
 import numpy as np
 import gradio as gr
-import spaces
 import sys
 import os
 import socket
@@ -11,42 +10,64 @@ import webbrowser
 sys.path.append('vggt/')
 import shutil
 from datetime import datetime
-from demo_hf import demo_fn
+from demo_hf import demo_fn #, initialize_model
 from omegaconf import DictConfig, OmegaConf
 import glob
 import gc
 import time
 from viser_fn import viser_wrapper
+from gradio_util import demo_predictions_to_glb
+from hydra.utils import instantiate
+import spaces
 
 
-def get_free_port():
-    """Get a free port using socket."""
-    # return 80
-    # return 8080
-    # return 10088 # for debugging
-    # return 7860
-    # return 7888
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(('', 0))
-        port = s.getsockname()[1]
-        return port
+# def get_free_port():
+#     """Get a free port using socket."""
+#     # return 80
+#     # return 8080
+#     # return 10088 # for debugging
+#     # return 7860
+#     # return 7888
+#     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+#         s.bind(('', 0))
+#         port = s.getsockname()[1]
+#         return port
 
 
+cfg_file = "config/base.yaml"
+cfg = OmegaConf.load(cfg_file)
+vggt_model = instantiate(cfg, _recursive_=False)
+_VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"
+# Reload vggt_model
+pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)
+
+if "vggt_model" in pretrain_model:
+    model_dict = pretrain_model["vggt_model"]
+    vggt_model.load_state_dict(model_dict, strict=False)
+else:
+    vggt_model.load_state_dict(pretrain_model, strict=True)
 
 
+# @torch.inference_mode()
 @spaces.GPU(duration=240)
 def vggt_demo(
     input_video,
     input_image,
+    conf_thres=3.0,
+    frame_filter="all",
+    mask_black_bg=False,
 ):
     start_time = time.time()
     gc.collect()
     torch.cuda.empty_cache()
 
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
     target_dir = f"input_images_{timestamp}"
     if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
@@ -65,9 +86,6 @@ def vggt_demo(
 
     if input_image is not None:
         input_image = sorted(input_image)
-        # recon_num = len(input_image)
-
-        # Copy files to the new directory
         for file_name in input_image:
             shutil.copy(file_name, target_dir_images)
     elif input_video is not None:
@@ -90,26 +108,37 @@ def vggt_demo(
 
             if count % frame_interval == 0:
                 cv2.imwrite(target_dir_images+"/"+f"{video_frame_num:06}.png", frame)
-                video_frame_num+=1
-
-        # recon_num = video_frame_num
-        # if recon_num<3:
-        #     return None, "Please input at least three frames"
+                video_frame_num+=1
     else:
-        return None, "Uploading not finished or Incorrect input format"
-
+        return None, "Uploading not finished or Incorrect input format", None, None
+
+    all_files = sorted(os.listdir(target_dir_images))
+    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+
+    # Update frame_filter choices
+    frame_filter_choices = ["All"] + all_files
+
     print(f"Files have been copied to {target_dir_images}")
     cfg.SCENE_DIR = target_dir
 
+    print("Running demo_fn")
+    with torch.no_grad():
+        predictions = demo_fn(cfg, vggt_model)
+    predictions["pred_extrinsic_list"] = None
+    print("Saving predictions")
+
+    prediction_save_path = f"{target_dir}/predictions.npz"
+    np.savez(prediction_save_path, **predictions)
+
+    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"
+
+    glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
+    glbscene.export(file_obj=glbfile)
 
     del predictions
     gc.collect()
     torch.cuda.empty_cache()
@@ -120,10 +149,31 @@ def vggt_demo(
     execution_time = end_time - start_time
     print(f"Execution time: {execution_time} seconds")
 
-    # Return None for the 3D
+    # Return None for the 3D vggt_model (since we're using viser) and the viser URL
     # viser_url = f"Viser visualization is ready at: http://localhost:{viser_port}"
     # print(viser_url) # Debug print
+    log = "Success. Waiting for visualization."
+    return glbfile, log, target_dir, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
+
+
+def update_visualization(target_dir, conf_thres, frame_filter, mask_black_bg):
+    loaded = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
+    # predictions = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
+    # predictions["arr_0"]
+    # for key in predictions.files: print(key)
+    predictions = {key: loaded[key] for key in loaded.keys()}
+
+    glbfile = target_dir + f"/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"
+
+    if not os.path.exists(glbfile):
+        glbscene = demo_predictions_to_glb(predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg)
+        glbscene.export(file_obj=glbfile)
+    return glbfile, "Updating Visualization", target_dir
 
 
@@ -177,8 +227,17 @@ with gr.Blocks() as demo:
     gr.Markdown("""
     # 🏛️ VGGT: Visual Geometry Grounded Transformer
 
-    <div style="font-size: 16px; line-height: 1.
-    Alpha version (
+    <div style="font-size: 16px; line-height: 1.5;">
+    <p><strong>Alpha version</strong> (under active development)</p>
+
+    <p>Upload a video or images to create a 3D reconstruction. Once your media appears in the left panel, click the "Reconstruct" button to begin processing.</p>
+
+    <h3>Usage Tips:</h3>
+    <ol>
+    <li>After reconstruction, you can fine-tune the visualization by adjusting the confidence threshold or selecting specific frames to display, then click "Update Visualization".</li>
+    <li>Performance note: While the model itself processes quickly (~0.2 seconds), initial setup and visualization may take longer. First-time use requires downloading model weights, and rendering dense point clouds can be resource-intensive.</li>
+    <li>Known limitation: The model currently exhibits inconsistent behavior with videos centered around human subjects. This issue is being addressed in upcoming updates.</li>
+    </ol>
     </div>
     """)
 
@@ -186,87 +245,69 @@ with gr.Blocks() as demo:
     with gr.Column(scale=1):
         input_video = gr.Video(label="Upload Video", interactive=True)
         input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
-
-
+
     with gr.Column(scale=3):
+        with gr.Column():
+            gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)**")
+            reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
+            # reconstruction_output = gr.Model3D(label="3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)", height=520, zoom_speed=0.5, pan_speed=0.5)
+
+        # Move these controls to a new row above the log output
+        with gr.Row():
+            conf_thres = gr.Slider(minimum=0.1, maximum=20.0, value=3.0, step=0.1, label="Conf Thres")
+            frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
+            mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
+
         log_output = gr.Textbox(label="Log")
+        # Add a hidden textbox for target_dir
+        target_dir_output = gr.Textbox(label="Target Dir", visible=False)
 
     with gr.Row():
         submit_btn = gr.Button("Reconstruct", scale=1)
-
+        revisual_btn = gr.Button("Update Visualization", scale=1)
+        clear_btn = gr.ClearButton([input_video, input_images, reconstruction_output, log_output, target_dir_output], scale=1) # Modified reconstruction_output
 
 
 
     examples = [
-        [
-        [
+        [counter_video, counter_images, 1.5, "All", False],
+        [flower_video, flower_images, 1.5, "All", False],
+        [kitchen_video, kitchen_images, 3, "All", False],
+        [fern_video, fern_images, 1.5, "All", False],
        # [person_video, person_images],
        # [statue_video, statue_images],
        # [drums_video, drums_images],
-        [
-        [fern_video, fern_images],
-        [horns_video, horns_images],
+        # [horns_video, horns_images, 1.5, "All", False],
        # [apple_video, apple_images],
        # [bonsai_video, bonsai_images],
     ]
 
-    def process_example(video, images):
-        """Wrapper function to ensure outputs are properly captured"""
-        model_output, log = vggt_demo(video, images)
-
-        # viser_wrapper(predictions, port=log)
-        # Get the hostname - use the actual hostname or IP where the server is running
-        # hostname = socket.gethostname()
-
-        # Extract port from log
-        port = log
-
-        # Create the viser URL using the hostname
-        # viser_url = f"http://{hostname}:{port}"
-
-        viser_url = f"http://localhost:{log}"
-        print(f"Viser URL: {viser_url}")
-
-        # Create the iframe HTML code. Set width and height appropriately.
-        iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px"></iframe>'
-
-        # Return the iframe code to update the gr.HTML component
-        return iframe_code, f"Visualization running at {viser_url}"
-
-
-    # TODO: move the selection of port outside of the demo function
-    # so that we can cache examples
-
     gr.Examples(examples=examples,
-                inputs=[input_video, input_images],
-                outputs=[
-                fn=
+                inputs=[input_video, input_images, conf_thres, frame_filter, mask_black_bg],
+                outputs=[reconstruction_output, log_output, target_dir_output, frame_filter], # Added frame_filter
+                fn=vggt_demo, # Use our wrapper function
                 cache_examples=False,
                 examples_per_page=50,
    )
 
    submit_btn.click(
-
-        [input_video, input_images],
-        [
+        vggt_demo, # Use the same wrapper function
+        [input_video, input_images, conf_thres, frame_filter, mask_black_bg],
+        [reconstruction_output, log_output, target_dir_output, frame_filter], # Added frame_filter to outputs
        # concurrency_limit=1
    )
+
+   revisual_btn.click(
+       update_visualization,
+       [target_dir_output, conf_thres, frame_filter, mask_black_bg],
+       [reconstruction_output, log_output, target_dir_output],
+   )
 
 # demo.launch(debug=True, share=True)
 # demo.launch(server_name="0.0.0.0", server_port=8082, debug=True, share=False)
 # demo.queue(max_size=20).launch(show_error=True, share=True)
 demo.queue(max_size=20).launch(show_error=True) #, share=True, server_port=7888, server_name="0.0.0.0")
+# share=True
 # demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)
 ########################################################################################################################
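For reference, a minimal sketch (not part of the commit) of the two-step flow app.py now implements: "Reconstruct" runs the model once and caches predictions.npz under target_dir, while "Update Visualization" only re-filters the cached predictions into a new GLB. The directory name below is a placeholder; the helper simply mirrors update_visualization above.

# Hedged sketch of the new caching flow, usable outside the Gradio UI.
import os
import numpy as np
from gradio_util import demo_predictions_to_glb

def rebuild_glb(target_dir, conf_thres=3.0, frame_filter="All", mask_black_bg=False):
    # Load the predictions cached by vggt_demo() after the "Reconstruct" step.
    loaded = np.load(f"{target_dir}/predictions.npz", allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    # Same file-naming scheme as app.py, so repeated settings reuse an existing GLB.
    glbfile = f"{target_dir}/glbscene_{conf_thres}_{frame_filter.replace('.', '_')}_mask{mask_black_bg}.glb"
    if not os.path.exists(glbfile):
        scene = demo_predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
        )
        scene.export(file_obj=glbfile)
    return glbfile

# Example: tighten the confidence threshold without rerunning the model.
# glb_path = rebuild_glb("input_images_20250101_120000_000000", conf_thres=5.0)  # placeholder dir

Note that frame_filter is wired as both an input and an output of submit_btn.click, so the dropdown is repopulated with the per-frame choices ("i: filename") returned by vggt_demo.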
demo_hf.py
CHANGED
@@ -11,26 +11,29 @@ from viser_fn import viser_wrapper
 
 
 # @hydra.main(config_path="config", config_name="base")
-def demo_fn(cfg: DictConfig) -> None:
-    print(cfg)
-    model = instantiate(cfg, _recursive_=False)
+def demo_fn(cfg: DictConfig, model) -> None:
+    print(cfg.SCENE_DIR)
 
     if not torch.cuda.is_available():
         raise ValueError("CUDA is not available. Check your environment.")
 
-
+    if torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
     model = model.to(device)
 
-    _VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"
+    # _VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"
 
-    # Reload model
-    pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)
+    # # Reload model
+    # pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)
 
-    if "model" in pretrain_model:
-        model_dict = pretrain_model["model"]
-        model.load_state_dict(model_dict, strict=False)
-    else:
-        model.load_state_dict(pretrain_model, strict=True)
+    # if "model" in pretrain_model:
+    #     model_dict = pretrain_model["model"]
+    #     model.load_state_dict(model_dict, strict=False)
+    # else:
+    #     model.load_state_dict(pretrain_model, strict=True)
 
 
     # batch = torch.load("/fsx-repligen/jianyuan/cvpr2025_ckpts/batch.pth")
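A hedged sketch of the calling convention after this change, assuming the config and checkpoint locations shown in app.py above: demo_fn no longer instantiates or loads the model itself; the caller builds it once at startup and passes it in, so repeated Gradio requests reuse the same weights.

# Sketch only; mirrors the model-loading block added to app.py in this commit.
import torch
from omegaconf import OmegaConf
from hydra.utils import instantiate
from demo_hf import demo_fn

cfg = OmegaConf.load("config/base.yaml")
vggt_model = instantiate(cfg, _recursive_=False)

_VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"
pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)
if "vggt_model" in pretrain_model:
    vggt_model.load_state_dict(pretrain_model["vggt_model"], strict=False)
else:
    vggt_model.load_state_dict(pretrain_model, strict=True)

cfg.SCENE_DIR = "input_images_example"  # placeholder directory of extracted frames
with torch.no_grad():
    predictions = demo_fn(cfg, vggt_model)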
gradio_util.py
CHANGED
@@ -1,56 +1,22 @@
-
-import os
-
-import trimesh
-import open3d as o3d
-
-import gradio as gr
-import numpy as np
-import matplotlib
-from scipy.spatial.transform import Rotation
-
-    print("Successfully imported the packages for Gradio visualization")
-except:
-    print(
-        f"Failed to import packages for Gradio visualization. Please disable gradio visualization"
-    )
-
-
-def visualize_by_gradio(glbfile):
-    """
-    Set up and launch a Gradio interface to visualize a GLB file.
-
-    """
-
-    def load_glb_file(glb_path):
-        # Check if the file exists and return the path or error message
-        if os.path.exists(glb_path):
-            return glb_path, "3D Model Loaded Successfully"
-        else:
-            return None, "File not found"
-
-        # 3D Model viewer component
-        model_viewer = gr.Model3D(
-            label="3D Model Viewer", height=600, value=initial_model
-        )
-
-        # Textbox for log output
-        log_output = gr.Textbox(label="Log", lines=2, value=log_message)
-
-        # Launch the Gradio interface
-        demo.launch(share=True)
+import os
+
+
+import trimesh
+# import open3d as o3d
+
+import gradio as gr
+import numpy as np
+import matplotlib
+from scipy.spatial.transform import Rotation
+
+# except:
+#     print(
+#         f"Failed to import packages for Gradio visualization. Please disable gradio visualization"
+#     )
 
 
-def
+def demo_predictions_to_glb(predictions, conf_thres=3.0, filter_by_frames="all", mask_black_bg=False) -> trimesh.Scene:
     """
     Converts VGG SFM predictions to a 3D scene represented as a GLB.
@@ -61,27 +27,51 @@ def vggsfm_predictions_to_glb(predictions) -> trimesh.Scene:
     trimesh.Scene: A 3D scene object.
     """
     # Convert predictions to numpy arrays
-
-    colors_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(
-        np.uint8
-    )
-
-    if True:
-        pcd = o3d.geometry.PointCloud()
-        pcd.points = o3d.utility.Vector3dVector(vertices_3d)
-        pcd.colors = o3d.utility.Vector3dVector(colors_rgb)
-
-        cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=1.0)
-        filtered_pcd = pcd.select_by_index(ind)
-
-        print(f"Filter out {len(vertices_3d) - len(filtered_pcd.points)} 3D points")
-        vertices_3d = np.asarray(filtered_pcd.points)
-        colors_rgb = np.asarray(filtered_pcd.colors).astype(np.uint8)
-
+    # pred_extrinsic_list', 'pred_world_points', 'pred_world_points_conf', 'images', 'last_pred_extrinsic
+
+    print("Building GLB scene")
+    selected_frame_idx = None
+    if filter_by_frames != "all":
+        try:
+            # Extract the index part before the colon
+            selected_frame_idx = int(filter_by_frames.split(":")[0])
+        except (ValueError, IndexError):
+            # If parsing fails, default to using all frames
+            pass
+
+    pred_world_points = predictions["pred_world_points"][0] # remove batch dimension
+    pred_world_points_conf = predictions["pred_world_points_conf"][0]
+    images = predictions["images"][0]
+    last_pred_extrinsic = predictions["last_pred_extrinsic"][0]
+
+    if selected_frame_idx is not None:
+        pred_world_points = pred_world_points[selected_frame_idx][None]
+        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
+        images = images[selected_frame_idx][None]
+        last_pred_extrinsic = last_pred_extrinsic[selected_frame_idx][None]
+
+    vertices_3d = pred_world_points.reshape(-1, 3)
+    colors_rgb = np.transpose(images, (0, 2, 3, 1)) #images.permute(0, 3, 1, 2)
+    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
+    camera_matrices = last_pred_extrinsic
+
+    conf = pred_world_points_conf.reshape(-1)
+    conf_mask = conf > conf_thres
+
+    if mask_black_bg:
+        black_bg_mask = colors_rgb.sum(axis=1) >= 16
+        conf_mask = conf_mask & black_bg_mask
+
+    vertices_3d = vertices_3d[conf_mask]
+    colors_rgb = colors_rgb[conf_mask]
+
+    # vertices_3d = predictions["points3D"].cpu().numpy()
+    # colors_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(
+    #     np.uint8
+    # )
+    # camera_matrices = predictions["extrinsics_opencv"].cpu().numpy()
 
     # Calculate the 5th and 95th percentiles along each axis
     lower_percentile = np.percentile(vertices_3d, 5, axis=0)
@@ -122,39 +112,10 @@ def vggsfm_predictions_to_glb(predictions) -> trimesh.Scene:
     # Align scene to the observation of the first camera
     scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
 
+    print("GLB Scene built")
     return scene_3d
 
 
-def apply_scene_alignment(
-    scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
-) -> trimesh.Scene:
-    """
-    Aligns the 3D scene based on the extrinsics of the first camera.
-
-    Args:
-        scene_3d (trimesh.Scene): The 3D scene to be aligned.
-        extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
-
-    Returns:
-        trimesh.Scene: Aligned 3D scene.
-    """
-    # Set transformations for scene alignment
-    opengl_conversion_matrix = get_opengl_conversion_matrix()
-
-    # Rotation matrix for alignment (180 degrees around the y-axis)
-    align_rotation = np.eye(4)
-    align_rotation[:3, :3] = Rotation.from_euler(
-        "y", 180, degrees=True
-    ).as_matrix()
-
-    # Apply transformation
-    initial_transformation = (
-        np.linalg.inv(extrinsics_matrices[0])
-        @ opengl_conversion_matrix
-        @ align_rotation
-    )
-    scene_3d.apply_transform(initial_transformation)
-    return scene_3d
 
 
 def integrate_camera_into_scene(
@@ -215,40 +176,57 @@ def integrate_camera_into_scene(
     scene.add_geometry(camera_mesh)
 
 
-    """
-
-    Args:
-
-    Returns:
-    """
-    #
-    num_vertices_cone = len(cone_shape.vertices)
-
-    for
-
-        v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
-
-    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
-    return np.array(faces_list)
+def apply_scene_alignment(
+    scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
+) -> trimesh.Scene:
+    """
+    Aligns the 3D scene based on the extrinsics of the first camera.
+
+    Args:
+        scene_3d (trimesh.Scene): The 3D scene to be aligned.
+        extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
+
+    Returns:
+        trimesh.Scene: Aligned 3D scene.
+    """
+    # Set transformations for scene alignment
+    opengl_conversion_matrix = get_opengl_conversion_matrix()
+
+    # Rotation matrix for alignment (180 degrees around the y-axis)
+    align_rotation = np.eye(4)
+    align_rotation[:3, :3] = Rotation.from_euler(
+        "y", 180, degrees=True
+    ).as_matrix()
+
+    # Apply transformation
+    initial_transformation = (
+        np.linalg.inv(extrinsics_matrices[0])
+        @ opengl_conversion_matrix
+        @ align_rotation
+    )
+    scene_3d.apply_transform(initial_transformation)
+    return scene_3d
+
+
+def get_opengl_conversion_matrix() -> np.ndarray:
+    """
+    Constructs and returns the OpenGL conversion matrix.
+
+    Returns:
+        numpy.ndarray: A 4x4 OpenGL conversion matrix.
+    """
+    # Create an identity matrix
+    matrix = np.identity(4)
+
+    # Flip the y and z axes
+    matrix[1, 1] = -1
+    matrix[2, 2] = -1
+
+    return matrix
 
 
 def transform_points(
@@ -280,18 +258,38 @@ def transform_points(
     return result
 
 
-    """
-
-    Returns:
-    """
-    # Create
+def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
+    """
+    Computes the faces for the camera mesh.
+
+    Args:
+        cone_shape (trimesh.Trimesh): The shape of the camera cone.
+
+    Returns:
+        np.ndarray: Array of faces for the camera mesh.
+    """
+    # Create pseudo cameras
+    faces_list = []
+    num_vertices_cone = len(cone_shape.vertices)
+
+    for face in cone_shape.faces:
+        if 0 in face:
+            continue
+        v1, v2, v3 = face
+        v1_offset, v2_offset, v3_offset = face + num_vertices_cone
+        v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
+
+        faces_list.extend(
+            [
+                (v1, v2, v2_offset),
+                (v1, v1_offset, v3),
+                (v3_offset, v2, v3),
+                (v1, v2, v2_offset_2),
+                (v1, v1_offset_2, v3),
+                (v3_offset_2, v2, v3),
+            ]
+        )
+
+    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
+    return np.array(faces_list)