# Spaces: Running on Zero
import atexit
import os
import shutil
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Union

# `spaces` stays ahead of the other third-party imports, as in the original order.
import spaces
import trimesh
import gradio as gr
from gradio_imageslider import ImageSlider
import cv2
import numpy as np
import imageio

from promptda.promptda import PromptDA
from promptda.utils.io_wrapper import load_image, load_depth
from promptda.utils.depth_utils import visualize_depth, unproject_depth
# Inference device. This Space always targets CUDA; a CPU/MPS fallback would
# need torch.cuda / torch.backends.mps checks, and torch is not imported here.
DEVICE = 'cuda'

# Load the ViT-L Prompt Depth Anything checkpoint once at import time, move it
# to DEVICE and put it in evaluation mode via .eval().
model = PromptDA.from_pretrained('depth-anything/promptda_vitl')
model = model.to(DEVICE).eval()

# Module-level single-worker thread pool for background tasks.
thread_pool_executor = ThreadPoolExecutor(max_workers=1)
def delete_later(path: Union[str, os.PathLike], delay: int = 300) -> threading.Timer:
    """Schedule best-effort removal of a file or directory tree.

    ``path`` is removed on a background timer thread after ``delay`` seconds.
    It is also removed at interpreter exit (via ``atexit``) if the timer has
    not fired by then. Paths that no longer exist are ignored.

    Args:
        path: File or directory to remove.
        delay: Seconds to wait before removing ``path``.

    Returns:
        The started daemon ``threading.Timer``. Callers may ignore it, or
        ``cancel()`` / ``join()`` it.
    """
    print(f"Scheduling deletion of {path} in {delay}s")

    def _delete():
        try:
            if os.path.isfile(path):
                os.remove(path)
                print(f"Deleted file: {path}")
            elif os.path.isdir(path):
                shutil.rmtree(path)
                print(f"Deleted directory: {path}")
        except OSError as exc:
            # Best effort: the path may have been removed concurrently (e.g. the
            # exit hook racing the timer). Report it and carry on.
            print(f"Failed to delete {path}: {exc}")

    def _delete_and_unregister():
        _delete()
        # Already handled: drop the exit hook so hooks don't pile up per request.
        atexit.unregister(_delete)

    # One timer per path. A shared single-worker pool that sleeps inside the task
    # would serialize the waits, so the k-th path would only go after k * delay.
    timer = threading.Timer(delay, _delete_and_unregister)
    timer.daemon = True  # never keep the process alive just to clean up
    # Register before starting so the timer cannot unregister a hook that
    # does not exist yet.
    atexit.register(_delete)
    timer.start()
    return timer
def run_with_gpu(image, prompt_depth):
    """Run PromptDA on one image / prompt-depth pair and return the depth map.

    Both inputs are moved to DEVICE, in order, before ``model.predict`` is
    called. The ``[0, 0]`` slice of the prediction (presumably batch 0,
    channel 0; confirm against PromptDA) is detached and returned as a NumPy
    array via ``.cpu().numpy()``.

    NOTE(review): `spaces` is imported but no @spaces.GPU decorator is applied.
    On ZeroGPU hardware this function (or its caller) may need it; confirm.
    """
    on_device = [tensor.to(DEVICE) for tensor in (image, prompt_depth)]
    prediction = model.predict(*on_device)
    first_map = prediction[0, 0].detach()
    return first_map.cpu().numpy()
def check_is_stray_scanner_app_capture(input_dir):
    """Validate that ``input_dir`` looks like an extracted Stray Scanner capture.

    Currently this only checks that ``rgb.mp4`` is present.

    Args:
        input_dir: Directory the capture archive was extracted into.

    Raises:
        FileNotFoundError: if ``input_dir/rgb.mp4`` does not exist.
    """
    # An explicit raise instead of ``assert``, so the check still runs under ``python -O``.
    if not os.path.exists(os.path.join(input_dir, 'rgb.mp4')):
        raise FileNotFoundError('rgb.mp4 not found')
def _run_tool(args, ok_codes=(0,)):
    """Run an external command without a shell; raise CalledProcessError on an unexpected exit code."""
    # Passing a list with no shell means uploaded file names are never
    # interpreted as shell syntax.
    completed = subprocess.run(args, stdin=subprocess.DEVNULL)
    if completed.returncode not in ok_codes:
        raise subprocess.CalledProcessError(completed.returncode, args)


