# Hugging Face Spaces status: Running on Zero
import os
import shutil
import subprocess

import gradio as gr
# HTML fragment passed to gr.Interface as the description: a centered row of
# project badges (website, arXiv, GitHub, social) followed by a short usage
# blurb.
desc = """
<p align="center">
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<p align="justify">
Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore examples listed at the bottom.
</p>
"""
def init_persistence(purge=False):
    """Point checkpoint and Hugging Face caches at ``/data`` when it exists.

    Does nothing if ``/data`` is absent. Otherwise sets the ``ckpt_dir``,
    ``TRANSFORMERS_CACHE``, ``HF_DATASETS_CACHE`` and ``HF_HOME`` environment
    variables to locations under ``/data``.

    Args:
        purge: When true (and ``/data`` exists), remove the contents of
            ``/data/Marigold_ckpt`` with a shell ``rm -rf``.
    """
    if os.path.exists('/data'):
        # NOTE(review): 'ckpt_dir' is presumably read by Marigold's weight
        # download script -- confirm against that script.
        os.environ.update({
            'ckpt_dir': "/data/Marigold_ckpt",
            'TRANSFORMERS_CACHE': "/data/hfcache",
            'HF_DATASETS_CACHE': "/data/hfcache",
            'HF_HOME': "/data/hfcache",
        })
        if purge:
            os.system("rm -rf /data/Marigold_ckpt/*")
def download_code_weights():
    """Clone the Marigold repository, fetch its weights, and list ``/data``.

    Every step is run through ``os.system``; exit statuses are not inspected,
    so a failing step does not stop the following ones.
    """
    setup_commands = (
        'git clone https://github.com/prs-eth/Marigold.git',
        'cd Marigold && bash script/download_weights.sh',
    )
    for shell_line in setup_commands:
        os.system(shell_line)

    # Diagnostic listings of the storage locations, printed to the process log.
    listed_dirs = (
        '/data',
        '/data/Marigold_ckpt',
        '/data/Marigold_ckpt/Marigold_v1_merged',
    )
    for location in listed_dirs:
        os.system(f'echo {location} && ls -la {location}')
def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    The extension match is case-insensitive. Entries are examined in sorted
    name order so the result is deterministic (``os.listdir`` itself returns
    entries in arbitrary order).

    Args:
        directory: Directory to search (not recursive).

    Returns:
        ``os.path.join(directory, name)`` for the first matching entry, or
        ``None`` when there is no match or *directory* does not exist / is
        not a directory.
    """
    try:
        entries = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        # A missing output directory means "no PNG found"; let the caller
        # report that instead of surfacing a raw filesystem error.
        return None
    return next(
        (
            os.path.join(directory, name)
            for name in entries
            if name.lower().endswith(".png")
        ),
        None,
    )
def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Run Marigold's ``run.py`` on one image and return the depth PNG paths.

    The input file is copied into ``<path_input>.input`` and ``run.py`` is
    executed from the ``Marigold`` directory with that folder as
    ``--input_rgb_dir`` and ``<path_input>.output`` as ``--output_dir``.
    When ``/data`` exists, the checkpoint at
    ``/data/Marigold_ckpt/Marigold_v1_merged`` is passed explicitly.

    Args:
        path_input: Path to the input image file.
        path_out_png, path_out_obj, path_out_2_png: Optional pre-computed
            results. If all three are given they are returned unchanged and
            no processing happens.

    Returns:
        ``(path_out_colored, path_out_bw)``: the first PNG found in the
        ``depth_colored`` and ``depth_bw`` output sub-directories.

    Raises:
        RuntimeError: If ``run.py`` cannot be started, exits with a non-zero
            status, or an expected output PNG is missing.
    """
    # NOTE(review): this shortcut returns three values while the normal path
    # returns two. `iface` supplies at most three positional inputs, so
    # path_out_2_png stays None there and this branch is not reached from the
    # UI defined in this file -- confirm whether it is still needed.
    if path_out_png is not None and path_out_obj is not None and path_out_2_png is not None:
        return path_out_png, path_out_obj, path_out_2_png

    # run.py executes with cwd="Marigold", so the paths must not be relative
    # to the current directory.
    path_input = os.path.abspath(path_input)
    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # Argument list instead of a shell string: paths containing quotes,
    # spaces or `$` are passed through verbatim and cannot alter the command.
    command = ["python3", "run.py"]
    if os.path.exists('/data'):
        command += ["--checkpoint", "/data/Marigold_ckpt/Marigold_v1_merged"]
    command += [
        "--input_rgb_dir", path_input_dir,
        "--output_dir", path_output_dir,
        "--n_infer", "5",
        "--denoise_steps", "10",
    ]
    try:
        subprocess.run(command, cwd="Marigold", check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        raise RuntimeError("Processing failed") from err

    # Output sub-directories used here: depth_colored, depth_bw
    # (the original note also lists depth_npy, which is not read).
    path_out_colored = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    if path_out_colored is None:
        raise RuntimeError("Processing failed")
    path_out_bw = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    if path_out_bw is None:
        raise RuntimeError("Processing failed")
    return path_out_colored, path_out_bw
# Gradio UI: one visible image input, two image outputs.
iface = gr.Interface(
    # --- behavior ---
    fn=marigold_process,
    # --- header ---
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    # --- components ---
    inputs=[
        # Handed to marigold_process as a file path (path_input).
        gr.Image(
            label="Input Image",
            type="filepath",
        ),
        # The two hidden File inputs line up positionally with the optional
        # path_out_png / path_out_obj parameters of marigold_process.
        # NOTE(review): presumably they exist so example rows can carry
        # pre-computed outputs -- confirm.
        gr.File(
            label="Predicted depth (red-near, blue-far)",
            visible=False,
        ),
        gr.File(
            label="Predicted depth (16-bit PNG)",
            visible=False,
        ),
    ],
    outputs=[
        gr.Image(
            label="Predicted depth (red-near, blue-far)",
            type="pil",
        ),
        gr.Image(
            label="Predicted depth (16-bit PNG)",
            type="pil",
            elem_classes="imgdownload",
        ),
    ],
    # --- presentation ---
    css="""
.viewport {
aspect-ratio: 4/3;
}
.imgdownload {
height: 32px;
}
""",
    # --- flagging / examples ---
    allow_flagging="never",
    # Examples are currently disabled; each row would be
    # [input image, output 1, output 2]:
    # examples=[
    #     [
    #         os.path.join(os.path.dirname(__file__), "files/test.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.2.png"),
    #     ],
    # ],
    cache_examples=True,
)
if __name__ == "__main__":
    # Startup sequence: storage/env setup, then code + weights, then the UI.
    init_persistence()
    download_code_weights()

    # Enable request queuing, then serve on all interfaces at port 7860.
    queued_app = iface.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)