Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,742 Bytes
2c83504 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import shutil
import subprocess

import gradio as gr
# HTML rendered as the Interface description: project badges plus a short blurb.
# NOTE(review): the blurb mentions examples, but the `examples=` argument of the
# Interface is currently commented out — keep the two in sync.
desc = """
<p align="center">
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="website">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arXiv">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<p align="justify">
Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore examples listed at the bottom.
</p>
"""
def init_persistence(purge=False, root="/data"):
    """Redirect checkpoint and Hugging Face caches onto persistent storage, if present.

    If ``root`` does not exist, this is a no-op and the environment is left
    untouched. Otherwise it sets ``ckpt_dir`` to ``<root>/Marigold_ckpt`` and
    ``TRANSFORMERS_CACHE``, ``HF_DATASETS_CACHE`` and ``HF_HOME`` to
    ``<root>/hfcache``. Values written to ``os.environ`` are inherited by
    subprocesses started afterwards.

    Args:
        purge: If true, delete everything inside ``<root>/Marigold_ckpt``
            (the directory itself is kept). Unlike the previous
            ``rm -rf <dir>/*`` this also removes dot-files, and it does not
            go through a shell.
        root: Persistent-storage mount point. The default ``/data`` preserves
            the original behaviour; presumably this is the Space's persistent
            volume — confirm for other deployments.
    """
    if not os.path.exists(root):
        return
    ckpt_dir = os.path.join(root, "Marigold_ckpt")
    hf_cache = os.path.join(root, "hfcache")
    os.environ['ckpt_dir'] = ckpt_dir
    os.environ['TRANSFORMERS_CACHE'] = hf_cache
    os.environ['HF_DATASETS_CACHE'] = hf_cache
    os.environ['HF_HOME'] = hf_cache
    if purge and os.path.isdir(ckpt_dir):
        # Snapshot the listing first so we never delete while iterating.
        with os.scandir(ckpt_dir) as it:
            entries = list(it)
        for entry in entries:
            # Do not follow symlinks: remove the link, not its target's contents.
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)
def _run_best_effort(cmd, cwd=None):
    """Run argv list *cmd* without a shell; print a warning instead of raising on failure.

    Returns True only if the command started and exited with status 0.
    """
    try:
        completed = subprocess.run(cmd, cwd=cwd, check=False)
    except OSError as exc:  # executable or cwd missing
        print(f"WARNING: could not run {cmd!r}: {exc}", flush=True)
        return False
    if completed.returncode != 0:
        print(f"WARNING: {cmd!r} exited with status {completed.returncode}", flush=True)
        return False
    return True


def download_code_weights():
    """Fetch the Marigold code and weights, then print storage diagnostics.

    Clones https://github.com/prs-eth/Marigold.git into ``./Marigold``. The
    clone is skipped when that directory already exists, since ``git clone``
    would refuse anyway. It then runs ``script/download_weights.sh`` from
    inside the checkout.

    Failures are reported as warnings on stdout rather than raised, which
    keeps the original best-effort behaviour. Previously ``os.system`` exit
    codes were discarded without any message. Finally it lists the
    ``/data`` checkpoint directories for debugging.
    """
    if os.path.isdir("Marigold"):
        print("Marigold/ already exists; skipping git clone", flush=True)
    else:
        _run_best_effort(["git", "clone", "https://github.com/prs-eth/Marigold.git"])

    if os.path.isdir("Marigold"):
        _run_best_effort(["bash", "script/download_weights.sh"], cwd="Marigold")
    else:
        print("WARNING: Marigold/ is missing; cannot download weights", flush=True)

    # Diagnostics: show what ended up on persistent storage.
    for path in (
        "/data",
        "/data/Marigold_ckpt",
        "/data/Marigold_ckpt/Marigold_v1_merged",
    ):
        # flush so our header is not reordered after the child's output
        print(path, flush=True)
        if os.path.isdir(path):
            _run_best_effort(["ls", "-la", path])
        else:
            print("  (does not exist)", flush=True)
