# MIT License
# Copyright (c) 2024 Jiahao Shao
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import functools
import os
import tempfile
import time
import zipfile
from io import BytesIO

import spaces
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import xformers
from PIL import Image
from tqdm import tqdm
import mediapy as media
from huggingface_hub import login

from chronodepth.unet_chronodepth import DiffusersUNetSpatioTemporalConditionModelChronodepth
from chronodepth.chronodepth_pipeline import ChronoDepthPipeline
from chronodepth.video_utils import resize_max_res, colorize_video_depth
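# -------------------- Demo configuration --------------------
# MAX_FRAME caps how many input frames the hosted demo processes per request;
# the remaining defaults feed the diffusion pipeline and the video
# pre-/post-processing steps below.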
MAX_FRAME = 15
default_seed = 2024
default_num_inference_steps = 5
default_n_tokens = 10
default_chunk_size = 5
default_video_processing_resolution = 768
default_decode_chunk_size = 8
def run_pipeline(pipe, video_rgb, generator, device):
    """
    Run the pipeline on the input video.

    args:
        pipe: ChronoDepthPipeline object
        video_rgb: input video, np.ndarray or torch.Tensor, shape [T, H, W, 3], range [0, 255]
        generator: torch.Generator used for diffusion sampling
        device: torch.device on which the predictions are post-processed
    returns:
        video_depth_pred: predicted depth, torch.Tensor, shape [T, H, W], range [0, 1]
    """
    if isinstance(video_rgb, torch.Tensor):
        video_rgb = video_rgb.cpu().numpy()
    original_height = video_rgb.shape[1]
    original_width = video_rgb.shape[2]

    # resize the video to the max processing resolution and normalize to [0, 1]
    video_rgb = resize_max_res(video_rgb, default_video_processing_resolution)
    video_rgb = video_rgb.astype(np.float32) / 255.0
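    # motion_bucket_id and fps are Stable Video Diffusion conditioning inputs
    # inherited by the pipeline; noise_aug_strength=0.0 disables noise
    # augmentation of the conditioning frames.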
    pipe_out = pipe(
        video_rgb,
        num_inference_steps=default_num_inference_steps,
        decode_chunk_size=default_decode_chunk_size,
        motion_bucket_id=127,
        fps=7,
        noise_aug_strength=0.0,
        generator=generator,
        infer_mode="ours",
        sigma_epsilon=-4,
    )
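    # Upsample the predictions back to the input resolution, clamp to the valid
    # depth range, and drop the channel dimension.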
    depth_frames_pred = pipe_out.frames
    depth_frames_pred = torch.from_numpy(depth_frames_pred).to(device)
    depth_frames_pred = F.interpolate(
        depth_frames_pred,
        size=(original_height, original_width),
        mode="bilinear",
        align_corners=False,
    )
    depth_frames_pred = depth_frames_pred.clamp(0, 1)
    depth_frames_pred = depth_frames_pred.squeeze(1)
    return depth_frames_pred
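
# Example (sketch): running run_pipeline on a short uint8 RGB clip `frames` of
# shape [T, H, W, 3], assuming a loaded `pipe`:
#   generator = torch.Generator(device=pipe.device).manual_seed(default_seed)
#   depth = run_pipeline(pipe, frames, generator, pipe.device)
#   depth.shape  # -> [T, H, W], values in [0, 1]
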
def process_video(
    pipe,
    path_input,
    num_inference_steps=default_num_inference_steps,
    out_max_frames=MAX_FRAME,
    progress=gr.Progress(),
):
    if path_input is None:
        raise gr.Error(
            "Missing video in the first pane: upload a file or use one from the gallery below."
        )
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing video {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")
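    # Fix the seed so repeated runs on the same video produce the same output.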
    generator = torch.Generator(device=pipe.device).manual_seed(default_seed)
    start_time = time.time()
    zipf = None
    try:
        # -------------------- Data --------------------
        video_data = media.read_video(path_input)
        fps = video_data.metadata.fps
        video_length = len(video_data)
        video_rgb = np.array(video_data)
        duration_sec = video_length / fps
        out_duration_sec = out_max_frames / fps
        if duration_sec > out_duration_sec:
            gr.Warning(
                f"Only the first {out_max_frames} frames (~{out_duration_sec:.1f} seconds) "
                f"will be processed; use the ChronoDepth repository on GitHub to process "
                f"the full video."
            )
            video_rgb = video_rgb[:out_max_frames]
        zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)

        # -------------------- Inference --------------------
        depth_pred = run_pipeline(pipe, video_rgb, generator, pipe.device)  # range [0, 1]
        depth_pred = depth_pred.cpu().numpy()
        depth_colored_pred = colorize_video_depth(depth_pred)  # range [0, 1] -> [0, 255]
        # -------------------- Save results --------------------
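        # Each depth frame is rescaled from [0, 1] to the full uint16 range and
        # stored as a lossless 16-bit PNG inside the zip archive.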
        for i in tqdm(range(len(depth_pred))):
            archive_path = os.path.join(
                f"{name_base}_depth_16bit", f"{i:05d}.png"
            )
            img_byte_arr = BytesIO()
            depth_16bit = Image.fromarray((depth_pred[i] * 65535.0).astype(np.uint16))
            depth_16bit.save(img_byte_arr, format="PNG")
            img_byte_arr.seek(0)
            zipf.writestr(archive_path, img_byte_arr.read())

        # Export the colorized depth to video
        media.write_video(path_out_vis, depth_colored_pred, fps=fps)
    finally:
        if zipf is not None:
            zipf.close()

    end_time = time.time()
    print(f"Processing time: {end_time - start_time:.2f} seconds")

    return (
        path_out_vis,
        [path_out_vis, path_out_16bit],
    )
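
# Example (sketch): calling process_video directly, assuming a loaded `pipe`
# and a local clip at "input.mp4" (hypothetical path):
#   vis_path, out_files = process_video(pipe, "input.mp4")
#   # vis_path -> colorized depth MP4, out_files -> [MP4, zip of 16-bit PNGs]
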
def run_demo_server(pipe):
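    # spaces.GPU schedules the wrapped callable on a ZeroGPU worker; `duration`
    # is the GPU time budget (in seconds) requested for each call.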
    process_pipe_video = spaces.GPU(
        functools.partial(process_video, pipe), duration=100
    )
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
    with gr.Blocks(
        analytics_enabled=False,
        title="ChronoDepth Video Depth Estimation",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1, h2, h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.HTML(
            """
            <h1>⏰ChronoDepth: Learning Temporally Consistent Video Depth from Video Diffusion Priors</h1>
            <div style="text-align: center; margin-top: 20px;">
                <a title="Website" href="https://jhaoshao.github.io/ChronoDepth/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/website?url=https%3A%2F%2Fjhaoshao.github.io%2FChronoDepth%2F&up_message=ChronoDepth&up_color=blue&style=flat&logo=timescale&logoColor=%23FFDC0F">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2406.01493" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/arXiv-PDF-b31b1b">
                </a>
                <a title="Github" href="https://github.com/jhaoshao/ChronoDepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/github/stars/jhaoshao/ChronoDepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
                </a>
            </div>
            <p style="margin-top: 20px; text-align: justify;">
                ChronoDepth estimates temporally consistent depth for streaming videos in the wild by leveraging video diffusion priors.
            </p>
            <p style="margin-top: 20px; text-align: justify;">
                Note: this demo processes at most 15 frames per video. To process longer videos, please use the ChronoDepth repository on GitHub.
            </p>
            """
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(
                    label="Input Video",
                    sources=["upload"],
                )
                with gr.Row():
                    video_submit_btn = gr.Button(
                        value="Compute Depth", variant="primary"
                    )
                    video_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                video_output_video = gr.Video(
                    label="Output video depth (red-near, blue-far)",
                    interactive=False,
                )
                video_output_files = gr.Files(
                    label="Depth outputs",
                    elem_id="download",
                    interactive=False,
                )
        gr.Examples(
            examples=[
                ["files/elephant.mp4"],
                ["files/kitti360_seq_0000.mp4"],
            ],
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            fn=process_pipe_video,
            cache_examples=True,
        )
        video_submit_btn.click(
            fn=process_pipe_video,
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            concurrency_limit=1,
        )
        video_reset_btn.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[video_input, video_output_video, video_output_files],
            concurrency_limit=1,
        )
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def main():
    CHECKPOINT = "jhshao/ChronoDepth-v1"

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    # -------------------- Model --------------------
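    # The ChronoDepth checkpoint provides a fine-tuned spatio-temporal UNet
    # that replaces the stock UNet of the SVD img2vid-xt pipeline.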
    unet = DiffusersUNetSpatioTemporalConditionModelChronodepth.from_pretrained(
        CHECKPOINT,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
    pipe = ChronoDepthPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt",
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.n_tokens = default_n_tokens
    pipe.chunk_size = default_chunk_size
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # fall back to the default attention implementation
    pipe = pipe.to(device)

    run_demo_server(pipe)
if __name__ == "__main__":
    main()