haotongl committed on
Commit
71ea1f8
·
1 Parent(s): fa2db85

initial version

Browse files
Files changed (1) hide show
  1. app.py +8 -218
app.py CHANGED
@@ -1,222 +1,12 @@
1
- import os
2
- import time
3
- import shutil
4
- from pathlib import Path
5
- from typing import Union
6
- import atexit
7
  import spaces
8
- from concurrent.futures import ThreadPoolExecutor
9
- import open3d as o3d
10
- import trimesh
11
-
12
  import gradio as gr
13
- from gradio_imageslider import ImageSlider
14
- import cv2
15
- import numpy as np
16
- import click
17
- import imageio
18
- from promptda.promptda import PromptDA
19
- from promptda.utils.io_wrapper import load_image, load_depth
20
- from promptda.utils.depth_utils import visualize_depth, unproject_depth
21
- # import torch
22
- DEVICE = 'cuda'
23
- # if torch.cuda.is_available(
24
- # ) else 'mps' if torch.backends.mps.is_available() else 'cpu'
25
- # model = PromptDA.from_pretrained('depth-anything/promptda_vitl').to(DEVICE).eval()
26
- model = PromptDA.from_pretrained('depth-anything/promptda_vitl').eval()
27
- thread_pool_executor = ThreadPoolExecutor(max_workers=1)
28
-
29
- def delete_later(path: Union[str, os.PathLike], delay: int = 300):
30
- print(f"Deleting file: {path}")
31
- def _delete():
32
- try:
33
- if os.path.isfile(path):
34
- os.remove(path)
35
- print(f"Deleted file: {path}")
36
- elif os.path.isdir(path):
37
- shutil.rmtree(path)
38
- print(f"Deleted directory: {path}")
39
- except:
40
- pass
41
- def _wait_and_delete():
42
- time.sleep(delay)
43
- _delete(path)
44
- thread_pool_executor.submit(_wait_and_delete)
45
- atexit.register(_delete)
46
-
47
 
48
  @spaces.GPU
