Spaces:
dylanebert
/
Running on Zero

File size: 3,501 Bytes
90fd8f8
1803df3
 
26b3ad9
b7c04cb
26b3ad9
 
 
bab8ce7
1803df3
 
 
 
 
1a5d02b
26b3ad9
92f2e1f
90fd8f8
26b3ad9
 
 
 
90fd8f8
 
26b3ad9
90fd8f8
26b3ad9
 
 
 
90fd8f8
 
26b3ad9
 
 
b7c04cb
26b3ad9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90fd8f8
92f2e1f
90fd8f8
26b3ad9
90fd8f8
26b3ad9
92f2e1f
 
 
 
 
 
26b3ad9
92f2e1f
 
90fd8f8
26b3ad9
 
 
92f2e1f
90fd8f8
 
26b3ad9
90fd8f8
26b3ad9
 
90fd8f8
26b3ad9
 
 
 
 
 
 
 
 
 
 
 
 
90fd8f8
 
 
dec9ec5
 
26b3ad9
 
 
 
90fd8f8
 
e2beaf5
 
 
 
 
 
90fd8f8
 
92f2e1f
26b3ad9
e2beaf5
26b3ad9
90fd8f8
26b3ad9
92f2e1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import shlex
import subprocess
import sys
import tempfile

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from gradio_client import Client, file

# Install the prebuilt rasterizer wheel shipped under ./wheel at startup.
# pip is invoked as ``<this interpreter> -m pip`` so the package is installed
# into the environment that is actually running the app, not into whichever
# ``pip`` executable happens to come first on PATH.
_wheel_install = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
    ],
    check=False,
)
if _wheel_install.returncode != 0:
    # Stay best-effort (the package may already be installed by other means),
    # but do not let a failed install pass unnoticed.
    print(
        "WARNING: pip install of the diff_gaussian_rasterization wheel exited "
        f"with status {_wheel_install.returncode}",
        file=sys.stderr,
    )

# Directory that generated .ply files are written to.
TMP_DIR = "/tmp"
os.makedirs(TMP_DIR, exist_ok=True)


def _load_fp16_cuda_pipeline(model_id, custom_pipeline):
    """Load a half-precision custom diffusers pipeline and move it to CUDA.

    Args:
        model_id: Hub repository the weights are loaded from.
        custom_pipeline: Hub repository providing the pipeline code.

    Returns:
        The pipeline object returned by ``.to("cuda")``.
    """
    # NOTE(review): trust_remote_code=True executes code fetched from the Hub
    # repositories named below; only point this at repositories you trust.
    pipeline = DiffusionPipeline.from_pretrained(
        model_id,
        custom_pipeline=custom_pipeline,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return pipeline.to("cuda")


# First stage used by run(): called with the input image, its output is handed
# to splat_pipeline.
image_pipeline = _load_fp16_cuda_pipeline(
    "ashawkey/imagedream-ipmv-diffusers",
    "dylanebert/multi_view_diffusion",
)


# Second stage used by run(): produces the gaussians and saves them as .ply.
splat_pipeline = _load_fp16_cuda_pipeline("dylanebert/LGM", "dylanebert/LGM")


@spaces.GPU
def run(input_image, convert):
    """Generate a 3D Gaussian ``.ply`` (optionally a mesh) from one image.

    Args:
        input_image: Image array as delivered by the ``gr.Image(type="numpy")``
            input. It is cast to float32 and divided by 255 before inference.
        convert: When truthy, the saved ``.ply`` is passed to
            ``convert_to_mesh`` and that function's result is returned instead.

    Returns:
        Path of the written ``.ply`` file, or the value returned by
        ``convert_to_mesh`` when ``convert`` is truthy.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        # Gradio passes None when "Generate" is clicked without an image;
        # surface a readable message instead of an AttributeError.
        raise gr.Error("Please provide an input image.")

    input_image = input_image.astype("float32") / 255.0
    # The first positional argument is an empty string (presumably the text
    # prompt -- confirm against the custom pipeline).
    images = image_pipeline(
        "", input_image, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    gaussians = splat_pipeline(images)

    # Reserve a unique file per call: a fixed "output.ply" would be overwritten
    # by a concurrent request before this one's result has been served.
    # The file is intentionally not deleted here because the caller still has
    # to read it after this function returns.
    with tempfile.NamedTemporaryFile(
        dir=TMP_DIR, prefix="output_", suffix=".ply", delete=False
    ) as reserved:
        output_ply_path = reserved.name
    splat_pipeline.save_ply(gaussians, output_ply_path)

    if convert:
        return convert_to_mesh(output_ply_path)
    return output_ply_path


def convert_to_mesh(input_ply):
    """Convert a ``.ply`` file to a mesh via the remote splat-to-mesh Space.

    Args:
        input_ply: Path of the ``.ply`` file to upload.

    Returns:
        Whatever the remote ``/run`` endpoint returns through
        ``Client.predict`` (used by ``run`` as the output mesh path).
    """
    client = Client("https://dylanebert-splat-to-mesh.hf.space/")
    try:
        return client.predict(file(input_ply), api_name="/run")
    finally:
        # Always release the client, including when the remote call fails.
        client.close()


# Page title; also used for the top-level Markdown heading.
_TITLE = "LGM Mini"

# HTML blurb rendered through gr.Markdown below the heading.
_DESCRIPTION = (
    "\n"
    "<div>\n"
    'A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">'
    "LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.\n"
    "</div>\n"
)

# Custom CSS passed to gr.Blocks; styles the element with id "duplicate-button".
css = "\n".join(
    [
        "",
        "#duplicate-button {",
        "    margin: auto;",
        "    color: white;",
        "    background: #1565c0;",
        "    border-radius: 100vh;",
        "}",
        "",
    ]
)

def update_warning(checked):
    """Return the HTML shown next to the mesh checkbox.

    Gives a red warning span when ``checked`` is truthy and an empty string
    otherwise.
    """
    return (
        '<span style="color: #ff0000;">Warning: Mesh conversion takes several minutes</span>'
        if checked
        else ""
    )


# Images listed in the examples gallery.
_EXAMPLE_IMAGES = [
    "data_test/frog_sweater.jpg",
    "data_test/bird.jpg",
    "data_test/boy.jpg",
    "data_test/cat_statue.jpg",
    "data_test/dragontoy.jpg",
    "data_test/gso_rabbit.jpg",
]

with gr.Blocks(title=_TITLE, css=css) as block:
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
    )

    # Heading and description.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"# {_TITLE}")
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant="panel"):
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            input_image = gr.Image(label="image", type="numpy")
            convert_checkbox = gr.Checkbox(label="Convert to Mesh")
            warning = gr.HTML()
            # Toggle the warning text whenever the checkbox changes.
            convert_checkbox.change(
                fn=update_warning,
                inputs=[convert_checkbox],
                outputs=[warning],
            )
            button_gen = gr.Button("Generate")

        # Right column: 3D viewer for the result.
        with gr.Column(scale=1):
            output_splat = gr.Model3D(label="3D Gaussians")

        button_gen.click(
            fn=run,
            inputs=[input_image, convert_checkbox],
            outputs=[output_splat],
        )

    # Examples always run without mesh conversion.
    gr.Examples(
        examples=_EXAMPLE_IMAGES,
        inputs=[input_image],
        outputs=[output_splat],
        fn=lambda x: run(input_image=x, convert=False),
        cache_examples=True,
        label="Image-to-3D Examples",
    )

block.queue().launch(debug=True, share=True)