vggsfm / app.py
JianyuanWang's picture
add two more examples; remove pillow
3a82b5b
raw
history blame
12.7 kB
import os
import cv2
import torch
import numpy as np
import gradio as gr
import trimesh
import sys
import os
sys.path.append('vggsfm_code/')
import shutil
from datetime import datetime
from vggsfm_code.hf_demo import demo_fn
from omegaconf import DictConfig, OmegaConf
from viz_utils.viz_fn import add_camera, apply_density_filter_np
import glob
#
from scipy.spatial.transform import Rotation
# import PIL
import gc
import open3d as o3d
# import spaces
# @spaces.GPU
def vggsfm_demo(
    input_video,
    input_image,
    query_frame_num,
    max_query_pts=4096,
):
    """Reconstruct a scene from uploaded images or a video and export it as GLB.

    Args:
        input_video: Path of an uploaded video. If it is not a ``str`` it is
            treated as a mapping and ``input_video["video"]["path"]`` is used.
            May be None.
        input_image: Sequence of uploaded image file paths, or None. When both
            inputs are given, only the images are reconstructed. The paths are
            sorted and at most 25 are used.
        query_frame_num: Number of query (key) frames; stored in
            ``cfg.query_frame_num``.
        max_query_pts: Maximum number of query points; stored in
            ``cfg.max_query_pts``.

    Returns:
        ``(glb_path, "Success")`` on success, or ``(None, message)`` when the
        input is unusable (fewer than three frames, or no input at all).
    """
    import time

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    debug = False  # True: load cached predictions instead of running the model
    max_input_image = 25  # per-request frame cap (also stated in the UI text)

    # Microsecond resolution keeps requests started in the same second from
    # sharing (and rmtree-ing) each other's scratch directory.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    def _reject(message):
        # Drop the scratch directory of a request that will not be reconstructed.
        shutil.rmtree(target_dir, ignore_errors=True)
        return None, message

    if debug:
        predictions = torch.load("predictions_scene2.pth")
    else:
        if input_video is not None and not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

        cfg_file = "vggsfm_code/cfgs/demo.yaml"
        cfg = OmegaConf.load(cfg_file)

        if input_image is not None:
            if len(input_image) < 3:
                return _reject("Please input at least three frames")
            input_image = sorted(input_image)[:max_input_image]
            # Copy files to the new directory
            for file_name in input_image:
                shutil.copy(file_name, target_dir_images)
        elif input_video is not None:
            num_frames = _extract_video_frames(input_video, target_dir_images, max_input_image)
            if num_frames < 3:
                return _reject("Please input at least three frames")
        else:
            return _reject("Input format incorrect")

        cfg.query_frame_num = query_frame_num
        cfg.max_query_pts = max_query_pts
        print(f"Files have been copied to {target_dir_images}")
        cfg.SCENE_DIR = target_dir
        predictions = demo_fn(cfg)

    glbscene = vggsfm_predictions_to_glb(predictions)
    glbfile = target_dir + "/glbscene.glb"
    glbscene.export(file_obj=glbfile)

    # Free the predictions (and cached GPU memory) before the next request.
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)
    execution_time = time.time() - start_time
    print(f"Execution time: {execution_time} seconds")
    return glbfile, "Success"


def _extract_video_frames(video_path, out_dir, max_frames, seconds_per_frame=1):
    """Save one frame every ``seconds_per_frame`` seconds of a video as PNG.

    Frames are written to ``out_dir`` as ``000000.png``, ``000001.png``, ...;
    at most ``max_frames`` are saved. Returns the number of frames written.
    """
    capture = cv2.VideoCapture(video_path)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        # A reported FPS below 1 (e.g. 0 for an unreadable stream) would give a
        # zero interval and a ZeroDivisionError in the modulo below.
        frame_interval = max(1, int(fps * seconds_per_frame))
        saved = 0
        count = 0
        # Strict '<' so no more than max_frames frames are ever written.
        while saved < max_frames:
            gotit, frame = capture.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                cv2.imwrite(out_dir + "/" + f"{saved:06}.png", frame)
                saved += 1
    finally:
        # Release the decoder handle even if reading/writing raises.
        capture.release()
    return saved
