|
import os |
|
import tempfile |
|
import gradio |
|
import argparse |
|
import math |
|
import torch |
|
import numpy as np |
|
import trimesh |
|
import copy |
|
import functools |
|
from scipy.spatial.transform import Rotation |
|
|
|
from dust3r.inference import inference |
|
from dust3r.model import AsymmetricCroCo3DStereo |
|
from dust3r.image_pairs import make_pairs |
|
from dust3r.utils.image import load_images, rgb |
|
from dust3r.utils.device import to_numpy |
|
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes |
|
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode |
|
|
|
import matplotlib.pyplot as pl |
|
# Module-wide batch size constant; not referenced in the visible part of this
# file — presumably consumed by the reconstruction helpers (TODO confirm).
batch_size = 1

# Allow CUDA matmuls to use TF32 (PyTorch flag; trades a little precision for
# speed on GPUs that support it).
torch.backends.cuda.matmul.allow_tf32 = True

# Switch matplotlib to interactive mode.
pl.ion()
|
|
|
def run_dust3r_inference(args, filelist, schedule, niter, min_conf_thr, as_pointcloud,
                         mask_sky, clean_depth, transparent_cams, cam_size,
                         scenegraph_type, winsize, refid):
    """Load a DUSt3R model and run one reconstruction, without building a UI.

    Args:
        args: Namespace-like object. Attributes read here: ``tmp_dir``,
            ``weights``, ``model_name``, ``device``, ``image_size`` and,
            optionally, ``silent`` (treated as ``False`` when absent).
        filelist, schedule, niter, min_conf_thr, as_pointcloud, mask_sky,
        clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
        refid: Forwarded unchanged, in this order, to
            ``get_reconstructed_scene``.

    Returns:
        ``(outfile, imgs)`` — the second and third items of the 3-tuple
        returned by ``get_reconstructed_scene``; the scene object is dropped.

    Notes:
        * The output directory comes from ``tempfile.mkdtemp`` and is not
          removed here, because the returned ``outfile`` is expected to live
          inside it. The caller owns the cleanup.
        * The model is loaded on every call.
          NOTE(review): consider caching it if this is called per request.
    """
    # Point tempfile at the requested base directory BEFORE creating the
    # output directory; doing it afterwards leaves args.tmp_dir without effect.
    if args.tmp_dir is not None:
        os.makedirs(args.tmp_dir, exist_ok=True)
        tempfile.tempdir = args.tmp_dir
    tmpdirname = tempfile.mkdtemp(suffix='dust3r_gradio_demo')

    # Explicit weights win; otherwise build the "naver/<model_name>" identifier.
    if args.weights is not None:
        weights_path = args.weights
    else:
        weights_path = "naver/" + args.model_name
    model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(args.device)

    # main_demo binds (outdir, model, device, silent, image_size) ahead of the
    # per-run arguments; pass `silent` here too so the positions line up.
    silent = getattr(args, 'silent', False)
    _scene, outfile, imgs = get_reconstructed_scene(
        tmpdirname, model, args.device, silent, args.image_size,
        filelist, schedule, niter, min_conf_thr, as_pointcloud,
        mask_sky, clean_depth, transparent_cams, cam_size,
        scenegraph_type, winsize, refid)

    return outfile, imgs
|
|
|
def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False):
    """Assemble the DUSt3R Blocks interface and launch it.

    ``tmpdirname``, ``model``, ``device``, ``silent`` and ``image_size`` are
    pre-bound onto ``get_reconstructed_scene``; ``tmpdirname`` and ``silent``
    are pre-bound onto ``get_3D_model_from_scene``. ``server_name`` and
    ``server_port`` are handed to ``demo.launch``.
    """
    reconstruct = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
    rebuild_model = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
    page_css = """.gradio-container {margin: 0 !important; min-width: 100%};"""

    # NOTE(review): container nesting below is reconstructed from the order of
    # the `with` statements — confirm it matches the intended layout.
    with gradio.Blocks(css=page_css, title="DUSt3R Demo") as demo:
        # Holds the value of the first `Run` output so the viewer controls can
        # re-render from it.
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">DUSt3R Demo</h2>')

        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")

            with gradio.Row():
                schedule = gradio.Dropdown(
                    ["linear", "cosine"],
                    value='linear', label="schedule", info="For global alignment!")
                niter = gradio.Number(
                    value=300, precision=0, minimum=0, maximum=5000,
                    label="num_iterations", info="For global alignment!")
                scenegraph_type = gradio.Dropdown(
                    ["complete", "swin", "oneref"],
                    value='complete', label="Scenegraph",
                    info="Define how to make pairs",
                    interactive=True)
                winsize = gradio.Slider(
                    label="Scene Graph: Window Size", value=1,
                    minimum=1, maximum=1, step=1, visible=False)
                refid = gradio.Slider(
                    label="Scene Graph: Id", value=0,
                    minimum=0, maximum=0, step=1, visible=False)

            run_btn = gradio.Button("Run")

            with gradio.Row():
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1)
                cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)

            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=False, label="As pointcloud")
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()
            outgallery = gradio.Gallery(label='rgb,depth,confidence', columns=3, height="100%")

            # --- event wiring -------------------------------------------
            # Both a scenegraph change and a new file selection refresh the
            # two scene-graph sliders.
            graph_inputs = [inputfiles, winsize, refid, scenegraph_type]
            graph_outputs = [winsize, refid]
            for source in (scenegraph_type, inputfiles):
                source.change(set_scenegraph_options,
                              inputs=graph_inputs,
                              outputs=graph_outputs)

            run_btn.click(fn=reconstruct,
                          inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
                                  mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, refid],
                          outputs=[scene, outmodel, outgallery])

            # Every viewer control re-renders the 3D model from the stored
            # scene with the same argument list.
            viewer_inputs = [scene, min_conf_thr, as_pointcloud, mask_sky,
                             clean_depth, transparent_cams, cam_size]
            # The threshold slider fires on release, the others on change.
            min_conf_thr.release(fn=rebuild_model, inputs=viewer_inputs, outputs=outmodel)
            for control in (cam_size, as_pointcloud, mask_sky, clean_depth, transparent_cams):
                control.change(fn=rebuild_model, inputs=viewer_inputs, outputs=outmodel)

    demo.launch(share=False, server_name=server_name, server_port=server_port)
