Create app.py
Browse files
app.py
CHANGED
@@ -1,28 +1,169 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
server_name = args.server_name
|
21 |
else:
|
22 |
-
server_name = '0.0.0.0'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import gradio
|
4 |
+
import argparse
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import trimesh
|
9 |
+
import copy
|
10 |
+
import functools
|
11 |
+
from scipy.spatial.transform import Rotation
|
12 |
+
|
13 |
+
from dust3r.inference import inference
|
14 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
15 |
+
from dust3r.image_pairs import make_pairs
|
16 |
+
from dust3r.utils.image import load_images, rgb
|
17 |
+
from dust3r.utils.device import to_numpy
|
18 |
+
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
19 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
20 |
+
|
21 |
+
import matplotlib.pyplot as pl
|
22 |
+
# Switch matplotlib to interactive mode as soon as this module is imported.
pl.ion()
|
23 |
+
|
24 |
+
# Enable TF32 for CUDA matrix multiplications, process-wide.
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
25 |
+
# Module-level batch size. NOTE(review): nothing in the visible part of this
# file reads it — confirm whether a helper defined elsewhere depends on it.
batch_size = 1
|
26 |
+
|
27 |
+
def run_dust3r_inference(args, filelist, schedule, niter, min_conf_thr, as_pointcloud,
                         mask_sky, clean_depth, transparent_cams, cam_size,
                         scenegraph_type, winsize, refid):
    """Load a DUSt3R checkpoint and run one scene reconstruction.

    Args:
        args: Namespace-like object. ``tmp_dir``, ``weights``, ``model_name``,
            ``device`` and ``image_size`` are read; an optional ``silent``
            attribute is honoured when present (defaults to ``False``).
        filelist, schedule, niter, min_conf_thr, as_pointcloud, mask_sky,
        clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
        refid: forwarded unchanged, in this order, to
            ``get_reconstructed_scene``.

    Returns:
        ``(outfile, imgs)``: the second and third items of the triple returned
        by ``get_reconstructed_scene`` (the first item, the scene, is dropped).
    """
    # Point tempfile at the requested root *before* creating the working
    # directory; otherwise mkdtemp() still uses the previous default location
    # and args.tmp_dir has no effect on this call.
    if args.tmp_dir is not None:
        os.makedirs(args.tmp_dir, exist_ok=True)
        tempfile.tempdir = args.tmp_dir
    tmpdirname = tempfile.mkdtemp(suffix='dust3r_gradio_demo')

    # An explicit weights path wins; otherwise build the "naver/<model_name>" id.
    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name
    # NOTE(review): the checkpoint is reloaded on every call; consider caching
    # the model if this function is invoked once per request.
    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device)

    # Bind the same leading arguments main_demo binds
    # (outdir, model, device, silent, image_size) so the user inputs below land
    # on the parameters they are meant for.
    # NOTE(review): get_reconstructed_scene is neither defined nor imported in
    # the visible part of this file — confirm where it comes from.
    silent = getattr(args, 'silent', False)
    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model,
                                  args.device, silent, args.image_size)

    # Run the reconstruction; the scene object itself is not needed by callers.
    _scene, outfile, imgs = recon_fun(filelist, schedule, niter, min_conf_thr, as_pointcloud,
                                      mask_sky, clean_depth, transparent_cams, cam_size,
                                      scenegraph_type, winsize, refid)

    return outfile, imgs
|
57 |
+
|
58 |
+
def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False):
    """Build the DUSt3R Gradio Blocks page and serve it.

    Args:
        tmpdirname: first argument bound to both ``get_reconstructed_scene``
            and ``get_3D_model_from_scene``.
        model, device, image_size, silent: bound to ``get_reconstructed_scene``
            (``silent`` is also bound to ``get_3D_model_from_scene``).
        server_name, server_port: passed to ``demo.launch``.
    """
    reconstruct = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
    rebuild_model = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
    page_css = """.gradio-container {margin: 0 !important; min-width: 100%};"""

    with gradio.Blocks(css=page_css, title="DUSt3R Demo") as demo:
        # The scene is kept in session state so that conf_thr, cam_size, ...
        # can be changed without rerunning the inference.
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")

            # Global-alignment and pairing options.
            with gradio.Row():
                schedule = gradio.Dropdown(
                    ["linear", "cosine"],
                    value='linear',
                    label="schedule",
                    info="For global alignment!",
                )
                niter = gradio.Number(
                    value=300,
                    precision=0,
                    minimum=0,
                    maximum=5000,
                    label="num_iterations",
                    info="For global alignment!",
                )
                scenegraph_type = gradio.Dropdown(
                    ["complete", "swin", "oneref"],
                    value='complete',
                    label="Scenegraph",
                    info="Define how to make pairs",
                    interactive=True,
                )
                winsize = gradio.Slider(
                    label="Scene Graph: Window Size",
                    value=1,
                    minimum=1,
                    maximum=1,
                    step=1,
                    visible=False,
                )
                refid = gradio.Slider(
                    label="Scene Graph: Id",
                    value=0,
                    minimum=0,
                    maximum=0,
                    step=1,
                    visible=False,
                )

            run_btn = gradio.Button("Run")

            with gradio.Row():
                # confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
                # camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
                # the two implemented post-processing steps
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()
            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")

            # --- events (registered in the same order as before) ---
            # Changing the scene-graph type or the uploaded files refreshes
            # the window-size / reference-id sliders.
            for trigger in (scenegraph_type, inputfiles):
                trigger.change(set_scenegraph_options,
                               inputs=[inputfiles, winsize, refid, scenegraph_type],
                               outputs=[winsize, refid])

            # "Run" performs the full reconstruction.
            run_btn.click(fn=reconstruct,
                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
                                  mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, refid],
                          outputs=[scene, outmodel, outgallery])

            # Every display option rebuilds the 3D model from the stored scene.
            # The threshold slider fires on release, the others on change.
            refresh_inputs = [scene, min_conf_thr, as_pointcloud, mask_sky,
                              clean_depth, transparent_cams, cam_size]
            min_conf_thr.release(fn=rebuild_model, inputs=list(refresh_inputs), outputs=outmodel)
            for option in (cam_size, as_pointcloud, mask_sky, clean_depth, transparent_cams):
                option.change(fn=rebuild_model, inputs=list(refresh_inputs), outputs=outmodel)
    demo.launch(share=False, server_name=server_name, server_port=server_port)