def vggsfm_predictions_to_glb(predictions, sphere=False):
    """Build a trimesh scene holding the reconstructed point cloud and cameras.

    Visualization approach learned from
    https://github.com/naver/dust3r/blob/main/dust3r/viz.py

    Args:
        predictions: Mapping of tensors. Reads ``"points3D"`` (3D points),
            ``"points3D_rgb"`` (per-point colors; multiplied by 255 here, so
            presumably in [0, 1] -- TODO confirm) and ``"extrinsics_opencv"``
            (per-frame world-to-camera matrices, copied into the top three rows
            of a 4x4 matrix, so presumably shaped (N, 3, 4)).
        sphere: Unfinished sphere-per-point mode kept for compatibility. When
            True it only prints a message and adds no point geometry.

    Returns:
        A ``trimesh.Scene`` re-expressed relative to the first camera, with
        y/z flipped (OpenCV -> OpenGL) and rotated 180 degrees about y.
    """
    points3D = predictions["points3D"].cpu().numpy()
    points3D_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(np.uint8)
    extrinsics_opencv = predictions["extrinsics_opencv"].cpu().numpy()

    glbscene = trimesh.Scene()

    # Statistical outlier removal. Only positions go through Open3D; the inlier
    # indices are then applied to the uint8 colors directly, so colors are not
    # round-tripped through Open3D's float64 color vectors.
    if len(points3D) > 0:
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points3D.astype(np.float64))
        _, inlier_idx = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=1.0)
        inlier_idx = np.asarray(inlier_idx, dtype=np.int64)
        print(f"Filter out {len(points3D) - len(inlier_idx)} 3D points")
        points3D = points3D[inlier_idx]
        points3D_rgb = points3D_rgb[inlier_idx]

    if sphere:
        # Sphere-per-point rendering was too slow and is not implemented.
        print("testing sphere")
    else:
        glbscene.add_geometry(trimesh.PointCloud(points3D, colors=points3D_rgb))

    camera_edge_colors = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
                          (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]

    # Promote the 3x4 world-to-camera matrices to homogeneous 4x4 so they can
    # be inverted into camera-to-world poses.
    frame_num = len(extrinsics_opencv)
    extrinsics_opencv_4x4 = np.zeros((frame_num, 4, 4))
    extrinsics_opencv_4x4[:, :3, :4] = extrinsics_opencv
    extrinsics_opencv_4x4[:, 3, 3] = 1

    for idx, cam_from_world in enumerate(extrinsics_opencv_4x4):
        cam_to_world = np.linalg.inv(cam_from_world)
        cur_cam_color = camera_edge_colors[idx % len(camera_edge_colors)]
        add_camera(glbscene, cam_to_world, cur_cam_color, image=None, imsize=(1024, 1024),
                   focal=None, screen_width=0.35)

    if frame_num == 0:
        # No camera to align to; leave the scene in world coordinates.
        return glbscene

    # Flip y and z to go from the OpenCV to the OpenGL camera convention.
    opengl_mat = np.array([[1, 0, 0, 0],
                           [0, -1, 0, 0],
                           [0, 0, -1, 0],
                           [0, 0, 0, 1]])
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    # Re-express the whole scene relative to the first camera's pose.
    glbscene.apply_transform(np.linalg.inv(np.linalg.inv(extrinsics_opencv_4x4[0]) @ opengl_mat @ rot))
    return glbscene
def _example_video(file_name):
    """Return the path of a bundled example video."""
    return f"vggsfm_code/examples/videos/{file_name}"


def _example_images(scene):
    """List the image files of a bundled example scene (unordered; [] if missing)."""
    return glob.glob(f"vggsfm_code/examples/{scene}/images/*")


# Example videos shipped with the demo.
apple_video = _example_video("apple_video.mp4")
british_museum_video = _example_video("british_museum_video.mp4")
cake_video = _example_video("cake_video.mp4")
bonsai_video = _example_video("bonsai_video.mp4")
face_video = _example_video("in2n_face_video.mp4")
counter_video = _example_video("in2n_counter_video.mp4")
horns_video = _example_video("llff_horns_video.mp4")
person_video = _example_video("in2n_person_video.mp4")
flower_video = _example_video("llff_flower_video.mp4")
fern_video = _example_video("llff_fern_video.mp4")

