arminak6 committed on
Commit e7d52aa · verified · 1 Parent(s): 8f1d080

Create app.py

Files changed (1)
  1. app.py +166 -25
app.py CHANGED
@@ -1,28 +1,169 @@
-from dust3r.demo import *
-
-if True:
-    # parser = get_args_parser()
-    # args = parser.parse_args()
-    # args = {
-    #     "--weights": "./dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
-    #     "tmp_dir": None,
-    #     "server_name": None,
-    #     "device": "cuda",
-
-    # }
-
-    # if args.tmp_dir is not None:
-    #     tmp_path = args.tmp_dir
-    #     os.makedirs(tmp_path, exist_ok=True)
-    #     tempfile.tempdir = tmp_path
-
-    if False: #args.server_name is not None:
        server_name = args.server_name
    else:
-        server_name = '0.0.0.0' # if args.local_network else '127.0.0.1'

-    model = load_model("./dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", "cuda")
-    # dust3r will write the 3D model inside tmpdirname
-    with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
-        print('Outputing stuff in', tmpdirname)
-        main_demo(tmpdirname, model, "cuda", 512, server_name, 7860)
+import os
+import tempfile
+import gradio
+import argparse
+import math
+import torch
+import numpy as np
+import trimesh
+import copy
+import functools
+from scipy.spatial.transform import Rotation
+
+from dust3r.inference import inference
+from dust3r.model import AsymmetricCroCo3DStereo
+from dust3r.image_pairs import make_pairs
+from dust3r.utils.image import load_images, rgb
+from dust3r.utils.device import to_numpy
+from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
+from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
+# demo helpers used below (reconstruction, 3D export, scene-graph options) come from dust3r.demo
+from dust3r.demo import get_reconstructed_scene, get_3D_model_from_scene, set_scenegraph_options
+
+import matplotlib.pyplot as pl
+pl.ion()
+
+torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
+batch_size = 1
+
+def run_dust3r_inference(args, filelist, schedule, niter, min_conf_thr, as_pointcloud,
+                         mask_sky, clean_depth, transparent_cams, cam_size,
+                         scenegraph_type, winsize, refid):
+    tmpdirname = tempfile.mkdtemp(suffix='dust3r_gradio_demo')
+    if args.tmp_dir is not None:
+        tmp_path = args.tmp_dir
+        os.makedirs(tmp_path, exist_ok=True)
+        tempfile.tempdir = tmp_path
+
+    if args.server_name is not None:
        server_name = args.server_name
    else:
+        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
+
+    if args.weights is not None:
+        weights_path = args.weights
+    else:
+        weights_path = "naver/" + args.model_name
+    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device)
+
+    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, args.device, args.image_size)
+    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
+
+    # Run the reconstruction function
+    scene, outfile, imgs = recon_fun(filelist, schedule, niter, min_conf_thr, as_pointcloud,
+                                     mask_sky, clean_depth, transparent_cams, cam_size,
+                                     scenegraph_type, winsize, refid)
+
+    # Return the result
+    return outfile, imgs
+
+def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False):
+    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
+    model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
+    with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="DUSt3R Demo") as demo:
+        # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
+        scene = gradio.State(None)
+        gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')
+        with gradio.Column():
+            inputfiles = gradio.File(file_count="multiple")
+            with gradio.Row():
+                schedule = gradio.Dropdown(["linear", "cosine"],
+                                           value='linear', label="schedule", info="For global alignment!")
+                niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
+                                      label="num_iterations", info="For global alignment!")
+                scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"],
+                                                  value='complete', label="Scenegraph",
+                                                  info="Define how to make pairs",
+                                                  interactive=True)
+                winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
+                                        minimum=1, maximum=1, step=1, visible=False)
+                refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
+
+            run_btn = gradio.Button("Run")
+
+            with gradio.Row():
+                # adjust the confidence threshold
+                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
+                # adjust the camera size in the output pointcloud
+                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
+            with gradio.Row():
+                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
+                # two post process implemented
+                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
+                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
+                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
+
+            outmodel = gradio.Model3D()
+            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")
+
+            # events
+            scenegraph_type.change(set_scenegraph_options,
+                                   inputs=[inputfiles, winsize, refid, scenegraph_type],
+                                   outputs=[winsize, refid])
+            inputfiles.change(set_scenegraph_options,
+                              inputs=[inputfiles, winsize, refid, scenegraph_type],
+                              outputs=[winsize, refid])
+            run_btn.click(fn=recon_fun,
+                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
+                                  mask_sky, clean_depth, transparent_cams, cam_size,
+                                  scenegraph_type, winsize, refid],
+                          outputs=[scene, outmodel, outgallery])
+            min_conf_thr.release(fn=model_from_scene_fun,
+                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                         clean_depth, transparent_cams, cam_size],
+                                 outputs=outmodel)
+            cam_size.change(fn=model_from_scene_fun,
+                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                    clean_depth, transparent_cams, cam_size],
+                            outputs=outmodel)
+            as_pointcloud.change(fn=model_from_scene_fun,
+                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                         clean_depth, transparent_cams, cam_size],
+                                 outputs=outmodel)
+            mask_sky.change(fn=model_from_scene_fun,
+                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                    clean_depth, transparent_cams, cam_size],
+                            outputs=outmodel)
+            clean_depth.change(fn=model_from_scene_fun,
+                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                       clean_depth, transparent_cams, cam_size],
+                               outputs=outmodel)
+            transparent_cams.change(model_from_scene_fun,
+                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
+                                            clean_depth, transparent_cams, cam_size],
+                                    outputs=outmodel)
+    demo.launch(share=False, server_name=server_name, server_port=server_port)
+
+# Gradio interface components
+inputfiles = gradio.File(label="Input Images", file_count="multiple")
+schedule = gradio.Dropdown(["linear", "cosine"], label="Schedule")
+niter = gradio.Number(label="Number of Iterations", precision=0, minimum=0, maximum=5000)
+min_conf_thr = gradio.Slider(label="Minimum Confidence Threshold", minimum=1.0, maximum=20.0, step=0.1, value=3.0)
+as_pointcloud = gradio.Checkbox(label="As Pointcloud", value=False)
+mask_sky = gradio.Checkbox(label="Mask Sky", value=False)
+clean_depth = gradio.Checkbox(label="Clean-up Depthmaps", value=True)
+transparent_cams = gradio.Checkbox(label="Transparent Cameras", value=False)
+cam_size = gradio.Slider(label="Camera Size", minimum=0.001, maximum=0.1, step=0.001, value=0.05)
+scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"], label="Scene Graph Type")
+winsize = gradio.Slider(label="Window Size", minimum=1, maximum=1, step=1, value=1)
+refid = gradio.Slider(label="Reference ID", minimum=0, maximum=0, step=1, value=0)
+
+# Function to connect Gradio inputs to your main logic
+def run_inference(filelist, schedule, niter, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
+                  transparent_cams, cam_size, scenegraph_type, winsize, refid):
+    args = None  # You need to define your args here
+    outfile, imgs = run_dust3r_inference(args, filelist, schedule, niter, min_conf_thr, as_pointcloud,
+                                         mask_sky, clean_depth, transparent_cams, cam_size,
+                                         scenegraph_type, winsize, refid)
+    return imgs

+# Launch the Gradio interface
+iface = gradio.Interface(
+    fn=run_inference,
+    inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
+            transparent_cams, cam_size, scenegraph_type, winsize, refid],
+    outputs=gradio.Gallery(label="Output Images", columns=3),
+    title="DUSt3R Demo",
+    description="Reconstruct 3D scenes from input images using DUSt3R.",
+)
+# server_name and server_port are launch() options, not Interface() arguments
+iface.launch(share=True, server_name='0.0.0.0', server_port=7860)
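
The new run_inference wrapper leaves args = None, so for reference here is a minimal sketch of the namespace that run_dust3r_inference reads. The attribute names are taken from the code above (args.tmp_dir, args.server_name, args.local_network, args.weights, args.model_name, args.device, args.image_size); the default values shown are assumptions based on the checkpoint and settings used in the previous app.py.

import argparse

# Assumed defaults for the args object consumed by run_dust3r_inference (sketch, not part of the commit).
args = argparse.Namespace(
    tmp_dir=None,            # optional custom temp directory for outputs
    server_name=None,        # explicit host; when None it is derived from local_network
    local_network=False,     # True binds 0.0.0.0, False binds 127.0.0.1
    weights=None,            # local checkpoint path, if any
    model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt",  # resolved to "naver/" + model_name
    device="cuda",
    image_size=512,
)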