jhaoshao committed
Commit 861fa04 · 1 Parent(s): b508869

release v1 demo

.gitattributes copy DELETED
@@ -1,37 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
- files/sora_1764106507569053773.mp4 filter=lfs diff=lfs merge=lfs -text
- files/sora_e2.mp4 filter=lfs diff=lfs merge=lfs -text

app.py CHANGED
@@ -30,31 +30,77 @@ import spaces
30
  import gradio as gr
31
  import numpy as np
32
  import torch as torch
 
 
33
  from PIL import Image
34
  from tqdm import tqdm
35
  import mediapy as media
36
 
37
  from huggingface_hub import login
38
 
39
- from chronodepth_pipeline import ChronoDepthPipeline
40
- from gradio_patches.examples import Examples
 
 
 
41
 
42
  default_seed = 2024
43
 
44
  default_num_inference_steps = 5
45
- default_num_frames = 10
46
- default_window_size = 9
47
  default_video_processing_resolution = 768
48
- default_video_out_max_frames = 80
49
- default_decode_chunk_size = 10
50
 
51
  def process_video(
52
  pipe,
53
  path_input,
54
  num_inference_steps=default_num_inference_steps,
55
- num_frames=default_num_frames,
56
- window_size=default_window_size,
57
- out_max_frames=default_video_out_max_frames,
58
  progress=gr.Progress(),
59
  ):
60
  if path_input is None:
@@ -75,14 +121,13 @@ def process_video(
75
  start_time = time.time()
76
  zipf = None
77
  try:
78
- if window_size is None or window_size == num_frames:
79
- inpaint_inference = False
80
- else:
81
- inpaint_inference = True
82
- data_ls = []
83
  video_data = media.read_video(path_input)
84
- video_length = len(video_data)
85
  fps = video_data.metadata.fps
 
 
86
 
87
  duration_sec = video_length / fps
88
 
@@ -92,113 +137,16 @@ def process_video(
92
  f"Only the first ~{int(out_duration_sec)} seconds will be processed; "
93
  f"use alternative setups such as ChronoDepth on github for full processing"
94
  )
95
- video_length = out_max_frames
96
-
97
- for i in tqdm(range(video_length-num_frames+1)):
98
- is_first_clip = i == 0
99
- is_last_clip = i == video_length - num_frames
100
- is_new_clip = (
101
- (inpaint_inference and i % window_size == 0)
102
- or (inpaint_inference == False and i % num_frames == 0)
103
- )
104
- if is_first_clip or is_last_clip or is_new_clip:
105
- data_ls.append(np.array(video_data[i: i+num_frames])) # [t, H, W, 3]
106
 
107
  zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)
108
 
109
- depth_colored_pred = []
110
- depth_pred = []
111
  # -------------------- Inference and saving --------------------
112
- with torch.no_grad():
113
- for iter, batch in enumerate(tqdm(data_ls)):
114
- rgb_int = batch
115
- input_images = [Image.fromarray(rgb_int[i]) for i in range(num_frames)]
116
-
117
- # Predict depth
118
- if iter == 0: # First clip
119
- pipe_out = pipe(
120
- input_images,
121
- num_frames=len(input_images),
122
- num_inference_steps=num_inference_steps,
123
- decode_chunk_size=default_decode_chunk_size,
124
- motion_bucket_id=127,
125
- fps=7,
126
- noise_aug_strength=0.0,
127
- generator=generator,
128
- )
129
- elif inpaint_inference and (iter == len(data_ls) - 1): # temporal inpaint inference for last clip
130
- last_window_size = window_size if video_length%window_size == 0 else video_length%window_size
131
- pipe_out = pipe(
132
- input_images,
133
- num_frames=num_frames,
134
- num_inference_steps=num_inference_steps,
135
- decode_chunk_size=default_decode_chunk_size,
136
- motion_bucket_id=127,
137
- fps=7,
138
- noise_aug_strength=0.0,
139
- generator=generator,
140
- depth_pred_last=depth_frames_pred_ts[last_window_size:],
141
- )
142
- elif inpaint_inference and iter > 0: # temporal inpaint inference
143
- pipe_out = pipe(
144
- input_images,
145
- num_frames=num_frames,
146
- num_inference_steps=num_inference_steps,
147
- decode_chunk_size=default_decode_chunk_size,
148
- motion_bucket_id=127,
149
- fps=7,
150
- noise_aug_strength=0.0,
151
- generator=generator,
152
- depth_pred_last=depth_frames_pred_ts[window_size:],
153
- )
154
- else: # separate inference
155
- pipe_out = pipe(
156
- input_images,
157
- num_frames=num_frames,
158
- num_inference_steps=num_inference_steps,
159
- decode_chunk_size=default_decode_chunk_size,
160
- motion_bucket_id=127,
161
- fps=7,
162
- noise_aug_strength=0.0,
163
- generator=generator,
164
- )
165
-
166
- depth_frames_pred = [pipe_out.depth_np[i] for i in range(num_frames)]
167
-
168
- depth_frames_colored_pred = []
169
- for i in range(num_frames):
170
- depth_frame_colored_pred = np.array(pipe_out.depth_colored[i])
171
- depth_frames_colored_pred.append(depth_frame_colored_pred)
172
- depth_frames_colored_pred = np.stack(depth_frames_colored_pred, axis=0)
173
-
174
- depth_frames_pred = np.stack(depth_frames_pred, axis=0)
175
- depth_frames_pred_ts = torch.from_numpy(depth_frames_pred).to(pipe.device)
176
- depth_frames_pred_ts = depth_frames_pred_ts * 2 - 1
177
-
178
- if inpaint_inference == False:
179
- if iter == len(data_ls) - 1:
180
- last_window_size = num_frames if video_length%num_frames == 0 else video_length%num_frames
181
- depth_colored_pred.append(depth_frames_colored_pred[-last_window_size:])
182
- depth_pred.append(depth_frames_pred[-last_window_size:])
183
- else:
184
- depth_colored_pred.append(depth_frames_colored_pred)
185
- depth_pred.append(depth_frames_pred)
186
- else:
187
- if iter == 0:
188
- depth_colored_pred.append(depth_frames_colored_pred)
189
- depth_pred.append(depth_frames_pred)
190
- elif iter == len(data_ls) - 1:
191
- depth_colored_pred.append(depth_frames_colored_pred[-last_window_size:])
192
- depth_pred.append(depth_frames_pred[-last_window_size:])
193
- else:
194
- depth_colored_pred.append(depth_frames_colored_pred[-window_size:])
195
- depth_pred.append(depth_frames_pred[-window_size:])
196
-
197
- depth_colored_pred = np.concatenate(depth_colored_pred, axis=0)
198
- depth_pred = np.concatenate(depth_pred, axis=0)
199
 
200
  # -------------------- Save results --------------------
201
- # Save images
202
  for i in tqdm(range(len(depth_pred))):
203
  archive_path = os.path.join(
204
  f"{name_base}_depth_16bit", f"{i:05d}.png"
@@ -211,6 +159,7 @@ def process_video(
211
 
212
  # Export to video
213
  media.write_video(path_out_vis, depth_colored_pred, fps=fps)
 
214
  finally:
215
  if zipf is not None:
216
  zipf.close()
@@ -225,7 +174,7 @@ def process_video(
225
 
226
  def run_demo_server(pipe):
227
  process_pipe_video = spaces.GPU(
228
- functools.partial(process_video, pipe), duration=220
229
  )
230
  os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
231
 
@@ -257,27 +206,27 @@ def run_demo_server(pipe):
257
  }
258
  """,
259
  ) as demo:
260
- gr.Markdown(
261
  """
262
- # ChronoDepth Video Depth Estimation
263
-
264
- <p align="center">
265
- <a title="Website" href="https://jhaoshao.github.io/ChronoDepth/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
266
- <img src="https://img.shields.io/website?url=https%3A%2F%2Fjhaoshao.github.io%2FChronoDepth%2F&up_message=ChronoDepth&up_color=blue&style=flat&logo=timescale&logoColor=%23FFDC0F">
267
- </a>
268
- <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
269
- <img src="https://img.shields.io/badge/arXiv-PDF-b31b1b">
270
- </a>
271
- <a title="Github" href="https://github.com/jhaoshao/ChronoDepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
272
- <img src="https://img.shields.io/github/stars/jhaoshao/ChronoDepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
273
- </a>
 
 
274
  </p>
275
-
276
- ChronoDepth is the state-of-the-art video depth estimator for videos in the wild.
277
- Upload your video and have a try!<br>
278
- We set denoising steps to 5, number of frames for each video clip to 10, and overlap between clips to 1.
279
-
280
- """
281
  )
282
 
283
  with gr.Row():
@@ -301,19 +250,16 @@ def run_demo_server(pipe):
301
  elem_id="download",
302
  interactive=False,
303
  )
304
- Examples(
305
- fn=process_pipe_video,
306
  examples=[
307
- os.path.join("files", name)
308
- for name in [
309
- "sora_e2.mp4",
310
- "sora_1758192960116785459.mp4",
311
- ]
312
  ],
313
  inputs=[video_input],
314
  outputs=[video_output_video, video_output_files],
 
315
  cache_examples=True,
316
- directory_name="examples_video",
317
  )
318
 
319
  video_submit_btn.click(
@@ -339,17 +285,30 @@ def run_demo_server(pipe):
339
 
340
 
341
  def main():
342
- CHECKPOINT = "jhshao/ChronoDepth"
343
 
344
  if "HF_TOKEN_LOGIN" in os.environ:
345
  login(token=os.environ["HF_TOKEN_LOGIN"])
346
 
347
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
348
  print(f"Running on device: {device}")
349
- pipe = ChronoDepthPipeline.from_pretrained(CHECKPOINT)
350
- try:
351
- import xformers
353
  pipe.enable_xformers_memory_efficient_attention()
354
  except:
355
  pass # run without xformers
 
30
  import gradio as gr
31
  import numpy as np
32
  import torch as torch
33
+ import torch.nn.functional as F
34
+ import xformers
35
  from PIL import Image
36
  from tqdm import tqdm
37
  import mediapy as media
38
 
39
  from huggingface_hub import login
40
 
41
+ from chronodepth.unet_chronodepth import DiffusersUNetSpatioTemporalConditionModelChronodepth
42
+ from chronodepth.chronodepth_pipeline import ChronoDepthPipeline
43
+ from chronodepth.video_utils import resize_max_res, colorize_video_depth
44
+
45
+ MAX_FRAME=15
46
 
47
  default_seed = 2024
48
 
49
  default_num_inference_steps = 5
50
+ default_n_tokens = 10
51
+ default_chunk_size = 5
52
  default_video_processing_resolution = 768
53
+ default_decode_chunk_size = 8
54
+
55
+
56
+ @torch.no_grad()
57
+ def run_pipeline(pipe, video_rgb, generator, device):
58
+ """
59
+ Run the pipe on the input video.
60
+ args:
61
+ pipe: ChronoDepthPipeline object
62
+ video_rgb: input video, torch.Tensor, shape [T, H, W, 3], range [0, 255]
63
+ generator: torch.Generator
64
+ returns:
65
+ video_depth_pred: predicted depth, torch.Tensor, shape [T, H, W], range [0, 1]
66
+ """
67
+ if isinstance(video_rgb, torch.Tensor):
68
+ video_rgb = video_rgb.cpu().numpy()
69
+
70
+ original_height = video_rgb.shape[1]
71
+ original_width = video_rgb.shape[2]
72
+
73
+ # resize the video to the max resolution
74
+ video_rgb = resize_max_res(video_rgb, default_video_processing_resolution)
75
+
76
+ video_rgb = video_rgb.astype(np.float32) / 255.0
77
+
78
+ pipe_out = pipe(
79
+ video_rgb,
80
+ num_inference_steps=default_num_inference_steps,
81
+ decode_chunk_size=default_decode_chunk_size,
82
+ motion_bucket_id=127,
83
+ fps=7,
84
+ noise_aug_strength=0.0,
85
+ generator=generator,
86
+ infer_mode="ours",
87
+ sigma_epsilon=-4,
88
+ )
89
+
90
+ depth_frames_pred = pipe_out.frames
91
+ depth_frames_pred = torch.from_numpy(depth_frames_pred).to(device)
92
+ depth_frames_pred = F.interpolate(depth_frames_pred, size=(original_height, original_width), mode="bilinear", align_corners=False)
93
+ depth_frames_pred = depth_frames_pred.clamp(0, 1)
94
+ depth_frames_pred = depth_frames_pred.squeeze(1)
95
+
96
+ return depth_frames_pred
97
+
98
 
99
  def process_video(
100
  pipe,
101
  path_input,
102
  num_inference_steps=default_num_inference_steps,
103
+ out_max_frames=MAX_FRAME,
 
 
104
  progress=gr.Progress(),
105
  ):
106
  if path_input is None:
 
121
  start_time = time.time()
122
  zipf = None
123
  try:
124
+ # -------------------- data --------------------
125
+ video_name = path_input.split('/')[-1].split('.')[0]
 
 
 
126
  video_data = media.read_video(path_input)
127
+
128
  fps = video_data.metadata.fps
129
+ video_length = len(video_data)
130
+ video_rgb = np.array(video_data)
131
 
132
  duration_sec = video_length / fps
133
 
 
137
  f"Only the first ~{int(out_duration_sec)} seconds will be processed; "
138
  f"use alternative setups such as ChronoDepth on github for full processing"
139
  )
140
+ video_rgb = video_rgb[:out_max_frames]
141
 
142
  zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)
143
 
 
 
144
  # -------------------- Inference and saving --------------------
145
+ depth_pred = run_pipeline(pipe, video_rgb, generator, pipe.device) # range [0, 1]
146
+ depth_pred = depth_pred.cpu().numpy()
147
+ depth_colored_pred = colorize_video_depth(depth_pred) # range [0, 1] -> [0, 255]
148
 
149
  # -------------------- Save results --------------------
 
150
  for i in tqdm(range(len(depth_pred))):
151
  archive_path = os.path.join(
152
  f"{name_base}_depth_16bit", f"{i:05d}.png"
 
159
 
160
  # Export to video
161
  media.write_video(path_out_vis, depth_colored_pred, fps=fps)
162
+
163
  finally:
164
  if zipf is not None:
165
  zipf.close()
 
174
 
175
  def run_demo_server(pipe):
176
  process_pipe_video = spaces.GPU(
177
+ functools.partial(process_video, pipe), duration=100
178
  )
179
  os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
180
 
 
206
  }
207
  """,
208
  ) as demo:
209
+ gr.HTML(
210
  """
211
+ <h1>⏰ChronoDepth: Learning Temporally Consistent Video Depth from Video Diffusion Priors</h1>
212
+ <div style="text-align: center; margin-top: 20px;">
213
+ <a title="Website" href="https://jhaoshao.github.io/ChronoDepth/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
214
+ <img src="https://img.shields.io/website?url=https%3A%2F%2Fjhaoshao.github.io%2FChronoDepth%2F&up_message=ChronoDepth&up_color=blue&style=flat&logo=timescale&logoColor=%23FFDC0F">
215
+ </a>
216
+ <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
217
+ <img src="https://img.shields.io/badge/arXiv-PDF-b31b1b">
218
+ </a>
219
+ <a title="Github" href="https://github.com/jhaoshao/ChronoDepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
220
+ <img src="https://img.shields.io/github/stars/jhaoshao/ChronoDepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
221
+ </a>
222
+ </div>
223
+ <p style="margin-top: 20px; text-align: justify;">
224
+ ChronoDepth is the state-of-the-art video depth estimator for streaming videos in the wild.
225
  </p>
226
+ <p style="margin-top: 20px; text-align: justify;">
227
+ Note: the maximum video length is limited to 100 frames in this demo. To process longer videos, please use the ChronoDepth code on GitHub.
228
+ </p>
229
+ """
 
 
230
  )
231
 
232
  with gr.Row():
 
250
  elem_id="download",
251
  interactive=False,
252
  )
253
+ gr.Examples(
 
254
  examples=[
255
+ ["files/elephant.mp4"],
256
+ ["files/kitti360_seq_0000.mp4"],
 
 
 
257
  ],
258
  inputs=[video_input],
259
  outputs=[video_output_video, video_output_files],
260
+ fn=process_pipe_video,
261
  cache_examples=True,
262
+ cache_mode="examples_video",
263
  )
264
 
265
  video_submit_btn.click(
 
285
 
286
 
287
  def main():
288
+ CHECKPOINT = "jhshao/ChronoDepth-v1"
289
 
290
  if "HF_TOKEN_LOGIN" in os.environ:
291
  login(token=os.environ["HF_TOKEN_LOGIN"])
292
 
293
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
294
  print(f"Running on device: {device}")
 
 
 
295
 
296
+ # -------------------- Model --------------------
297
+ unet = DiffusersUNetSpatioTemporalConditionModelChronodepth.from_pretrained(
298
+ CHECKPOINT,
299
+ low_cpu_mem_usage=True,
300
+ torch_dtype=torch.float16,
301
+ )
302
+ pipe = ChronoDepthPipeline.from_pretrained(
303
+ "stabilityai/stable-video-diffusion-img2vid-xt",
304
+ unet=unet,
305
+ torch_dtype=torch.float16,
306
+ variant="fp16",
307
+ )
308
+ pipe.n_tokens = default_n_tokens
309
+ pipe.chunk_size = default_chunk_size
310
+
311
+ try:
312
  pipe.enable_xformers_memory_efficient_attention()
313
  except:
314
  pass # run without xformers
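
For orientation, the sketch below strings the new pieces together outside the Gradio app: the fine-tuned UNet is loaded into the SVD-based ChronoDepthPipeline exactly as in main() above, and a short clip is processed roughly the way run_pipeline() does. The input and output paths are placeholders; the other values mirror the defaults set in this file (seed 2024, MAX_FRAME of 15, n_tokens 10, chunk_size 5), and a CUDA device is assumed for the fp16 weights.

```python
# Minimal sketch, not part of the commit: load the v1 pipeline as main() does
# and run it on a short clip as run_pipeline() does. Paths are placeholders.
import numpy as np
import torch
import mediapy as media

from chronodepth.unet_chronodepth import DiffusersUNetSpatioTemporalConditionModelChronodepth
from chronodepth.chronodepth_pipeline import ChronoDepthPipeline
from chronodepth.video_utils import resize_max_res, colorize_video_depth

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

unet = DiffusersUNetSpatioTemporalConditionModelChronodepth.from_pretrained(
    "jhshao/ChronoDepth-v1", low_cpu_mem_usage=True, torch_dtype=torch.float16,
)
pipe = ChronoDepthPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    unet=unet, torch_dtype=torch.float16, variant="fp16",
).to(device)
pipe.n_tokens = 10    # frames attended to per window (default_n_tokens)
pipe.chunk_size = 5   # new frames committed per window step (default_chunk_size)

video = media.read_video("input.mp4")              # (T, H, W, 3) uint8
rgb = resize_max_res(np.array(video)[:15], 768)    # cap frames, snap dims to /64
rgb = rgb.astype(np.float32) / 255.0               # [0, 255] -> [0, 1]

out = pipe(
    rgb,
    num_inference_steps=5,
    decode_chunk_size=8,
    motion_bucket_id=127,
    fps=7,
    noise_aug_strength=0.0,
    generator=torch.Generator(device=device).manual_seed(2024),
    infer_mode="ours",
    sigma_epsilon=-4,
)
depth = out.frames.squeeze(1)                      # (T, H, W), range [0, 1]
media.write_video("depth_vis.mp4", colorize_video_depth(depth), fps=video.metadata.fps)
```
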
chronodepth/__init__.py ADDED
@@ -0,0 +1 @@
+ from .chronodepth_pipeline import ChronoDepthPipeline
chronodepth/chronodepth_pipeline.py ADDED
@@ -0,0 +1,662 @@
1
+ import inspect
2
+ from typing import Union, Optional, List
3
+
4
+ import torch
5
+ import numpy as np
6
+ from tqdm.auto import tqdm
7
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
8
+ _resize_with_antialiasing,
9
+ StableVideoDiffusionPipelineOutput,
10
+ StableVideoDiffusionPipeline,
11
+ )
12
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
13
+ from einops import rearrange
14
+
15
+
16
+ class ChronoDepthPipeline(StableVideoDiffusionPipeline):
17
+
18
+ @torch.inference_mode()
19
+ def encode_images(self,
20
+ images: torch.Tensor,
21
+ decode_chunk_size=5,
22
+ ):
23
+ video_length = images.shape[1]
24
+ images = rearrange(images, "b f c h w -> (b f) c h w")
25
+ latents = []
26
+ for i in range(0, images.shape[0], decode_chunk_size):
27
+ latents_chunk = self.vae.encode(images[i : i + decode_chunk_size]).latent_dist.sample()
28
+ latents.append(latents_chunk)
29
+ latents = torch.cat(latents, dim=0)
30
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
31
+ latents = latents * self.vae.config.scaling_factor
32
+ return latents
33
+
34
+ @torch.inference_mode()
35
+ def _encode_image(self, images, device, discard=True, chunk_size=14):
36
+ '''
37
+ Setting the images to a zero tensor discards the image embeddings when `discard` is True.
38
+ '''
39
+ dtype = next(self.image_encoder.parameters()).dtype
40
+
41
+ images = _resize_with_antialiasing(images, (224, 224))
42
+ images = (images + 1.0) / 2.0
43
+
44
+ if discard:
45
+ images = torch.zeros_like(images)
46
+
47
+ image_embeddings = []
48
+ for i in range(0, images.shape[0], chunk_size):
49
+ tmp = self.feature_extractor(
50
+ images=images[i : i + chunk_size],
51
+ do_normalize=True,
52
+ do_center_crop=False,
53
+ do_resize=False,
54
+ do_rescale=False,
55
+ return_tensors="pt",
56
+ ).pixel_values
57
+
58
+ tmp = tmp.to(device=device, dtype=dtype)
59
+ image_embeddings.append(self.image_encoder(tmp).image_embeds)
60
+ image_embeddings = torch.cat(image_embeddings, dim=0)
61
+ image_embeddings = image_embeddings.unsqueeze(1) # [t, 1, 1024]
62
+
63
+ return image_embeddings
64
+
65
+ def decode_depth(self, depth_latent: torch.Tensor, decode_chunk_size=5) -> torch.Tensor:
66
+ num_frames = depth_latent.shape[1]
67
+ depth_latent = rearrange(depth_latent, "b f c h w -> (b f) c h w")
68
+
69
+ depth_latent = depth_latent / self.vae.config.scaling_factor
70
+
71
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
72
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
73
+
74
+ depth_frames = []
75
+ for i in range(0, depth_latent.shape[0], decode_chunk_size):
76
+ num_frames_in = depth_latent[i : i + decode_chunk_size].shape[0]
77
+ decode_kwargs = {}
78
+ if accepts_num_frames:
79
+ # we only pass num_frames_in if it's expected
80
+ decode_kwargs["num_frames"] = num_frames_in
81
+
82
+ depth_frame = self.vae.decode(depth_latent[i : i + decode_chunk_size], **decode_kwargs).sample
83
+ depth_frames.append(depth_frame)
84
+
85
+ depth_frames = torch.cat(depth_frames, dim=0)
86
+ depth_frames = depth_frames.reshape(-1, num_frames, *depth_frames.shape[1:])
87
+ depth_mean = depth_frames.mean(dim=2, keepdim=True)
88
+
89
+ return depth_mean
90
+
91
+ @staticmethod
92
+ def check_inputs(images, height, width):
93
+ if (
94
+ not isinstance(images, torch.Tensor)
95
+ and not isinstance(images, np.ndarray)
96
+ ):
97
+ raise ValueError(
98
+ "`images` has to be of type `torch.Tensor` or `numpy.ndarray` but is"
99
+ f" {type(images)}"
100
+ )
101
+
102
+ if height % 64 != 0 or width % 64 != 0:
103
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
104
+
105
+ @torch.no_grad()
106
+ def __call__(
107
+ self,
108
+ input_images: Union[np.ndarray, torch.FloatTensor],
109
+ height: int = 576,
110
+ width: int = 768,
111
+ num_inference_steps: int = 10,
112
+ fps: int = 7,
113
+ motion_bucket_id: int = 127,
114
+ noise_aug_strength: float = 0.02,
115
+ decode_chunk_size: Optional[int] = None,
116
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
117
+ show_progress_bar: bool = True,
118
+ latents: Optional[torch.Tensor] = None,
119
+ infer_mode: str = 'ours',
120
+ sigma_epsilon: float = -4,
121
+ ):
122
+ """
123
+ Args:
124
+ input_images: shape [T, H, W, 3] if np.ndarray or [T, 3, H, W] if torch.FloatTensor, range [0, 1]
125
+ height: int, height of the input image
126
+ width: int, width of the input image
127
+ num_inference_steps: int, number of inference steps
128
+ fps: int, frames per second
129
+ motion_bucket_id: int, motion bucket id
130
+ noise_aug_strength: float, noise augmentation strength
131
+ decode_chunk_size: int, decode chunk size
132
+ generator: torch.Generator or List[torch.Generator], random number generator
133
+ show_progress_bar: bool, show progress bar
134
+ """
135
+ assert height >= 0 and width >=0
136
+ assert num_inference_steps >=1
137
+
138
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
139
+
140
+ # 1. Check inputs. Raise error if not correct
141
+ self.check_inputs(input_images, height, width)
142
+
143
+ # 2. Define call parameters
144
+ batch_size = 1 # only support batch size 1 for now
145
+ device = self._execution_device
146
+
147
+ # 3. Encode input image
148
+ if isinstance(input_images, np.ndarray):
149
+ input_images = torch.from_numpy(input_images.transpose(0, 3, 1, 2))
150
+ else:
151
+ assert isinstance(input_images, torch.Tensor)
152
+ input_images = input_images.to(device=device)
153
+ input_images = input_images * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
154
+
155
+ discard_clip_features = True
156
+ image_embeddings = self._encode_image(input_images, device,
157
+ discard=discard_clip_features,
158
+ chunk_size=decode_chunk_size
159
+ )
160
+
161
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
162
+ # is why it is reduced here.
163
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
164
+ fps = fps - 1
165
+
166
+ # 4. Encode input image using VAE
167
+ noise = randn_tensor(input_images.shape, generator=generator, device=device, dtype=input_images.dtype)
168
+ input_images = input_images + noise_aug_strength * noise
169
+
170
+ rgb_batch = input_images.unsqueeze(0)
171
+
172
+ added_time_ids = self._get_add_time_ids(
173
+ fps,
174
+ motion_bucket_id,
175
+ noise_aug_strength,
176
+ image_embeddings.dtype,
177
+ batch_size,
178
+ 1, # do not modify this!
179
+ False, # do not modify this!
180
+ )
181
+ added_time_ids = added_time_ids.to(device)
182
+
183
+ if infer_mode == 'ours':
184
+ depth_pred_raw = self.single_infer_ours(
185
+ rgb_batch,
186
+ image_embeddings,
187
+ added_time_ids,
188
+ num_inference_steps,
189
+ show_progress_bar,
190
+ generator,
191
+ decode_chunk_size=decode_chunk_size,
192
+ latents=latents,
193
+ sigma_epsilon=sigma_epsilon,
194
+ )
195
+ elif infer_mode == 'replacement':
196
+ depth_pred_raw = self.single_infer_replacement(
197
+ rgb_batch,
198
+ image_embeddings,
199
+ added_time_ids,
200
+ num_inference_steps,
201
+ show_progress_bar,
202
+ generator,
203
+ decode_chunk_size=decode_chunk_size,
204
+ latents=latents,
205
+ )
206
+ elif infer_mode == 'naive':
207
+ depth_pred_raw = self.single_infer_naive_sliding_window(
208
+ rgb_batch,
209
+ image_embeddings,
210
+ added_time_ids,
211
+ num_inference_steps,
212
+ show_progress_bar,
213
+ generator,
214
+ decode_chunk_size=decode_chunk_size,
215
+ latents=latents,
216
+ )
217
+
218
+
219
+ depth_frames = depth_pred_raw.cpu().numpy().astype(np.float32)
220
+
221
+ self.maybe_free_model_hooks()
222
+
223
+ return StableVideoDiffusionPipelineOutput(
224
+ frames = depth_frames,
225
+ )
226
+
227
+ @torch.no_grad()
228
+ def single_infer_ours(self,
229
+ input_rgb: torch.Tensor,
230
+ image_embeddings: torch.Tensor,
231
+ added_time_ids: torch.Tensor,
232
+ num_inference_steps: int,
233
+ show_pbar: bool,
234
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
235
+ decode_chunk_size=1,
236
+ latents: Optional[torch.Tensor] = None,
237
+ sigma_epsilon: float = -4,
238
+ ):
239
+ device = input_rgb.device
240
+ H, W = input_rgb.shape[-2:]
241
+
242
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
243
+ if needs_upcasting:
244
+ self.vae.to(dtype=torch.float32)
245
+
246
+ rgb_latent = self.encode_images(input_rgb)
247
+ rgb_latent = rgb_latent.to(image_embeddings.dtype)
248
+
249
+ torch.cuda.empty_cache()
250
+
251
+ # cast back to fp16 if needed
252
+ if needs_upcasting:
253
+ self.vae.to(dtype=torch.float16)
254
+
255
+ # Prepare timesteps
256
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
257
+ timesteps = self.scheduler.timesteps
258
+
259
+ batch_size, n_frames, _, _, _ = rgb_latent.shape
260
+ num_channels_latents = self.unet.config.in_channels
261
+
262
+ curr_frame = 0
263
+ depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
264
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
265
+
266
+ # first chunk
267
+ horizon = min(n_frames-curr_frame, self.n_tokens)
268
+ start_frame = 0
269
+ chunk = self.prepare_latents(
270
+ batch_size,
271
+ horizon,
272
+ num_channels_latents,
273
+ H,
274
+ W,
275
+ image_embeddings.dtype,
276
+ device,
277
+ generator,
278
+ latents,
279
+ )
280
+ depth_latent = torch.cat([depth_latent, chunk], 1)
281
+ if show_pbar:
282
+ iterable = tqdm(
283
+ enumerate(timesteps),
284
+ total=len(timesteps),
285
+ leave=False,
286
+ desc=" " * 4 + "Diffusion denoising first chunk",
287
+ )
288
+ else:
289
+ iterable = enumerate(timesteps)
290
+
291
+ for i, t in iterable:
292
+ curr_timesteps = torch.tensor([t]*horizon).to(device)
293
+ depth_latent = self.scheduler.scale_model_input(depth_latent, t)
294
+ noise_pred = self.unet(
295
+ torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
296
+ curr_timesteps[start_frame:],
297
+ image_embeddings[start_frame:curr_frame+horizon],
298
+ added_time_ids=added_time_ids
299
+ )[0]
300
+ depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
301
+
302
+ self.scheduler._step_index = None
303
+ curr_frame += horizon
304
+ pbar.update(horizon)
305
+
306
+ while curr_frame < n_frames:
307
+ if self.chunk_size > 0:
308
+ horizon = min(n_frames - curr_frame, self.chunk_size)
309
+ else:
310
+ horizon = min(n_frames - curr_frame, self.n_tokens)
311
+ assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
312
+ chunk = self.prepare_latents(
313
+ batch_size,
314
+ horizon,
315
+ num_channels_latents,
316
+ H,
317
+ W,
318
+ image_embeddings.dtype,
319
+ device,
320
+ generator,
321
+ latents,
322
+ )
323
+ depth_latent = torch.cat([depth_latent, chunk], 1)
324
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
325
+
326
+ pbar.set_postfix(
327
+ {
328
+ "start": start_frame,
329
+ "end": curr_frame + horizon,
330
+ }
331
+ )
332
+
333
+ if show_pbar:
334
+ iterable = tqdm(
335
+ enumerate(timesteps),
336
+ total=len(timesteps),
337
+ leave=False,
338
+ desc=" " * 4 + "Diffusion denoising ",
339
+ )
340
+ else:
341
+ iterable = enumerate(timesteps)
342
+
343
+ for i, t in iterable:
344
+ t_horizon = torch.tensor([t]*horizon).to(device)
345
+ # t_context = timesteps[-1] * torch.ones((curr_frame,), dtype=t.dtype).to(device)
346
+ t_context = sigma_epsilon * torch.ones((curr_frame,), dtype=t.dtype).to(device)
347
+ curr_timesteps = torch.concatenate((t_context, t_horizon), 0)
348
+ depth_latent[:, curr_frame:] = self.scheduler.scale_model_input(depth_latent[:, curr_frame:], t)
349
+ noise_pred = self.unet(
350
+ torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
351
+ curr_timesteps[start_frame:],
352
+ image_embeddings[start_frame:curr_frame+horizon],
353
+ added_time_ids=added_time_ids
354
+ )[0]
355
+ depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
356
+
357
+ self.scheduler._step_index = None
358
+ curr_frame += horizon
359
+ pbar.update(horizon)
360
+
361
+ torch.cuda.empty_cache()
362
+ if needs_upcasting:
363
+ self.vae.to(dtype=torch.float16)
364
+ depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
365
+ # clip prediction
366
+ depth = torch.clip(depth, -1.0, 1.0)
367
+ # shift to [0, 1]
368
+ depth = (depth + 1.0) / 2.0
369
+
370
+ return depth.squeeze(0)
371
+
372
+ @torch.no_grad()
373
+ def single_infer_replacement(self,
374
+ input_rgb: torch.Tensor,
375
+ image_embeddings: torch.Tensor,
376
+ added_time_ids: torch.Tensor,
377
+ num_inference_steps: int,
378
+ show_pbar: bool,
379
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
380
+ decode_chunk_size=1,
381
+ latents: Optional[torch.Tensor] = None,
382
+ ):
383
+ device = input_rgb.device
384
+ H, W = input_rgb.shape[-2:]
385
+
386
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
387
+ if needs_upcasting:
388
+ self.vae.to(dtype=torch.float32)
389
+
390
+ rgb_latent = self.encode_images(input_rgb)
391
+ rgb_latent = rgb_latent.to(image_embeddings.dtype)
392
+
393
+ torch.cuda.empty_cache()
394
+
395
+ # cast back to fp16 if needed
396
+ if needs_upcasting:
397
+ self.vae.to(dtype=torch.float16)
398
+
399
+ # Prepare timesteps
400
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
401
+ timesteps = self.scheduler.timesteps
402
+
403
+ batch_size, n_frames, _, _, _ = rgb_latent.shape
404
+ num_channels_latents = self.unet.config.in_channels
405
+
406
+ curr_frame = 0
407
+ depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
408
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
409
+
410
+ # first chunk
411
+ horizon = min(n_frames-curr_frame, self.n_tokens)
412
+ start_frame = 0
413
+ chunk = self.prepare_latents(
414
+ batch_size,
415
+ horizon,
416
+ num_channels_latents,
417
+ H,
418
+ W,
419
+ image_embeddings.dtype,
420
+ device,
421
+ generator,
422
+ latents,
423
+ )
424
+ depth_latent = torch.cat([depth_latent, chunk], 1)
425
+ if show_pbar:
426
+ iterable = tqdm(
427
+ enumerate(timesteps),
428
+ total=len(timesteps),
429
+ leave=False,
430
+ desc=" " * 4 + "Diffusion denoising first chunk",
431
+ )
432
+ else:
433
+ iterable = enumerate(timesteps)
434
+
435
+ for i, t in iterable:
436
+ curr_timesteps = torch.tensor([t]*horizon).to(device)
437
+ depth_latent = self.scheduler.scale_model_input(depth_latent, t)
438
+ noise_pred = self.unet(
439
+ torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
440
+ curr_timesteps[start_frame:],
441
+ image_embeddings[start_frame:curr_frame+horizon],
442
+ added_time_ids=added_time_ids
443
+ )[0]
444
+ depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
445
+
446
+ self.scheduler._step_index = None
447
+ curr_frame += horizon
448
+ pbar.update(horizon)
449
+
450
+ while curr_frame < n_frames:
451
+ if self.chunk_size > 0:
452
+ horizon = min(n_frames - curr_frame, self.chunk_size)
453
+ else:
454
+ horizon = min(n_frames - curr_frame, self.n_tokens)
455
+ assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
456
+ chunk = self.prepare_latents(
457
+ batch_size,
458
+ horizon,
459
+ num_channels_latents,
460
+ H,
461
+ W,
462
+ image_embeddings.dtype,
463
+ device,
464
+ generator,
465
+ latents,
466
+ )
467
+ depth_latent = torch.cat([depth_latent, chunk], 1)
468
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
469
+ depth_pred_last_latent = depth_latent[:, start_frame:curr_frame].clone()
470
+
471
+ pbar.set_postfix(
472
+ {
473
+ "start": start_frame,
474
+ "end": curr_frame + horizon,
475
+ }
476
+ )
477
+
478
+ if show_pbar:
479
+ iterable = tqdm(
480
+ enumerate(timesteps),
481
+ total=len(timesteps),
482
+ leave=False,
483
+ desc=" " * 4 + "Diffusion denoising ",
484
+ )
485
+ else:
486
+ iterable = enumerate(timesteps)
487
+
488
+ for i, t in iterable:
489
+ curr_timesteps = torch.tensor([t]*(curr_frame+horizon-start_frame)).to(device)
490
+ epsilon = randn_tensor(
491
+ depth_pred_last_latent.shape,
492
+ generator=generator,
493
+ device=device,
494
+ dtype=image_embeddings.dtype
495
+ )
496
+ depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent + epsilon * self.scheduler.sigmas[i]
497
+ depth_latent[:, start_frame:] = self.scheduler.scale_model_input(depth_latent[:, start_frame:], t)
498
+ noise_pred = self.unet(
499
+ torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
500
+ curr_timesteps,
501
+ image_embeddings[start_frame:curr_frame+horizon],
502
+ added_time_ids=added_time_ids
503
+ )[0]
504
+ depth_latent[:, start_frame:] = self.scheduler.step(noise_pred, t, depth_latent[:, start_frame:]).prev_sample
505
+
506
+ depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent
507
+ self.scheduler._step_index = None
508
+ curr_frame += horizon
509
+ pbar.update(horizon)
510
+
511
+ torch.cuda.empty_cache()
512
+ if needs_upcasting:
513
+ self.vae.to(dtype=torch.float16)
514
+ depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
515
+ # clip prediction
516
+ depth = torch.clip(depth, -1.0, 1.0)
517
+ # shift to [0, 1]
518
+ depth = (depth + 1.0) / 2.0
519
+
520
+ return depth.squeeze(0)
521
+
522
+ @torch.no_grad()
523
+ def single_infer_naive_sliding_window(self,
524
+ input_rgb: torch.Tensor,
525
+ image_embeddings: torch.Tensor,
526
+ added_time_ids: torch.Tensor,
527
+ num_inference_steps: int,
528
+ show_pbar: bool,
529
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
530
+ decode_chunk_size=1,
531
+ latents: Optional[torch.Tensor] = None,
532
+ ):
533
+ device = input_rgb.device
534
+ H, W = input_rgb.shape[-2:]
535
+
536
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
537
+ if needs_upcasting:
538
+ self.vae.to(dtype=torch.float32)
539
+
540
+ rgb_latent = self.encode_images(input_rgb)
541
+ rgb_latent = rgb_latent.to(image_embeddings.dtype)
542
+
543
+ torch.cuda.empty_cache()
544
+
545
+ # cast back to fp16 if needed
546
+ if needs_upcasting:
547
+ self.vae.to(dtype=torch.float16)
548
+
549
+ # Prepare timesteps
550
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
551
+ timesteps = self.scheduler.timesteps
552
+
553
+ batch_size, n_frames, _, _, _ = rgb_latent.shape
554
+ num_channels_latents = self.unet.config.in_channels
555
+
556
+ curr_frame = 0
557
+ depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
558
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
559
+
560
+ # first chunk
561
+ horizon = min(n_frames-curr_frame, self.n_tokens)
562
+ start_frame = 0
563
+ chunk = self.prepare_latents(
564
+ batch_size,
565
+ horizon,
566
+ num_channels_latents,
567
+ H,
568
+ W,
569
+ image_embeddings.dtype,
570
+ device,
571
+ generator,
572
+ latents,
573
+ )
574
+ depth_latent = torch.cat([depth_latent, chunk], 1)
575
+ if show_pbar:
576
+ iterable = tqdm(
577
+ enumerate(timesteps),
578
+ total=len(timesteps),
579
+ leave=False,
580
+ desc=" " * 4 + "Diffusion denoising first chunk",
581
+ )
582
+ else:
583
+ iterable = enumerate(timesteps)
584
+
585
+ for i, t in iterable:
586
+ curr_timesteps = torch.tensor([t]*horizon).to(device)
587
+ depth_latent = self.scheduler.scale_model_input(depth_latent, t)
588
+ noise_pred = self.unet(
589
+ torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
590
+ curr_timesteps[start_frame:],
591
+ image_embeddings[start_frame:curr_frame+horizon],
592
+ added_time_ids=added_time_ids
593
+ )[0]
594
+ depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
595
+
596
+ self.scheduler._step_index = None
597
+ curr_frame += horizon
598
+ pbar.update(horizon)
599
+
600
+ while curr_frame < n_frames:
601
+ if self.chunk_size > 0:
602
+ horizon = min(n_frames - curr_frame, self.chunk_size)
603
+ else:
604
+ horizon = min(n_frames - curr_frame, self.n_tokens)
605
+ assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
606
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
607
+
608
+ chunk = self.prepare_latents(
609
+ batch_size,
610
+ curr_frame+horizon-start_frame,
611
+ num_channels_latents,
612
+ H,
613
+ W,
614
+ image_embeddings.dtype,
615
+ device,
616
+ generator,
617
+ latents,
618
+ )
619
+
620
+ pbar.set_postfix(
621
+ {
622
+ "start": start_frame,
623
+ "end": curr_frame + horizon,
624
+ }
625
+ )
626
+
627
+ if show_pbar:
628
+ iterable = tqdm(
629
+ enumerate(timesteps),
630
+ total=len(timesteps),
631
+ leave=False,
632
+ desc=" " * 4 + "Diffusion denoising ",
633
+ )
634
+ else:
635
+ iterable = enumerate(timesteps)
636
+
637
+ for i, t in iterable:
638
+ curr_timesteps = torch.tensor([t]*(curr_frame+horizon-start_frame)).to(device)
639
+ chunk = self.scheduler.scale_model_input(chunk, t)
640
+ noise_pred = self.unet(
641
+ torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], chunk], dim=2),
642
+ curr_timesteps,
643
+ image_embeddings[start_frame:curr_frame+horizon],
644
+ added_time_ids=added_time_ids
645
+ )[0]
646
+ chunk = self.scheduler.step(noise_pred, t, chunk).prev_sample
647
+
648
+ depth_latent = torch.cat([depth_latent, chunk[:, -horizon:]], 1)
649
+ self.scheduler._step_index = None
650
+ curr_frame += horizon
651
+ pbar.update(horizon)
652
+
653
+ torch.cuda.empty_cache()
654
+ if needs_upcasting:
655
+ self.vae.to(dtype=torch.float16)
656
+ depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
657
+ # clip prediction
658
+ depth = torch.clip(depth, -1.0, 1.0)
659
+ # shift to [0, 1]
660
+ depth = (depth + 1.0) / 2.0
661
+
662
+ return depth.squeeze(0)
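
The windowed sampling in single_infer_ours and single_infer_replacement is driven entirely by n_tokens (the attention window) and chunk_size (new frames committed per step). As a reading aid, here is a small standalone sketch that reproduces just that scheduling arithmetic — which latent frames each denoising pass sees and how many new frames it commits. The helper name is illustrative; the defaults mirror pipe.n_tokens = 10 and pipe.chunk_size = 5 from app.py.

```python
# Pure-Python sketch of the sliding-window schedule used above (no model calls).
def window_schedule(n_frames, n_tokens=10, chunk_size=5):
    windows = []
    curr_frame = 0
    # first chunk: denoise up to n_tokens frames from scratch
    horizon = min(n_frames - curr_frame, n_tokens)
    windows.append((0, curr_frame + horizon, horizon))  # (start, end, new frames)
    curr_frame += horizon
    # subsequent chunks: condition on the tail of already-denoised frames
    while curr_frame < n_frames:
        horizon = min(n_frames - curr_frame, chunk_size if chunk_size > 0 else n_tokens)
        start_frame = max(0, curr_frame + horizon - n_tokens)
        windows.append((start_frame, curr_frame + horizon, horizon))
        curr_frame += horizon
    return windows

print(window_schedule(23))
# [(0, 10, 10), (5, 15, 5), (10, 20, 5), (13, 23, 3)]
```
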
chronodepth/unet_chronodepth.py ADDED
@@ -0,0 +1,151 @@
1
+ from typing import Union, Tuple
2
+
3
+ import torch
4
+ from diffusers import UNetSpatioTemporalConditionModel
5
+ from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
6
+
7
+ class DiffusersUNetSpatioTemporalConditionModelChronodepth(
8
+ UNetSpatioTemporalConditionModel
9
+ ):
10
+
11
+ def forward(
12
+ self,
13
+ sample: torch.FloatTensor,
14
+ timestep: Union[torch.Tensor, float, int],
15
+ encoder_hidden_states: torch.Tensor,
16
+ added_time_ids: torch.Tensor,
17
+ return_dict: bool = True,
18
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
19
+ r"""
20
+ The [`UNetSpatioTemporalConditionModel`] forward method.
21
+
22
+ Args:
23
+ sample (`torch.FloatTensor`):
24
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
25
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
26
+ encoder_hidden_states (`torch.FloatTensor`):
27
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
28
+ added_time_ids: (`torch.FloatTensor`):
29
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
30
+ embeddings and added to the time embeddings.
31
+ return_dict (`bool`, *optional*, defaults to `True`):
32
+ Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
+ tuple.
+ Returns:
+ [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
37
+ a `tuple` is returned where the first element is the sample tensor.
38
+ """
39
+ # 1. time
40
+ timesteps = timestep
41
+ if not torch.is_tensor(timesteps):
42
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
43
+ # This would be a good case for the `match` statement (Python 3.10+)
44
+ is_mps = sample.device.type == "mps"
45
+ if isinstance(timestep, float):
46
+ dtype = torch.float32 if is_mps else torch.float64
47
+ else:
48
+ dtype = torch.int32 if is_mps else torch.int64
49
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
50
+ elif len(timesteps.shape) == 0:
51
+ timesteps = timesteps[None].to(sample.device)
52
+
53
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
54
+ batch_size, num_frames = sample.shape[:2]
55
+ # timesteps = timesteps.expand(batch_size)
56
+
57
+ t_emb = self.time_proj(timesteps)
58
+
59
+ # `Timesteps` does not contain any weights and will always return f32 tensors
60
+ # but time_embedding might actually be running in fp16. so we need to cast here.
61
+ # there might be better ways to encapsulate this.
62
+ t_emb = t_emb.to(dtype=sample.dtype)
63
+
64
+ emb = self.time_embedding(t_emb)
65
+
66
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
67
+ time_embeds = time_embeds.reshape((batch_size, -1)).repeat(num_frames, 1)
68
+ time_embeds = time_embeds.to(emb.dtype)
69
+ aug_emb = self.add_embedding(time_embeds)
70
+ emb = emb + aug_emb
71
+
72
+ # Flatten the batch and frames dimensions
73
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
74
+ sample = sample.flatten(0, 1)
75
+
76
+ # Repeat the embeddings num_video_frames times
77
+ # emb: [batch, channels] -> [batch * frames, channels]
78
+ # emb = emb.repeat_interleave(num_frames, dim=0) # TODO: sjh: maybe check later
79
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
80
+ # encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
81
+
82
+ ######### some modifications by Jiahao #########
83
+ # emb: [batch * frames, channels]
84
+ # no need to be repeated, because different frames have different time embeddings
85
+ # encoder_hidden_states: [batch * frames, 1, channels]
86
+ # no need to be repeated, because different frames have different encoder_hidden_states
87
+
88
+ # 2. pre-process
89
+ sample = self.conv_in(sample)
90
+
91
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
92
+
93
+ down_block_res_samples = (sample,)
94
+ for downsample_block in self.down_blocks:
95
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
96
+ sample, res_samples = downsample_block(
97
+ hidden_states=sample,
98
+ temb=emb,
99
+ encoder_hidden_states=encoder_hidden_states,
100
+ image_only_indicator=image_only_indicator,
101
+ )
102
+ else:
103
+ sample, res_samples = downsample_block(
104
+ hidden_states=sample,
105
+ temb=emb,
106
+ image_only_indicator=image_only_indicator,
107
+ )
108
+
109
+ down_block_res_samples += res_samples
110
+
111
+ # 4. mid
112
+ sample = self.mid_block(
113
+ hidden_states=sample,
114
+ temb=emb,
115
+ encoder_hidden_states=encoder_hidden_states,
116
+ image_only_indicator=image_only_indicator,
117
+ )
118
+
119
+ # 5. up
120
+ for i, upsample_block in enumerate(self.up_blocks):
121
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
122
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
123
+
124
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
125
+ sample = upsample_block(
126
+ hidden_states=sample,
127
+ temb=emb,
128
+ res_hidden_states_tuple=res_samples,
129
+ encoder_hidden_states=encoder_hidden_states,
130
+ image_only_indicator=image_only_indicator,
131
+ )
132
+ else:
133
+ sample = upsample_block(
134
+ hidden_states=sample,
135
+ temb=emb,
136
+ res_hidden_states_tuple=res_samples,
137
+ image_only_indicator=image_only_indicator,
138
+ )
139
+
140
+ # 6. post-process
141
+ sample = self.conv_norm_out(sample)
142
+ sample = self.conv_act(sample)
143
+ sample = self.conv_out(sample)
144
+
145
+ # 7. Reshape back to original shape
146
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
147
+
148
+ if not return_dict:
149
+ return (sample,)
150
+
151
+ return UNetSpatioTemporalConditionOutput(sample=sample)
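
The comments in forward() above note that, unlike the stock SVD UNet, the per-frame timesteps and CLIP embeddings are not repeated across frames because each frame already carries its own. A dummy-tensor sketch of the shape convention (dimensions are illustrative; the 8 input channels stand for the concatenated RGB and depth latents the pipeline feeds in):

```python
# Shape sketch only, assumed dims, no real UNet weights involved.
import torch

batch, frames, channels, height, width = 1, 10, 8, 72, 96
sample = torch.randn(batch, frames, channels, height, width)
timesteps = torch.randn(frames)                        # one sigma/timestep per frame
encoder_hidden_states = torch.randn(frames, 1, 1024)   # one CLIP embedding per frame

# forward() flattens batch and frame dims before the spatio-temporal blocks ...
flat = sample.flatten(0, 1)                            # (batch * frames, C, H, W)
assert flat.shape == (batch * frames, channels, height, width)
assert timesteps.shape[0] == flat.shape[0] == encoder_hidden_states.shape[0]

# ... and reshapes back at the end, as in step 7 of the method.
restored = flat.reshape(batch, frames, *flat.shape[1:])
assert torch.equal(restored, sample)
```
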
chronodepth/video_utils.py ADDED
@@ -0,0 +1,53 @@
1
+ import numpy as np
2
+ import cv2
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+
6
+
7
+ def resize_max_res(video_rgb, max_res, interpolation=cv2.INTER_LINEAR):
8
+ """
9
+ Resize the video to the max resolution while keeping the aspect ratio.
10
+ Args:
11
+ video_rgb: (T, H, W, 3), RGB video, uint8
12
+ max_res: int, max resolution
13
+ Returns:
14
+ video_rgb: (T, H_new, W_new, 3), resized RGB video, uint8
15
+ """
16
+ original_height = video_rgb.shape[1]
17
+ original_width = video_rgb.shape[2]
18
+
19
+ # round the height and width to the nearest multiple of 64
20
+ height = round(original_height / 64) * 64
21
+ width = round(original_width / 64) * 64
22
+
23
+ # resize the video if the height or width is larger than max_res
24
+ if max(height, width) > max_res:
25
+ scale = max_res / max(original_height, original_width)
26
+ height = round(original_height * scale / 64) * 64
27
+ width = round(original_width * scale / 64) * 64
28
+
29
+ frames = []
30
+ for i in range(video_rgb.shape[0]):
31
+ frames.append(cv2.resize(video_rgb[i], (width, height), interpolation=interpolation))
32
+
33
+ frames = np.array(frames)
34
+ return frames
35
+
36
+
37
+ def colorize_video_depth(depth_video, colormap="Spectral"):
38
+ """
39
+ Colorize the depth video using the specified colormap.
40
+ depth_video: (T, H, W), depth video, [0, 1]
41
+ return:
42
+ colored_depth_video: (T, H, W, 3), colored depth video, dtype=uint8
43
+ """
44
+ if isinstance(depth_video, torch.Tensor):
45
+ depth_video = depth_video.cpu().numpy()
46
+ T, H, W = depth_video.shape
47
+ colored_depth_video = []
48
+ for i in range(T):
49
+ colored_depth = plt.get_cmap(colormap)(depth_video[i], bytes=True)[...,:3]
50
+ colored_depth_video.append(colored_depth)
51
+ colored_depth_video = np.stack(colored_depth_video, axis=0)
52
+
53
+ return colored_depth_video
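
A short usage sketch for the two helpers above, run on random arrays so it is self-contained; it assumes the chronodepth package added in this commit is importable.

```python
# Hedged usage sketch for resize_max_res and colorize_video_depth (dummy data).
import numpy as np

from chronodepth.video_utils import resize_max_res, colorize_video_depth

video = np.random.randint(0, 256, size=(4, 720, 1280, 3), dtype=np.uint8)  # (T, H, W, 3)
resized = resize_max_res(video, max_res=768)
# longest side capped at 768, both sides rounded to multiples of 64
assert max(resized.shape[1:3]) <= 768
assert resized.shape[1] % 64 == 0 and resized.shape[2] % 64 == 0

depth = np.random.rand(4, 256, 448).astype(np.float32)   # depth in [0, 1]
colored = colorize_video_depth(depth)                     # (T, H, W, 3) uint8
assert colored.shape == (4, 256, 448, 3) and colored.dtype == np.uint8
```
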
chronodepth_pipeline.py DELETED
@@ -1,530 +0,0 @@
1
- # Adapted from Marigold: https://github.com/prs-eth/Marigold and diffusers
2
-
3
- import inspect
4
- from typing import Union, Optional, List
5
-
6
- import torch
7
- import numpy as np
8
- import matplotlib.pyplot as plt
9
- from tqdm.auto import tqdm
10
- import PIL
11
- from PIL import Image
12
- from diffusers import (
13
- DiffusionPipeline,
14
- EulerDiscreteScheduler,
15
- UNetSpatioTemporalConditionModel,
16
- AutoencoderKLTemporalDecoder,
17
- )
18
- from diffusers.image_processor import VaeImageProcessor
19
- from diffusers.utils import BaseOutput
20
- from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
21
- from transformers import (
22
- CLIPVisionModelWithProjection,
23
- CLIPImageProcessor,
24
- )
25
- from einops import rearrange, repeat
26
-
27
-
28
- class ChronoDepthOutput(BaseOutput):
29
- r"""
30
- Output class for zero-shot text-to-video pipeline.
31
-
32
- Args:
33
- frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
34
- List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
35
- num_channels)`.
36
- """
37
- depth_np: np.ndarray
38
- depth_colored: Union[List[PIL.Image.Image], np.ndarray]
39
-
40
-
41
- class ChronoDepthPipeline(DiffusionPipeline):
42
- model_cpu_offload_seq = "image_encoder->unet->vae"
43
- _callback_tensor_inputs = ["latents"]
44
- rgb_latent_scale_factor = 0.18215
45
- depth_latent_scale_factor = 0.18215
46
-
47
- def __init__(
48
- self,
49
- vae: AutoencoderKLTemporalDecoder,
50
- image_encoder: CLIPVisionModelWithProjection,
51
- unet: UNetSpatioTemporalConditionModel,
52
- scheduler: EulerDiscreteScheduler,
53
- feature_extractor: CLIPImageProcessor,
54
- ):
55
- super().__init__()
56
-
57
- self.register_modules(
58
- vae=vae,
59
- image_encoder=image_encoder,
60
- unet=unet,
61
- scheduler=scheduler,
62
- feature_extractor=feature_extractor,
63
- )
64
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
65
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
66
- if not hasattr(self, "dtype"):
67
- self.dtype = self.unet.dtype
68
-
69
- def encode_RGB(self,
70
- image: torch.Tensor,
71
- ):
72
- video_length = image.shape[1]
73
- image = rearrange(image, "b f c h w -> (b f) c h w")
74
- latents = self.vae.encode(image).latent_dist.sample()
75
- latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
76
- latents = latents * self.vae.config.scaling_factor
77
-
78
- return latents
79
-
80
- def _encode_image(self, image, device, discard=True):
81
- '''
82
- set image to zero tensor discards the image embeddings if discard is True
83
- '''
84
- dtype = next(self.image_encoder.parameters()).dtype
85
-
86
- if not isinstance(image, torch.Tensor):
87
- image = self.image_processor.pil_to_numpy(image)
88
- if discard:
89
- image = np.zeros_like(image)
90
- image = self.image_processor.numpy_to_pt(image)
91
-
92
- # We normalize the image before resizing to match with the original implementation.
93
- # Then we unnormalize it after resizing.
94
- image = image * 2.0 - 1.0
95
- image = _resize_with_antialiasing(image, (224, 224))
96
- image = (image + 1.0) / 2.0
97
-
98
- # Normalize the image with for CLIP input
99
- image = self.feature_extractor(
100
- images=image,
101
- do_normalize=True,
102
- do_center_crop=False,
103
- do_resize=False,
104
- do_rescale=False,
105
- return_tensors="pt",
106
- ).pixel_values
107
-
108
- image = image.to(device=device, dtype=dtype)
109
- image_embeddings = self.image_encoder(image).image_embeds
110
- image_embeddings = image_embeddings.unsqueeze(1)
111
-
112
- return image_embeddings
113
-
114
- def decode_depth(self, depth_latent: torch.Tensor, decode_chunk_size=5) -> torch.Tensor:
115
- num_frames = depth_latent.shape[1]
116
- depth_latent = rearrange(depth_latent, "b f c h w -> (b f) c h w")
117
-
118
- depth_latent = depth_latent / self.vae.config.scaling_factor
119
-
120
- forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
121
- accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
122
-
123
- depth_frames = []
124
- for i in range(0, depth_latent.shape[0], decode_chunk_size):
125
- num_frames_in = depth_latent[i : i + decode_chunk_size].shape[0]
126
- decode_kwargs = {}
127
- if accepts_num_frames:
128
- # we only pass num_frames_in if it's expected
129
- decode_kwargs["num_frames"] = num_frames_in
130
-
131
- depth_frame = self.vae.decode(depth_latent[i : i + decode_chunk_size], **decode_kwargs).sample
132
- depth_frames.append(depth_frame)
133
-
134
- depth_frames = torch.cat(depth_frames, dim=0)
135
- depth_frames = depth_frames.reshape(-1, num_frames, *depth_frames.shape[1:])
136
- depth_mean = depth_frames.mean(dim=2, keepdim=True)
137
-
138
- return depth_mean
139
-
140
- def _get_add_time_ids(self,
141
- fps,
142
- motion_bucket_id,
143
- noise_aug_strength,
144
- dtype,
145
- batch_size,
146
- ):
147
- add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
148
-
149
- passed_add_embed_dim = self.unet.config.addition_time_embed_dim * \
150
- len(add_time_ids)
151
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
152
-
153
- if expected_add_embed_dim != passed_add_embed_dim:
154
- raise ValueError(
155
- f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
156
- )
157
-
158
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
159
- add_time_ids = add_time_ids.repeat(batch_size, 1)
160
- return add_time_ids
161
-
162
- def decode_latents(self, latents, num_frames, decode_chunk_size=14):
163
- # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
164
- latents = latents.flatten(0, 1)
165
-
166
- latents = 1 / self.vae.config.scaling_factor * latents
167
-
168
- forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
169
- accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
170
-
171
- # decode decode_chunk_size frames at a time to avoid OOM
172
- frames = []
173
- for i in range(0, latents.shape[0], decode_chunk_size):
174
- num_frames_in = latents[i : i + decode_chunk_size].shape[0]
175
- decode_kwargs = {}
176
- if accepts_num_frames:
177
- # we only pass num_frames_in if it's expected
178
- decode_kwargs["num_frames"] = num_frames_in
179
-
180
- frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
181
- frames.append(frame)
182
- frames = torch.cat(frames, dim=0)
183
-
184
- # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
185
- frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
186
-
187
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
188
- frames = frames.float()
189
- return frames
190
-
191
- def check_inputs(self, image, height, width):
192
- if (
193
- not isinstance(image, torch.Tensor)
194
- and not isinstance(image, PIL.Image.Image)
195
- and not isinstance(image, list)
196
- ):
197
- raise ValueError(
198
- "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
199
- f" {type(image)}"
200
- )
201
-
202
- if height % 64 != 0 or width % 64 != 0:
203
- raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
204
-
205
- def prepare_latents(
206
- self,
207
- shape,
208
- dtype,
209
- device,
210
- generator,
211
- latent=None,
212
- ):
213
- if isinstance(generator, list) and len(generator) != shape[0]:
214
- raise ValueError(
215
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
216
- f" size of {shape[0]}. Make sure the batch size matches the length of the generators."
217
- )
218
-
219
- if latent is None:
220
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
221
- else:
222
- latents = latent.to(device)
223
-
224
- # scale the initial noise by the standard deviation required by the scheduler
225
- latents = latents * self.scheduler.init_noise_sigma
226
- return latents
227
-
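`prepare_latents` simply draws seeded Gaussian noise and scales it by the scheduler's `init_noise_sigma` so the first denoising step sees the expected noise level. A sketch of the same idea; `EulerDiscreteScheduler` is used here only as a stand-in, since the real scheduler is loaded from the pretrained checkpoint:

```python
import torch
from diffusers import EulerDiscreteScheduler  # stand-in for the checkpoint's scheduler

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(5)

generator = torch.Generator().manual_seed(2024)
shape = (1, 10, 4, 72, 96)  # (batch, frames, latent channels, H/8, W/8)
# fresh Gaussian noise, scaled to the noise level the first denoising step expects
latents = torch.randn(shape, generator=generator) * scheduler.init_noise_sigma
print(latents.shape, float(scheduler.init_noise_sigma))
```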
228
- @property
229
- def num_timesteps(self):
230
- return self._num_timesteps
231
-
232
- @torch.no_grad()
233
- def __call__(
234
- self,
235
- input_image: Union[List[PIL.Image.Image], torch.FloatTensor],
236
- height: int = 576,
237
- width: int = 768,
238
- num_frames: Optional[int] = None,
239
- num_inference_steps: int = 10,
240
- fps: int = 7,
241
- motion_bucket_id: int = 127,
242
- noise_aug_strength: float = 0.02,
243
- decode_chunk_size: Optional[int] = None,
244
- color_map: str="Spectral",
245
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
246
- show_progress_bar: bool = True,
247
- match_input_res: bool = True,
248
- depth_pred_last: Optional[torch.FloatTensor] = None,
249
- ):
250
- assert height >= 0 and width >=0
251
- assert num_inference_steps >=1
252
-
253
- num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
254
- decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
255
-
256
- # 1. Check inputs. Raise error if not correct
257
- self.check_inputs(input_image, height, width)
258
-
259
- # 2. Define call parameters
260
- if isinstance(input_image, list):
261
- batch_size = 1
262
- input_size = input_image[0].size
263
- elif isinstance(input_image, torch.Tensor):
264
- batch_size = input_image.shape[0]
265
- input_size = input_image.shape[:-3:-1]
266
- assert batch_size == 1, "Batch size must be 1 for now"
267
- device = self._execution_device
268
-
269
- # 3. Encode input image
270
- image_embeddings = self._encode_image(input_image[0], device)
271
- image_embeddings = image_embeddings.repeat((batch_size, 1, 1))
272
-
273
- # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
274
- # is why it is reduced here.
275
- # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
276
- fps = fps - 1
277
-
278
- # 4. Encode input image using VAE
279
- input_image = self.image_processor.preprocess(input_image, height=height, width=width).to(device)
280
- assert input_image.min() >= -1.0 and input_image.max() <= 1.0
281
- noise = randn_tensor(input_image.shape, generator=generator, device=device, dtype=input_image.dtype)
282
- input_image = input_image + noise_aug_strength * noise
283
- if depth_pred_last is not None:
284
- depth_pred_last = depth_pred_last.to(device)
285
- # resize depth
286
- from torchvision.transforms import InterpolationMode
287
- from torchvision.transforms.functional import resize
288
- depth_pred_last = resize(depth_pred_last.unsqueeze(1), (height, width), InterpolationMode.NEAREST_EXACT, antialias=True)
289
- depth_pred_last = repeat(depth_pred_last, 'f c h w ->b f c h w', b=batch_size)
290
-
291
- rgb_batch = repeat(input_image, 'f c h w ->b f c h w', b=batch_size)
292
-
293
- added_time_ids = self._get_add_time_ids(
294
- fps,
295
- motion_bucket_id,
296
- noise_aug_strength,
297
- image_embeddings.dtype,
298
- batch_size,
299
- )
300
- added_time_ids = added_time_ids.to(device)
301
-
302
- depth_pred_raw = self.single_infer(rgb_batch,
303
- image_embeddings,
304
- added_time_ids,
305
- num_inference_steps,
306
- show_progress_bar,
307
- generator,
308
- depth_pred_last=depth_pred_last,
309
- decode_chunk_size=decode_chunk_size)
310
-
311
- depth_colored_img_list = []
312
- depth_frames = []
313
- for i in range(num_frames):
314
- depth_frame = depth_pred_raw[:, i].squeeze()
315
-
316
- # Convert to numpy
317
- depth_frame = depth_frame.cpu().numpy().astype(np.float32)
318
-
319
- if match_input_res:
320
- pred_img = Image.fromarray(depth_frame)
321
- pred_img = pred_img.resize(input_size, resample=Image.NEAREST)
322
- depth_frame = np.asarray(pred_img)
323
-
324
- # Clip output range (the prediction is now back at the original resolution)
325
- depth_frame = depth_frame.clip(0, 1)
326
-
327
- # Colorize
328
- depth_colored = plt.get_cmap(color_map)(depth_frame, bytes=True)[..., :3]
329
- depth_colored_img = Image.fromarray(depth_colored)
330
-
331
- depth_colored_img_list.append(depth_colored_img)
332
- depth_frames.append(depth_frame)
333
-
334
- depth_frames = np.stack(depth_frames)
335
-
336
- self.maybe_free_model_hooks()
337
-
338
- return ChronoDepthOutput(
339
- depth_np = depth_frames,
340
- depth_colored = depth_colored_img_list,
341
- )
342
-
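A hypothetical end-to-end call of the pipeline, pieced together from the `__call__` signature above. The checkpoint id, the frame filenames, and the assumption that `ChronoDepthPipeline` behaves like a regular diffusers `DiffusionPipeline` (so `from_pretrained` and `to` work) are not taken from this diff:

```python
import torch
from PIL import Image
from chronodepth_pipeline import ChronoDepthPipeline

# hypothetical checkpoint id and frame filenames
pipe = ChronoDepthPipeline.from_pretrained("jhshao/ChronoDepth", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

frames = [Image.open(f"frame_{i:03d}.png").convert("RGB") for i in range(10)]
out = pipe(
    frames,                      # List[PIL.Image.Image], as accepted by __call__
    height=576,
    width=768,
    num_frames=len(frames),
    num_inference_steps=5,
    decode_chunk_size=5,
    generator=torch.Generator("cuda").manual_seed(2024),
)
depth = out.depth_np         # (num_frames, H, W) float32 depth in [0, 1]
colored = out.depth_colored  # list of colorized PIL images
```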
343
- @torch.no_grad()
344
- def single_infer(self,
345
- input_rgb: torch.Tensor,
346
- image_embeddings: torch.Tensor,
347
- added_time_ids: torch.Tensor,
348
- num_inference_steps: int,
349
- show_pbar: bool,
350
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
351
- depth_pred_last: Optional[torch.Tensor] = None,
352
- decode_chunk_size=1,
353
- ):
354
- device = input_rgb.device
355
-
356
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
357
- if needs_upcasting:
358
- self.vae.to(dtype=torch.float32)
359
-
360
- rgb_latent = self.encode_RGB(input_rgb)
361
- rgb_latent = rgb_latent.to(image_embeddings.dtype)
362
- if depth_pred_last is not None:
363
- depth_pred_last = depth_pred_last.repeat(1, 1, 3, 1, 1)
364
- depth_pred_last_latent = self.encode_RGB(depth_pred_last)
365
- depth_pred_last_latent = depth_pred_last_latent.to(image_embeddings.dtype)
366
- else:
367
- depth_pred_last_latent = None
368
-
369
- # cast back to fp16 if needed
370
- if needs_upcasting:
371
- self.vae.to(dtype=torch.float16)
372
-
373
- # Prepare timesteps
374
- self.scheduler.set_timesteps(num_inference_steps, device=device)
375
- timesteps = self.scheduler.timesteps
376
-
377
- depth_latent = self.prepare_latents(
378
- rgb_latent.shape,
379
- image_embeddings.dtype,
380
- device,
381
- generator
382
- )
383
-
384
- if show_pbar:
385
- iterable = tqdm(
386
- enumerate(timesteps),
387
- total=len(timesteps),
388
- leave=False,
389
- desc=" " * 4 + "Diffusion denoising",
390
- )
391
- else:
392
- iterable = enumerate(timesteps)
393
-
394
- for i, t in iterable:
395
- if depth_pred_last_latent is not None:
396
- known_frames_num = depth_pred_last_latent.shape[1]
397
- epsilon = randn_tensor(
398
- depth_pred_last_latent.shape,
399
- generator=generator,
400
- device=device,
401
- dtype=image_embeddings.dtype
402
- )
403
- depth_latent[:, :known_frames_num] = depth_pred_last_latent + epsilon * self.scheduler.sigmas[i]
404
- depth_latent = self.scheduler.scale_model_input(depth_latent, t)
405
- unet_input = torch.cat([rgb_latent, depth_latent], dim=2)
406
-
407
- noise_pred = self.unet(
408
- unet_input, t, image_embeddings, added_time_ids=added_time_ids
409
- )[0]
410
-
411
- # compute the previous noisy sample x_t -> x_t-1
412
- depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
413
-
414
- torch.cuda.empty_cache()
415
- if needs_upcasting:
416
- self.vae.to(dtype=torch.float16)
417
- depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
418
- # clip prediction
419
- depth = torch.clip(depth, -1.0, 1.0)
420
- # shift to [0, 1]
421
- depth = (depth + 1.0) / 2.0
422
-
423
- return depth
424
-
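The loop in `single_infer` implements sliding-window conditioning: at every step, the latents of frames whose depth was already predicted are overwritten with their clean latents re-noised to the current sigma, so only the new frames are effectively denoised from scratch. A toy sketch of just that replacement step (shapes and the sigma schedule are made up):

```python
import torch

num_steps, known, total = 5, 3, 10
sigmas = torch.linspace(10.0, 0.0, num_steps + 1)           # toy noise schedule
depth_latent = torch.randn(1, total, 4, 8, 8) * sigmas[0]   # start from pure noise
known_latent = torch.randn(1, known, 4, 8, 8)               # clean latents of already-predicted frames

for i in range(num_steps):
    # re-noise the known frames to the current noise level before each UNet call,
    # exactly like the depth_pred_last_latent branch above
    depth_latent[:, :known] = known_latent + torch.randn_like(known_latent) * sigmas[i]
    # ... UNet prediction and scheduler.step(...) would update depth_latent here ...

print(depth_latent.shape)  # torch.Size([1, 10, 4, 8, 8])
```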
425
- # resizing utils
426
- def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
427
- h, w = input.shape[-2:]
428
- factors = (h / size[0], w / size[1])
429
-
430
- # First, we have to determine sigma
431
- # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
432
- sigmas = (
433
- max((factors[0] - 1.0) / 2.0, 0.001),
434
- max((factors[1] - 1.0) / 2.0, 0.001),
435
- )
436
-
437
- # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
438
- # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
439
- # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
440
- ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
441
-
442
- # Make sure it is odd
443
- if (ks[0] % 2) == 0:
444
- ks = ks[0] + 1, ks[1]
445
-
446
- if (ks[1] % 2) == 0:
447
- ks = ks[0], ks[1] + 1
448
-
449
- input = _gaussian_blur2d(input, ks, sigmas)
450
-
451
- output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
452
- return output
453
-
454
-
455
- def _compute_padding(kernel_size):
456
- """Compute padding tuple."""
457
- # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
458
- # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
459
- if len(kernel_size) < 2:
460
- raise AssertionError(kernel_size)
461
- computed = [k - 1 for k in kernel_size]
462
-
463
- # for even kernels we need to do asymmetric padding :(
464
- out_padding = 2 * len(kernel_size) * [0]
465
-
466
- for i in range(len(kernel_size)):
467
- computed_tmp = computed[-(i + 1)]
468
-
469
- pad_front = computed_tmp // 2
470
- pad_rear = computed_tmp - pad_front
471
-
472
- out_padding[2 * i + 0] = pad_front
473
- out_padding[2 * i + 1] = pad_rear
474
-
475
- return out_padding
476
-
477
-
478
- def _filter2d(input, kernel):
479
- # prepare kernel
480
- b, c, h, w = input.shape
481
- tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
482
-
483
- tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
484
-
485
- height, width = tmp_kernel.shape[-2:]
486
-
487
- padding_shape: list[int] = _compute_padding([height, width])
488
- input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
489
-
490
- # kernel and input tensor reshape to align element-wise or batch-wise params
491
- tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
492
- input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
493
-
494
- # convolve the tensor with the kernel.
495
- output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
496
-
497
- out = output.view(b, c, h, w)
498
- return out
499
-
500
-
501
- def _gaussian(window_size: int, sigma):
502
- if isinstance(sigma, float):
503
- sigma = torch.tensor([[sigma]])
504
-
505
- batch_size = sigma.shape[0]
506
-
507
- x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
508
-
509
- if window_size % 2 == 0:
510
- x = x + 0.5
511
-
512
- gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
513
-
514
- return gauss / gauss.sum(-1, keepdim=True)
515
-
516
-
517
- def _gaussian_blur2d(input, kernel_size, sigma):
518
- if isinstance(sigma, tuple):
519
- sigma = torch.tensor([sigma], dtype=input.dtype)
520
- else:
521
- sigma = sigma.to(dtype=input.dtype)
522
-
523
- ky, kx = int(kernel_size[0]), int(kernel_size[1])
524
- bs = sigma.shape[0]
525
- kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
526
- kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
527
- out_x = _filter2d(input, kernel_x[..., None, :])
528
- out = _filter2d(out_x, kernel_y[..., None])
529
-
530
- return out
 
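The `_resize_with_antialiasing` / `_gaussian_*` helpers above reproduce skimage-style antialiased downscaling: blur with a sigma derived from the scale factor, then interpolate. A compact approximation using torchvision's `gaussian_blur`; border handling differs slightly, so this is a sketch rather than a drop-in replacement:

```python
import torch
from torchvision.transforms.functional import gaussian_blur

def resize_antialiased(x, size):
    h, w = x.shape[-2:]
    factors = (h / size[0], w / size[1])
    # same sigma rule as above (borrowed from skimage): half of (factor - 1)
    sigma = max((factors[0] - 1.0) / 2.0, (factors[1] - 1.0) / 2.0, 0.001)
    k = max(int(4 * sigma) | 1, 3)  # roughly 4*sigma taps, forced odd
    x = gaussian_blur(x, kernel_size=k, sigma=sigma)
    return torch.nn.functional.interpolate(x, size=size, mode="bicubic", align_corners=True)

img = torch.rand(1, 3, 512, 768)
print(resize_antialiased(img, (256, 384)).shape)  # torch.Size([1, 3, 256, 384])
```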
gradio_patches/examples.py DELETED
@@ -1,13 +0,0 @@
1
- from pathlib import Path
2
-
3
- import gradio
4
- from gradio.utils import get_cache_folder
5
-
6
-
7
- class Examples(gradio.helpers.Examples):
8
- def __init__(self, *args, directory_name=None, **kwargs):
9
- super().__init__(*args, **kwargs, _initiated_directly=False)
10
- if directory_name is not None:
11
- self.cached_folder = get_cache_folder() / directory_name
12
- self.cached_file = Path(self.cached_folder) / "log.csv"
13
- self.create()
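For context, a hypothetical snippet showing how this removed `Examples` patch was meant to be used: identical to `gr.Examples`, except cached outputs land in a custom sub-folder of Gradio's cache. The component names, example path, and processing function are placeholders, not taken from this repository:

```python
import gradio as gr
from gradio_patches.examples import Examples  # the class removed above

def predict(video_path):
    # placeholder processing function; the real demo wires this to the depth pipeline
    return video_path

with gr.Blocks() as demo:
    video_in = gr.Video(label="Input video")
    video_out = gr.Video(label="Depth video")
    Examples(
        examples=[["files/example.mp4"]],   # placeholder example path
        inputs=[video_in],
        outputs=[video_out],
        fn=predict,
        cache_examples=True,
        directory_name="examples_video",    # cached under <gradio cache>/examples_video/log.csv
    )
```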
 
requirements.txt CHANGED
@@ -1,14 +1,16 @@
1
  spaces
2
- gradio>=4.32.1
3
- diffusers==0.26.0
4
  easydict==1.13
5
  einops==0.8.0
6
  matplotlib==3.8.4
7
  mediapy==1.2.2
8
  numpy==1.26.4
9
  Pillow==10.3.0
10
- torch==2.0.1
11
- torchvision==0.15.2
 
12
  tqdm==4.66.2
13
  accelerate==0.28.0
14
- transformers==4.36.2
 
 
1
  spaces
2
+ gradio==4.32.1
3
+ diffusers==0.29.1
4
  easydict==1.13
5
  einops==0.8.0
6
  matplotlib==3.8.4
7
  mediapy==1.2.2
8
  numpy==1.26.4
9
  Pillow==10.3.0
10
+ torch==2.1.0
11
+ torchvision==0.16.0
12
+ xformers==0.0.22.post7
13
  tqdm==4.66.2
14
  accelerate==0.28.0
15
+ transformers==4.36.2
16
+ opencv-python
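An optional sanity check (not part of the Space) to confirm the installed packages match the new pins before launching the demo:

```python
from importlib.metadata import version

expected = {
    "torch": "2.1.0",
    "torchvision": "0.16.0",
    "xformers": "0.0.22.post7",
    "diffusers": "0.29.1",
    "gradio": "4.32.1",
    "transformers": "4.36.2",
}
for name, want in expected.items():
    have = version(name)
    flag = "OK" if have == want else f"MISMATCH (expected {want})"
    print(f"{name:13s} {have:15s} {flag}")
```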