def _extract_first_frame(zip_path, root_dir, input_dir):
    """Unzip the capture into root_dir, validate it, and decode frame 0 of rgb.mp4; return the JPEG path."""
    # unzip exits with 1 when it only emitted warnings, so that still counts as success.
    _run_tool(['unzip', '-o', zip_path, '-d', root_dir], ok_codes=(0, 1))
    check_is_stray_scanner_app_capture(input_dir)
    rgb_dir = os.path.join(input_dir, 'rgb')
    os.makedirs(rgb_dir, exist_ok=True)
    # Only the first frame is used, so decode just one. -y overwrites a frame
    # left over from a previous run in the same directory instead of prompting.
    _run_tool([
        'ffmpeg', '-y', '-i', os.path.join(input_dir, 'rgb.mp4'),
        '-start_number', '0', '-frames:v', '1', '-q:v', '2',
        os.path.join(rgb_dir, '%06d.jpg'),
    ])
    return os.path.join(rgb_dir, '000000.jpg')


def _draw_label(vis, text, box_left):
    """Draw white `text` on a dark-grey 245x40 px box (left edge at x=box_left, 5 px above the bottom)."""
    height = vis.shape[0]
    vis = cv2.rectangle(vis,
                        (box_left, height - 45),
                        (box_left + 245, height - 5),
                        (70, 70, 70), -1)
    return cv2.putText(vis, text,
                       (box_left + 15, height - 45 + 27),
                       cv2.FONT_HERSHEY_SIMPLEX,
                       1, (255, 255, 255), 2, cv2.LINE_AA)


def _export_point_clouds(depth, color, input_dir):
    """Unproject `depth` with the capture intrinsics and write pointcloud.ply / pointcloud.glb; return (glb, ply) paths."""
    ixt = np.loadtxt(os.path.join(input_dir, 'camera_matrix.csv'), delimiter=',')
    # Rescale the intrinsics from a 1920-px long side to the inference image size.
    # Presumably 1920 is the capture's native resolution and ixt is a 3x3 K; confirm.
    scale = 1920 / max(color.shape[1], color.shape[0])
    ixt[:2] = ixt[:2] / scale
    points, colors = unproject_depth(depth, ixt=ixt, color=color, ret_pcd=False)

    ply_path = os.path.join(input_dir, 'pointcloud.ply')
    trimesh.PointCloud(vertices=points, colors=colors).export(ply_path)

    # GLB copy for the in-browser viewer. Flip Y and Z (presumably OpenCV -> glTF
    # camera convention) and append a constant alpha channel of 1.
    glb_path = os.path.join(input_dir, 'pointcloud.glb')
    rgb = np.asarray(colors).astype(np.float32)
    rgba = np.concatenate([rgb, np.ones_like(rgb[:, :1])], axis=1)
    scene_3d = trimesh.Scene()
    scene_3d.add_geometry(trimesh.PointCloud(
        vertices=np.asarray(points) * np.array([[1, -1, -1]]),
        colors=rgba.astype(np.float64),
    ))
    scene_3d.export(file_obj=glb_path)
    return glb_path, ply_path


# NOTE(review): `spaces` is imported but this decorator is commented out. On
# ZeroGPU hardware it is presumably needed here or on run_with_gpu; confirm.
# @spaces.GPU
def run(input_file, resolution):
    """Estimate depth for the first frame of an uploaded Stray Scanner capture zip.

    Args:
        input_file: The uploaded zip, either as a path (str / PathLike) or as a
            file wrapper exposing ``.name``.
        resolution: Value of the "Inference resolution" dropdown.
            NOTE(review): currently unused; the image is loaded at whatever
            size ``load_image`` produces.

    Returns:
        A 5-tuple in the order the UI expects:
        - the RGB frame as a uint8 array;
        - the ``(output depth, prompt depth)`` visualization pair;
        - the ``.glb`` point cloud as a ``Path``;
        - the ``.ply`` point cloud as a POSIX path string;
        - the 16-bit millimetre depth PNG as a POSIX path string.

    Raises:
        subprocess.CalledProcessError: if ``unzip`` or ``ffmpeg`` fails.
        Any error from ``check_is_stray_scanner_app_capture`` when the archive
        has no ``<scene>/rgb.mp4``.
    """
    # Accept a plain path as well as a wrapper exposing .name (what the code read originally).
    if isinstance(input_file, (str, os.PathLike)):
        input_file = os.fspath(input_file)
    else:
        input_file = input_file.name
    root_dir = os.path.dirname(input_file)
    # The archive is expected to unpack into a folder named after the zip
    # (scene.zip -> scene/). Path.stem keeps any dots in the folder name.
    input_dir = os.path.join(root_dir, Path(input_file).stem)

    try:
        image_path = _extract_first_frame(input_file, root_dir, input_dir)

        # Loading & inference on the first frame.
        image = load_image(image_path)
        prompt_depth = load_depth(os.path.join(input_dir, 'depth', '000000.png'))
        depth = run_with_gpu(image, prompt_depth)
        color = (image[0].permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)

        # Both visualizations use the output depth's value range so they compare directly.
        vis_depth, depth_min, depth_max = visualize_depth(depth, ret_minmax=True)
        vis_prompt_depth = visualize_depth(prompt_depth[0, 0].detach().cpu().numpy(),
                                           depth_min=depth_min, depth_max=depth_max)
        # Resize the prompt-depth visualization to the output size, nearest-neighbour.
        vis_prompt_depth = cv2.resize(vis_prompt_depth, (vis_depth.shape[1], vis_depth.shape[0]),
                                      interpolation=cv2.INTER_NEAREST)
        vis_prompt_depth = _draw_label(vis_prompt_depth, 'Prompt depth', vis_prompt_depth.shape[1] - 250)
        vis_depth = _draw_label(vis_depth, 'Output depth', 5)

        glb_path, ply_path = _export_point_clouds(depth, color, input_dir)

        # 16-bit PNG in millimetres. Clip so out-of-range depths saturate instead
        # of wrapping around in the uint16 cast.
        depth_path = os.path.join(input_dir, 'depth.png')
        output_depth = np.clip(depth * 1000, 0, np.iinfo(np.uint16).max).astype(np.uint16)
        imageio.imwrite(depth_path, output_depth)
    finally:
        # Clean up the extracted capture and the uploaded zip even when a step fails.
        delete_later(Path(input_dir))
        delete_later(Path(input_file))

    return color, (vis_depth, vis_prompt_depth), Path(glb_path), Path(ply_path).as_posix(), Path(depth_path).as_posix()
