File size: 4,272 Bytes
2c83504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cce5d7
2c83504
 
 
 
 
 
 
 
 
 
3636bd3
 
 
2c83504
 
 
 
 
 
 
 
 
 
 
0ea72b7
2c83504
 
 
3636bd3
 
 
 
2c83504
3636bd3
2c83504
 
 
 
 
 
 
 
 
 
 
 
ca43de2
2c83504
ca43de2
2c83504
 
ca43de2
0f05329
ca43de2
2c83504
 
 
 
 
 
 
 
 
0f05329
2c83504
 
 
 
 
4cce5d7
 
 
0f05329
 
4cce5d7
 
 
0f05329
 
4cce5d7
 
 
0f05329
 
4cce5d7
 
2c83504
 
 
 
 
0f05329
2c83504
 
0f05329
2c83504
 
 
 
4cce5d7
2c83504
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import shutil
import subprocess

import gradio as gr


# HTML rendered as the Gradio interface description: project badges followed
# by a short usage note.
desc = """
    <p align="center">
    <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="website">
    </a>
    <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arxiv">
    </a>
    <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
    </a>
    <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
    </a>
    </p>
    <p align="justify">
    Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore the examples listed at the bottom.
    </p>
"""


def download_code(
    repo_url="https://github.com/prs-eth/Marigold.git",
    target_dir="Marigold",
):
    """Clone the Marigold inference code unless it is already present.

    The clone goes into ``target_dir`` relative to the current working
    directory. ``marigold_process`` runs ``run.py`` from a ``Marigold``
    directory, so the default must stay in sync with it.

    Args:
        repo_url: Git URL of the repository to clone.
        target_dir: Destination directory for the clone.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
            Failing here is clearer than letting every later request fail.
    """
    if os.path.isdir(target_dir):
        # Already cloned (e.g. the app was restarted); cloning again would fail.
        return
    # List form, no shell: nothing in the arguments is interpreted by a shell.
    subprocess.run(["git", "clone", repo_url, target_dir], check=True)


def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    Entries are scanned in sorted name order so the result is deterministic
    (``os.listdir`` returns entries in arbitrary order). The extension check
    is case-insensitive, and only regular files match. A missing directory
    yields ``None`` rather than an exception, so callers can treat "no output"
    uniformly.

    Args:
        directory: Directory to search (non-recursive).

    Returns:
        ``os.path.join(directory, name)`` for the first matching file, or
        ``None`` if there is none.
    """
    try:
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        return None
    for name in names:
        path = os.path.join(directory, name)
        if name.lower().endswith(".png") and os.path.isfile(path):
            return path
    return None


def marigold_process(path_input, path_out_vis=None, path_out_pred=None):
    """Run Marigold depth estimation on a single image.

    When both output paths are already supplied (the bundled examples ship
    with precomputed results), they are returned unchanged and no inference
    runs.

    Otherwise, the input image is copied into ``<path_input>.input``.
    ``Marigold/run.py`` is invoked on it, writing to ``<path_input>.output``.
    The first PNG in each of the ``depth_colored`` and ``depth_bw``
    subdirectories of the output directory is returned.

    Args:
        path_input: Filesystem path of the uploaded image.
        path_out_vis: Optional precomputed colored depth image path.
        path_out_pred: Optional precomputed grayscale depth image path.

    Returns:
        Tuple ``(path_out_vis, path_out_pred)`` of image file paths.

    Raises:
        gr.Error: if no input image was given or the run produced no output.
    """
    if path_out_vis is not None and path_out_pred is not None:
        return path_out_vis, path_out_pred

    if path_input is None:
        raise gr.Error("Please upload an input image")

    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # Pass arguments as a list (no shell) so that upload paths containing
    # quotes, spaces or shell metacharacters cannot break or inject into the
    # command. cwd="Marigold" replaces the former "cd Marigold &&".
    result = subprocess.run(
        [
            "python3", "run.py",
            "--input_rgb_dir", path_input_dir,
            "--output_dir", path_output_dir,
            "--n_infer", "10",
            "--denoise_steps", "10",
        ],
        cwd="Marigold",
    )
    if result.returncode != 0:
        raise gr.Error(f"Processing failed (exit code {result.returncode})")

    path_out_vis = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    path_out_pred = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if path_out_vis is None or path_out_pred is None:
        raise gr.Error("Processing failed")

    return path_out_vis, path_out_pred


def _example_paths(stem):
    """Return [input, colored depth, grayscale depth] paths for a bundled example."""
    here = os.path.dirname(__file__)
    return [
        os.path.join(here, f"files/{stem}{suffix}.jpg")
        for suffix in ("", "_vis", "_pred")
    ]


# Extra styling injected into the Gradio page.
_CUSTOM_CSS = """
    .viewport {
        aspect-ratio: 4/3;
    }
    .imgdownload {
        height: 64px;
    }
    """

# The two hidden inputs carry precomputed example outputs so that
# marigold_process can return them directly instead of running inference.
iface = gr.Interface(
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    fn=marigold_process,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.Image(
            label="Predicted depth (red-near, blue-far)",
            type="filepath",
            visible=False,
        ),
        gr.Image(label="Predicted depth", type="filepath", visible=False),
    ],
    outputs=[
        gr.Image(label="Predicted depth (red-near, blue-far)", type="pil"),
        gr.Image(label="Predicted depth", type="pil", elem_classes="imgdownload"),
    ],
    allow_flagging="never",
    examples=[_example_paths(stem) for stem in ("bee", "cat", "swings")],
    css=_CUSTOM_CSS,
    cache_examples=True,
)


if __name__ == "__main__":
    # The inference code must be on disk before any request reaches
    # marigold_process, which runs scripts from the cloned repository.
    download_code()
    app = iface.queue()
    app.launch(server_port=7860, server_name="0.0.0.0")