Spaces:
Sleeping
Sleeping
# Copyright 2024 Anton Obukhov and Kevin Qu, ETH Zurich. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# -------------------------------------------------------------------------- | |
# If you find this code useful, we kindly ask you to cite our paper in your work. | |
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation | |
# More information about the method can be found at https://marigoldcomputervision.github.io | |
# -------------------------------------------------------------------------- | |
from __future__ import annotations

import functools
import os
import subprocess
import sys
import tempfile
import warnings

# NOTE(review): `spaces` is kept ahead of torch/gradio as in the original
# ordering (Hugging Face ZeroGPU reportedly expects this) — confirm before reordering.
import spaces
import gradio as gr
import numpy as np
import torch as torch
from PIL import Image
from gradio_imageslider import ImageSlider
from huggingface_hub import login

from gradio_patches.examples import Examples
from marigold_iid_appearance import MarigoldIIDAppearancePipeline
from marigold_iid_lighting import MarigoldIIDLightingPipeline
# Ignore warnings whose message says a LoginButton was created outside of a
# Blocks context.
warnings.filterwarnings(
    action="ignore",
    message=".*LoginButton created outside of a Blocks context.*",
)

# Default inference settings, shared by the UI widgets and the processing
# functions' keyword defaults.
default_seed = 2024  # fixed seed passed to every pipeline call
default_image_denoise_steps = 4  # number of denoising steps
default_image_ensemble_size = 1  # number of ensemble members
default_image_processing_res = 768  # 0 would select native resolution
default_image_reproducuble = True  # (sic) name kept for compatibility
default_model_type = "appearance"  # "appearance" or "lighting"

# Pipelines currently held in memory, keyed by model type.
loaded_pipelines = {}
# Loaders for each supported model type. The lambdas defer both the class
# lookup and the (slow) `from_pretrained` call until a pipeline is requested.
_PIPELINE_LOADERS = {
    "appearance": lambda: MarigoldIIDAppearancePipeline.from_pretrained(
        "prs-eth/marigold-iid-appearance-v1-1"
    ),
    "lighting": lambda: MarigoldIIDLightingPipeline.from_pretrained(
        "prs-eth/marigold-iid-lighting-v1-1"
    ),
}


def _get_pipeline(model_type):
    """Return the cached pipeline for ``model_type``, loading it on first use.

    Before loading, every other cached pipeline is evicted to save GPU memory
    (drop the eviction if enough memory is available for faster switching
    between models). A pipeline is only stored in ``loaded_pipelines`` after
    it has been moved to the target device, so a failed load never leaves a
    half-initialized entry in the cache.
    """
    if model_type in loaded_pipelines:
        return loaded_pipelines[model_type]

    # Snapshot the keys first: the dict is mutated while evicting.
    stale = [name for name in loaded_pipelines if name != model_type]
    for name in stale:
        del loaded_pipelines[name]
    if stale:
        torch.cuda.empty_cache()

    # Move the pipeline to GPU if available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = _PIPELINE_LOADERS[model_type]().to(device)
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:  # xformers is optional: run without it
        print(f"Running without xformers memory-efficient attention: {exc}")

    loaded_pipelines[model_type] = pipe
    return pipe


def process_with_loaded_pipeline(
    image_path,
    model_type=default_model_type,
    denoise_steps=default_image_denoise_steps,
    ensemble_size=default_image_ensemble_size,
    processing_res=default_image_processing_res,
):
    """Run Marigold-IID on one image with the (cached) pipeline for ``model_type``.

    Args:
        image_path: Path of the input image file.
        model_type: ``"appearance"`` or ``"lighting"``.
        denoise_steps: Number of denoising steps, forwarded to ``process_image``.
        ensemble_size: Ensemble size, forwarded to ``process_image``.
        processing_res: Processing resolution, forwarded to ``process_image``.

    Returns:
        The result of ``process_image`` for this image.

    Raises:
        ValueError: If ``model_type`` is not a supported model type.
    """
    # Validate up front so an unknown type fails with a clear message instead
    # of a KeyError deep inside the caching logic.
    if model_type not in _PIPELINE_LOADERS:
        raise ValueError(
            f"Unknown model type {model_type!r}; "
            f"expected one of {sorted(_PIPELINE_LOADERS)}"
        )
    return process_image(
        pipe=_get_pipeline(model_type),
        path_input=image_path,
        denoise_steps=denoise_steps,
        ensemble_size=ensemble_size,
        processing_res=processing_res,
        model_type=model_type,
    )
