marigold-dc / app.py
toshas's picture
Update app.py
f829455 verified
raw
history blame
10.6 kB
# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import functools
import os
import spaces
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch as torch
from PIL import Image
from scipy.ndimage import maximum_filter
from marigold_dc import MarigoldDepthCompletionPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import login
# Forwarded unchanged to the pipeline call in `process` as its `dry_run`
# argument; what a dry run does is defined by the pipeline, not by this file.
DRY_RUN = False
def dilate_rgb_image(image, kernel_size):
    """Dilate the first three channels of an image independently.

    The channels at indices 0, 1 and 2 of the last axis are each passed
    through ``scipy.ndimage.maximum_filter`` with ``size=kernel_size``, and
    the filtered channels are stacked along a new last axis. Channels beyond
    the third (for example alpha) are not part of the result.

    Args:
        image: Array whose last axis holds at least three channels.
        kernel_size: Window size handed to ``maximum_filter``.

    Returns:
        Array with three channels on the last axis.
    """
    dilated_channels = [
        maximum_filter(image[..., channel], size=kernel_size)
        for channel in range(3)
    ]
    return np.stack(dilated_channels, axis=-1)
def generate_rmse_plot(steps, metrics, denoise_steps):
    """Build the RMSE-per-step line chart.

    Args:
        steps: X values, one per recorded step.
        metrics: RMSE values matching ``steps``. Must be non-empty, because
            its minimum and maximum define the y-axis range.
        denoise_steps: Number of denoising steps; sets the x-axis extent and
            the tick spacing.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single "RMSE" trace.
    """
    lowest, highest = min(metrics), max(metrics)

    # Tick every step for short runs, every fifth step otherwise.
    tick_spacing = 1 if denoise_steps < 20 else 5

    rmse_trace = go.Scatter(
        x=steps,
        y=metrics,
        mode="lines+markers",
        line={"color": "#af2928"},
        name="RMSE",
    )
    fig = go.Figure()
    fig.add_trace(rmse_trace)

    # The y axis is logarithmic, so its range is expressed as log10 values.
    # The lower bound is floored at 0.1 to keep the logarithm defined.
    log_y_range = [
        np.log10(max(lowest - 0.1, 0.1)),
        np.log10(highest + 1),
    ]

    layout_options = {
        "autosize": False,
        "height": 300,
        "xaxis_title": "Steps",
        "xaxis_range": [0, denoise_steps + 1],
        "xaxis": {
            "scaleanchor": "y",
            "scaleratio": 1.5,
            "dtick": tick_spacing,
        },
        "yaxis_title": "RMSE",
        "yaxis_range": log_y_range,
        "yaxis": {
            "type": "log",
        },
        "hovermode": "x unified",
        "template": "plotly_white",
    }
    fig.update_layout(**layout_options)
    return fig
def process(
    pipe,
    path_image,
    path_sparse,
    denoise_steps,
):
    """Run depth completion and stream a visualization after every step.

    Args:
        pipe: Pipeline object. It is called with the image, the sparse depth
            and the step count, and is iterated for ``(prediction, rmse)``
            pairs; its ``image_processor.visualize_depth`` colorizes depth.
        path_image: Path of the input image, opened with PIL.
        path_sparse: Path of the sparse depth map, loaded with ``np.load``.
            Entries ``<= 0`` are treated as missing measurements.
        denoise_steps: Number of denoising steps; the pipeline is asked for
            ``denoise_steps + 1`` inference steps.

    Yields:
        ``([vis_sparse, vis_pred], plot)`` for each pipeline iteration:
        the colorized (and dilated) sparse depth, the colorized prediction,
        and the RMSE plot covering all steps seen so far.

    Raises:
        ValueError: If an input path is missing or the sparse depth map has
            no positive entries.
    """
    if path_image is None or path_sparse is None:
        raise ValueError(
            "Both an input image and a sparse depth file are required."
        )

    sparse_depth = np.load(path_sparse)
    sparse_depth_valid = sparse_depth[sparse_depth > 0]
    if sparse_depth_valid.size == 0:
        # Fail with a clear message instead of NumPy's empty-reduction error.
        raise ValueError(
            "Sparse depth map contains no valid (positive) measurements."
        )
    sparse_depth_min = np.min(sparse_depth_valid)
    sparse_depth_max = np.max(sparse_depth_valid)

    metrics = []
    steps = []

    # Open the image once and release the file handle when streaming ends.
    with Image.open(path_image) as image:
        # A resolution of 0 is passed unless the longer side exceeds 768 px.
        processing_resolution = 768 if max(image.size) > 768 else 0

        for step, (pred, rmse) in enumerate(
            pipe(
                image=image,
                sparse_depth=sparse_depth,
                num_inference_steps=denoise_steps + 1,
                processing_resolution=processing_resolution,
                dry_run=DRY_RUN,
            )
        ):
            # Shared color range spanning both the sparse input and the
            # current prediction, so the two visualizations are comparable.
            min_both = min(sparse_depth_min, pred.min().item())
            max_both = max(sparse_depth_max, pred.max().item())

            metrics.append(rmse)
            steps.append(step)

            vis_pred = pipe.image_processor.visualize_depth(
                pred, val_min=min_both, val_max=max_both
            )[0]
            vis_sparse = pipe.image_processor.visualize_depth(
                sparse_depth, val_min=min_both, val_max=max_both
            )[0]

            # Black out missing measurements, then thicken the remaining
            # points so isolated samples stay visible.
            vis_sparse = np.array(vis_sparse)
            vis_sparse[sparse_depth <= 0] = (0, 0, 0)
            vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=5)
            vis_sparse = Image.fromarray(vis_sparse)

            plot = generate_rmse_plot(steps, metrics, denoise_steps)

            yield (
                [vis_sparse, vis_pred],
                plot,
            )
