# MIT License
# Copyright (c) 2024 Jiahao Shao
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import functools
import os
import time
import zipfile
import tempfile
from io import BytesIO

import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import mediapy as media
from huggingface_hub import login

from chronodepth_pipeline import ChronoDepthPipeline
from gradio_patches.examples import Examples

default_seed = 2024
default_num_inference_steps = 5
default_num_frames = 10
default_window_size = 9
default_video_processing_resolution = 768
default_video_out_max_frames = 80
default_decode_chunk_size = 10
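# Note on the defaults above: the video is processed as a sliding window of clips.
# Each clip is `default_num_frames` frames long and, in inpaint mode, the window
# advances by `default_window_size` frames, so consecutive clips overlap by
# num_frames - window_size frames (1 with these defaults, matching the UI text below).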
def process_video(
    pipe,
    path_input,
    num_inference_steps=default_num_inference_steps,
    num_frames=default_num_frames,
    window_size=default_window_size,
    out_max_frames=default_video_out_max_frames,
    progress=gr.Progress(),
):
    if path_input is None:
        raise gr.Error(
            "Missing video in the first pane: upload a file or use one from the gallery below."
        )

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing video {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")

    generator = torch.Generator(device=pipe.device).manual_seed(default_seed)

    start_time = time.time()
    zipf = None
    try:
        # Use temporal inpainting between clips unless one window covers a whole clip.
        if window_size is None or window_size == num_frames:
            inpaint_inference = False
        else:
            inpaint_inference = True

        data_ls = []
        video_data = media.read_video(path_input)
        video_length = len(video_data)
        fps = video_data.metadata.fps
        duration_sec = video_length / fps
        out_duration_sec = out_max_frames / fps
        if duration_sec > out_duration_sec:
            gr.Warning(
                f"Only the first ~{int(out_duration_sec)} seconds will be processed; "
                "run the full ChronoDepth pipeline from GitHub to process the whole video."
            )
            video_length = out_max_frames

        # Split the video into overlapping clips of `num_frames` frames.
        for i in tqdm(range(video_length - num_frames + 1)):
            is_first_clip = i == 0
            is_last_clip = i == video_length - num_frames
            is_new_clip = (
                (inpaint_inference and i % window_size == 0)
                or (not inpaint_inference and i % num_frames == 0)
            )
            if is_first_clip or is_last_clip or is_new_clip:
                data_ls.append(np.array(video_data[i : i + num_frames]))  # [t, H, W, 3]

        zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)

        depth_colored_pred = []
        depth_pred = []
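        # Each clip below is denoised in one of three ways:
        #   * first clip: plain inference;
        #   * middle clips (inpaint mode): conditioned on the depth predicted for the
        #     frames it shares with the previous clip, via `depth_pred_last`;
        #   * last clip (inpaint mode): conditioned on however many frames remain.
        # Without inpaint mode, clips are processed independently.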
        # -------------------- Inference and saving --------------------
        with torch.no_grad():
            for clip_idx, batch in enumerate(tqdm(data_ls)):
                rgb_int = batch
                input_images = [Image.fromarray(rgb_int[i]) for i in range(num_frames)]

                # Predict depth
                if clip_idx == 0:  # first clip
                    pipe_out = pipe(
                        input_images,
                        num_frames=len(input_images),
                        num_inference_steps=num_inference_steps,
                        decode_chunk_size=default_decode_chunk_size,
                        motion_bucket_id=127,
                        fps=7,
                        noise_aug_strength=0.0,
                        generator=generator,
                    )
                elif inpaint_inference and (clip_idx == len(data_ls) - 1):  # temporal inpaint inference for last clip
                    last_window_size = (
                        window_size
                        if video_length % window_size == 0
                        else video_length % window_size
                    )
                    pipe_out = pipe(
                        input_images,
                        num_frames=num_frames,
                        num_inference_steps=num_inference_steps,
                        decode_chunk_size=default_decode_chunk_size,
                        motion_bucket_id=127,
                        fps=7,
                        noise_aug_strength=0.0,
                        generator=generator,
                        depth_pred_last=depth_frames_pred_ts[last_window_size:],
                    )
                elif inpaint_inference and clip_idx > 0:  # temporal inpaint inference
                    pipe_out = pipe(
                        input_images,
                        num_frames=num_frames,
                        num_inference_steps=num_inference_steps,
                        decode_chunk_size=default_decode_chunk_size,
                        motion_bucket_id=127,
                        fps=7,
                        noise_aug_strength=0.0,
                        generator=generator,
                        depth_pred_last=depth_frames_pred_ts[window_size:],
                    )
                else:  # separate inference
                    pipe_out = pipe(
                        input_images,
                        num_frames=num_frames,
                        num_inference_steps=num_inference_steps,
                        decode_chunk_size=default_decode_chunk_size,
                        motion_bucket_id=127,
                        fps=7,
                        noise_aug_strength=0.0,
                        generator=generator,
                    )

                depth_frames_pred = [pipe_out.depth_np[i] for i in range(num_frames)]

                depth_frames_colored_pred = []
                for i in range(num_frames):
                    depth_frame_colored_pred = np.array(pipe_out.depth_colored[i])
                    depth_frames_colored_pred.append(depth_frame_colored_pred)
                depth_frames_colored_pred = np.stack(depth_frames_colored_pred, axis=0)
                depth_frames_pred = np.stack(depth_frames_pred, axis=0)

                # Keep the latest predictions (mapped from [0, 1] to [-1, 1]) to condition the next clip.
                depth_frames_pred_ts = torch.from_numpy(depth_frames_pred).to(pipe.device)
                depth_frames_pred_ts = depth_frames_pred_ts * 2 - 1

                # Accumulate only the frames not already covered by a previous clip.
                if not inpaint_inference:
                    if clip_idx == len(data_ls) - 1:
                        last_window_size = (
                            num_frames
                            if video_length % num_frames == 0
                            else video_length % num_frames
                        )
                        depth_colored_pred.append(depth_frames_colored_pred[-last_window_size:])
                        depth_pred.append(depth_frames_pred[-last_window_size:])
                    else:
                        depth_colored_pred.append(depth_frames_colored_pred)
                        depth_pred.append(depth_frames_pred)
                else:
                    if clip_idx == 0:
                        depth_colored_pred.append(depth_frames_colored_pred)
                        depth_pred.append(depth_frames_pred)
                    elif clip_idx == len(data_ls) - 1:
                        depth_colored_pred.append(depth_frames_colored_pred[-last_window_size:])
                        depth_pred.append(depth_frames_pred[-last_window_size:])
                    else:
                        depth_colored_pred.append(depth_frames_colored_pred[-window_size:])
                        depth_pred.append(depth_frames_pred[-window_size:])

        depth_colored_pred = np.concatenate(depth_colored_pred, axis=0)
        depth_pred = np.concatenate(depth_pred, axis=0)
        # -------------------- Save results --------------------
        # Save the 16-bit depth maps into the zip archive.
        for i in tqdm(range(len(depth_pred))):
            archive_path = os.path.join(
                f"{name_base}_depth_16bit", f"{i:05d}.png"
            )
            img_byte_arr = BytesIO()
            depth_16bit = Image.fromarray((depth_pred[i] * 65535.0).astype(np.uint16))
            depth_16bit.save(img_byte_arr, format="PNG")
            img_byte_arr.seek(0)
            zipf.writestr(archive_path, img_byte_arr.read())

        # Export the colored depth maps to video.
        media.write_video(path_out_vis, depth_colored_pred, fps=fps)
    finally:
        if zipf is not None:
            zipf.close()

    end_time = time.time()
    print(f"Processing time: {end_time - start_time} seconds")

    return (
        path_out_vis,
        [path_out_vis, path_out_16bit],
    )
def run_demo_server(pipe):
    process_pipe_video = spaces.GPU(
        functools.partial(process_video, pipe), duration=220
    )
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="ChronoDepth Video Depth Estimation",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.Markdown(
            """
            # ChronoDepth Video Depth Estimation

            <p align="center">
                <a title="Website" href="https://jhaoshao.github.io/ChronoDepth/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/website?url=https%3A%2F%2Fjhaoshao.github.io%2FChronoDepth%2F&up_message=ChronoDepth&up_color=blue&style=flat&logo=timescale&logoColor=%23FFDC0F">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/arXiv-PDF-b31b1b">
                </a>
                <a title="Github" href="https://github.com/jhaoshao/ChronoDepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/github/stars/jhaoshao/ChronoDepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
                </a>
            </p>

            ChronoDepth is the state-of-the-art video depth estimator for videos in the wild.
            Upload your video and give it a try!<br>
            We set the number of denoising steps to 5, the number of frames per video clip to 10, and the overlap between clips to 1.
            """
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(
                    label="Input Video",
                    sources=["upload"],
                )
                with gr.Row():
                    video_submit_btn = gr.Button(
                        value="Compute Depth", variant="primary"
                    )
                    video_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                video_output_video = gr.Video(
                    label="Output video depth (red-near, blue-far)",
                    interactive=False,
                )
                video_output_files = gr.Files(
                    label="Depth outputs",
                    elem_id="download",
                    interactive=False,
                )

        Examples(
            fn=process_pipe_video,
            examples=[
                os.path.join("files", name)
                for name in [
                    "sora_e2.mp4",
                    "sora_1758192960116785459.mp4",
                ]
            ],
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            cache_examples=True,
            directory_name="examples_video",
        )

        video_submit_btn.click(
            fn=process_pipe_video,
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            concurrency_limit=1,
        )

        video_reset_btn.click(
            # Clear all three components; the lambda returns one value per output.
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[video_input, video_output_video, video_output_files],
            concurrency_limit=1,
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def main():
    CHECKPOINT = "jhshao/ChronoDepth"

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    pipe = ChronoDepthPipeline.from_pretrained(CHECKPOINT)

    try:
        import xformers  # noqa: F401
        pipe.enable_xformers_memory_efficient_attention()
    except ImportError:
        pass  # run without xformers

    pipe = pipe.to(device)
    run_demo_server(pipe)


if __name__ == "__main__":
    main()
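# Usage note (assumption: the ChronoDepth repository's dependencies are installed):
# running this script starts the Gradio demo on http://0.0.0.0:7860. Outside the UI,
# process_video() can also be called directly with a loaded pipeline and a video path,
# e.g. process_video(pipe, "files/sora_e2.mp4"), which returns the colored depth video
# path and the list of downloadable output files.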