def process_image_check(path_input):
    """Stop the submit chain with a user-facing error when no image is given.

    Returns None when ``path_input`` is present. When it is None, raises
    ``gr.Error`` so the UI shows the message.
    """
    if path_input is not None:
        return
    raise gr.Error(
        "Missing image in the first pane: upload a file or use one from the gallery below."
    )
# Components saved per model type, in slider display order, as
# (attribute on the pipeline output, tag used in the saved file names).
# Each attribute `<attr>` is paired with an `<attr>_colored` visualization.
_IID_OUTPUTS = {
    "appearance": (
        ("albedo", "albedo_app"),
        ("material", "material"),
    ),
    "lighting": (
        ("albedo", "albedo_res"),
        ("shading", "shading"),
        ("residual", "residual"),
    ),
}

# Number of image-slider outputs fed by process_image; a model with fewer
# components repeats its last pair as a placeholder that is not displayed.
_NUM_SLIDER_OUTPUTS = 3


def _save_prediction(pipe_out, attr, tag, path_output_dir, name_base):
    """Save ``pipe_out.<attr>`` with ``np.save`` and ``pipe_out.<attr>_colored`` as PNG.

    Returns:
        ``(path_vis, path_raw)``: paths of the written ``.png`` and ``.npy`` files.
    """
    path_raw = os.path.join(path_output_dir, f"{name_base}_{tag}_fp32.npy")
    path_vis = os.path.join(path_output_dir, f"{name_base}_{tag}.png")
    np.save(path_raw, getattr(pipe_out, attr))
    getattr(pipe_out, f"{attr}_colored").save(path_vis)
    return path_vis, path_raw


def process_image(
    pipe,
    path_input,
    denoise_steps=default_image_denoise_steps,
    ensemble_size=default_image_ensemble_size,
    processing_res=default_image_processing_res,
    model_type=default_model_type,
):
    """Run ``pipe`` on the image at ``path_input`` and save its predictions.

    Outputs are written to ``<input path without extension>_output/``.

    Args:
        pipe: A loaded Marigold-IID pipeline matching ``model_type``.
        path_input: Path of the input image file.
        denoise_steps: Number of denoising steps.
        ensemble_size: Ensemble size.
        processing_res: Processing resolution; 0 selects native resolution.
        model_type: ``"appearance"`` or ``"lighting"``.

    Returns:
        A 4-tuple. The first three items are ``[path_input, visualization_path]``
        pairs for the image sliders. The appearance model repeats its last pair
        as a placeholder for the undisplayed third slider. The fourth item lists
        the saved files: all ``.png`` visualizations, then all ``.npy`` arrays.

    Raises:
        ValueError: If ``model_type`` is not a supported model type.
    """
    # Validate before running the (expensive) pipeline; previously an unknown
    # type ran inference and then silently returned None.
    if model_type not in _IID_OUTPUTS:
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of {sorted(_IID_OUTPUTS)}"
        )

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing image {name_base}{name_ext}")

    # The context manager closes the image file when inference is done.
    # NOTE(review): assumes `pipe` keeps no lazy reference to the PIL image
    # after returning — confirm.
    with Image.open(path_input) as input_image:
        pipe_out = pipe(
            input_image,
            denoising_steps=denoise_steps,
            ensemble_size=ensemble_size,
            processing_res=processing_res,
            # 1 at native resolution, otherwise 0 (presumably "let the pipeline
            # decide"). TODO: do we abuse "batch size" notation here?
            batch_size=1 if processing_res == 0 else 0,
            seed=default_seed,
            show_progress_bar=True,
        )

    # Outputs live next to the input. (A tempfile.mkdtemp() directory used to
    # be created here and then discarded, leaking one directory per call.)
    path_output_dir = os.path.splitext(path_input)[0] + "_output"
    os.makedirs(path_output_dir, exist_ok=True)

    saved = [
        _save_prediction(pipe_out, attr, tag, path_output_dir, name_base)
        for attr, tag in _IID_OUTPUTS[model_type]
    ]
    paths_vis = [path_vis for path_vis, _ in saved]
    paths_raw = [path_raw for _, path_raw in saved]

    padded_vis = paths_vis + [paths_vis[-1]] * (_NUM_SLIDER_OUTPUTS - len(paths_vis))
    sliders = [[path_input, path_vis] for path_vis in padded_vis]
    return (*sliders, paths_vis + paths_raw)