# Markdown shown at the top of the Gradio page (rendered via gr.Markdown in main()).
# It documents the expected input (Stray Scanner zip) and the output depth PNG
# format (uint16, millimetres).
DESCRIPTION = """
# Estimate accurate and high-resolution depth maps from your iPhone capture.
Project Page: [Prompt Depth Anything](https://promptda.github.io/)
## Requirements:
1. iPhone 12 Pro or later Pro models, iPad 2020 Pro or later Pro models.
2. Free iOS App: [Stray Scanner App](https://apps.apple.com/us/app/stray-scanner/id1557051662).
## Testing Steps:
1. Capture a scene with the Stray Scanner App. Use the iPhone [Files App](https://apps.apple.com/us/app/files/id1232058109) to compress it into a zip file and transfer it to your computer. [Example screen recording.](https://haotongl.github.io/promptda/assets/ScreenRecording_12-16-2024.mp4).
2. Upload the zip file and click "Submit" to get the depth map of the first frame.
Note:
- Currently, this demo only supports inference for the first frame. If you need to obtain all depth frames, please refer to our [GitHub repo](https://github.com/DepthAnything/PromptDA).
- The depth map is stored as uint16, with a unit of millimeters.
- **You can refer to the bottom of this page for an example demo.**
"""
def main():
    """Build the Gradio Blocks UI for the depth demo and launch it with share=True.

    Layout, top to bottom:
    - the description;
    - the capture upload and resolution dropdown;
    - the Submit button;
    - RGB / depth-comparison views;
    - the 3D viewer plus download files;
    - a cached example.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(DESCRIPTION)

        # Inputs.
        with gr.Row():
            capture_zip = gr.File(type="filepath", label="Stray scanner app capture zip file")
            resolution_choice = gr.Dropdown(choices=['756x1008', '1428x1904'], value='756x1008', label="Inference resolution")
        submit_button = gr.Button("Submit")

        # 2D results.
        with gr.Row():
            with gr.Column():
                rgb_view = gr.Image(type="numpy", label="RGB Image")
            with gr.Column():
                depth_slider = ImageSlider(label="Output depth / prompt depth", position=0.5)

        # 3D result and downloads.
        with gr.Row():
            with gr.Column():
                point_cloud_viewer = gr.Model3D(label="3D Viewer", display_mode='solid', clear_color=[1.0, 1.0, 1.0, 1.0])
            with gr.Column():
                ply_download = gr.File(type="filepath", label="Download the unprojected point cloud as .ply file", height=30)
                depth_png_download = gr.File(type="filepath", label="Download the depth map as .png file", height=30)

        # Order must match the 5-tuple returned by run().
        result_components = [rgb_view, depth_slider, point_cloud_viewer, ply_download, depth_png_download]
        run_inputs = [capture_zip, resolution_choice]

        # Example capture at the bottom of the page. cache_examples=True asks
        # Gradio to precompute and cache its outputs.
        gr.Examples(
            examples=[["data/assets/example0_chair.zip", "756x1008"]],
            fn=run,
            inputs=run_inputs,
            outputs=result_components,
            label="Examples",
            cache_examples=True,
        )
        submit_button.click(fn=run, inputs=run_inputs, outputs=result_components)

    demo.launch(share=True)
# Launch the demo only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()