import gc
import glob
import os
import shutil
import socket
import sys
import time
import webbrowser
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf

sys.path.append('vggt/')

from demo_hf import demo_fn
from viser_fn import viser_wrapper


def get_free_port():
    """Get a free port using socket."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        port = s.getsockname()[1]
        return port


def vggt_demo(input_video, input_image):
    """Run VGGT reconstruction on an uploaded video or image set and start a viser server."""
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    debug = False

    # Stage the inputs in a fresh, timestamped directory.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images)

    if input_video is not None:
        if not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        input_image = sorted(input_image)
        # recon_num = len(input_image)
        # Copy files to the new directory
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        # Sample roughly one frame per second from the video.
        vs = cv2.VideoCapture(input_video)
        fps = vs.get(cv2.CAP_PROP_FPS)
        frame_rate = 1
        frame_interval = max(int(fps * frame_rate), 1)  # guard against a zero FPS reading
        video_frame_num = 0
        count = 0
        while True:
            gotit, frame = vs.read()
            count += 1
            if not gotit:
                break
            if count % frame_interval == 0:
                cv2.imwrite(os.path.join(target_dir_images, f"{video_frame_num:06}.png"), frame)
                video_frame_num += 1
        vs.release()
    else:
        return None, "Upload not finished or incorrect input format"

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    predictions = demo_fn(cfg)

    # Get a free port for viser
    viser_port = get_free_port()

    # Start viser visualization in a separate thread/process
    viser_wrapper(predictions, port=viser_port)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")

    return None, viser_port


statue_video = "examples/videos/statue_video.mp4"
apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"
horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"
flower_video = "examples/videos/llff_flower_video.mp4"
fern_video = "examples/videos/llff_fern_video.mp4"
drums_video = "examples/videos/drums_video.mp4"
kitchen_video = "examples/videos/kitchen_video.mp4"

###########################################################################################

apple_images = glob.glob('examples/apple/images/*')
bonsai_images = glob.glob('examples/bonsai/images/*')
cake_images = glob.glob('examples/cake/images/*')
british_museum_images = glob.glob('examples/british_museum/images/*')
face_images = glob.glob('examples/in2n_face/images/*')
counter_images = glob.glob('examples/in2n_counter/images/*')
horns_images = glob.glob('examples/llff_horns/images/*')
person_images = glob.glob('examples/in2n_person/images/*')
flower_images = glob.glob('examples/llff_flower/images/*')
fern_images = glob.glob('examples/llff_fern/images/*')
statue_images = glob.glob('examples/statue/images/*')
drums_images = glob.glob('examples/drums/images/*')
kitchen_images = glob.glob('examples/kitchen/images/*')

###########################################################################################

with gr.Blocks() as demo:
    gr.Markdown("""
# 🏛️ VGGT: Visual Geometry Grounded Transformer
Alpha version (testing).
""") with gr.Row(): with gr.Column(scale=1): input_video = gr.Video(label="Upload Video", interactive=True) input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True) with gr.Column(scale=3): viser_output = gr.HTML( label="Viser Visualization", value='''
3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)
</div>''',
            )
            log_output = gr.Textbox(label="Log")

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        clear_btn = gr.ClearButton([input_video, input_images, viser_output, log_output], scale=1)

    examples = [
        [flower_video, flower_images],
        [kitchen_video, kitchen_images],
        [counter_video, counter_images],
        [fern_video, fern_images],
        [horns_video, horns_images],
    ]

    def process_example(video, images):
        """Wrapper function to ensure outputs are properly captured."""
        model_output, log = vggt_demo(video, images)
        viser_url = f"http://localhost:{log}"
        print(f"Viser URL: {viser_url}")

        # Create the iframe HTML code. Set width and height appropriately
        # (the height simply matches the placeholder above).
        iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px" frameborder="0"></iframe>'
        return iframe_code, f"Visualization running at {viser_url}"

    gr.Examples(
        examples=examples,
        inputs=[input_video, input_images],
        outputs=[viser_output, log_output],  # Output to viser_output
        fn=process_example,  # Use our wrapper function
        cache_examples=False,
        examples_per_page=50,
    )

    submit_btn.click(
        process_example,  # Use the same wrapper function
        [input_video, input_images],
        [viser_output, log_output],  # Output to viser_output
        concurrency_limit=1,
    )

demo.queue(max_size=20).launch(show_error=True, share=True, server_port=7888, server_name="0.0.0.0")