49
- def run_with_gpu(image, prompt_depth):
50
- image = image.to(DEVICE)
51
- prompt_depth = prompt_depth.to(DEVICE)
52
- model.to(DEVICE)
53
- depth = model.predict(image, prompt_depth)
54
- depth = depth[0, 0].detach().cpu().numpy()
55
- return depth
56
-
57
- def check_is_stray_scanner_app_capture(input_dir):
58
- assert os.path.exists(os.path.join(input_dir, 'rgb.mp4')), 'rgb.mp4 not found'
59
- pass
60
-
61
- def run(input_file, resolution):
62
- # unzip zip file
63
- input_file = input_file.name
64
- root_dir = os.path.dirname(input_file)
65
- scene_name = input_file.split('/')[-1].split('.')[0]
66
- input_dir = os.path.join(root_dir, scene_name)
67
- cmd = f'unzip -o {input_file} -d {root_dir}'
68
- os.system(cmd)
69
- check_is_stray_scanner_app_capture(input_dir)
70
-
71
- # extract rgb images
72
- os.makedirs(os.path.join(input_dir, 'rgb'), exist_ok=True)
73
- cmd = f'ffmpeg -i {input_dir}/rgb.mp4 -start_number 0 -frames:v 10 -q:v 2 {input_dir}/rgb/%06d.jpg'
74
- os.system(cmd)
75
-
76
- # Loading & Inference
77
- image_path = os.path.join(input_dir, 'rgb', '000000.jpg')
78
- image = load_image(image_path)
79
- prompt_depth_path = os.path.join(input_dir, 'depth/000000.png')
80
- prompt_depth = load_depth(prompt_depth_path)
81
- depth = run_with_gpu(image, prompt_depth)
82
-
83
-
84
- color = (image[0].permute(1,2,0).cpu().numpy() * 255.).astype(np.uint8)
85
-
86
- # Visualization file
87
- vis_depth, depth_min, depth_max = visualize_depth(depth, ret_minmax=True)
88
- vis_prompt_depth = visualize_depth(prompt_depth[0, 0].detach().cpu().numpy(), depth_min=depth_min, depth_max=depth_max)
89
- vis_prompt_depth = cv2.resize(vis_prompt_depth, (vis_depth.shape[1], vis_depth.shape[0]), interpolation=cv2.INTER_NEAREST)
90
-
91
- # PLY File
92
- ixt_path = os.path.join(input_dir, f'camera_matrix.csv')
93
- ixt = np.loadtxt(ixt_path, delimiter=',')
94
- orig_max = 1920
95
- now_max = max(color.shape[1], color.shape[0])
96
- scale = orig_max / now_max
97
- ixt[:2] = ixt[:2] / scale
98
- pcd = unproject_depth(depth, ixt=ixt, color=color, ret_pcd=True)
99
- ply_path = os.path.join(input_dir, f'pointcloud.ply')
100
- o3d.io.write_point_cloud(ply_path, pcd)
101
-
102
- glb_path = os.path.join(input_dir, f'pointcloud.glb')
103
- scene_3d = trimesh.Scene()
104
- glb_colors = np.asarray(pcd.colors).astype(np.float32)
105
- glb_colors = np.concatenate([glb_colors, np.ones_like(glb_colors[:, :1])], axis=1)
106
- # glb_colors = (np.asarray(pcd.colors) * 255).astype(np.uint8)
107
- pcd_data = trimesh.PointCloud(
108
- vertices=np.asarray(pcd.points) * np.array([[1, -1, -1]]),
109
- colors=glb_colors.astype(np.float64),
110
- )
111
- scene_3d.add_geometry(pcd_data)
112
- scene_3d.export(file_obj=glb_path)
113
- # o3d.io.write_point_cloud(glb_path, pcd)
114
-
115
- # Depth Map Original Value
116
- depth_path = os.path.join(input_dir, f'depth.png')
117
- output_depth = (depth * 1000).astype(np.uint16)
118
- imageio.imwrite(depth_path, output_depth)
119
-
120
-
121
- delete_later(Path(input_dir))
122
- delete_later(Path(input_file))
123
-
124
- return color, (vis_depth, vis_prompt_depth), Path(glb_path), Path(ply_path).as_posix(), Path(depth_path).as_posix()
125
-
126
- DESCRIPTION = """
127
- # Estimate accurate and high-resolution depth maps from your iPhone capture.
128
-
129
- Project Page: [Prompt Depth Anything](https://promptda.github.io/)
130
-
131
- ## Requirements:
132
- 1. iPhone 12 Pro or later Pro models, iPad 2020 Pro or later Pro models
133
- 2. Free iOS App: [Stray Scanner App](https://apps.apple.com/us/app/stray-scanner/id1557051662)
134
-
135
- ## Testing Steps:
136
- 1. Capture a scene with the Stray Scanner App.
137
- 2. Use the iPhone [Files App](https://apps.apple.com/us/app/files/id1232058109) to compress it into a zip file and transfer it to your computer. (Long press the capture folder to compress)
138
- 3. Upload the zip file and click "Submit" to get the depth map of the first frame.
139
-
140
- Note:
141
- - Currently, this demo only supports inference for the first frame. If you need to obtain all depth frames, please refer to our [GitHub repo](https://github.com/DepthAnything/PromptDA).
142
- - The depth map is stored as uint16, with a unit of millimeters.
143
- """
144
-
145
- @click.command()
146
- @click.option('--share', is_flag=True, help='Whether to run the app in shared mode.')
147
- def main(share: bool):
148
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
149
- gr.Markdown(DESCRIPTION)
150
-
151
- with gr.Row():
152
- input_file = gr.File(type="filepath", label="Stray scanner app capture zip file")
153
- resolution = gr.Dropdown(choices=['756x1008', '1428x1904'], value='756x1008', label="Inference resolution")
154
- submit_btn = gr.Button("Submit")
155
-
156
- gr.Examples(examples=[
157
- ["data/assets/example0_chair.zip", "756x1008"]
158
- ],
159
- inputs=[input_file, resolution],
160
- label="Examples",
161
- )
162
-
163
- with gr.Row():
164
- with gr.Column():
165
- output_rgb = gr.Image(type="numpy", label="RGB Image")
166
- with gr.Column():
167
- output_depths = ImageSlider(label="Output depth / prompt depth", position=0.5)
168
-
169
- with gr.Row():
170
- with gr.Column():
171
- output_3d_model = gr.Model3D(label="3D Viewer", display_mode='solid', clear_color=[1.0, 1.0, 1.0, 1.0])
172
- with gr.Column():
173
- output_ply = gr.File(type="filepath", label="Download the unprojected point cloud as .ply file", height=30)
174
- output_depth_map = gr.File(type="filepath", label="Download the depth map as .png file", height=30)
175
- outputs = [
176
- output_rgb,
177
- output_depths,
178
- output_3d_model,
179
- output_ply,
180
- output_depth_map,
181
- ]
182
- # gr.Examples(examples=[
183
- # ["data/assets/example0_chair.zip", "756x1008"]
184
- # ],
185
- # fn=run,
186
- # inputs=[input_file, resolution],
187
- # outputs=outputs,
188
- # label="Examples",
189
- # cache_examples=True,
190
- # )
191
- submit_btn.click(run,
192
- inputs=[input_file, resolution],
193
- outputs=outputs)
194
-
195
- demo.launch(share=share)
196
- # def main(share: bool):
197
- # gr.Interface(
198
- # fn=run,
199
- # inputs=[
200
- # gr.File(type="filepath", label="Upload a stray scanner app capture zip file"),
201
- # gr.Dropdown(choices=['756x1008', '1428x1904'], value='756x1008', label="Inference resolution")
202
- # ],
203
- # outputs=[
204
- # gr.Image(type="numpy", label="RGB Image"),
205
- # ImageSlider(label="Depth map / prompt depth", position=0.5),
206
- # gr.Model3D(label="3D Viewer", display_mode='solid', clear_color=[1.0, 1.0, 1.0, 1.0]),
207
- # gr.File(type="filepath", label="Download the unprojected point cloud as .ply file"),
208
- # gr.File(type="filepath", label="Download the depth map as .png file"),
209
- # ],
210
- # title=None,
211
- # description=DESCRIPTION,
212
- # clear_btn=None,
213
- # allow_flagging="never",
214
- # theme=gr.themes.Soft(),
215
- # examples=[
216
- # ["data/assets/8b98276b0a.zip"]
217
- # ]
218
- # ).launch(share=True)
219
-
220
-
221
- if __name__ == '__main__':
222
- main()
 
 
 
 
 
 
 
1
  import spaces
 
 
 
 
2
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  @spaces.GPU
5
+ def generate(prompt):
6
+ return 'hello, world'
7
+
8
+ gr.Interface(
9
+ fn=generate,
10
+ inputs=gr.Text(),
11
+ outputs=gr.Gallery(),
12
+ ).launch()