|
134 |
+
|
135 |
+
# Gradio interface components
# Keyword names follow the ones main_demo already uses (value= / minimum= /
# maximum=) instead of default= / min= / max=, and the widgets that previously
# had no initial value now start from the same defaults as main_demo.
inputfiles = gradio.File(label="Input Images", file_count="multiple")
schedule = gradio.Dropdown(["linear", "cosine"], value="linear", label="Schedule")
# precision=0 keeps the iteration count an integer (replaces the former step=1).
niter = gradio.Number(label="Number of Iterations", value=300, precision=0, minimum=0, maximum=5000)
min_conf_thr = gradio.Slider(label="Minimum Confidence Threshold", value=3.0,
                             minimum=1.0, maximum=20.0, step=0.1)
as_pointcloud = gradio.Checkbox(label="As Pointcloud", value=False)
mask_sky = gradio.Checkbox(label="Mask Sky", value=False)
clean_depth = gradio.Checkbox(label="Clean-up Depthmaps", value=True)
transparent_cams = gradio.Checkbox(label="Transparent Cameras", value=False)
cam_size = gradio.Slider(label="Camera Size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"], value="complete",
                                  label="Scene Graph Type")
# NOTE(review): both sliders below have minimum == maximum, so they are pinned
# to a single value; nothing in this Interface updates their range.
winsize = gradio.Slider(label="Window Size", value=1, minimum=1, maximum=1, step=1)
refid = gradio.Slider(label="Reference ID", value=0, minimum=0, maximum=0, step=1)
|
148 |
+
|
149 |
+
# Function to connect Gradio inputs to your main logic
|
150 |
+
def run_inference(filelist, schedule, niter, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
                  transparent_cams, cam_size, scenegraph_type, winsize, refid):
    """Gradio callback: run a DUSt3R reconstruction and return its images.

    All parameters are the values of the Interface inputs and are forwarded
    unchanged to ``run_dust3r_inference``.

    Returns:
        The ``imgs`` item returned by ``run_dust3r_inference``; the output
        file path is discarded because the Interface only shows a gallery.
    """
    # run_dust3r_inference reads attributes from ``args``; passing None made
    # every request fail with AttributeError. Provide a concrete namespace.
    # NOTE(review): the checkpoint name and image size are assumed defaults —
    # confirm them against the weights this deployment should serve.
    args = argparse.Namespace(
        tmp_dir=None,
        server_name=None,
        local_network=False,
        weights=None,
        model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device="cuda" if torch.cuda.is_available() else "cpu",
        image_size=512,
        silent=False,
    )
    _outfile, imgs = run_dust3r_inference(args, filelist, schedule, niter, min_conf_thr, as_pointcloud,
                                          mask_sky, clean_depth, transparent_cams, cam_size,
                                          scenegraph_type, winsize, refid)
    return imgs
|
157 |
|
158 |
+
# Launch the Gradio interface
iface = gradio.Interface(
    fn=run_inference,
    inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
            transparent_cams, cam_size, scenegraph_type, winsize, refid],
    outputs=gradio.Gallery(label="Output Images", columns=3),
    title="DUSt3R Demo",
    description="Reconstruct 3D scenes from input images using DUSt3R.",
)
# The bind address and port are launch() options (main_demo passes them the
# same way), so they are given here rather than to the Interface constructor.
# NOTE(review): share=True requests a public share link — confirm that is intended.
iface.launch(share=True, server_name='0.0.0.0', server_port=7860)
|