# Copyright 2024 Anton Obukhov, Bingxin Ke & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------
import logging
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput, check_min_version
from PIL import Image
from PIL.Image import Resampling
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.27.0.dev0")
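
# A minimal usage sketch, not executed by this module. The checkpoint identifier below is an
# assumption for illustration only; substitute the Marigold IID (lighting) weights you actually use.
#
#     import torch
#     from PIL import Image
#
#     pipe = MarigoldIIDLightingPipeline.from_pretrained(
#         "path/or/hub-id-of-marigold-iid-lighting-checkpoint",  # hypothetical identifier
#         torch_dtype=torch.float16,
#     ).to("cuda")
#     result = pipe(Image.open("input.jpg"), ensemble_size=4, seed=0)
#     result.albedo_colored.save("albedo.png")
#     result.shading_colored.save("shading.png")
#     result.residual_colored.save("residual.png")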


class MarigoldIIDLightingOutput(BaseOutput):
    """
    Output class for the Marigold-IID-Lighting pipeline.

    Args:
        albedo (`np.ndarray`):
            Predicted albedo map with shape [3, H, W] and values in the range [0, 1].
        albedo_colored (`PIL.Image.Image`):
            Colorized albedo map with shape [H, W, 3].
        shading (`np.ndarray`):
            Predicted diffuse shading map with shape [3, H, W] and values in the range [0, 1].
        shading_colored (`PIL.Image.Image`):
            Colorized diffuse shading map with shape [H, W, 3].
        residual (`np.ndarray`):
            Predicted non-diffuse residual map with shape [3, H, W] and values in the range [0, 1].
        residual_colored (`PIL.Image.Image`):
            Colorized non-diffuse residual map with shape [H, W, 3].
    """

    albedo: np.ndarray
    albedo_colored: Image.Image
    shading: np.ndarray
    shading_colored: Image.Image
    residual: np.ndarray
    residual_colored: Image.Image


class MarigoldIIDLightingPipeline(DiffusionPipeline):
    """
    Pipeline for Intrinsic Image Decomposition (albedo, diffuse shading and non-diffuse residual) using Marigold:
    https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the target latents (albedo, shading, residual), conditioned on the image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) model to encode and decode images and target maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: DDIMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        self.empty_text_embed = None
        self.n_targets = 3  # albedo, shading, residual

    def __call__(
        self,
        input_image: Image.Image,
        denoising_steps: int = 4,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        save_memory: bool = False,
        seed: Union[int, None] = None,
        color_map: str = "Spectral",  # TODO change colorization api based on modality
        show_progress_bar: bool = True,
        **kwargs,
    ) -> MarigoldIIDLightingOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `4`):
                Number of diffusion denoising steps (DDIM) during inference.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `768`):
                Maximum resolution of processing. If set to 0, the input is not resized.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the prediction to match the input resolution.
                Only valid if `processing_res` is greater than 0.
            resample_method (`str`, *optional*, defaults to `"bilinear"`):
                Resampling method used to resize images and predictions; one of `bilinear`, `bicubic` or `nearest`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, a suitable batch size is chosen automatically.
            save_memory (`bool`, defaults to `False`):
                Extra steps to save memory at the cost of performance.
            seed (`int`, *optional*, defaults to `None`):
                Reproducibility seed.
            color_map (`str`, *optional*, defaults to `"Spectral"`):
                Colormap name kept for API consistency; not used by this pipeline (see the TODO in the signature).
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.

        Returns:
            `MarigoldIIDLightingOutput`: Output class for the Marigold monocular intrinsic image decomposition
            (lighting) prediction pipeline, including:
            - **albedo** (`np.ndarray`) Predicted albedo map with shape [3, H, W]
            - **albedo_colored** (`PIL.Image.Image`) Colorized albedo map with shape [H, W, 3]
            - **shading** (`np.ndarray`) Predicted diffuse shading map with shape [3, H, W]
            - **shading_colored** (`PIL.Image.Image`) Colorized diffuse shading map with shape [H, W, 3]
            - **residual** (`np.ndarray`) Predicted non-diffuse residual map with shape [3, H, W]
            - **residual_colored** (`PIL.Image.Image`) Colorized non-diffuse residual map with shape [H, W, 3]
        """
        if not match_input_res:
            assert processing_res is not None
        assert processing_res >= 0
        assert denoising_steps >= 1
        assert ensemble_size >= 1

        # Check if the number of denoising steps is reasonable
        self.check_inference_step(denoising_steps)

        resample_method: Resampling = self.get_pil_resample_method(resample_method)

        W, H = input_image.size

        if processing_res > 0:
            input_image = self.resize_max_res(
                input_image,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )
        input_image = input_image.convert("RGB")
        image = np.asarray(input_image)

        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
        rgb_norm = rgb_norm.to(self.device)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0  # TODO remove this

        def ensemble(
            targets: torch.Tensor,
            return_uncertainty: bool = False,
            reduction: str = "median",
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            uncertainty = None
            if reduction == "mean":
                prediction = torch.mean(targets, dim=0, keepdim=True)
                if return_uncertainty:
                    uncertainty = torch.std(targets, dim=0, keepdim=True)
            elif reduction == "median":
                prediction = torch.median(targets, dim=0, keepdim=True).values
                if return_uncertainty:
                    uncertainty = torch.median(
                        torch.abs(targets - prediction), dim=0, keepdim=True
                    ).values
            else:
                raise ValueError(f"Unrecognized reduction method: {reduction}.")
            return prediction, uncertainty
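
        # Each of the `ensemble_size` copies of the image below is denoised starting from an
        # independent noise sample; the per-pixel median (or mean) of the resulting predictions
        # is then taken by `ensemble` above as the final estimate.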
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size <= 0:
            batch_size = self.find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )
        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=batch_size, shuffle=False
        )

        target_pred_ls = []
        iterable = single_rgb_loader
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        for batch in iterable:
            (batched_img,) = batch
            target_pred = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                seed=seed,
                show_pbar=show_progress_bar,
            )
            target_pred = target_pred.detach()
            if save_memory:
                target_pred = target_pred.cpu()
            target_pred_ls.append(target_pred)
        target_preds = torch.concat(target_pred_ls, dim=0)
        pred_uncert = None

        if save_memory:
            torch.cuda.empty_cache()

        if ensemble_size > 1:
            final_pred, pred_uncert = ensemble(
                target_preds,
                reduction="median",
                return_uncertainty=False,
            )
        else:
            final_pred = target_preds
            pred_uncert = None

        if match_input_res:
            final_pred = torch.nn.functional.interpolate(
                final_pred, (H, W), mode="bilinear"  # TODO: parameterize this method
            )  # [1, 3*n_targets, H, W]
            if pred_uncert is not None:
                pred_uncert = torch.nn.functional.interpolate(
                    pred_uncert.unsqueeze(1), (H, W), mode="bilinear"
                ).squeeze(1)  # [1, H, W]

        # Convert to numpy
        final_pred = final_pred.squeeze()
        final_pred = final_pred.cpu().numpy()

        albedo = final_pred[0:3, :, :]
        shading = final_pred[3:6, :, :]
        residual = final_pred[6:, :, :]

        albedo_colored = (albedo + 1.0) * 0.5  # [-1, 1] -> [0, 1]
        albedo_colored = albedo_colored ** (1 / 2.2)  # linear to sRGB (consistent with the IID-Appearance model)
        albedo_colored = (albedo_colored * 255).astype(np.uint8)
        albedo_colored = self.chw2hwc(albedo_colored)
        albedo_colored_img = Image.fromarray(albedo_colored)

        shading_colored = (shading + 1.0) * 0.5
        shading_colored = shading_colored / shading_colored.max()  # rescale for better visualization
        shading_colored = (shading_colored * 255).astype(np.uint8)
        shading_colored = self.chw2hwc(shading_colored)
        shading_colored_img = Image.fromarray(shading_colored)

        residual_colored = (residual + 1.0) * 0.5
        residual_colored = residual_colored / residual_colored.max()  # rescale for better visualization
        residual_colored = (residual_colored * 255).astype(np.uint8)
        residual_colored = self.chw2hwc(residual_colored)
        residual_colored_img = Image.fromarray(residual_colored)

        out = MarigoldIIDLightingOutput(
            albedo=albedo,
            albedo_colored=albedo_colored_img,
            shading=shading,
            shading_colored=shading_colored_img,
            residual=residual,
            residual_colored=residual_colored_img,
        )
        return out

    def check_inference_step(self, n_step: int):
        """
        Check if the number of denoising steps is reasonable.

        Args:
            n_step (`int`): denoising steps
        """
        assert n_step >= 1
        if not isinstance(self.scheduler, DDIMScheduler):
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode the text embedding for the empty prompt.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        seed: Union[int, None],
        show_pbar: bool,
    ) -> torch.Tensor:
        """
        Perform an individual IID prediction without ensembling.
        """
        device = rgb_in.device

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        target_latent_shape = list(rgb_latent.shape)
        target_latent_shape[1] *= self.n_targets  # (B, 4*n_targets, h, w)

        # Initialize prediction latent with noise
        if seed is None:
            rand_num_generator = None
        else:
            rand_num_generator = torch.Generator(device=device)
            rand_num_generator.manual_seed(seed)
        target_latents = torch.randn(
            target_latent_shape,
            device=device,
            dtype=self.dtype,
            generator=rand_num_generator,
        )  # [B, 4*n_targets, h, w]

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)
        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, target_latents], dim=1
            )  # this order is important
            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4*n_targets, h, w]
            # compute the previous noisy sample x_t -> x_t-1
            target_latents = self.scheduler.step(
                noise_pred, t, target_latents, generator=rand_num_generator
            ).prev_sample
            # torch.cuda.empty_cache()  # TODO is it really needed here, even if memory saving?

        targets = self.decode_targets(target_latents)  # [B, 3*n_targets, H, W]
        targets = torch.clip(targets, -1.0, 1.0)
        return targets

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode an RGB image into a latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.latent_scale_factor
        return rgb_latent

    def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
        """
        Decode the target latents into target maps.

        Args:
            target_latents (`torch.Tensor`):
                Target latents to be decoded, with 4 channels per target.

        Returns:
            `torch.Tensor`: Decoded target maps, concatenated along the channel axis.
        """
        assert target_latents.shape[1] == self.n_targets * 4

        # scale latent
        target_latents = target_latents / self.latent_scale_factor
        # decode each 4-channel chunk separately
        targets = []
        for i in range(self.n_targets):
            latent = target_latents[:, i * 4 : (i + 1) * 4, :, :]
            z = self.vae.post_quant_conv(latent)
            decoded = self.vae.decoder(z)
            targets.append(decoded)
        return torch.cat(targets, dim=1)
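
    # Note: the decoded tensor concatenates the three targets along the channel axis,
    # i.e. channels 0:3 correspond to the first latent chunk, 3:6 to the second and 6:9 to the
    # third, matching the albedo / shading / residual slicing performed in `__call__`.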

    @staticmethod
    def get_pil_resample_method(method_str: str) -> Resampling:
        resample_method_dic = {
            "bilinear": Resampling.BILINEAR,
            "bicubic": Resampling.BICUBIC,
            "nearest": Resampling.NEAREST,
        }
        resample_method = resample_method_dic.get(method_str, None)
        if resample_method is None:
            raise ValueError(f"Unknown resampling method: {method_str}")
        else:
            return resample_method

    @staticmethod
    def resize_max_res(
        img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
    ) -> Image.Image:
        """
        Resize an image to limit the maximum edge length while keeping the aspect ratio.
        """
        original_width, original_height = img.size
        downscale_factor = min(
            max_edge_resolution / original_width, max_edge_resolution / original_height
        )
        new_width = int(original_width * downscale_factor)
        new_height = int(original_height * downscale_factor)
        resized_img = img.resize((new_width, new_height), resample=resample_method)
        return resized_img

    @staticmethod
    def chw2hwc(chw):
        assert 3 == len(chw.shape)
        if isinstance(chw, torch.Tensor):
            hwc = torch.permute(chw, (1, 2, 0))
        elif isinstance(chw, np.ndarray):
            hwc = np.moveaxis(chw, 0, -1)
        else:
            raise TypeError(f"Unsupported input type: {type(chw)}")
        return hwc

    @staticmethod
    def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
        """
        Automatically search for a suitable inference batch size.

        Args:
            ensemble_size (`int`):
                Number of predictions to be ensembled.
            input_res (`int`):
                Operating resolution of the input image.
            dtype (`torch.dtype`):
                Data type used for inference.

        Returns:
            `int`: Operating batch size.
        """
        # Search table for suggested max. inference batch size
        bs_search_table = [
            # tested on A100-PCIE-80GB
            {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
            {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
            # tested on A100-PCIE-40GB
            {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
            {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
            {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
            {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
            # tested on RTX3090, RTX4090
            {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
            {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
            {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
            {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
            {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
            {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
            # tested on GTX1080Ti
            {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
            {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
            {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
            {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
            {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
        ]

        if not torch.cuda.is_available():
            return 1

        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
        filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
        for settings in sorted(
            filtered_bs_search_table,
            key=lambda k: (k["res"], -k["total_vram"]),
        ):
            if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
                bs = settings["bs"]
                if bs > ensemble_size:
                    bs = ensemble_size
                elif math.ceil(ensemble_size / 2) < bs < ensemble_size:
                    bs = math.ceil(ensemble_size / 2)
                return bs
        return 1