File size: 4,742 Bytes
2c83504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import shutil

import gradio as gr


# HTML shown as the Gradio interface description: project badges plus a short intro.
desc = """
    <p align="center">
    <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="website">
    </a>
    <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arxiv">
    </a>
    <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
    </a>
    <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
    </a>
    </p>
    <p align="justify">
    Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore examples listed at the bottom.
    </p>
"""


def init_persistence(purge=False, root="/data"):
    """Point checkpoint and Hugging Face cache locations at persistent storage.

    Does nothing when ``root`` does not exist. Otherwise it sets these
    environment variables:

    * ``ckpt_dir`` to ``<root>/Marigold_ckpt``.
    * ``TRANSFORMERS_CACHE``, ``HF_DATASETS_CACHE`` and ``HF_HOME`` to
      ``<root>/hfcache``.

    Args:
        purge: When true, delete everything inside the checkpoint directory
            (the directory itself is kept).
        root: Persistent storage mount point. Defaults to ``/data``.
    """
    if not os.path.exists(root):
        return

    ckpt_dir = os.path.join(root, "Marigold_ckpt")
    hf_cache = os.path.join(root, "hfcache")
    os.environ['ckpt_dir'] = ckpt_dir
    os.environ['TRANSFORMERS_CACHE'] = hf_cache
    os.environ['HF_DATASETS_CACHE'] = hf_cache
    os.environ['HF_HOME'] = hf_cache

    if not purge or not os.path.isdir(ckpt_dir):
        return
    # Clear the checkpoint directory's contents in-process instead of shelling out to `rm -rf`.
    for name in os.listdir(ckpt_dir):
        entry = os.path.join(ckpt_dir, name)
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            # Plain files and symlinks (including links to directories) are unlinked.
            os.remove(entry)


def download_code_weights():
    """Clone the Marigold repository, run its weight-download script, then list dirs.

    Every step is a shell command passed to ``os.system``. Exit statuses are
    ignored, so a failing step does not stop the remaining ones.
    """
    commands = (
        'git clone https://github.com/prs-eth/Marigold.git',
        'cd Marigold && bash script/download_weights.sh',
        # Diagnostics only: print listings of the persistent checkpoint locations.
        'echo /data && ls -la /data',
        'echo /data/Marigold_ckpt && ls -la /data/Marigold_ckpt',
        'echo /data/Marigold_ckpt/Marigold_v1_merged && ls -la /data/Marigold_ckpt/Marigold_v1_merged',
    )
    for command in commands:
        os.system(command)


def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    The extension check is case-insensitive, and only regular files are
    considered. Entries are scanned in sorted name order so the result is
    deterministic (``os.listdir`` order is arbitrary).

    A missing directory yields ``None`` instead of raising. Callers can then
    treat "output directory never created" the same as "no PNG produced".
    """
    if not os.path.isdir(directory):
        return None
    for name in sorted(os.listdir(directory)):
        candidate = os.path.join(directory, name)
        if name.lower().endswith(".png") and os.path.isfile(candidate):
            return candidate
    return None


def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Run Marigold depth estimation on one image and return the two depth renderings.

    Args:
        path_input: Filesystem path of the input image (the Gradio input uses
            ``type="filepath"``).
        path_out_png: Optional precomputed colored depth image.
        path_out_obj: Optional precomputed 16-bit depth PNG. When both this
            and ``path_out_png`` are given, inference is skipped and the two
            paths are returned unchanged (e.g. for cached example rows).
        path_out_2_png: Unused. Kept so callers passing four positional
            arguments keep working.

    Returns:
        ``(colored_depth_png_path, bw_depth_png_path)``, one value per output
        component of the interface.

    Raises:
        ValueError: If no input image was provided.
        RuntimeError: If Marigold did not produce both expected PNG outputs.
    """
    import subprocess

    # The interface declares exactly two outputs, so the short-circuit returns two values.
    if path_out_png is not None and path_out_obj is not None:
        return path_out_png, path_out_obj

    if path_input is None:
        raise ValueError("No input image provided")

    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    cmd = ["python3", "run.py"]
    if os.path.exists('/data'):
        cmd += ["--checkpoint", "/data/Marigold_ckpt/Marigold_v1_merged"]
    cmd += [
        "--input_rgb_dir", path_input_dir,
        "--output_dir", path_output_dir,
        "--n_infer", "5",
        "--denoise_steps", "10",
    ]
    # Argument list without a shell: the upload filename is user-controlled and
    # must not be interpreted by the shell. Failure is detected via missing outputs.
    result = subprocess.run(cmd, cwd="Marigold", check=False)

    def _first_png(subdir):
        # Tolerate a missing output subdirectory (run.py failed before writing it).
        folder = os.path.join(path_output_dir, subdir)
        return find_first_png(folder) if os.path.isdir(folder) else None

    # run.py writes depth_colored/, depth_bw/ and depth_npy/ under the output dir.
    path_out_colored = _first_png("depth_colored")
    path_out_bw = _first_png("depth_bw")
    if path_out_colored is None or path_out_bw is None:
        raise RuntimeError(f"Processing failed (run.py exit code {result.returncode})")

    return path_out_colored, path_out_bw


# Gradio UI: one uploaded image in, two depth renderings out, computed by
# marigold_process. Component construction order is significant to Gradio,
# so keep the argument layout as-is when editing.
iface = gr.Interface(
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    fn=marigold_process,
    inputs=[
        gr.Image(
            label="Input Image",
            type="filepath",  # marigold_process receives a path on disk, not pixel data
        ),
        # Hidden inputs whose labels mirror the two outputs below. Presumably
        # they carry precomputed outputs for example rows (see the
        # commented-out `examples`) so inference can be skipped. TODO confirm.
        gr.File(
            label="Predicted depth (red-near, blue-far)",
            visible=False,
        ),
        gr.File(
            label="Predicted depth (16-bit PNG)",
            visible=False,
        ),
    ],
    # marigold_process returns file paths (from find_first_png), one per output.
    outputs=[
        gr.Image(
            label="Predicted depth (red-near, blue-far)",
            type="pil",
        ),
        gr.Image(
            label="Predicted depth (16-bit PNG)",
            type="pil",
            elem_classes="imgdownload",  # styled by the `.imgdownload` rule in `css`
        ),
    ],
    allow_flagging="never",  # hide Gradio's flagging button
    # Example rows are currently disabled. Each row would supply the input
    # image plus the two precomputed outputs for the hidden inputs above.
    # examples=[
    #     [
    #         os.path.join(os.path.dirname(__file__), "files/test.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.2.png"),
    #     ],
    # ],
    # NOTE(review): no component here sets elem_classes="viewport"; presumably
    # `.viewport` targets a Gradio-internal element. TODO confirm.
    css="""
    .viewport {
        aspect-ratio: 4/3;
    }
    .imgdownload {
        height: 32px;
    }
    """,
    # NOTE(review): no examples are configured (the list above is commented
    # out), so this flag presumably has no effect. Confirm for the installed
    # Gradio version.
    cache_examples=True,
)


if __name__ == "__main__":
    # Environment first: checkpoint/cache paths must be set before code and weights are fetched.
    init_persistence()
    download_code_weights()
    # Enable request queuing, then serve on all interfaces at the Space's port.
    app = iface.queue()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )