Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,272 Bytes
2c83504 4cce5d7 2c83504 3636bd3 2c83504 0ea72b7 2c83504 3636bd3 2c83504 3636bd3 2c83504 ca43de2 2c83504 ca43de2 2c83504 ca43de2 0f05329 ca43de2 2c83504 0f05329 2c83504 4cce5d7 0f05329 4cce5d7 0f05329 4cce5d7 0f05329 4cce5d7 2c83504 0f05329 2c83504 0f05329 2c83504 4cce5d7 2c83504 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import shutil
import subprocess
import sys

import gradio as gr
# HTML rendered above the interface: project badges (website, arXiv, GitHub,
# social) followed by a short usage blurb.
desc = """
<p align="center">
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="website">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arxiv">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<p align="justify">
Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore examples listed at the bottom.
</p>
"""
def download_code():
    """Clone the Marigold repository into ``./Marigold`` unless it already exists.

    ``marigold_process`` runs ``run.py`` from inside this checkout (relative to the
    current working directory), so the checkout must exist before real inference
    requests arrive.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero status.
    """
    if os.path.isdir("Marigold"):
        # On a re-run the checkout is already there; git would refuse to clone
        # into a non-empty destination, so there is nothing to do.
        return
    # Argument list (no shell) and check=True: a failed clone is reported here
    # instead of surfacing later as a confusing "Processing failed".
    subprocess.run(
        ["git", "clone", "https://github.com/prs-eth/Marigold.git", "Marigold"],
        check=True,
    )
def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    Entries are examined in sorted name order, so the result is deterministic
    (``os.listdir`` returns entries in arbitrary order). The extension check is
    case-insensitive, and sub-directories are skipped even if their name ends in
    ``.png``. The search is not recursive.

    Args:
        directory: Directory to search.

    Returns:
        ``os.path.join(directory, name)`` for the first matching file, or ``None``
        when nothing matches or *directory* does not exist.
    """
    try:
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        # A missing output directory means the producer failed; return None so the
        # caller can report that, rather than leaking a raw filesystem error.
        return None
    for name in names:
        path = os.path.join(directory, name)
        if name.lower().endswith(".png") and os.path.isfile(path):
            return path
    return None
def marigold_process(path_input, path_out_vis=None, path_out_pred=None):
    """Estimate depth for one image by running Marigold's ``run.py`` in a subprocess.

    When both precomputed output paths are supplied (the hidden inputs that the
    bundled examples fill in), they are returned unchanged and no inference runs.

    Otherwise the image is copied into ``<path_input>.input/`` and ``run.py`` is
    executed from the ``Marigold`` checkout (relative to the current working
    directory), writing its results under ``<path_input>.output/``.

    Args:
        path_input: Path to the input image file.
        path_out_vis: Optional precomputed colored-depth image path.
        path_out_pred: Optional precomputed depth image path.

    Returns:
        ``(path_out_vis, path_out_pred)``: the first PNG in the output's
        ``depth_colored`` and ``depth_bw`` sub-directories, respectively.

    Raises:
        ValueError: if no input image was provided.
        RuntimeError: if ``run.py`` fails or produces no PNG output.
    """
    if path_out_vis is not None and path_out_pred is not None:
        return path_out_vis, path_out_pred
    if path_input is None:
        raise ValueError("No input image provided")

    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # Run without a shell: the uploaded file name is user-controlled and must never
    # be interpreted as shell syntax. Paths are made absolute because the child runs
    # with cwd="Marigold", while the directories above were created relative to ours.
    # sys.executable keeps the child on the same interpreter/environment as the app.
    result = subprocess.run(
        [
            sys.executable, "run.py",
            "--input_rgb_dir", os.path.abspath(path_input_dir),
            "--output_dir", os.path.abspath(path_output_dir),
            "--n_infer", "10",
            "--denoise_steps", "10",
        ],
        cwd="Marigold",
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"Processing failed (run.py exited with status {result.returncode})"
        )

    path_out_vis = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    path_out_pred = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if path_out_vis is None or path_out_pred is None:
        raise RuntimeError("Processing failed")
    return path_out_vis, path_out_pred
# Gradio UI. The two invisible image inputs let every example row carry its
# precomputed outputs; marigold_process returns those directly when both are
# given, so example rows are served from the bundled files without inference.
iface = gr.Interface(
    fn=marigold_process,
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.Image(label="Predicted depth (red-near, blue-far)", type="filepath", visible=False),
        gr.Image(label="Predicted depth", type="filepath", visible=False),
    ],
    outputs=[
        gr.Image(label="Predicted depth (red-near, blue-far)", type="pil"),
        gr.Image(label="Predicted depth", type="pil", elem_classes="imgdownload"),
    ],
    # Each row: [input image, precomputed colored depth, precomputed depth],
    # resolved relative to this file's directory.
    examples=[
        [
            os.path.join(os.path.dirname(__file__), rel_path)
            for rel_path in ("files/bee.jpg", "files/bee_vis.jpg", "files/bee_pred.jpg")
        ],
        [
            os.path.join(os.path.dirname(__file__), rel_path)
            for rel_path in ("files/cat.jpg", "files/cat_vis.jpg", "files/cat_pred.jpg")
        ],
        [
            os.path.join(os.path.dirname(__file__), rel_path)
            for rel_path in ("files/swings.jpg", "files/swings_vis.jpg", "files/swings_pred.jpg")
        ],
    ],
    cache_examples=True,
    allow_flagging="never",
    css="""
.viewport {
aspect-ratio: 4/3;
}
.imgdownload {
height: 64px;
}
""",
)
if __name__ == "__main__":
    # The Marigold checkout must exist before any non-example request:
    # marigold_process runs run.py from inside it.
    download_code()
    iface.queue().launch(
        server_port=7860,
        server_name="0.0.0.0",  # listen on all network interfaces
    )
|