def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    - The extension check is case-insensitive.
    - Entries are visited in sorted name order so the result is deterministic;
      ``os.listdir`` order is arbitrary.
    - Only regular files qualify, so a sub-directory named ``*.png`` is skipped.
    - A missing *directory* yields ``None`` instead of raising
      ``FileNotFoundError``. Callers can therefore treat "no output produced"
      and "output directory never created" the same way.
    """
    if not os.path.isdir(directory):
        return None
    for name in sorted(os.listdir(directory)):
        candidate = os.path.join(directory, name)
        if name.lower().endswith(".png") and os.path.isfile(candidate):
            return candidate
    return None
def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Run Marigold depth estimation on one image and return the two depth renderings.

    Args:
        path_input: Filesystem path of the input image (the Interface's
            ``gr.Image(type="filepath")`` input).
        path_out_png: Optional precomputed colored depth map, fed by the first
            hidden ``gr.File`` input.
        path_out_obj: Optional precomputed 16-bit depth PNG. Despite its name,
            the Interface feeds it the second hidden ``gr.File`` input.
        path_out_2_png: Unused. The Interface supplies only three inputs, so it
            is always ``None`` there; it is kept so direct callers still work.

    Returns:
        ``(colored_png_path, bw_png_path)``, one value per Interface output.
        When both precomputed paths are given, they are returned unchanged.

    Raises:
        RuntimeError: if the inference script cannot be started, or produces
            no PNG under ``depth_colored`` / ``depth_bw``.
    """
    # Short-circuit for precomputed results. This previously also required
    # path_out_2_png, which the 3-input Interface never passes, so it never
    # fired; it also returned 3 values to an Interface with 2 outputs.
    if path_out_png is not None and path_out_obj is not None:
        return path_out_png, path_out_obj

    # Absolute paths: run.py executes with Marigold/ as its working directory.
    path_input_dir = os.path.abspath(path_input + ".input")
    path_output_dir = os.path.abspath(path_input + ".output")
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # argv list, no shell: the path may embed an uploaded file's name, which
    # must never be interpreted by a shell (quotes, $(...), etc.).
    cmd = ["python3", "run.py"]
    if os.path.exists('/data'):
        cmd += ["--checkpoint", "/data/Marigold_ckpt/Marigold_v1_merged"]
    cmd += [
        "--input_rgb_dir", path_input_dir,
        "--output_dir", path_output_dir,
        "--n_infer", "5",
        "--denoise_steps", "10",
    ]
    try:
        returncode = subprocess.run(cmd, cwd="Marigold", check=False).returncode
    except OSError as exc:
        raise RuntimeError(f"Processing failed: could not start run.py ({exc})") from exc

    # run.py output sub-directories: depth_colored, depth_bw, depth_npy
    path_out_colored = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    path_out_bw = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    if path_out_colored is None or path_out_bw is None:
        raise RuntimeError(f"Processing failed (run.py exit status {returncode})")
    return path_out_colored, path_out_bw
# Gradio UI. Interface passes the three input components positionally to
# marigold_process (path_input, path_out_png, path_out_obj); the two output
# components receive the two paths it returns.
iface = gr.Interface(
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    fn=marigold_process,
    inputs=[
        gr.Image(
            label="Input Image",
            type="filepath",
        ),
        # Hidden (visible=False) inputs: the user never fills them. Presumably
        # they are meant to carry precomputed outputs for the disabled
        # `examples` below — confirm before re-enabling examples.
        gr.File(
            label="Predicted depth (red-near, blue-far)",
            visible=False,
        ),
        gr.File(
            label="Predicted depth (16-bit PNG)",
            visible=False,
        ),
    ],
    outputs=[
        gr.Image(
            label="Predicted depth (red-near, blue-far)",
            type="pil",
        ),
        gr.Image(
            label="Predicted depth (16-bit PNG)",
            type="pil",
            # Matches the `.imgdownload` rule in `css` below.
            elem_classes="imgdownload",
        ),
    ],
    allow_flagging="never",
    # Examples are currently disabled. Each row lists one value per input
    # component: input image plus the two precomputed outputs.
    # examples=[
    #     [
    #         os.path.join(os.path.dirname(__file__), "files/test.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.2.png"),
    #     ],
    # ],
    # NOTE(review): no component in this block sets elem_classes="viewport",
    # so the `.viewport` rule may be unused — confirm.
    css="""
.viewport {
aspect-ratio: 4/3;
}
.imgdownload {
height: 32px;
}
""",
    # NOTE(review): with `examples` commented out there is nothing to cache,
    # so this flag presumably has no effect — confirm.
    cache_examples=True,
)
def main():
    """Set up persistent storage, fetch Marigold code/weights, then serve the UI.

    The queued app listens on all interfaces (0.0.0.0) at port 7860.
    """
    init_persistence()
    download_code_weights()
    app = iface.queue()
    app.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()
|