# Matching per-scene image folders.
apple_images = _example_images("apple")
bonsai_images = _example_images("bonsai")
cake_images = _example_images("cake")
british_museum_images = _example_images("british_museum")
face_images = _example_images("in2n_face")
counter_images = _example_images("in2n_counter")
horns_images = _example_images("llff_horns")
person_images = _example_images("in2n_person")
flower_images = _example_images("llff_flower")
fern_images = _example_images("llff_fern")
# Intro text rendered under the page title.
_INTRO_MARKDOWN = """
<div style="text-align: left;">
<p>Welcome to <a href="https://vggsfm.github.io/" target="_blank">VGGSfM</a> demo!
This space demonstrates 3D reconstruction from input image frames. </p>
<p>To get started quickly, you can click on our <strong> examples (the bottom of the page) </strong>. If you want to reconstruct your own data, simply: </p>
<ul style="display: inline-block; text-align: left;">
<li>upload images (.jpg, .png, etc.), or </li>
<li>upload a video (.mp4, .mov, etc.) </li>
</ul>
<p>If both images and videos are uploaded, the demo will only reconstruct the uploaded images. By default, we extract <strong> 1 image frame per second from the input video </strong>. To prevent crashes on the Hugging Face space, we currently limit reconstruction to the first 25 image frames. </p>
<p>SfM methods are designed for <strong> rigid/static reconstruction </strong>. When dealing with dynamic/moving inputs, these methods may still work by focusing on the rigid parts of the scene. However, to ensure high-quality results, it is better to minimize the presence of moving objects in the input data. </p>
<p>The reconstruction should typically take <strong> up to 90 seconds </strong>. If it takes longer, the input data is likely not well-conditioned or the query images/points are set too high. </p>
<p>If you meet any problem, feel free to create an issue in our <a href="https://github.com/facebookresearch/vggsfm" target="_blank">GitHub Repo</a> ⭐</p>
<p>(Please note that running reconstruction on Hugging Face space is slower than on a local machine.) </p>
</div>
"""

with gr.Blocks() as demo:
    gr.Markdown("# 🏛️ VGGSfM: Visual Geometry Grounded Deep Structure From Motion")
    gr.Markdown(_INTRO_MARKDOWN)

    # Top row: inputs on the left, 3D viewer and log on the right.
    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Input video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Input Images", interactive=True)
            num_query_images = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=4,
                label="Number of query images (key frames)",
                info="More query images usually lead to better reconstruction at lower speeds. If the viewpoint differences between your images are minimal, you can set this value to 1. ",
            )
            num_query_points = gr.Slider(
                minimum=600,
                maximum=6000,
                step=1,
                value=2048,
                label="Number of query points",
                info="More query points usually lead to denser reconstruction at lower speeds.",
            )
        with gr.Column(scale=3):
            reconstruction_output = gr.Model3D(label="Reconstruction", height=520, zoom_speed=0.5, pan_speed=0.5)
            log_output = gr.Textbox(label="Log")

    # Action row.
    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        clear_btn = gr.ClearButton(
            [input_video, input_images, num_query_images, num_query_points, reconstruction_output, log_output],
            scale=1,
        )

    # Each row: [video, images, num query images, num query points].
    examples = [
        [counter_video, counter_images, 4, 2048],
        [person_video, person_images, 3, 2048],
        [horns_video, horns_images, 3, 4096],
        [fern_video, fern_images, 2, 4096],
        [flower_video, flower_images, 2, 4096],
        [face_video, face_images, 4, 2048],
        [apple_video, apple_images, 6, 2048],
        [british_museum_video, british_museum_images, 1, 4096],
        [bonsai_video, bonsai_images, 3, 2048],
    ]

    # The example gallery and the button drive the same function with the
    # same components.
    reconstruct_inputs = [input_video, input_images, num_query_images, num_query_points]
    reconstruct_outputs = [reconstruction_output, log_output]

    gr.Examples(
        examples=examples,
        inputs=reconstruct_inputs,
        outputs=reconstruct_outputs,
        fn=vggsfm_demo,
        cache_examples=True,
    )

    # One reconstruction at a time to bound GPU memory use.
    submit_btn.click(
        fn=vggsfm_demo,
        inputs=reconstruct_inputs,
        outputs=reconstruct_outputs,
        concurrency_limit=1,
    )

demo.queue(max_size=20)
demo.launch(show_error=True, share=True)
# demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)
########################################################################################################################
# else:
# import glob
# files = glob.glob(f'vggsfm_code/examples/cake/images/*', recursive=True)
# vggsfm_demo(files, None, None)
# demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)