# Copyright 2024 Anton Obukhov, Bingxin Ke & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
import logging
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.utils import BaseOutput, check_min_version
from PIL import Image
from PIL.Image import Resampling
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.27.0.dev0")
class MarigoldIIDLightingOutput(BaseOutput):
"""
    Output class for the Marigold-IID-Lighting pipeline.

    Args:
        albedo (`np.ndarray`):
            Predicted albedo map with shape [3, H, W] and values in the range [-1, 1].
        albedo_colored (`PIL.Image.Image`):
            Colorized albedo map with shape [H, W, 3].
        shading (`np.ndarray`):
            Predicted diffuse shading map with shape [3, H, W] and values in the range [-1, 1].
        shading_colored (`PIL.Image.Image`):
            Colorized diffuse shading map with shape [H, W, 3].
        residual (`np.ndarray`):
            Predicted non-diffuse residual map with shape [3, H, W] and values in the range [-1, 1].
        residual_colored (`PIL.Image.Image`):
            Colorized non-diffuse residual map with shape [H, W, 3].
"""
albedo: np.ndarray
albedo_colored: Image.Image
shading: np.ndarray
shading_colored: Image.Image
residual: np.ndarray
residual_colored: Image.Image
class MarigoldIIDLightingPipeline(DiffusionPipeline):
"""
    Pipeline for Intrinsic Image Decomposition (albedo, diffuse shading and non-diffuse residual) using Marigold: https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the target latents (albedo, shading, residual), conditioned on the image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and target maps
            to and from latent representations.
scheduler (`DDIMScheduler`):
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
text_encoder (`CLIPTextModel`):
Text-encoder, for empty text embedding.
tokenizer (`CLIPTokenizer`):
CLIP tokenizer.
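
    Example:
        A minimal usage sketch; the checkpoint identifier below is a placeholder and should be replaced with the
        path or Hub id of an actual Marigold IID Lighting checkpoint.

        ```py
        >>> import torch
        >>> from PIL import Image

        >>> pipe = MarigoldIIDLightingPipeline.from_pretrained(
        ...     "path/to/marigold-iid-lighting-checkpoint", torch_dtype=torch.float16
        ... ).to("cuda")
        >>> image = Image.open("input.jpg")
        >>> output = pipe(image, denoising_steps=4, ensemble_size=1, processing_res=768)
        >>> output.albedo_colored.save("albedo.png")
        >>> output.shading_colored.save("shading.png")
        >>> output.residual_colored.save("residual.png")
        ```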
"""
latent_scale_factor = 0.18215
def __init__(
self,
unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: DDIMScheduler,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
):
super().__init__()
self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
self.empty_text_embed = None
self.n_targets = 3 # Albedo, shading, residual
@torch.no_grad()
def __call__(
self,
        input_image: Image.Image,
denoising_steps: int = 4,
ensemble_size: int = 10,
processing_res: int = 768,
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
save_memory: bool = False,
seed: Union[int, None] = None,
color_map: str = "Spectral", # TODO change colorization api based on modality
show_progress_bar: bool = True,
**kwargs,
) -> MarigoldIIDLightingOutput:
"""
Function invoked when calling the pipeline.
Args:
            input_image (`Image.Image`):
                Input RGB (or grayscale) image.
            denoising_steps (`int`, *optional*, defaults to `4`):
Number of diffusion denoising steps (DDIM) during inference.
ensemble_size (`int`, *optional*, defaults to `10`):
Number of predictions to be ensembled.
processing_res (`int`, *optional*, defaults to `768`):
Maximum resolution of processing.
If set to 0: will not resize at all.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` > 0.
resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and target predictions. This can be one of `bilinear`, `bicubic` or `nearest`; defaults to `bilinear`.
batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
If set to 0, the script will automatically decide the proper batch size.
save_memory (`bool`, defaults to `False`):
                Extra steps to save memory at the cost of performance.
seed (`int`, *optional*, defaults to `None`)
Reproducibility seed.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized map generation):
                Colormap used to colorize the predicted maps (currently not used by this pipeline).
show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising.
Returns:
            `MarigoldIIDLightingOutput`: Output class for the Marigold monocular intrinsic image decomposition (lighting) prediction pipeline, including:
            - **albedo** (`np.ndarray`) Predicted albedo map with shape [3, H, W] and values in the range [-1, 1]
            - **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with shape [H, W, 3]
            - **shading** (`np.ndarray`) Predicted diffuse shading map with shape [3, H, W] and values in the range [-1, 1]
            - **shading_colored** (`PIL.Image.Image`) Colorized diffuse shading map with shape [H, W, 3]
            - **residual** (`np.ndarray`) Predicted non-diffuse residual map with shape [3, H, W] and values in the range [-1, 1]
            - **residual_colored** (`PIL.Image.Image`) Colorized non-diffuse residual map with shape [H, W, 3]
"""
if not match_input_res:
assert processing_res is not None
assert processing_res >= 0
assert denoising_steps >= 1
assert ensemble_size >= 1
# Check if denoising step is reasonable
self.check_inference_step(denoising_steps)
resample_method: Resampling = self.get_pil_resample_method(resample_method)
W, H = input_image.size
if processing_res > 0:
input_image = self.resize_max_res(
input_image, max_edge_resolution=processing_res, resample_method=resample_method,
)
input_image = input_image.convert("RGB")
image = np.asarray(input_image)
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
rgb_norm = rgb_norm.to(self.device)
assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0 # TODO remove this
        def ensemble(
            targets: torch.Tensor, return_uncertainty: bool = False, reduction: str = "median",
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
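            # Reduce the ensemble along dim 0; uncertainty is the standard deviation (mean reduction)
            # or the median absolute deviation (median reduction).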
uncertainty = None
if reduction == "mean":
prediction = torch.mean(targets, dim=0, keepdim=True)
if return_uncertainty:
uncertainty = torch.std(targets, dim=0, keepdim=True)
elif reduction == "median":
prediction = torch.median(targets, dim=0, keepdim=True).values
if return_uncertainty:
uncertainty = torch.median(
torch.abs(targets - prediction), dim=0, keepdim=True
).values
else:
raise ValueError(f"Unrecognized reduction method: {reduction}.")
return prediction, uncertainty
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
single_rgb_dataset = TensorDataset(duplicated_rgb)
if batch_size <= 0:
batch_size = self.find_batch_size(
ensemble_size=ensemble_size,
input_res=max(rgb_norm.shape[1:]),
dtype=self.dtype,
)
single_rgb_loader = DataLoader(
single_rgb_dataset, batch_size=batch_size, shuffle=False
)
target_pred_ls = []
iterable = single_rgb_loader
if show_progress_bar:
iterable = tqdm(
single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
)
for batch in iterable:
(batched_img,) = batch
target_pred = self.single_infer(
rgb_in=batched_img,
num_inference_steps=denoising_steps,
seed=seed,
show_pbar=show_progress_bar,
)
            target_pred = target_pred.detach()
            if save_memory:
                target_pred = target_pred.cpu()
            target_pred_ls.append(target_pred)
target_preds = torch.concat(target_pred_ls, dim=0)
pred_uncert = None
if save_memory:
torch.cuda.empty_cache()
        if ensemble_size > 1:
            final_pred, pred_uncert = ensemble(
                target_preds,
                reduction="median",
                return_uncertainty=False,
            )
else:
final_pred = target_preds
pred_uncert = None
        if match_input_res:
            final_pred = torch.nn.functional.interpolate(
                final_pred, (H, W), mode="bilinear"  # TODO: parameterize this method
            )  # [1, 3 * n_targets, H, W]
if pred_uncert is not None:
pred_uncert = torch.nn.functional.interpolate(
pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
).squeeze(
1
) # [1,H,W]
# Convert to numpy
final_pred = final_pred.squeeze()
final_pred = final_pred.cpu().numpy()
albedo = final_pred[0:3, :, :]
shading = final_pred[3:6, :, :]
residual = final_pred[6:, :, :]
albedo_colored = (albedo + 1.0) * 0.5 # [-1,1] -> [0,1]
albedo_colored = albedo_colored ** (1/2.2) # from linear to sRGB (to be consistent with IID-Appearance model)
albedo_colored = (albedo_colored * 255).astype(np.uint8)
albedo_colored = self.chw2hwc(albedo_colored)
albedo_colored_img = Image.fromarray(albedo_colored)
shading_colored = (shading + 1.0) * 0.5
shading_colored = shading_colored / shading_colored.max() # rescale for better visualization
shading_colored = (shading_colored * 255).astype(np.uint8)
shading_colored = self.chw2hwc(shading_colored)
shading_colored_img = Image.fromarray(shading_colored)
residual_colored = (residual + 1.0) * 0.5
residual_colored = residual_colored / residual_colored.max() # rescale for better visualization
residual_colored = (residual_colored * 255).astype(np.uint8)
residual_colored = self.chw2hwc(residual_colored)
residual_colored_img = Image.fromarray(residual_colored)
out = MarigoldIIDLightingOutput(
albedo=albedo,
albedo_colored=albedo_colored_img,
shading=shading,
shading_colored=shading_colored_img,
residual=residual,
residual_colored=residual_colored_img
)
return out
def check_inference_step(self, n_step: int):
"""
Check if denoising step is reasonable
Args:
n_step (`int`): denoising steps
"""
assert n_step >= 1
if isinstance(self.scheduler, DDIMScheduler):
pass
else:
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
def encode_empty_text(self):
"""
Encode text embedding for empty prompt.
"""
prompt = ""
text_inputs = self.tokenizer(
prompt,
padding="do_not_pad",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
@torch.no_grad()
def single_infer(
self,
rgb_in: torch.Tensor,
num_inference_steps: int,
seed: Union[int, None],
show_pbar: bool,
) -> torch.Tensor:
"""
        Perform an individual IID prediction without ensembling.
"""
device = rgb_in.device
# Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps # [T]
# Encode image
rgb_latent = self.encode_rgb(rgb_in)
        target_latent_shape = list(rgb_latent.shape)
        target_latent_shape[1] *= self.n_targets  # (B, 4 * n_targets, h, w)
# Initialize prediction latent with noise
if seed is None:
rand_num_generator = None
else:
rand_num_generator = torch.Generator(device=device)
rand_num_generator.manual_seed(seed)
        target_latents = torch.randn(
            target_latent_shape,
            device=device,
            dtype=self.dtype,
            generator=rand_num_generator,
        )  # [B, 4 * n_targets, h, w]
# Batched empty text embedding
if self.empty_text_embed is None:
self.encode_empty_text()
batch_empty_text_embed = self.empty_text_embed.repeat(
(rgb_latent.shape[0], 1, 1)
) # [B, 2, 1024]
# Denoising loop
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latents], dim=1
            )  # image latent first, then target latents; this order must match the training layout
# predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4 * n_targets, h, w]
# compute the previous noisy sample x_t -> x_t-1
target_latents = self.scheduler.step(
noise_pred, t, target_latents, generator=rand_num_generator
).prev_sample
# torch.cuda.empty_cache() # TODO is it really needed here, even if memory saving?
        targets = self.decode_targets(target_latents)  # [B, 3 * n_targets, H, W]
targets = torch.clip(targets, -1.0, 1.0)
return targets
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (`torch.Tensor`):
Input RGB image to be encoded.
Returns:
`torch.Tensor`: Image latent.
"""
# encode
h = self.vae.encoder(rgb_in)
moments = self.vae.quant_conv(h)
mean, logvar = torch.chunk(moments, 2, dim=1)
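        # Use only the mean of the latent distribution (deterministic encoding); the log-variance is discarded.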
# scale latent
rgb_latent = mean * self.latent_scale_factor
return rgb_latent
def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
"""
Decode target latent into target map.
Args:
target_latents (`torch.Tensor`):
Target latent to be decoded.
Returns:
`torch.Tensor`: Decoded target map.
"""
        assert target_latents.shape[1] == self.n_targets * 4
# scale latent
target_latents = target_latents / self.latent_scale_factor
# decode
targets = []
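        # Each target (albedo, shading, residual) occupies 4 latent channels and decodes to a 3-channel map;
        # the decoded maps are concatenated along the channel dimension.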
for i in range(self.n_targets):
latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
z = self.vae.post_quant_conv(latent)
stacked = self.vae.decoder(z)
targets.append(stacked)
return torch.cat(targets, dim=1)
@staticmethod
def get_pil_resample_method(method_str: str) -> Resampling:
resample_method_dic = {
"bilinear": Resampling.BILINEAR,
"bicubic": Resampling.BICUBIC,
"nearest": Resampling.NEAREST,
}
        resample_method = resample_method_dic.get(method_str, None)
        if resample_method is None:
            raise ValueError(f"Unknown resampling method: {method_str}")
else:
return resample_method
@staticmethod
def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
"""
Resize image to limit maximum edge length while keeping aspect ratio.
"""
original_width, original_height = img.size
downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
new_width = int(original_width * downscale_factor)
new_height = int(original_height * downscale_factor)
resized_img = img.resize((new_width, new_height), resample=resample_method)
return resized_img
@staticmethod
    def chw2hwc(chw):
        assert 3 == len(chw.shape)
        if isinstance(chw, torch.Tensor):
            hwc = torch.permute(chw, (1, 2, 0))
        elif isinstance(chw, np.ndarray):
            hwc = np.moveaxis(chw, 0, -1)
        else:
            raise TypeError(f"Unsupported input type: {type(chw)}")
        return hwc
@staticmethod
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
"""
        Automatically search for a suitable inference batch size.

        Args:
            ensemble_size (`int`):
                Number of predictions to be ensembled.
            input_res (`int`):
                Operating resolution of the input image.
            dtype (`torch.dtype`):
                Data type used for inference.

        Returns:
`int`: Operating batch size.
"""
# Search table for suggested max. inference batch size
bs_search_table = [
# tested on A100-PCIE-80GB
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
# tested on A100-PCIE-40GB
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
# tested on RTX3090, RTX4090
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
# tested on GTX1080Ti
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]
if not torch.cuda.is_available():
return 1
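        # torch.cuda.mem_get_info() returns (free, total) in bytes; use the total VRAM in GiB.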
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
for settings in sorted(
filtered_bs_search_table,
key=lambda k: (k["res"], -k["total_vram"]),
):
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
bs = settings["bs"]
if bs > ensemble_size:
bs = ensemble_size
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
bs = math.ceil(ensemble_size / 2)
return bs
return 1