def run_demo_server(pipe):
    """Build the Gradio interface around ``pipe`` and launch the server.

    The UI takes an image and a sparse depth file, streams the results of
    ``process`` into an image slider and an RMSE plot, offers a set of
    examples, and provides a Clear button. The call blocks inside
    ``launch`` on ``0.0.0.0:7860``.

    Args:
        pipe: Pipeline object bound as the first argument of ``process``.
    """
    # Bind the pipeline and wrap the resulting callable with spaces.GPU.
    process_pipe = spaces.GPU(functools.partial(process, pipe))
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="Marigold Depth Completion",
        css="""
#short {
height: 130px;
}
.slider .inner {
width: 4px;
background: #FFF;
}
.slider .icon-wrap svg {
fill: #FFF;
stroke: #FFF;
stroke-width: 3px;
}
.viewport {
aspect-ratio: 4/3;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
""",
    ) as demo:
        gr.HTML(
            """
<h1>⇆ Marigold-DC: Zero-Shot Monocular Depth Completion with Guided Diffusion</h1>
<p align="center">
<a title="Website" href="https://MarigoldDepthCompletion.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%F0%9F%A4%8D%20Project%20-Website-blue" alt="Website Badge">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2412.13389" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-af2928" alt="arXiv Badge">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold-dc" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold-dc?label=GitHub&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a><br>
Start exploring the interactive examples at the bottom of the page!
</p>
"""
        )

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                input_sparse = gr.File(
                    label="Input sparse depth (numpy file)",
                    elem_id="short",
                )
                with gr.Accordion("Advanced options", open=False):
                    denoise_steps = gr.Slider(
                        label="Number of denoising steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=50,
                    )
                with gr.Row():
                    submit_btn = gr.Button(value="Compute Depth", variant="primary")
                    clear_btn = gr.Button(value="Clear")
            with gr.Column():
                output_slider = ImageSlider(
                    label="Completed depth (red-near, blue-far)",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                )
                # NOTE(review): the stylesheet above targets the class
                # ".viewport", but this component only sets the *id*
                # "viewport", so that rule does not match it. Left as is
                # because enabling it would change the layout -- confirm
                # the intended styling.
                plot = gr.Plot(
                    label="RMSE between input and result",
                    elem_id="viewport",
                )

        inputs = [
            input_image,
            input_sparse,
            denoise_steps,
        ]
        outputs = [
            output_slider,
            plot,
        ]

        def submit_depth_fn(path_image, path_sparse, denoise_steps):
            """Stream every intermediate result of the wrapped ``process``."""
            yield from process_pipe(path_image, path_sparse, denoise_steps)

        submit_btn.click(
            fn=submit_depth_fn,
            inputs=inputs,
            outputs=outputs,
        )

        gr.Examples(
            fn=submit_depth_fn,
            examples=[
                [
                    "files/kitti_1.png",
                    "files/kitti_1.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/kitti_2.png",
                    "files/kitti_2.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/teaser.png",
                    "files/teaser_1000.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/teaser.png",
                    "files/teaser_100.npy",
                    10,  # denoise_steps
                ],
                [
                    "files/teaser.png",
                    "files/teaser_10.npy",
                    10,  # denoise_steps
                ],
            ],
            inputs=inputs,
            outputs=outputs,
            cache_examples="lazy",
        )

        def clear_fn():
            """Reset both inputs and both outputs to their empty state."""
            return [
                gr.Image(value=None, interactive=True),
                gr.File(None, interactive=True),
                None,  # output_slider
                None,  # plot: avoid leaving a stale RMSE curve behind
            ]

        clear_btn.click(
            fn=clear_fn,
            inputs=[],
            outputs=[
                input_image,
                input_sparse,
                output_slider,
                plot,
            ],
        )

        demo.queue(
            api_open=False,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
        )
def main():
    """Load the pipeline, move it to the available device, and serve the demo.

    Steps: print the installed packages via ``pip freeze``, log in to the
    Hugging Face Hub when ``HF_TOKEN_LOGIN`` is set, load the checkpoint,
    try to enable xformers attention (best effort), move the pipeline to
    CUDA when available (CPU otherwise), and start the demo server.
    """
    checkpoint = "prs-eth/marigold-depth-v1-0"

    os.system("pip freeze")

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipe = MarigoldDepthCompletionPipeline.from_pretrained(checkpoint)
    try:
        # Imported only to probe availability before enabling the feature.
        import xformers  # noqa: F401

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        # Best effort: run without xformers, but say why, and do not
        # swallow KeyboardInterrupt/SystemExit as a bare ``except`` would.
        print(f"xformers attention not enabled: {exc!r}")

    pipe = pipe.to(device)
    run_demo_server(pipe)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()