sdsdsdadasd3 committed
Commit 3838dc1 (1 parent: 7c2f6e2)

[Release] v1.0.1


- improve performance: lower the default num denoising steps (10 -> 5) and cfg scale (1.2 -> 1.0), and reduce the GPU duration budget (140s -> 120s)
- improve efficiency: default process_length/target_fps to -1 (full clip at native FPS), expose a target FPS slider, and move read_video_frames into depthcrafter/utils.py so app.py and run.py share one implementation

Files changed (3)
  1. app.py +23 -15
  2. depthcrafter/utils.py +44 -0
  3. run.py +2 -50
app.py CHANGED
@@ -17,11 +17,11 @@ from huggingface_hub import hf_hub_download
 from depthcrafter.utils import read_video_frames, vis_sequence_depth, save_video
 
 examples = [
-    ["examples/example_01.mp4", 10, 1.2, 1024, 60],
-    ["examples/example_02.mp4", 10, 1.2, 1024, 60],
-    ["examples/example_03.mp4", 10, 1.2, 1024, 60],
-    ["examples/example_04.mp4", 10, 1.2, 1024, 60],
-    ["examples/example_05.mp4", 10, 1.2, 1024, 60],
+    ["examples/example_01.mp4", 5, 1.0, 1024, -1, -1],
+    ["examples/example_02.mp4", 5, 1.0, 1024, -1, -1],
+    ["examples/example_03.mp4", 5, 1.0, 1024, -1, -1],
+    ["examples/example_04.mp4", 5, 1.0, 1024, -1, -1],
+    ["examples/example_05.mp4", 5, 1.0, 1024, -1, -1],
 ]
 
 
@@ -39,18 +39,18 @@ pipe = DepthCrafterPipeline.from_pretrained(
 pipe.to("cuda")
 
 
-@spaces.GPU(duration=140)
+@spaces.GPU(duration=120)
 def infer_depth(
     video: str,
     num_denoising_steps: int,
     guidance_scale: float,
     max_res: int = 1024,
-    process_length: int = 195,
+    process_length: int = -1,
+    target_fps: int = -1,
     #
     save_folder: str = "./demo_output",
     window_size: int = 110,
     overlap: int = 25,
-    target_fps: int = 15,
     seed: int = 42,
     track_time: bool = True,
     save_npz: bool = False,
@@ -59,7 +59,6 @@ def infer_depth(
     pipe.enable_xformers_memory_efficient_attention()
 
     frames, target_fps = read_video_frames(video, process_length, target_fps, max_res)
-    print(f"==> video name: {video}, frames shape: {frames.shape}")
 
     # inference the depth map using the DepthCrafter pipeline
     with torch.inference_mode():
@@ -82,6 +81,7 @@ def infer_depth(
     vis = vis_sequence_depth(res)
     # save the depth map and visualization with the target FPS
     save_path = os.path.join(save_folder, os.path.splitext(os.path.basename(video))[0])
+    print(f"==> saving results to {save_path}")
     os.makedirs(os.path.dirname(save_path), exist_ok=True)
     if save_npz:
         np.savez_compressed(save_path + ".npz", depth=res)
@@ -155,14 +155,14 @@ def construct_demo():
                     label="num denoising steps",
                     minimum=1,
                     maximum=25,
-                    value=10,
+                    value=5,
                     step=1,
                 )
                 guidance_scale = gr.Slider(
                     label="cfg scale",
                     minimum=1.0,
                     maximum=1.2,
-                    value=1.2,
+                    value=1.0,
                     step=0.1,
                 )
                 max_res = gr.Slider(
@@ -174,11 +174,18 @@ def construct_demo():
                 )
                 process_length = gr.Slider(
                     label="process length",
-                    minimum=1,
+                    minimum=-1,
                     maximum=280,
                     value=60,
                     step=1,
                 )
+                process_target_fps = gr.Slider(
+                    label="target FPS",
+                    minimum=-1,
+                    maximum=30,
+                    value=15,
+                    step=1,
+                )
                 generate_btn = gr.Button("Generate")
             with gr.Column(scale=2):
                 pass
@@ -191,6 +198,7 @@ def construct_demo():
             guidance_scale,
             max_res,
             process_length,
+            process_target_fps,
         ],
         outputs=[output_video_1, output_video_2],
         fn=infer_depth,
@@ -216,6 +224,7 @@ def construct_demo():
             guidance_scale,
             max_res,
             process_length,
+            process_target_fps,
         ],
         outputs=[output_video_1, output_video_2],
     )
@@ -223,9 +232,8 @@ def construct_demo():
     return depthcrafter_iface
 
 
-demo = construct_demo()
-
 if __name__ == "__main__":
+    demo = construct_demo()
     demo.queue()
-    # demo.launch(server_name="0.0.0.0", server_port=80, debug=True)
+    # demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
     demo.launch(share=True)
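Note on the new defaults: process_length=-1 and target_fps=-1 act as sentinels for "process the whole clip" and "keep the native frame rate", which is why the examples above now pass -1, -1 instead of a fixed length. A minimal sketch of the new call path (illustrative, not part of the commit; assumes one of the example clips above is available locally and that decord can open it):

    from depthcrafter.utils import read_video_frames

    # -1 / -1: keep every sampled frame at the clip's native frame rate,
    # so the stride logic in read_video_frames drops nothing.
    frames, fps = read_video_frames(
        "examples/example_01.mp4",  # path taken from the examples list
        process_length=-1,
        target_fps=-1,
        max_res=1024,
    )
    print(frames.shape, fps)  # (num_frames, H, W, 3), float32 in [0, 1]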
depthcrafter/utils.py CHANGED
@@ -5,6 +5,50 @@ import PIL.Image
 import matplotlib.cm as cm
 import mediapy
 import torch
+from decord import VideoReader, cpu
+
+dataset_res_dict = {
+    "sintel": [448, 1024],
+    "scannet": [640, 832],
+    "KITTI": [384, 1280],
+    "bonn": [512, 640],
+    "NYUv2": [448, 640],
+}
+
+
+def read_video_frames(video_path, process_length, target_fps, max_res, dataset="open"):
+    if dataset == "open":
+        print("==> processing video: ", video_path)
+        vid = VideoReader(video_path, ctx=cpu(0))
+        print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
+        original_height, original_width = vid.get_batch([0]).shape[1:3]
+        height = round(original_height / 64) * 64
+        width = round(original_width / 64) * 64
+        if max(height, width) > max_res:
+            scale = max_res / max(original_height, original_width)
+            height = round(original_height * scale / 64) * 64
+            width = round(original_width * scale / 64) * 64
+    else:
+        height = dataset_res_dict[dataset][0]
+        width = dataset_res_dict[dataset][1]
+
+    vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
+
+    fps = vid.get_avg_fps() if target_fps == -1 else target_fps
+    stride = round(vid.get_avg_fps() / fps)
+    stride = max(stride, 1)
+    frames_idx = list(range(0, len(vid), stride))
+    print(
+        f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
+    )
+    if process_length != -1 and process_length < len(frames_idx):
+        frames_idx = frames_idx[:process_length]
+    print(
+        f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
+    )
+    frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
+
+    return frames, fps
 
 
 def save_video(
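The helper snaps the spatial dimensions to multiples of 64 (presumably so they survive the UNet's repeated 2x downsampling) and subsamples time by an integer stride of round(native_fps / target_fps). A worked sketch of just the stride arithmetic, with hypothetical numbers and illustrative names (pure arithmetic, no video I/O):

    def subsample_indices(num_frames, avg_fps, target_fps, process_length):
        # Mirrors read_video_frames: target_fps == -1 means "native frame rate".
        fps = avg_fps if target_fps == -1 else target_fps
        stride = max(round(avg_fps / fps), 1)
        idx = list(range(0, num_frames, stride))
        # process_length == -1 means "no cap on the number of frames".
        if process_length != -1 and process_length < len(idx):
            idx = idx[:process_length]
        return idx

    # A 300-frame clip at 30 fps, sampled at 15 fps: every other frame.
    assert subsample_indices(300, 30.0, 15, -1) == list(range(0, 300, 2))
    # With process_length=60, only the first 60 of those indices survive.
    assert len(subsample_indices(300, 30.0, 15, 60)) == 60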
run.py CHANGED
@@ -3,21 +3,12 @@ import os
 import numpy as np
 import torch
 
-from decord import VideoReader, cpu
 from diffusers.training_utils import set_seed
 from fire import Fire
 
 from depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
 from depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
-from depthcrafter.utils import vis_sequence_depth, save_video
-
-dataset_res_dict = {
-    "sintel": [448, 1024],
-    "scannet": [640, 832],
-    "KITTI": [384, 1280],
-    "bonn": [512, 640],
-    "NYUv2": [448, 640],
-}
+from depthcrafter.utils import vis_sequence_depth, save_video, read_video_frames
 
 
 class DepthCrafterDemo:
@@ -59,45 +50,6 @@ class DepthCrafterDemo:
         print("Xformers is not enabled")
         self.pipe.enable_attention_slicing()
 
-    @staticmethod
-    def read_video_frames(
-        video_path, process_length, target_fps, max_res, dataset="open"
-    ):
-        if dataset == "open":
-            print("==> processing video: ", video_path)
-            vid = VideoReader(video_path, ctx=cpu(0))
-            print(
-                "==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:])
-            )
-            original_height, original_width = vid.get_batch([0]).shape[1:3]
-            height = round(original_height / 64) * 64
-            width = round(original_width / 64) * 64
-            if max(height, width) > max_res:
-                scale = max_res / max(original_height, original_width)
-                height = round(original_height * scale / 64) * 64
-                width = round(original_width * scale / 64) * 64
-        else:
-            height = dataset_res_dict[dataset][0]
-            width = dataset_res_dict[dataset][1]
-
-        vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
-
-        fps = vid.get_avg_fps() if target_fps == -1 else target_fps
-        stride = round(vid.get_avg_fps() / fps)
-        stride = max(stride, 1)
-        frames_idx = list(range(0, len(vid), stride))
-        print(
-            f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
-        )
-        if process_length != -1 and process_length < len(frames_idx):
-            frames_idx = frames_idx[:process_length]
-        print(
-            f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
-        )
-        frames = vid.get_batch(frames_idx).asnumpy().astype("float32") / 255.0
-
-        return frames, fps
-
     def infer(
         self,
         video: str,
@@ -116,7 +68,7 @@ class DepthCrafterDemo:
     ):
         set_seed(seed)
 
-        frames, target_fps = self.read_video_frames(
+        frames, target_fps = read_video_frames(
             video,
             process_length,
             target_fps,
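The run.py change is pure deduplication: read_video_frames never used self, so it moves from a @staticmethod on DepthCrafterDemo into depthcrafter/utils.py, and both entry points now import the one implementation (the dataset_res_dict table moves with it). A toy sketch of the pattern, with illustrative names not taken from the repo:

    class Before:
        @staticmethod
        def helper(x):  # uses no self/cls: nothing ties it to the class
            return 2 * x

        def run(self, x):
            return self.helper(x)

    def helper(x):  # module level: importable by every entry point
        return 2 * x

    class After:
        def run(self, x):
            return helper(x)

    assert Before().run(3) == After().run(3) == 6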