|
|
|
|
|
# Module-level input components for the `gradio.Interface` built below.
# Gradio components take `minimum` / `maximum` / `value` (the spelling
# main_demo already uses), not `min` / `max` / `default`. Initial values mirror
# the ones in main_demo so that no control starts out as None.
inputfiles = gradio.File(label="Input Images", file_count="multiple")
schedule = gradio.Dropdown(["linear", "cosine"], value="linear", label="Schedule")
niter = gradio.Number(label="Number of Iterations", value=300, precision=0, minimum=0, maximum=5000)
min_conf_thr = gradio.Slider(label="Minimum Confidence Threshold", value=3.0,
                             minimum=1.0, maximum=20.0, step=0.1)
as_pointcloud = gradio.Checkbox(label="As Pointcloud", value=False)
mask_sky = gradio.Checkbox(label="Mask Sky", value=False)
clean_depth = gradio.Checkbox(label="Clean-up Depthmaps", value=True)
transparent_cams = gradio.Checkbox(label="Transparent Cameras", value=False)
cam_size = gradio.Slider(label="Camera Size", value=0.05, minimum=0.001, maximum=0.1, step=0.001)
scenegraph_type = gradio.Dropdown(["complete", "swin", "oneref"], value="complete",
                                  label="Scene Graph Type")
# Both sliders keep a fixed single-value range: nothing at module level is
# wired to update their bounds, so they always submit 1 and 0 respectively.
winsize = gradio.Slider(label="Window Size", value=1, minimum=1, maximum=1, step=1)
refid = gradio.Slider(label="Reference ID", value=0, minimum=0, maximum=0, step=1)
|
|
|
|
|
def _default_inference_args():
    """Build the settings namespace that run_dust3r_inference reads."""
    # NOTE(review): model_name and image_size are assumed defaults for the
    # public 512px DUSt3R checkpoint — confirm against the model you deploy.
    return argparse.Namespace(
        tmp_dir=None,
        server_name=None,
        local_network=False,
        weights=None,
        model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device='cuda' if torch.cuda.is_available() else 'cpu',
        image_size=512,
        silent=False,
    )


def run_inference(filelist, schedule, niter, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
                  transparent_cams, cam_size, scenegraph_type, winsize, refid):
    """Gradio callback: run a reconstruction and return only the gallery images.

    The twelve UI values are forwarded unchanged to ``run_dust3r_inference``.
    The ``outfile`` it returns is discarded, so only ``imgs`` reaches the
    caller.

    Previously ``args`` was ``None``, so the first attribute access inside
    ``run_dust3r_inference`` raised ``AttributeError`` on every call; a
    populated default namespace is passed instead.
    """
    args = _default_inference_args()
    _outfile, imgs = run_dust3r_inference(args, filelist, schedule, niter, min_conf_thr, as_pointcloud,
                                          mask_sky, clean_depth, transparent_cams, cam_size,
                                          scenegraph_type, winsize, refid)
    return imgs
|
|
|
|
|
# Single-function UI around run_inference; the result is shown as a gallery.
iface = gradio.Interface(
    fn=run_inference,
    inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
            transparent_cams, cam_size, scenegraph_type, winsize, refid],
    outputs=gradio.Gallery(label="Output Images", columns=3),
    title="DUSt3R Demo",
    description="Reconstruct 3D scenes from input images using DUSt3R.",
)

# server_name / server_port are launch() options (as in main_demo's
# demo.launch call), not Interface constructor arguments.
iface.launch(share=True, server_name='0.0.0.0', server_port=7860)
|
|