|
import functools |
|
import os |
|
import shutil |
|
import zipfile |
|
from io import BytesIO |
|
|
|
import spaces |
|
import gradio as gr |
|
import imageio as imageio |
|
import numpy as np |
|
import torch as torch |
|
from PIL import Image |
|
from diffusers import UNet2DConditionModel, LCMScheduler |
|
from gradio_imageslider import ImageSlider |
|
from huggingface_hub import login |
|
from tqdm import tqdm |
|
|
|
from extrude import extrude_depth_3d |
|
from marigold_depth_estimation_lcm import MarigoldDepthConsistencyPipeline |
|
|
|
default_seed = 2024 |
|
|
|
default_image_denoise_steps = 4 |
|
default_image_ensemble_size = 1 |
|
default_image_processing_res = 768 |
|
default_image_reproducuble = True |
|
|
|
default_video_depth_latent_init_strength = 0.1 |
|
default_video_denoise_steps = 1 |
|
default_video_ensemble_size = 1 |
|
default_video_processing_res = 768 |
|
default_video_out_fps = 15 |
|
default_video_out_max_frames = 100 |
|
|
|
default_bas_plane_near = 0.0 |
|
default_bas_plane_far = 1.0 |
|
default_bas_embossing = 20 |
|
default_bas_denoise_steps = 4 |
|
default_bas_ensemble_size = 1 |
|
default_bas_processing_res = 768 |
|
default_bas_size_longest_px = 512 |
|
default_bas_size_longest_cm = 10 |
|
default_bas_filter_size = 3 |
|
default_bas_frame_thickness = 5 |
|
default_bas_frame_near = 1 |
|
default_bas_frame_far = 1 |
|
|
|
|
|
def process_image( |
|
pipe, |
|
path_input, |
|
denoise_steps=default_image_denoise_steps, |
|
ensemble_size=default_image_ensemble_size, |
|
processing_res=default_image_processing_res, |
|
reproducible=default_image_reproducuble, |
|
): |
|
input_image = Image.open(path_input) |
|
|
|
pipe_out = pipe( |
|
input_image, |
|
denoising_steps=denoise_steps, |
|
ensemble_size=ensemble_size, |
|
processing_res=processing_res, |
|
batch_size=1 if processing_res == 0 else 0, |
|
seed=default_seed if reproducible else None, |
|
show_progress_bar=False, |
|
) |
|
|
|
depth_pred = pipe_out.depth_np |
|
depth_colored = pipe_out.depth_colored |
|
depth_16bit = (depth_pred * 65535.0).astype(np.uint16) |
|
|
|
path_output_dir = os.path.splitext(path_input)[0] + "_output" |
|
os.makedirs(path_output_dir, exist_ok=True) |
|
|
|
name_base = os.path.splitext(os.path.basename(path_input))[0] |
|
path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy") |
|
path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png") |
|
path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png") |
|
|
|
np.save(path_out_fp32, depth_pred) |
|
Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16") |
|
depth_colored.save(path_out_vis) |
|
|
|
return ( |
|
[path_out_16bit, path_out_vis], |
|
[path_out_16bit, path_out_fp32, path_out_vis], |
|
) |
|
|
|
|
|
def process_video( |
|
pipe, |
|
path_input, |
|
depth_latent_init_strength=default_video_depth_latent_init_strength, |
|
denoise_steps=default_video_denoise_steps, |
|
ensemble_size=default_video_ensemble_size, |
|
processing_res=default_video_processing_res, |
|
out_fps=default_video_out_fps, |
|
out_max_frames=default_video_out_max_frames, |
|
progress=gr.Progress(), |
|
): |
|
path_output_dir = os.path.splitext(path_input)[0] + "_output" |
|
os.makedirs(path_output_dir, exist_ok=True) |
|
|
|
name_base = os.path.splitext(os.path.basename(path_input))[0] |
|
path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4") |
|
path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip") |
|
|
|
reader = imageio.get_reader(path_input) |
|
|
|
meta_data = reader.get_meta_data() |
|
fps = meta_data["fps"] |
|
size = meta_data["size"] |
|
duration_sec = meta_data["duration"] |
|
|
|
if fps <= out_fps: |
|
frame_interval, out_fps = 1, fps |
|
else: |
|
frame_interval = round(fps / out_fps) |
|
out_fps = fps / frame_interval |
|
|
|
out_duration_sec = out_max_frames / out_fps |
|
if duration_sec > out_duration_sec: |
|
gr.Warning( |
|
f"Only the first ~{int(out_duration_sec)} seconds will be processed; " |
|
f"use alternative setups for full processing" |
|
) |
|
|
|
writer = imageio.get_writer(path_out_vis, fps=out_fps) |
|
zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED) |
|
prev_depth_latent = None |
|
|
|
pbar = tqdm(desc="Processing Video", total=out_max_frames) |
|
|
|
out_frame_id = 0 |
|
for frame_id, frame in enumerate(reader): |
|
if not (frame_id % frame_interval == 0): |
|
continue |
|
out_frame_id += 1 |
|
pbar.update(1) |
|
if out_frame_id > out_max_frames: |
|
break |
|
|
|
frame_pil = Image.fromarray(frame) |
|
|
|
pipe_out = pipe( |
|
frame_pil, |
|
denoising_steps=denoise_steps, |
|
ensemble_size=ensemble_size, |
|
processing_res=processing_res, |
|
match_input_res=False, |
|
batch_size=0, |
|
depth_latent_init=prev_depth_latent, |
|
depth_latent_init_strength=depth_latent_init_strength, |
|
seed=default_seed, |
|
show_progress_bar=False, |
|
) |
|
|
|
prev_depth_latent = pipe_out.depth_latent |
|
|
|
processed_frame = pipe_out.depth_colored |
|
processed_frame = imageio.core.util.Array(np.array(processed_frame)) |
|
writer.append_data(processed_frame) |
|
|
|
processed_frame = (65535 * np.clip(pipe_out.depth_np, 0.0, 1.0)).astype( |
|
np.uint16 |
|
) |
|
processed_frame = Image.fromarray(processed_frame, mode="I;16") |
|
|
|
archive_path = os.path.join( |
|
f"{name_base}_depth_16bit", f"{out_frame_id:05d}.png" |
|
) |
|
img_byte_arr = BytesIO() |
|
processed_frame.save(img_byte_arr, format="png") |
|
img_byte_arr.seek(0) |
|
zipf.writestr(archive_path, img_byte_arr.read()) |
|
|
|
reader.close() |
|
writer.close() |
|
zipf.close() |
|
|
|
return ( |
|
path_out_vis, |
|
[path_out_vis, path_out_16bit], |
|
) |
|
|
|
|
|
def process_bas( |
|
pipe, |
|
path_input, |
|
plane_near=default_bas_plane_near, |
|
plane_far=default_bas_plane_far, |
|
embossing=default_bas_embossing, |
|
denoise_steps=default_bas_denoise_steps, |
|
ensemble_size=default_bas_ensemble_size, |
|
processing_res=default_bas_processing_res, |
|
size_longest_px=default_bas_size_longest_px, |
|
size_longest_cm=default_bas_size_longest_cm, |
|
filter_size=default_bas_filter_size, |
|
frame_thickness=default_bas_frame_thickness, |
|
frame_near=default_bas_frame_near, |
|
frame_far=default_bas_frame_far, |
|
): |
|
if plane_near >= plane_far: |
|
raise gr.Error("NEAR plane must have a value smaller than the FAR plane") |
|
|
|
path_output_dir = os.path.splitext(path_input)[0] + "_output" |
|
os.makedirs(path_output_dir, exist_ok=True) |
|
|
|
name_base, name_ext = os.path.splitext(os.path.basename(path_input)) |
|
|
|
input_image = Image.open(path_input) |
|
|
|
pipe_out = pipe( |
|
input_image, |
|
denoising_steps=denoise_steps, |
|
ensemble_size=ensemble_size, |
|
processing_res=processing_res, |
|
seed=default_seed, |
|
show_progress_bar=False, |
|
) |
|
|
|
depth_pred = pipe_out.depth_np * 65535 |
|
|
|
def _process_3d( |
|
size_longest_px, |
|
filter_size, |
|
vertex_colors, |
|
scene_lights, |
|
output_model_scale=None, |
|
prepare_for_3d_printing=False, |
|
): |
|
image_rgb_w, image_rgb_h = input_image.width, input_image.height |
|
image_rgb_d = max(image_rgb_w, image_rgb_h) |
|
image_new_w = size_longest_px * image_rgb_w // image_rgb_d |
|
image_new_h = size_longest_px * image_rgb_h // image_rgb_d |
|
|
|
image_rgb_new = os.path.join( |
|
path_output_dir, f"{name_base}_rgb_{size_longest_px}{name_ext}" |
|
) |
|
image_depth_new = os.path.join( |
|
path_output_dir, f"{name_base}_depth_{size_longest_px}.png" |
|
) |
|
input_image.resize((image_new_w, image_new_h), Image.LANCZOS).save( |
|
image_rgb_new |
|
) |
|
Image.fromarray(depth_pred).convert(mode="F").resize( |
|
(image_new_w, image_new_h), Image.BILINEAR |
|
).convert("I").save(image_depth_new) |
|
|
|
path_glb, path_stl = extrude_depth_3d( |
|
image_rgb_new, |
|
image_depth_new, |
|
output_model_scale=size_longest_cm * 10 |
|
if output_model_scale is None |
|
else output_model_scale, |
|
filter_size=filter_size, |
|
coef_near=plane_near, |
|
coef_far=plane_far, |
|
emboss=embossing / 100, |
|
f_thic=frame_thickness / 100, |
|
f_near=frame_near / 100, |
|
f_back=frame_far / 100, |
|
vertex_colors=vertex_colors, |
|
scene_lights=scene_lights, |
|
prepare_for_3d_printing=prepare_for_3d_printing, |
|
) |
|
|
|
return path_glb, path_stl |
|
|
|
path_viewer_glb, _ = _process_3d( |
|
256, filter_size, vertex_colors=False, scene_lights=True, output_model_scale=1 |
|
) |
|
path_files_glb, path_files_stl = _process_3d( |
|
size_longest_px, filter_size, vertex_colors=True, scene_lights=False, prepare_for_3d_printing=True |
|
) |
|
|
|
return path_viewer_glb, [path_files_glb, path_files_stl] |
|
|
|
|
|
def run_demo_server(pipe): |
|
process_pipe_image = spaces.GPU(lambda *args, **kwargs: process_image(pipe, *args, **kwargs)) |
|
process_pipe_video = spaces.GPU(lambda *args, **kwargs: process_video(pipe, *args, **kwargs)) |
|
process_pipe_bas = spaces.GPU(lambda *args, **kwargs: process_bas(pipe, *args, **kwargs)) |
|
os.environ["GRADIO_ALLOW_FLAGGING"] = "never" |
|
|
|
gradio_theme = gr.themes.Default() |
|
|
|
with gr.Blocks( |
|
theme=gradio_theme, |
|
title="Marigold-LCM Depth Estimation", |
|
css=""" |
|
#download { |
|
height: 118px; |
|
} |
|
.slider .inner { |
|
width: 5px; |
|
background: #FFF; |
|
} |
|
.viewport { |
|
aspect-ratio: 4/3; |
|
} |
|
.tabs button.selected { |
|
font-size: 20px !important; |
|
color: crimson !important; |
|
} |
|
""", |
|
head=""" |
|
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script> |
|
<script> |
|
window.dataLayer = window.dataLayer || []; |
|
function gtag() {dataLayer.push(arguments);} |
|
gtag('js', new Date()); |
|
gtag('config', 'G-1FWSVCGZTG'); |
|
</script> |
|
""", |
|
) as demo: |
|
gr.Markdown( |
|
""" |
|
<h1 align="center">Marigold-LCM Depth Estimation</h1> |
|
<p align="center"> |
|
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://www.obukhov.ai/img/badges/badge-website.svg"> |
|
</a> |
|
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg"> |
|
</a> |
|
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars"> |
|
</a> |
|
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> |
|
</a> |
|
</p> |
|
<p align="justify"> |
|
Marigold-LCM is the fast version of Marigold, the state-of-the-art depth estimator for images in the wild. |
|
It combines the power of the original Marigold 10-step estimator and the Latent Consistency Models, delivering high-quality results in as little as <b>one step</b>. |
|
We provide three functions in this demo: Image, Video, and Bas-relief 3D processing — <b>see the tabs below</b>. |
|
Upload your content into the <b>left</b> side, or click any of the <b>examples</b> below. |
|
Wait a second (for images and 3D) or a minute (for videos), and interact with the result in the <b>right</b> side. |
|
To avoid queuing, fork the demo into your profile. |
|
</p> |
|
""" |
|
) |
|
|
|
with gr.Tabs(elem_classes=["tabs"]): |
|
with gr.Tab("Image"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
image_input = gr.Image( |
|
label="Input Image", |
|
type="filepath", |
|
) |
|
with gr.Row(): |
|
image_submit_btn = gr.Button( |
|
value="Compute Depth", variant="primary" |
|
) |
|
image_reset_btn = gr.Button(value="Reset") |
|
with gr.Accordion("Advanced options", open=False): |
|
image_denoise_steps = gr.Slider( |
|
label="Number of denoising steps", |
|
minimum=1, |
|
maximum=4, |
|
step=1, |
|
value=default_image_denoise_steps, |
|
) |
|
image_ensemble_size = gr.Slider( |
|
label="Ensemble size", |
|
minimum=1, |
|
maximum=10, |
|
step=1, |
|
value=default_image_ensemble_size, |
|
) |
|
image_processing_res = gr.Radio( |
|
[ |
|
("Native", 0), |
|
("Recommended", 768), |
|
], |
|
label="Processing resolution", |
|
value=default_image_processing_res, |
|
) |
|
with gr.Column(): |
|
image_output_slider = ImageSlider( |
|
label="Predicted depth (red-near, blue-far)", |
|
type="filepath", |
|
show_download_button=True, |
|
show_share_button=True, |
|
interactive=False, |
|
elem_classes="slider", |
|
position=0.25, |
|
) |
|
image_output_files = gr.Files( |
|
label="Depth outputs", |
|
elem_id="download", |
|
interactive=False, |
|
) |
|
gr.Examples( |
|
fn=process_pipe_image, |
|
examples=[ |
|
os.path.join("files", "image", name) |
|
for name in [ |
|
"arc.jpeg", |
|
"berries.jpeg", |
|
"butterfly.jpeg", |
|
"cat.jpg", |
|
"concert.jpeg", |
|
"dog.jpeg", |
|
"doughnuts.jpeg", |
|
"einstein.jpg", |
|
"food.jpeg", |
|
"glasses.jpeg", |
|
"house.jpg", |
|
"lake.jpeg", |
|
"marigold.jpeg", |
|
"portrait_1.jpeg", |
|
"portrait_2.jpeg", |
|
"pumpkins.jpg", |
|
"puzzle.jpeg", |
|
"road.jpg", |
|
"scientists.jpg", |
|
"surfboards.jpeg", |
|
"surfer.jpeg", |
|
"swings.jpg", |
|
"switzerland.jpeg", |
|
"teamwork.jpeg", |
|
"wave.jpeg", |
|
] |
|
], |
|
inputs=[image_input], |
|
outputs=[image_output_slider, image_output_files], |
|
cache_examples=True, |
|
) |
|
|
|
with gr.Tab("Video"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
video_input = gr.Video( |
|
label="Input Video", |
|
sources=["upload"], |
|
) |
|
with gr.Row(): |
|
video_submit_btn = gr.Button( |
|
value="Compute Depth", variant="primary" |
|
) |
|
video_reset_btn = gr.Button(value="Reset") |
|
with gr.Column(): |
|
video_output_video = gr.Video( |
|
label="Output video depth (red-near, blue-far)", |
|
interactive=False, |
|
) |
|
video_output_files = gr.Files( |
|
label="Depth outputs", |
|
elem_id="download", |
|
interactive=False, |
|
) |
|
gr.Examples( |
|
fn=process_pipe_video, |
|
examples=[ |
|
os.path.join("files", "video", name) |
|
for name in [ |
|
"cab.mp4", |
|
"elephant.mp4", |
|
"obama.mp4", |
|
] |
|
], |
|
inputs=[video_input], |
|
outputs=[video_output_video, video_output_files], |
|
cache_examples=True, |
|
) |
|
|
|
with gr.Tab("Bas-relief (3D)"): |
|
gr.Markdown( |
|
""" |
|
<p align="justify"> |
|
This part of the demo uses Marigold-LCM to create a bas-relief model. |
|
The models are watertight, with correct normals, and exported in the STL format, which makes them <b>3D-printable</b>. |
|
Start by uploading the image and click "Create" with the default parameters. |
|
To improve the result, click "Clear", adjust the geometry sliders below, and click "Create" again. |
|
</p> |
|
""", |
|
) |
|
with gr.Row(): |
|
with gr.Column(): |
|
bas_input = gr.Image( |
|
label="Input Image", |
|
type="filepath", |
|
) |
|
with gr.Row(): |
|
bas_submit_btn = gr.Button(value="Create 3D", variant="primary") |
|
bas_clear_btn = gr.Button(value="Clear") |
|
bas_reset_btn = gr.Button(value="Reset") |
|
with gr.Accordion("3D printing demo: Main options", open=True): |
|
bas_plane_near = gr.Slider( |
|
label="Relative position of the near plane (between 0 and 1)", |
|
minimum=0.0, |
|
maximum=1.0, |
|
step=0.001, |
|
value=default_bas_plane_near, |
|
) |
|
bas_plane_far = gr.Slider( |
|
label="Relative position of the far plane (between near and 1)", |
|
minimum=0.0, |
|
maximum=1.0, |
|
step=0.001, |
|
value=default_bas_plane_far, |
|
) |
|
bas_embossing = gr.Slider( |
|
label="Embossing level", |
|
minimum=0, |
|
maximum=100, |
|
step=1, |
|
value=default_bas_embossing, |
|
) |
|
with gr.Accordion("3D printing demo: Advanced options", open=False): |
|
bas_denoise_steps = gr.Slider( |
|
label="Number of denoising steps", |
|
minimum=1, |
|
maximum=4, |
|
step=1, |
|
value=default_bas_denoise_steps, |
|
) |
|
bas_ensemble_size = gr.Slider( |
|
label="Ensemble size", |
|
minimum=1, |
|
maximum=10, |
|
step=1, |
|
value=default_bas_ensemble_size, |
|
) |
|
bas_processing_res = gr.Radio( |
|
[ |
|
("Native", 0), |
|
("Recommended", 768), |
|
], |
|
label="Processing resolution", |
|
value=default_bas_processing_res, |
|
) |
|
bas_size_longest_px = gr.Slider( |
|
label="Size (px) of the longest side", |
|
minimum=256, |
|
maximum=1024, |
|
step=256, |
|
value=default_bas_size_longest_px, |
|
) |
|
bas_size_longest_cm = gr.Slider( |
|
label="Size (cm) of the longest side", |
|
minimum=1, |
|
maximum=100, |
|
step=1, |
|
value=default_bas_size_longest_cm, |
|
) |
|
bas_filter_size = gr.Slider( |
|
label="Size (px) of the smoothing filter", |
|
minimum=1, |
|
maximum=5, |
|
step=2, |
|
value=default_bas_filter_size, |
|
) |
|
bas_frame_thickness = gr.Slider( |
|
label="Frame thickness", |
|
minimum=0, |
|
maximum=100, |
|
step=1, |
|
value=default_bas_frame_thickness, |
|
) |
|
bas_frame_near = gr.Slider( |
|
label="Frame's near plane offset", |
|
minimum=-100, |
|
maximum=100, |
|
step=1, |
|
value=default_bas_frame_near, |
|
) |
|
bas_frame_far = gr.Slider( |
|
label="Frame's far plane offset", |
|
minimum=1, |
|
maximum=10, |
|
step=1, |
|
value=default_bas_frame_far, |
|
) |
|
with gr.Column(): |
|
bas_output_viewer = gr.Model3D( |
|
camera_position=(75.0, 90.0, 1.25), |
|
elem_classes="viewport", |
|
label="3D preview (low-res, relief highlight)", |
|
interactive=False, |
|
) |
|
bas_output_files = gr.Files( |
|
label="3D model outputs (high-res)", |
|
elem_id="download", |
|
interactive=False, |
|
) |
|
gr.Examples( |
|
fn=process_pipe_bas, |
|
examples=[ |
|
[ |
|
"files/basrelief/coin.jpg", |
|
0.0, |
|
0.66, |
|
15, |
|
4, |
|
4, |
|
768, |
|
512, |
|
10, |
|
3, |
|
5, |
|
0, |
|
1, |
|
], |
|
[ |
|
"files/basrelief/einstein.jpg", |
|
0.0, |
|
0.5, |
|
50, |
|
2, |
|
1, |
|
768, |
|
512, |
|
10, |
|
3, |
|
5, |
|
-15, |
|
1, |
|
], |
|
[ |
|
"files/basrelief/food.jpeg", |
|
0.0, |
|
1.0, |
|
20, |
|
2, |
|
4, |
|
768, |
|
512, |
|
10, |
|
3, |
|
5, |
|
-5, |
|
1, |
|
], |
|
], |
|
inputs=[ |
|
bas_input, |
|
bas_plane_near, |
|
bas_plane_far, |
|
bas_embossing, |
|
bas_denoise_steps, |
|
bas_ensemble_size, |
|
bas_processing_res, |
|
bas_size_longest_px, |
|
bas_size_longest_cm, |
|
bas_filter_size, |
|
bas_frame_thickness, |
|
bas_frame_near, |
|
bas_frame_far, |
|
], |
|
outputs=[bas_output_viewer, bas_output_files], |
|
cache_examples=True, |
|
) |
|
|
|
image_submit_btn.click( |
|
fn=process_pipe_image, |
|
inputs=[ |
|
image_input, |
|
image_denoise_steps, |
|
image_ensemble_size, |
|
image_processing_res, |
|
], |
|
outputs=[image_output_slider, image_output_files], |
|
concurrency_limit=1, |
|
) |
|
|
|
image_reset_btn.click( |
|
fn=lambda: ( |
|
None, |
|
None, |
|
None, |
|
default_image_ensemble_size, |
|
default_image_denoise_steps, |
|
default_image_processing_res, |
|
), |
|
inputs=[], |
|
outputs=[ |
|
image_input, |
|
image_output_slider, |
|
image_output_files, |
|
image_ensemble_size, |
|
image_denoise_steps, |
|
image_processing_res, |
|
], |
|
concurrency_limit=1, |
|
) |
|
|
|
video_submit_btn.click( |
|
fn=process_pipe_video, |
|
inputs=[video_input], |
|
outputs=[video_output_video, video_output_files], |
|
concurrency_limit=1, |
|
) |
|
|
|
video_reset_btn.click( |
|
fn=lambda: (None, None, None), |
|
inputs=[], |
|
outputs=[video_input, video_output_video, video_output_files], |
|
concurrency_limit=1, |
|
) |
|
|
|
def wrapper_process_pipe_bas(*args, **kwargs): |
|
out = list(process_pipe_bas(*args, **kwargs)) |
|
out = [gr.Button(interactive=False), gr.Image(interactive=False)] + out |
|
return out |
|
|
|
bas_submit_btn.click( |
|
fn=wrapper_process_pipe_bas, |
|
inputs=[ |
|
bas_input, |
|
bas_plane_near, |
|
bas_plane_far, |
|
bas_embossing, |
|
bas_denoise_steps, |
|
bas_ensemble_size, |
|
bas_processing_res, |
|
bas_size_longest_px, |
|
bas_size_longest_cm, |
|
bas_filter_size, |
|
bas_frame_thickness, |
|
bas_frame_near, |
|
bas_frame_far, |
|
], |
|
outputs=[bas_submit_btn, bas_input, bas_output_viewer, bas_output_files], |
|
concurrency_limit=1, |
|
) |
|
|
|
bas_clear_btn.click( |
|
fn=lambda: (gr.Button(interactive=True), None, None), |
|
inputs=[], |
|
outputs=[ |
|
bas_submit_btn, |
|
bas_output_viewer, |
|
bas_output_files, |
|
], |
|
concurrency_limit=1, |
|
) |
|
|
|
bas_reset_btn.click( |
|
fn=lambda: ( |
|
gr.Button(interactive=True), |
|
None, |
|
None, |
|
None, |
|
default_bas_plane_near, |
|
default_bas_plane_far, |
|
default_bas_embossing, |
|
default_bas_denoise_steps, |
|
default_bas_ensemble_size, |
|
default_bas_processing_res, |
|
default_bas_size_longest_px, |
|
default_bas_size_longest_cm, |
|
default_bas_filter_size, |
|
default_bas_frame_thickness, |
|
default_bas_frame_near, |
|
default_bas_frame_far, |
|
), |
|
inputs=[], |
|
outputs=[ |
|
bas_submit_btn, |
|
bas_input, |
|
bas_output_viewer, |
|
bas_output_files, |
|
bas_plane_near, |
|
bas_plane_far, |
|
bas_embossing, |
|
bas_denoise_steps, |
|
bas_ensemble_size, |
|
bas_processing_res, |
|
bas_size_longest_px, |
|
bas_size_longest_cm, |
|
bas_filter_size, |
|
bas_frame_thickness, |
|
bas_frame_near, |
|
bas_frame_far, |
|
], |
|
concurrency_limit=1, |
|
) |
|
|
|
demo.queue( |
|
api_open=False, |
|
).launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
) |
|
|
|
|
|
def main(): |
|
CHECKPOINT = "prs-eth/marigold-v1-0" |
|
CHECKPOINT_UNET_LCM = "prs-eth/marigold-lcm-v1-0" |
|
|
|
if "HF_TOKEN_LOGIN" in os.environ: |
|
login(token=os.environ["HF_TOKEN_LOGIN"]) |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
pipe = MarigoldDepthConsistencyPipeline.from_pretrained( |
|
CHECKPOINT, |
|
unet=UNet2DConditionModel.from_pretrained( |
|
CHECKPOINT_UNET_LCM, subfolder="unet", use_auth_token=True |
|
), |
|
) |
|
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) |
|
try: |
|
import xformers |
|
|
|
pipe.enable_xformers_memory_efficient_attention() |
|
except: |
|
pass |
|
|
|
pipe = pipe.to(device) |
|
run_demo_server(pipe) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|