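# NOTE: the following lines are residue from the hosting site's file viewer,
# kept as comments so that the module remains valid Python:
#   toshas's picture
#   attempt to fix crashing of the demo build
#   df96952
#   raw
#   history blame
#   17.5 kB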
# Copyright 2024 Anton Obukhov and Kevin Qu, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
from __future__ import annotations
import functools
import os
import tempfile
import warnings
import spaces
import gradio as gr
import numpy as np
import torch as torch
from PIL import Image
from gradio_imageslider import ImageSlider
from huggingface_hub import login
from gradio_patches.examples import Examples
from marigold_iid_appearance import MarigoldIIDAppearancePipeline
from marigold_iid_lighting import MarigoldIIDLightingPipeline
# Suppress warnings whose message matches the "LoginButton created outside of a
# Blocks context" pattern.
warnings.filterwarnings(
    "ignore", message=".*LoginButton created outside of a Blocks context.*"
)

# Defaults shared by the UI widgets and the processing functions below.
default_seed = 2024  # passed as `seed` to every pipeline call
default_image_denoise_steps = 4
default_image_ensemble_size = 1
default_image_processing_res = 768  # the UI labels 768 "Recommended" and 0 "Native"
# NOTE(review): the name is misspelled ("reproducible") and it is not referenced
# anywhere in the visible code; confirm before renaming or removing.
default_image_reproducuble = True
default_model_type = "appearance"
loaded_pipelines = {}  # Cache of loaded pipelines, keyed by model type
def process_with_loaded_pipeline(
    image_path,
    model_type=default_model_type,
    denoise_steps=default_image_denoise_steps,
    ensemble_size=default_image_ensemble_size,
    processing_res=default_image_processing_res,
):
    """Run intrinsic image decomposition on one image, loading the model lazily.

    The pipeline for ``model_type`` is taken from the module-level
    ``loaded_pipelines`` cache. On a cache miss any previously cached pipeline
    is dropped first (to save GPU memory), then the requested pipeline is
    loaded, moved to CUDA when available (CPU otherwise) and stored.

    Args:
        image_path: Path of the input image file.
        model_type: Either ``"appearance"`` or ``"lighting"``.
        denoise_steps: Forwarded to ``process_image``.
        ensemble_size: Forwarded to ``process_image``.
        processing_res: Forwarded to ``process_image``.

    Returns:
        Whatever ``process_image`` returns for the selected model type.

    Raises:
        KeyError: If ``model_type`` is not one of the supported model types.
    """
    pipeline_sources = {
        "appearance": (
            MarigoldIIDAppearancePipeline,
            "prs-eth/marigold-iid-appearance-v1-1",
        ),
        "lighting": (
            MarigoldIIDLightingPipeline,
            "prs-eth/marigold-iid-lighting-v1-1",
        ),
    }

    if model_type not in loaded_pipelines:
        # Resolve the model first so that an unknown type fails before anything
        # is evicted from the cache.
        pipeline_cls, checkpoint = pipeline_sources[model_type]

        # Keep a single pipeline resident to save GPU memory. This can be
        # removed if enough memory is available for faster switching between
        # models.
        if loaded_pipelines:
            loaded_pipelines.clear()
            torch.cuda.empty_cache()

        new_pipe = pipeline_cls.from_pretrained(checkpoint)

        # Move the pipeline to GPU if available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        new_pipe = new_pipe.to(device)

        try:
            new_pipe.enable_xformers_memory_efficient_attention()
        except Exception as exc:
            # Best effort only: run without xformers, but leave a trace in the
            # log. KeyboardInterrupt/SystemExit are deliberately not caught.
            print(f"xformers attention not enabled, running without it: {exc}")

        # Cache only once the pipeline is fully prepared, so a failure above
        # never leaves a half-initialized entry behind.
        loaded_pipelines[model_type] = new_pipe

    pipe = loaded_pipelines[model_type]

    # Process the image using the preloaded pipeline.
    return process_image(
        pipe=pipe,
        path_input=image_path,
        denoise_steps=denoise_steps,
        ensemble_size=ensemble_size,
        processing_res=processing_res,
        model_type=model_type,
    )
def process_image_check(path_input):
    """Reject a submission that carries no input image.

    Returns ``None`` when ``path_input`` is anything other than ``None``;
    otherwise raises ``gr.Error`` with a message shown to the user.
    """
    if path_input is not None:
        return
    raise gr.Error(
        "Missing image in the first pane: upload a file or use one from the gallery below."
    )
# For each model type: (pipeline output attribute, file-name stem) of every
# predicted modality, in the order in which they are shown in the UI.
_IID_MODALITIES = {
    "appearance": (
        ("albedo", "albedo_app"),
        ("material", "material"),
    ),
    "lighting": (
        ("albedo", "albedo_res"),
        ("shading", "shading"),
        ("residual", "residual"),
    ),
}


def process_image(
    pipe,
    path_input,
    denoise_steps=default_image_denoise_steps,
    ensemble_size=default_image_ensemble_size,
    processing_res=default_image_processing_res,
    model_type=default_model_type,
):
    """Run ``pipe`` on one image and write its predictions next to the input.

    Outputs go to the directory ``<input path without extension>_output``. For
    every modality of the model type, the raw prediction is stored with
    ``np.save`` as ``<name>_<stem>_fp32.npy`` and its colored visualization as
    ``<name>_<stem>.png``.

    Args:
        pipe: Callable pipeline; its result must expose ``<modality>`` and
            ``<modality>_colored`` attributes for each modality of the type.
        path_input: Path of the input image file.
        denoise_steps: Passed to the pipeline as ``denoising_steps``.
        ensemble_size: Passed to the pipeline as ``ensemble_size``.
        processing_res: Passed to the pipeline as ``processing_res``.
        model_type: ``"appearance"`` (albedo, material) or ``"lighting"``
            (albedo, shading, residual).

    Returns:
        A 4-tuple: three ``[path_input, visualization_path]`` pairs for the
        three image sliders, followed by the list of all written files
        (visualizations first, then ``.npy`` files). For ``"appearance"`` the
        third pair repeats the material pair as a placeholder which is not
        displayed.

    Raises:
        ValueError: If ``model_type`` is not supported. Raised before the
            pipeline is run, so no inference time is wasted.
    """
    if model_type not in _IID_MODALITIES:
        raise ValueError(f"Unsupported model type: {model_type!r}")

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    # The context manager closes the underlying file once inference is done.
    with Image.open(path_input) as input_image:
        pipe_out = pipe(
            input_image,
            denoising_steps=denoise_steps,
            ensemble_size=ensemble_size,
            processing_res=processing_res,
            batch_size=1
            if processing_res == 0
            else 0,  # TODO: do we abuse "batch size" notation here?
            seed=default_seed,
            show_progress_bar=True,
        )

    path_output_dir = os.path.splitext(path_input)[0] + "_output"
    os.makedirs(path_output_dir, exist_ok=True)

    paths_vis = []
    paths_npy = []
    for attr_name, file_stem in _IID_MODALITIES[model_type]:
        path_npy = os.path.join(path_output_dir, f"{name_base}_{file_stem}_fp32.npy")
        path_vis = os.path.join(path_output_dir, f"{name_base}_{file_stem}.png")
        np.save(path_npy, getattr(pipe_out, attr_name))
        getattr(pipe_out, f"{attr_name}_colored").save(path_vis)
        paths_npy.append(path_npy)
        paths_vis.append(path_vis)

    slider_values = [[path_input, path_vis] for path_vis in paths_vis]
    # The UI always has three sliders; repeat the last pair as a placeholder
    # (not displayed) when the model predicts fewer modalities.
    while len(slider_values) < 3:
        slider_values.append(list(slider_values[-1]))

    return (*slider_values, paths_vis + paths_npy)
def run_demo_server():
    """Build the Marigold-IID Gradio interface and serve it.

    The UI consists of an input column (image, model type, advanced options,
    submit/reset buttons), an output column (three image sliders and a file
    list), and a gallery of examples. After wiring the events, the app is
    queued and launched on ``0.0.0.0:7860``.
    """
    # Wrap the inference entry point with `spaces.GPU` (duration=120).
    # NOTE(review): `functools.partial` is applied without binding any
    # arguments; presumably a leftover or needed by `spaces.GPU` - confirm.
    process_pipe_image = spaces.GPU(
        functools.partial(process_with_loaded_pipeline), duration=120
    )
    gradio_theme = gr.themes.Default()

    # `css` styles the download box, sliders, tabs and headings; `head` injects
    # the Google Analytics (gtag) snippet into the page.
    with gr.Blocks(
        theme=gradio_theme,
        title="Marigold Intrinsic Image Decomposition (Marigold-IID)",
        css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
""",
        head="""
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-1FWSVCGZTG');
</script>
""",
    ) as demo:
        # Page title and badge links (website, paper, GitHub, social).
        gr.Markdown(
            """
# Marigold Intrinsic Image Decomposition (IID)
<p align="center">
<a title="Website" href="https://marigoldcomputervision.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
"""
        )

        with gr.Row():
            # Input column: image, model choice, advanced options, buttons.
            with gr.Column():
                image_input = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                model_type = gr.Radio(
                    [
                        ("Appearance (albedo & material)", "appearance"),
                        ("Lighting (albedo, shading & residual)", "lighting"),
                    ],
                    label="Model type: Marigold-IID-Appearance or Marigold IID-Lighting",
                    value=default_model_type,
                )
                with gr.Accordion("Advanced options", open=True):
                    image_ensemble_size = gr.Slider(
                        label="Ensemble size",
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=default_image_ensemble_size,
                    )
                    image_denoise_steps = gr.Slider(
                        label="Number of denoising steps",
                        minimum=1,
                        maximum=10,
                        step=1,
                        value=default_image_denoise_steps,
                    )
                    # 0 selects the native resolution; 768 is the default.
                    image_processing_res = gr.Radio(
                        [
                            ("Native", 0),
                            ("Recommended", 768),
                        ],
                        label="Processing resolution",
                        value=default_image_processing_res,
                    )
                with gr.Row():
                    image_submit_btn = gr.Button(value="Compute IID", variant="primary")
                    image_reset_btn = gr.Button(value="Reset")
            # Output column: three before/after sliders and the file list.
            # The third slider starts hidden because the default model type
            # ("appearance") only fills the first two.
            with gr.Column():
                image_output_slider1 = ImageSlider(
                    label="Predicted Albedo",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=True,
                )
                image_output_slider2 = ImageSlider(
                    label="Predicted Material",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=True,
                )
                image_output_slider3 = ImageSlider(
                    label="Predicted Residual",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=False,
                )
                image_output_files = gr.Files(
                    label="Output files",
                    elem_id="download",
                    interactive=False,
                )

        # Function to toggle visibility and set dynamic labels
        def toggle_sliders_and_labels(model_type):
            """Return updates for the three sliders matching ``model_type``.

            "appearance" shows albedo and material and hides the third slider;
            "lighting" shows albedo, shading and residual. Any other value
            falls through and implicitly returns ``None``.
            """
            if model_type == "appearance":
                return (
                    gr.update(visible=True, label="Predicted Albedo"),
                    gr.update(visible=True, label="Predicted Material"),
                    gr.update(visible=False),  # Hide third slider
                )
            elif model_type == "lighting":
                return (
                    gr.update(visible=True, label="Predicted Albedo"),
                    gr.update(visible=True, label="Predicted Shading"),
                    gr.update(visible=True, label="Predicted Residual"),
                )

        # Attach the change event to update sliders
        model_type.change(
            fn=toggle_sliders_and_labels,
            inputs=[model_type],
            outputs=[image_output_slider1, image_output_slider2, image_output_slider3],
            show_progress=False,
        )

        # Example gallery: every listed image is paired with both model types.
        # Only the image and model type are supplied, so the remaining
        # parameters of the processing function take their defaults.
        Examples(
            fn=process_pipe_image,
            examples=[
                [os.path.join("files", "image", name), _model_type]
                for name in [
                    "livingroom.jpg",
                    "books.jpg",
                    "food_counter.png",
                    "cat2.png",
                    "costumes.png",
                    "icecream.jpg",
                    "juices.jpeg",
                    "cat.jpg",
                    "food.jpeg",
                    "puzzle.jpeg",
                    "screw.png",
                ]
                for _model_type in ["appearance", "lighting"]
            ],
            inputs=[image_input, model_type],
            outputs=[
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
            ],
            cache_examples=True,  # TODO: toggle later
            directory_name="examples_images",
        )

        ### Image tab
        # Two-step submit: first validate that an image is present (unqueued,
        # without preprocessing), then chain the actual inference via
        # `.success` - presumably it only fires if the check did not raise.
        image_submit_btn.click(
            fn=process_image_check,
            inputs=image_input,
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=process_pipe_image,
            inputs=[
                image_input,
                model_type,
                image_denoise_steps,
                image_ensemble_size,
                image_processing_res,
            ],
            outputs=[
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
            ],
            concurrency_limit=1,
        )

        # Reset: clear the input, the three sliders and the file list, and
        # restore the four controls to their defaults. The tuple order must
        # match the `outputs` list below.
        image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
                None,
                None,
                default_model_type,
                default_image_ensemble_size,
                default_image_denoise_steps,
                default_image_processing_res,
            ),
            inputs=[],
            outputs=[
                image_input,
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
                model_type,
                image_ensemble_size,
                image_denoise_steps,
                image_processing_res,
            ],
            queue=False,
        )

    # Enable the queue with `api_open=False` and serve on all interfaces,
    # port 7860.
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def main():
    """Entry point: log installed packages, optionally log in, start the demo.

    The output of ``pip freeze`` is printed first. If the environment variable
    ``HF_TOKEN_LOGIN`` is set, its value is used to log in to the Hugging Face
    Hub before the demo server is started.
    """
    os.system("pip freeze")

    hf_token = os.environ.get("HF_TOKEN_LOGIN")
    if hf_token is not None:
        login(token=hf_token)

    run_demo_server()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()