def run_demo_server():
    """Build the Marigold-IID Gradio interface and launch it on 0.0.0.0:7860.

    The UI takes an input image and a model type ("appearance" or
    "lighting"). It runs ``process_with_loaded_pipeline`` through the
    ``spaces.GPU`` wrapper and shows the predictions in image sliders, plus a
    file list for download.
    """
    # Wrap the processing entry point with `spaces.GPU`. `duration=120` is
    # presumably the per-call GPU time allowance in seconds (confirm against
    # the `spaces` docs). The partial binds no arguments, so it forwards
    # calls to process_with_loaded_pipeline unchanged.
    process_pipe_image = spaces.GPU(
        functools.partial(process_with_loaded_pipeline), duration=120
    )
    gradio_theme = gr.themes.Default()
    with gr.Blocks(
        theme=gradio_theme,
        title="Marigold Intrinsic Image Decomposition (Marigold-IID)",
        # Page CSS: download-box height, slider handle, viewport aspect ratio,
        # selected-tab styling, centered headings and feedback-list spacing.
        css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
""",
        # Extra markup for the page <head>: a Google Analytics (gtag.js) snippet.
        head="""
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-1FWSVCGZTG');
</script>
""",
    ) as demo:
        # Header: title plus badges linking to the project website, the arXiv
        # paper, the GitHub repository and the author's social profile.
        gr.Markdown(
            """
# Marigold Intrinsic Image Decomposition (IID)
<p align="center">
<a title="Website" href="https://marigoldcomputervision.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
"""
        )
        with gr.Row():
            # Left column: input image and controls.
            with gr.Column():
                # type="filepath": handlers receive the image as a file path.
                image_input = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                # Choices are (display label, value) pairs. The value selects
                # which pipeline process_with_loaded_pipeline loads.
                model_type = gr.Radio(
                    [
                        ("Appearance (albedo & material)", "appearance"),
                        ("Lighting (albedo, shading & residual)", "lighting"),
                    ],
                    label="Model type: Marigold-IID-Appearance or Marigold IID-Lighting",
                    value=default_model_type,
                )
                with gr.Accordion("Advanced options", open=True):
                    image_ensemble_size = gr.Slider(
                        label="Ensemble size",
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=default_image_ensemble_size,
                    )
                    image_denoise_steps = gr.Slider(
                        label="Number of denoising steps",
                        minimum=1,
                        maximum=10,
                        step=1,
                        value=default_image_denoise_steps,
                    )
                    # 0 = "Native" resolution; 768 = "Recommended".
                    image_processing_res = gr.Radio(
                        [
                            ("Native", 0),
                            ("Recommended", 768),
                        ],
                        label="Processing resolution",
                        value=default_image_processing_res,
                    )
                with gr.Row():
                    image_submit_btn = gr.Button(value="Compute IID", variant="primary")
                    image_reset_btn = gr.Button(value="Reset")
            # Right column: outputs. Each slider compares the input with one
            # predicted component, and toggle_sliders_and_labels relabels the
            # sliders per model type. The third slider starts hidden and is
            # only shown for the lighting model.
            with gr.Column():
                image_output_slider1 = ImageSlider(
                    label="Predicted Albedo",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=True,
                )
                image_output_slider2 = ImageSlider(
                    label="Predicted Material",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=True,
                )
                image_output_slider3 = ImageSlider(
                    label="Predicted Residual",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                    visible=False,
                )
                # Download list for the saved .png visualizations and .npy arrays.
                image_output_files = gr.Files(
                    label="Output files",
                    elem_id="download",
                    interactive=False,
                )

        # Function to toggle visibility and set dynamic labels
        def toggle_sliders_and_labels(model_type):
            """Return updates for the three output sliders matching ``model_type``.

            "appearance" shows albedo and material and hides the third slider.
            "lighting" shows albedo, shading and residual. Any other value
            falls through and implicitly returns None.
            """
            if model_type == "appearance":
                return (
                    gr.update(visible=True, label="Predicted Albedo"),
                    gr.update(visible=True, label="Predicted Material"),
                    gr.update(visible=False),  # Hide third slider
                )
            elif model_type == "lighting":
                return (
                    gr.update(visible=True, label="Predicted Albedo"),
                    gr.update(visible=True, label="Predicted Shading"),
                    gr.update(visible=True, label="Predicted Residual"),
                )

        # Attach the change event to update sliders
        model_type.change(
            fn=toggle_sliders_and_labels,
            inputs=[model_type],
            outputs=[image_output_slider1, image_output_slider2, image_output_slider3],
            show_progress=False,
        )
        # Example gallery: every example image paired with both model types.
        # Examples only supply the image and model type, so the remaining
        # arguments of process_pipe_image fall back to their defaults.
        Examples(
            fn=process_pipe_image,
            examples=[
                [os.path.join("files", "image", name), _model_type]
                for name in [
                    "livingroom.jpg",
                    "books.jpg",
                    "food_counter.png",
                    "cat2.png",
                    "costumes.png",
                    "icecream.jpg",
                    "juices.jpeg",
                    "cat.jpg",
                    "food.jpeg",
                    "puzzle.jpeg",
                    "screw.png",
                ]
                for _model_type in ["appearance", "lighting"]
            ],
            inputs=[image_input, model_type],
            outputs=[
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
            ],
            cache_examples=True,  # TODO: toggle later
            directory_name="examples_images",
        )

        # Image tab: first validate that an image was provided (unqueued,
        # unpreprocessed check). Only if that succeeds, run the GPU processing
        # step, with at most one call at a time (concurrency_limit=1).
        image_submit_btn.click(
            fn=process_image_check,
            inputs=image_input,
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=process_pipe_image,
            inputs=[
                image_input,
                model_type,
                image_denoise_steps,
                image_ensemble_size,
                image_processing_res,
            ],
            outputs=[
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
            ],
            concurrency_limit=1,
        )

        # Reset: clear the input and all outputs, and restore default settings.
        image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
                None,
                None,
                default_model_type,
                default_image_ensemble_size,
                default_image_denoise_steps,
                default_image_processing_res,
            ),
            inputs=[],
            outputs=[
                image_input,
                image_output_slider1,
                image_output_slider2,
                image_output_slider3,
                image_output_files,
                model_type,
                image_ensemble_size,
                image_denoise_steps,
                image_processing_res,
            ],
            queue=False,
        )

    # Enable the request queue and serve on all interfaces, port 7860.
    # api_open=False: per the Gradio docs, the REST routes stay closed, so
    # requests cannot bypass the queue.
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def main():
    """Log installed packages, log in to the Hugging Face Hub if configured, and serve the demo."""
    # Print the package list for debugging. Use the running interpreter's pip so
    # the listing reflects the environment that actually runs this app (a bare
    # `pip` on PATH may belong to another Python). A fixed argument list avoids
    # spawning a shell.
    subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)

    # An unset or empty HF_TOKEN_LOGIN means "no login"; an empty token used to
    # be passed straight to login().
    token = os.environ.get("HF_TOKEN_LOGIN")
    if token:
        login(token=token)
    run_demo_server()


if __name__ == "__main__":
    main()