haotongl committed on
Commit
141f1e8
·
1 Parent(s): 71ea1f8

initial version

Browse files
Files changed (1) hide show
  1. app.py +215 -8
app.py CHANGED
@@ -1,12 +1,219 @@
 
 
 
 
 
 
1
  import spaces
 
 
 
 
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  @spaces.GPU
5
- def generate(prompt):
6
- return 'hello, world'
7
-
8
- gr.Interface(
9
- fn=generate,
10
- inputs=gr.Text(),
11
- outputs=gr.Gallery(),
12
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import shutil
4
+ from pathlib import Path
5
+ from typing import Union
6
+ import atexit
7
  import spaces
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ import open3d as o3d
10
+ import trimesh
11
+
12
  import gradio as gr
13
+ from gradio_imageslider import ImageSlider
14
+ import cv2
15
+ import numpy as np
16
+ import click
17
+ import imageio
18
+ from promptda.promptda import PromptDA
19
+ from promptda.utils.io_wrapper import load_image, load_depth
20
+ from promptda.utils.depth_utils import visualize_depth, unproject_depth
21
+ # import torch
22
# Device that run_with_gpu moves the inputs and the model to.
DEVICE = 'cuda'
# if torch.cuda.is_available(
# ) else 'mps' if torch.backends.mps.is_available() else 'cpu'
# model = PromptDA.from_pretrained('depth-anything/promptda_vitl').to(DEVICE).eval()
# Loaded once at import time without a device move; run_with_gpu calls
# model.to(DEVICE) before predicting.
model = PromptDA.from_pretrained('depth-anything/promptda_vitl').eval()
# Single worker used by delete_later to run its delayed deletions.
thread_pool_executor = ThreadPoolExecutor(max_workers=1)
28
+
29
def delete_later(path: Union[str, os.PathLike], delay: int = 300):
    """Schedule ``path`` for removal after ``delay`` seconds and at exit.

    The deletion is submitted to the module-level ``thread_pool_executor``
    and is additionally registered with ``atexit`` so the path is also
    cleaned up when the interpreter shuts down.

    Args:
        path: File or directory to remove.
        delay: Seconds to wait before the background deletion runs.
    """
    print(f"Deleting file: {path}")

    def _delete():
        # Best-effort cleanup: the path may already have been removed by the
        # other trigger (timer vs. atexit), in which case neither branch runs.
        try:
            if os.path.isfile(path):
                os.remove(path)
                print(f"Deleted file: {path}")
            elif os.path.isdir(path):
                shutil.rmtree(path)
                print(f"Deleted directory: {path}")
        except OSError as exc:
            # Report instead of silently swallowing, but never crash the
            # worker thread or the exit handler over a failed cleanup.
            print(f"Failed to delete {path}: {exc}")

    def _wait_and_delete():
        time.sleep(delay)
        # _delete takes no arguments (it closes over ``path``); passing one
        # would raise TypeError inside the worker, where the executor's
        # future would hide it and the deletion would never happen.
        _delete()

    thread_pool_executor.submit(_wait_and_delete)
    atexit.register(_delete)
46
+
47
 
48
  @spaces.GPU
49
def run_with_gpu(image, prompt_depth):
    """Predict depth for one image / prompt-depth pair on ``DEVICE``.

    Both inputs and the module-level ``model`` are moved to ``DEVICE``,
    ``model.predict`` is called, and element ``[0, 0]`` of its result is
    returned as a NumPy array on the CPU.
    """
    image_on_device = image.to(DEVICE)
    prompt_on_device = prompt_depth.to(DEVICE)
    model.to(DEVICE)
    prediction = model.predict(image_on_device, prompt_on_device)
    return prediction[0, 0].detach().cpu().numpy()
56
+
57
def check_is_stray_scanner_app_capture(input_dir):
    """Verify that ``input_dir`` contains the ``rgb.mp4`` of a capture.

    Args:
        input_dir: Directory the uploaded capture was extracted to.

    Raises:
        AssertionError: If ``rgb.mp4`` does not exist inside ``input_dir``.
    """
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if not os.path.exists(os.path.join(input_dir, 'rgb.mp4')):
        raise AssertionError('rgb.mp4 not found')
60
+
61
def run(input_file, resolution):
    """Run depth inference on the first frame of an uploaded capture zip.

    The zip is extracted next to the uploaded file, the first frames of
    ``rgb.mp4`` are dumped as JPEGs, and frame ``000000`` is passed together
    with ``depth/000000.png`` to ``run_with_gpu``. The result is written as a
    point cloud (``.ply`` and ``.glb``) and a 16-bit depth PNG inside the
    extracted directory, which is then scheduled for deletion.

    Args:
        input_file: Uploaded zip, either a path (``str`` / ``os.PathLike``)
            or an object exposing the path as ``.name``.
        resolution: Selected inference resolution. Currently unused.

    Returns:
        A tuple ``(color, (vis_depth, vis_prompt_depth), glb_path, ply_path,
        depth_path)`` where ``color`` is the uint8 RGB frame, the second item
        is the pair of depth visualizations, ``glb_path`` is a ``Path`` and
        the last two are POSIX path strings.

    Raises:
        AssertionError: If the extracted directory has no ``rgb.mp4``.
        subprocess.CalledProcessError: If ``ffmpeg`` exits with an error.
    """
    # Local import keeps this block self-contained; the file's import list
    # does not include subprocess.
    import subprocess

    # Accept plain paths as well as upload objects carrying ``.name``.
    if isinstance(input_file, (str, os.PathLike)):
        input_file = os.fspath(input_file)
    else:
        input_file = input_file.name
    root_dir = os.path.dirname(input_file)
    scene_name = os.path.basename(input_file).split('.')[0]
    input_dir = os.path.join(root_dir, scene_name)

    # Extract with the standard library rather than a shell command so that
    # file names containing spaces or shell metacharacters are handled safely.
    shutil.unpack_archive(input_file, root_dir, 'zip')
    check_is_stray_scanner_app_capture(input_dir)

    # Extract the first RGB frames. Arguments are passed as a list (no shell),
    # ``-y`` avoids blocking on an overwrite prompt when frames already exist,
    # and ``check=True`` surfaces an ffmpeg failure here instead of as a
    # confusing missing-image error further down.
    rgb_dir = os.path.join(input_dir, 'rgb')
    os.makedirs(rgb_dir, exist_ok=True)
    subprocess.run(
        [
            'ffmpeg', '-y',
            '-i', os.path.join(input_dir, 'rgb.mp4'),
            '-start_number', '0',
            '-frames:v', '10',
            '-q:v', '2',
            os.path.join(rgb_dir, '%06d.jpg'),
        ],
        check=True,
    )

    # Loading & Inference
    image_path = os.path.join(input_dir, 'rgb', '000000.jpg')
    image = load_image(image_path)
    prompt_depth_path = os.path.join(input_dir, 'depth/000000.png')
    prompt_depth = load_depth(prompt_depth_path)
    depth = run_with_gpu(image, prompt_depth)

    color = (image[0].permute(1, 2, 0).cpu().numpy() * 255.).astype(np.uint8)

    # Visualization: both maps share the predicted depth's min/max range, and
    # the prompt depth is resized to the prediction's size for the slider.
    vis_depth, depth_min, depth_max = visualize_depth(depth, ret_minmax=True)
    vis_prompt_depth = visualize_depth(
        prompt_depth[0, 0].detach().cpu().numpy(),
        depth_min=depth_min,
        depth_max=depth_max,
    )
    vis_prompt_depth = cv2.resize(
        vis_prompt_depth,
        (vis_depth.shape[1], vis_depth.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    # PLY file: rescale the first two rows of the intrinsics from the capture
    # size to the loaded image size.
    ixt_path = os.path.join(input_dir, 'camera_matrix.csv')
    ixt = np.loadtxt(ixt_path, delimiter=',')
    # NOTE(review): 1920 is presumably the long side of the captured video
    # that camera_matrix.csv refers to -- confirm against the capture format.
    orig_max = 1920
    now_max = max(color.shape[1], color.shape[0])
    scale = orig_max / now_max
    ixt[:2] = ixt[:2] / scale
    pcd = unproject_depth(depth, ixt=ixt, color=color, ret_pcd=True)
    ply_path = os.path.join(input_dir, 'pointcloud.ply')
    o3d.io.write_point_cloud(ply_path, pcd)

    # GLB file for the 3D viewer: colors get an extra all-ones column and the
    # Y and Z coordinates are negated.
    glb_path = os.path.join(input_dir, 'pointcloud.glb')
    scene_3d = trimesh.Scene()
    glb_colors = np.asarray(pcd.colors).astype(np.float32)
    glb_colors = np.concatenate([glb_colors, np.ones_like(glb_colors[:, :1])], axis=1)
    pcd_data = trimesh.PointCloud(
        vertices=np.asarray(pcd.points) * np.array([[1, -1, -1]]),
        colors=glb_colors.astype(np.float64),
    )
    scene_3d.add_geometry(pcd_data)
    scene_3d.export(file_obj=glb_path)

    # Depth map with original values, scaled by 1000 and stored as uint16.
    depth_path = os.path.join(input_dir, 'depth.png')
    output_depth = (depth * 1000).astype(np.uint16)
    imageio.imwrite(depth_path, output_depth)

    # The outputs live inside input_dir, so they are removed together with it.
    delete_later(Path(input_dir))
    delete_later(Path(input_file))

    return color, (vis_depth, vis_prompt_depth), Path(glb_path), Path(ply_path).as_posix(), Path(depth_path).as_posix()
125
+
126
# Markdown text passed to gr.Interface as the page description in main().
DESCRIPTION = """
# Estimate accurate and high-resolution depth maps from your iPhone capture.

Project Page: [Prompt Depth Anything](https://promptda.github.io/)

## Requirements:
1. iPhone 12 Pro or later Pro models, iPad 2020 Pro or later Pro models
2. Free iOS App: [Stray Scanner App](https://apps.apple.com/us/app/stray-scanner/id1557051662)

## Testing Steps:
1. Capture a scene with the Stray Scanner App.
2. Use the iPhone [Files App](https://apps.apple.com/us/app/files/id1232058109) to compress it into a zip file and transfer it to your computer. (Long press the capture folder to compress)
3. Upload the zip file and click "Submit" to get the depth map of the first frame.

Note:
- Currently, this demo only supports inference for the first frame. If you need to obtain all depth frames, please refer to our [GitHub repo](https://github.com/DepthAnything/PromptDA).
- The depth map is stored as uint16, with a unit of millimeters.
"""
144
+
145
+ # @click.command()
146
+ # @click.option('--share', is_flag=True, help='Whether to run the app in shared mode.')
147
+ # def main():
148
+ # with gr.Blocks(theme=gr.themes.Soft()) as demo:
149
+ # gr.Markdown(DESCRIPTION)
150
+
151
+ # with gr.Row():
152
+ # input_file = gr.File(type="filepath", label="Stray scanner app capture zip file")
153
+ # resolution = gr.Dropdown(choices=['756x1008', '1428x1904'], value='756x1008', label="Inference resolution")
154
+ # submit_btn = gr.Button("Submit")
155
+
156
+ # # gr.Examples(examples=[
157
+ # # ["data/assets/example0_chair.zip", "756x1008"]
158
+ # # ],
159
+ # # inputs=[input_file, resolution],
160
+ # # label="Examples",
161
+ # # )
162
+
163
+ # with gr.Row():
164
+ # with gr.Column():
165
+ # output_rgb = gr.Image(type="numpy", label="RGB Image")
166
+ # with gr.Column():
167
+ # output_depths = ImageSlider(label="Output depth / prompt depth", position=0.5)
168
+
169
+ # with gr.Row():
170
+ # with gr.Column():
171
+ # output_3d_model = gr.Model3D(label="3D Viewer", display_mode='solid', clear_color=[1.0, 1.0, 1.0, 1.0])
172
+ # with gr.Column():
173
+ # output_ply = gr.File(type="filepath", label="Download the unprojected point cloud as .ply file", height=30)
174
+ # output_depth_map = gr.File(type="filepath", label="Download the depth map as .png file", height=30)
175
+ # outputs = [
176
+ # output_rgb,
177
+ # output_depths,
178
+ # output_3d_model,
179
+ # output_ply,
180
+ # output_depth_map,
181
+ # ]
182
+ # gr.Examples(examples=[
183
+ # ["data/assets/example0_chair.zip", "756x1008"]
184
+ # ],
185
+ # fn=run,
186
+ # inputs=[input_file, resolution],
187
+ # outputs=outputs,
188
+ # label="Examples",
189
+ # cache_examples=True,
190
+ # )
191
+ # submit_btn.click(run,
192
+ # inputs=[input_file, resolution],
193
+ # outputs=outputs)
194
+
195
+ # demo.launch()
196
def main():
    """Build the Gradio interface around ``run`` and launch it."""
    input_components = [
        gr.File(type="filepath", label="Stray scanner app capture zip file"),
        gr.Dropdown(choices=['756x1008', '1428x1904'], value='756x1008', label="Inference resolution"),
    ]
    # Order matches the tuple returned by ``run``.
    output_components = [
        gr.Image(type="numpy", label="RGB Image"),
        ImageSlider(label="Depth map / prompt depth", position=0.5),
        gr.Model3D(label="3D Viewer", display_mode='solid', clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(type="filepath", label="Download the unprojected point cloud as .ply file"),
        gr.File(type="filepath", label="Download the depth map as .png file"),
    ]
    demo = gr.Interface(
        fn=run,
        inputs=input_components,
        outputs=output_components,
        title=None,
        description=DESCRIPTION,
        clear_btn=None,
        allow_flagging="never",
        theme=gr.themes.Soft(),
        examples=[
            ["data/assets/example0_chair.zip"],
        ],
    )
    demo.launch()


main()