# ChronoDepth / app.py
# MIT License
# Copyright (c) 2024 Jiahao Shao
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import functools
import os
import time
import zipfile
import tempfile
from io import BytesIO

import spaces
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import xformers  # noqa: F401  (ensures xformers is installed for memory-efficient attention)
from PIL import Image
from tqdm import tqdm
import mediapy as media
from huggingface_hub import login

from chronodepth.unet_chronodepth import DiffusersUNetSpatioTemporalConditionModelChronodepth
from chronodepth.chronodepth_pipeline import ChronoDepthPipeline
from chronodepth.video_utils import resize_max_res, colorize_video_depth

MAX_FRAME = 15  # maximum number of frames the demo will process
default_seed = 2024
default_num_inference_steps = 5
default_n_tokens = 10
default_chunk_size = 5
default_video_processing_resolution = 768  # processing resolution passed to resize_max_res
default_decode_chunk_size = 8  # number of frames decoded per VAE chunk


@torch.no_grad()
def run_pipeline(pipe, video_rgb, generator, device):
    """
    Run the ChronoDepth pipeline on an input video.

    args:
        pipe: ChronoDepthPipeline object
        video_rgb: input video, np.ndarray or torch.Tensor, shape [T, H, W, 3], range [0, 255]
        generator: torch.Generator used for diffusion sampling
        device: torch.device that the predictions are moved to
    returns:
        video_depth_pred: predicted depth, torch.Tensor, shape [T, H, W], range [0, 1]
    """
    if isinstance(video_rgb, torch.Tensor):
        video_rgb = video_rgb.cpu().numpy()
    original_height = video_rgb.shape[1]
    original_width = video_rgb.shape[2]

    # resize the video to the processing resolution and normalize to [0, 1]
    video_rgb = resize_max_res(video_rgb, default_video_processing_resolution)
    video_rgb = video_rgb.astype(np.float32) / 255.0

    pipe_out = pipe(
        video_rgb,
        num_inference_steps=default_num_inference_steps,
        decode_chunk_size=default_decode_chunk_size,
        motion_bucket_id=127,
        fps=7,
        noise_aug_strength=0.0,
        generator=generator,
        infer_mode="ours",
        sigma_epsilon=-4,
    )
    depth_frames_pred = pipe_out.frames

    # resize predictions back to the original resolution and clamp to the valid range
    depth_frames_pred = torch.from_numpy(depth_frames_pred).to(device)
    depth_frames_pred = F.interpolate(
        depth_frames_pred, size=(original_height, original_width), mode="bilinear", align_corners=False
    )
    depth_frames_pred = depth_frames_pred.clamp(0, 1)
    depth_frames_pred = depth_frames_pred.squeeze(1)
    return depth_frames_pred
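

# Standalone usage sketch (not executed by the demo): assuming a pipeline has been
# loaded as in main() below, run_pipeline can be driven directly from a video file.
# The file names here are placeholders.
#
#   video = media.read_video("input.mp4")                       # [T, H, W, 3], uint8
#   rgb = np.array(video)[:MAX_FRAME]
#   gen = torch.Generator(device=pipe.device).manual_seed(default_seed)
#   depth = run_pipeline(pipe, rgb, gen, pipe.device)            # [T, H, W], float in [0, 1]
#   colored = colorize_video_depth(depth.cpu().numpy())
#   media.write_video("depth_colored.mp4", colored, fps=video.metadata.fps)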


def process_video(
    pipe,
    path_input,
    num_inference_steps=default_num_inference_steps,
    out_max_frames=MAX_FRAME,
    progress=gr.Progress(),
):
    if path_input is None:
        raise gr.Error(
            "Missing video in the first pane: upload a file or use one from the gallery below."
        )

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing video {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")

    generator = torch.Generator(device=pipe.device).manual_seed(default_seed)

    start_time = time.time()
    zipf = None
    try:
        # -------------------- Data --------------------
        video_name = path_input.split('/')[-1].split('.')[0]
        video_data = media.read_video(path_input)
        fps = video_data.metadata.fps
        video_length = len(video_data)
        video_rgb = np.array(video_data)

        duration_sec = video_length / fps
        out_duration_sec = out_max_frames / fps
        if duration_sec > out_duration_sec:
            gr.Warning(
                f"Only the first {out_max_frames} frames (~{out_duration_sec:.1f} s) will be processed; "
                f"use the ChronoDepth code on GitHub to process the full video."
            )
            video_rgb = video_rgb[:out_max_frames]

        zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)

        # -------------------- Inference --------------------
        depth_pred = run_pipeline(pipe, video_rgb, generator, pipe.device)  # range [0, 1]
        depth_pred = depth_pred.cpu().numpy()
        depth_colored_pred = colorize_video_depth(depth_pred)  # range [0, 1] -> [0, 255]

        # -------------------- Save results --------------------
        # Each depth frame is written into the zip archive as a 16-bit PNG.
        for i in tqdm(range(len(depth_pred))):
            archive_path = os.path.join(
                f"{name_base}_depth_16bit", f"{i:05d}.png"
            )
            img_byte_arr = BytesIO()
            depth_16bit = Image.fromarray((depth_pred[i] * 65535.0).astype(np.uint16))
            depth_16bit.save(img_byte_arr, format="png")
            img_byte_arr.seek(0)
            zipf.writestr(archive_path, img_byte_arr.read())

        # Export the colorized depth video
        media.write_video(path_out_vis, depth_colored_pred, fps=fps)
    finally:
        if zipf is not None:
            zipf.close()

    end_time = time.time()
    print(f"Processing time: {end_time - start_time:.2f} seconds")

    return (
        path_out_vis,
        [path_out_vis, path_out_16bit],
    )
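

# Reading the archived depth back (sketch, for reference only): each PNG stores the
# relative depth from run_pipeline scaled to the uint16 range, so dividing by 65535
# recovers values in [0, 1]. The archive and entry names below are placeholders.
#
#   with zipfile.ZipFile("example_depth_16bit.zip") as zf:
#       with zf.open("example_depth_16bit/00000.png") as f:
#           depth = np.asarray(Image.open(f), dtype=np.float32) / 65535.0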


def run_demo_server(pipe):
    process_pipe_video = spaces.GPU(
        functools.partial(process_video, pipe), duration=100
    )
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="ChronoDepth Video Depth Estimation",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.HTML(
            """
            <h1>⏰ChronoDepth: Learning Temporally Consistent Video Depth from Video Diffusion Priors</h1>
            <div style="text-align: center; margin-top: 20px;">
                <a title="Website" href="https://jhaoshao.github.io/ChronoDepth/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/website?url=https%3A%2F%2Fjhaoshao.github.io%2FChronoDepth%2F&up_message=ChronoDepth&up_color=blue&style=flat&logo=timescale&logoColor=%23FFDC0F">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/arXiv-PDF-b31b1b">
                </a>
                <a title="Github" href="https://github.com/jhaoshao/ChronoDepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/github/stars/jhaoshao/ChronoDepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
                </a>
            </div>
            <p style="margin-top: 20px; text-align: justify;">
                ChronoDepth is the state-of-the-art video depth estimator for streaming videos in the wild.
            </p>
            <p style="margin-top: 20px; text-align: justify;">
                PS: The maximum video length is limited to 15 frames for this demo. To process longer videos, please use the ChronoDepth code on GitHub.
            </p>
            """
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(
                    label="Input Video",
                    sources=["upload"],
                )
                with gr.Row():
                    video_submit_btn = gr.Button(
                        value="Compute Depth", variant="primary"
                    )
                    video_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                video_output_video = gr.Video(
                    label="Output video depth (red-near, blue-far)",
                    interactive=False,
                )
                video_output_files = gr.Files(
                    label="Depth outputs",
                    elem_id="download",
                    interactive=False,
                )

        gr.Examples(
            examples=[
                ["files/elephant.mp4"],
                ["files/kitti360_seq_0000.mp4"],
            ],
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            fn=process_pipe_video,
            cache_examples=True,
            cache_mode="examples_video",
        )
        video_submit_btn.click(
            fn=process_pipe_video,
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            concurrency_limit=1,
        )

        video_reset_btn.click(
            # clear the input video and both output panes
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[video_input, video_output_video, video_output_files],
            concurrency_limit=1,
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def main():
    CHECKPOINT = "jhshao/ChronoDepth-v1"

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    # -------------------- Model --------------------
    # The ChronoDepth UNet is loaded from its own checkpoint and plugged into the
    # Stable Video Diffusion img2vid-xt pipeline, which provides the remaining components.
    unet = DiffusersUNetSpatioTemporalConditionModelChronodepth.from_pretrained(
        CHECKPOINT,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
    pipe = ChronoDepthPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt",
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.n_tokens = default_n_tokens
    pipe.chunk_size = default_chunk_size

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # run without xformers

    pipe = pipe.to(device)

    run_demo_server(pipe)


if __name__ == "__main__":
    main()