# Copyright 2024 Anton Obukhov, Bingxin Ke & Kevin Qu, ETH Zurich and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldcomputervision.github.io
# --------------------------------------------------------------------------

import logging
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput, check_min_version
from PIL import Image
from PIL.Image import Resampling
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")


class MarigoldIIDLightingOutput(BaseOutput):
    """
    Output class for Marigold-IID-Lighting pipeline.

    Args:
        albedo (`np.ndarray`):
            Predicted albedo map with the shape of [3, H, W] and values in the range of [-1, 1].
        albedo_colored (`PIL.Image.Image`):
            sRGB visualization of the albedo map with the shape of [H, W, 3].
        shading (`np.ndarray`):
            Predicted diffuse shading map with the shape of [3, H, W] and values in the range of [-1, 1].
        shading_colored (`PIL.Image.Image`):
            Max-normalized visualization of the diffuse shading map with the shape of [H, W, 3].
        residual (`np.ndarray`):
            Predicted non-diffuse residual map with the shape of [3, H, W] and values in the range of [-1, 1].
        residual_colored (`PIL.Image.Image`):
            Max-normalized visualization of the non-diffuse residual map with the shape of [H, W, 3].
    """

    albedo: np.ndarray
    albedo_colored: Image.Image
    shading: np.ndarray
    shading_colored: Image.Image
    residual: np.ndarray
    residual_colored: Image.Image


class MarigoldIIDLightingPipeline(DiffusionPipeline):
    """
    Pipeline for Intrinsic Image Decomposition (albedo, diffuse shading and non-diffuse residual) using Marigold:
    https://marigoldcomputervision.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net that denoises the channel-stacked target latents (albedo, shading, residual),
            conditioned on the image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model used to encode images into latents and decode target latents back
            into maps.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the target latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    # Multiplier applied to the VAE latent mean after encoding; divided out again before decoding.
    latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: DDIMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        # Lazily computed by `encode_empty_text` on first inference.
        self.empty_text_embed = None
        self.n_targets = 3  # Albedo, shading, residual

    @torch.no_grad()
    def __call__(
        self,
        input_image: Image.Image,
        denoising_steps: int = 4,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        save_memory: bool = False,
        seed: Optional[int] = None,
        color_map: str = "Spectral",  # TODO change colorization api based on modality
        show_progress_bar: bool = True,
        **kwargs,
    ) -> MarigoldIIDLightingOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `4`):
                Number of diffusion denoising steps (DDIM) during inference.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled (median over the ensemble).
            processing_res (`int`, *optional*, defaults to `768`):
                Maximum resolution of processing. If set to 0: will not resize at all.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize the predictions back to the input resolution (bilinear interpolation).
            resample_method: (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize the input image to `processing_res`. This can be one of
                `bilinear`, `bicubic` or `nearest`. Predictions are always resized bilinearly.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, the script will automatically decide the proper batch size.
            save_memory (`bool`, defaults to `False`):
                Move per-batch predictions to CPU and empty the CUDA cache, at the cost of performance.
            seed (`int`, *optional*, defaults to `None`):
                Reproducibility seed. One generator seeded with it is shared by all inference batches, so ensemble
                members drawn in different batches start from different noise.
            color_map (`str`, *optional*, defaults to `"Spectral"`):
                Currently unused; kept for API compatibility.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            **kwargs:
                Ignored.

        Returns:
            `MarigoldIIDLightingOutput`: Output class for Marigold monocular intrinsic image decomposition (lighting)
            prediction pipeline, including:
            - **albedo** (`np.ndarray`) Predicted albedo map with the shape of [3, H, W], values in [-1, 1]
            - **albedo_colored** (`PIL.Image.Image`) sRGB visualization of the albedo with the shape of [H, W, 3]
            - **shading** (`np.ndarray`) Predicted diffuse shading map with the shape of [3, H, W], values in [-1, 1]
            - **shading_colored** (`PIL.Image.Image`) Max-normalized shading visualization with the shape of [H, W, 3]
            - **residual** (`np.ndarray`) Predicted non-diffuse residual with the shape of [3, H, W], values in
              [-1, 1]
            - **residual_colored** (`PIL.Image.Image`) Max-normalized residual visualization with the shape of
              [H, W, 3]
        """
        if not match_input_res:
            assert processing_res is not None
        assert processing_res >= 0
        assert denoising_steps >= 1
        assert ensemble_size >= 1

        # Check if denoising step is reasonable
        self.check_inference_step(denoising_steps)

        pil_resample = self.get_pil_resample_method(resample_method)

        # ----------------- Image preprocessing -----------------
        W, H = input_image.size  # input resolution, restored at the end when match_input_res is set
        if processing_res > 0:
            input_image = self.resize_max_res(
                input_image,
                max_edge_resolution=processing_res,
                resample_method=pil_resample,
            )
        input_image = input_image.convert("RGB")
        image = np.asarray(input_image)

        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
        rgb_norm = rgb_norm.to(self.device)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0  # TODO remove this

        # ----------------- Prediction -----------------
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size <= 0:
            batch_size = self.find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )
        single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=batch_size, shuffle=False)

        # A single generator for the whole ensemble. Re-seeding per batch (the previous behavior) gave every
        # batch identical initial noise, so ensemble members in different batches were exact duplicates.
        # With one batch, the noise is the same as seeding inside `single_infer`.
        if seed is None:
            generator = None
        else:
            generator = torch.Generator(device=rgb_norm.device)
            generator.manual_seed(seed)

        target_pred_ls = []
        iterable = single_rgb_loader
        if show_progress_bar:
            iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False)
        for (batched_img,) in iterable:
            target_pred = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                seed=seed,
                show_pbar=show_progress_bar,
                generator=generator,
            ).detach()
            if save_memory:
                target_pred = target_pred.cpu()
            target_pred_ls.append(target_pred)
        target_preds = torch.concat(target_pred_ls, dim=0)  # [ensemble_size, 3 * n_targets, h, w]

        if save_memory:
            torch.cuda.empty_cache()

        # ----------------- Test-time ensembling -----------------
        pred_uncert = None
        if ensemble_size > 1:
            final_pred, pred_uncert = self._ensemble(target_preds, reduction="median", return_uncertainty=False)
        else:
            final_pred = target_preds

        # Resize back to the input resolution
        if match_input_res:
            final_pred = torch.nn.functional.interpolate(
                final_pred, (H, W), mode="bilinear"  # TODO: parameterize this method
            )  # [1, 3 * n_targets, H, W]
            if pred_uncert is not None:
                # The uncertainty keeps its channel dim ([1, C, h, w]), so it is resized exactly like final_pred.
                pred_uncert = torch.nn.functional.interpolate(pred_uncert, (H, W), mode="bilinear")

        # Convert to numpy. Index the batch dim rather than squeeze(), which would also drop 1-pixel spatial dims.
        final_pred = final_pred[0].cpu().numpy()  # [3 * n_targets, H, W]

        albedo = final_pred[0:3, :, :]
        shading = final_pred[3:6, :, :]
        residual = final_pred[6:, :, :]

        # ----------------- Visualization -----------------
        albedo_vis = (albedo + 1.0) * 0.5  # [-1,1] -> [0,1]
        # from linear to sRGB (to be consistent with IID-Appearance model)
        albedo_vis = albedo_vis ** (1 / 2.2)
        albedo_colored_img = self._to_uint8_image(albedo_vis)

        # Shading and residual are rescaled by their maximum for better visualization
        shading_colored_img = self._to_uint8_image(self._normalize_by_max((shading + 1.0) * 0.5))
        residual_colored_img = self._to_uint8_image(self._normalize_by_max((residual + 1.0) * 0.5))

        return MarigoldIIDLightingOutput(
            albedo=albedo,
            albedo_colored=albedo_colored_img,
            shading=shading,
            shading_colored=shading_colored_img,
            residual=residual,
            residual_colored=residual_colored_img,
        )

    @staticmethod
    def _ensemble(
        targets: torch.Tensor,
        return_uncertainty: bool = False,
        reduction: str = "median",
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Reduce a stack of predictions along dim 0 into a single prediction (dim 0 kept with size 1).

        Args:
            targets (`torch.Tensor`): Predictions stacked along dim 0.
            return_uncertainty (`bool`): Also return a spread estimate (std for "mean", median absolute
                deviation for "median").
            reduction (`str`): "mean" or "median".

        Returns:
            `Tuple[torch.Tensor, Optional[torch.Tensor]]`: The reduced prediction and the uncertainty (or `None`).

        Raises:
            ValueError: If `reduction` is not recognized.
        """
        uncertainty = None
        if reduction == "mean":
            prediction = torch.mean(targets, dim=0, keepdim=True)
            if return_uncertainty:
                uncertainty = torch.std(targets, dim=0, keepdim=True)
        elif reduction == "median":
            prediction = torch.median(targets, dim=0, keepdim=True).values
            if return_uncertainty:
                uncertainty = torch.median(torch.abs(targets - prediction), dim=0, keepdim=True).values
        else:
            raise ValueError(f"Unrecognized reduction method: {reduction}.")
        return prediction, uncertainty

    @staticmethod
    def _normalize_by_max(arr: np.ndarray) -> np.ndarray:
        """Divide `arr` by its maximum; leave it unchanged when the maximum is not positive (avoids NaNs)."""
        peak = arr.max()
        if peak > 0:
            return arr / peak
        return arr

    @classmethod
    def _to_uint8_image(cls, chw: np.ndarray) -> Image.Image:
        """Convert a [C, H, W] float array with values in [0, 1] into an 8-bit PIL image."""
        return Image.fromarray(cls.chw2hwc((chw * 255).astype(np.uint8)))

    def check_inference_step(self, n_step: int):
        """
        Check if denoising step is reasonable.

        Args:
            n_step (`int`): denoising steps

        Raises:
            RuntimeError: If the configured scheduler is not a `DDIMScheduler`.
        """
        assert n_step >= 1
        if not isinstance(self.scheduler, DDIMScheduler):
            raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt and cache it in `self.empty_text_embed`.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        seed: Union[int, None],
        show_pbar: bool,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Perform an individual iid prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`): Batch of images normalized to [-1, 1], [B, 3, H, W].
            num_inference_steps (`int`): Number of DDIM denoising steps.
            seed (`int` or `None`): Seed for a fresh generator, used only when `generator` is not given.
            show_pbar (`bool`): Display a progress bar of diffusion denoising.
            generator (`torch.Generator`, *optional*): Generator for the initial noise. It takes precedence over
                `seed`, and its state advances, so repeated calls draw different noise.

        Returns:
            `torch.Tensor`: Predicted maps stacked along channels (albedo, shading, residual), values clipped to
            [-1, 1].
        """
        device = rgb_in.device

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        # One latent per target, stacked along the channel dim
        target_latent_shape = list(rgb_latent.shape)
        target_latent_shape[1] *= self.n_targets

        # Initialize prediction latent with noise
        if generator is None and seed is not None:
            generator = torch.Generator(device=device)
            generator.manual_seed(seed)
        target_latents = torch.randn(
            target_latent_shape,
            device=device,
            dtype=self.dtype,
            generator=generator,
        )

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1))

        # Denoising loop
        iterable = timesteps
        if show_pbar:
            iterable = tqdm(
                timesteps,
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )

        for t in iterable:
            # this order is important: image latent first, then the target latents
            unet_input = torch.cat([rgb_latent, target_latents], dim=1)

            # predict the noise residual
            noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample

            # compute the previous noisy sample x_t -> x_t-1
            target_latents = self.scheduler.step(noise_pred, t, target_latents, generator=generator).prev_sample

        targets = self.decode_targets(target_latents)
        targets = torch.clip(targets, -1.0, 1.0)

        return targets

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent (the scaled mean of the VAE encoder's moments).
        """
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        # The moments hold (mean, logvar) along dim 1; only the mean is used, so encoding is deterministic.
        mean, _ = torch.chunk(moments, 2, dim=1)
        return mean * self.latent_scale_factor

    def decode_targets(self, target_latents: torch.Tensor) -> torch.Tensor:
        """
        Decode channel-stacked target latents into target maps.

        Args:
            target_latents (`torch.Tensor`):
                Target latents to be decoded, 4 channels per target stacked along dim 1.

        Returns:
            `torch.Tensor`: Decoded target maps, concatenated along dim 1 in the same target order.
        """
        assert target_latents.shape[1] == self.n_targets * 4

        # scale latent
        target_latents = target_latents / self.latent_scale_factor

        # decode each 4-channel target latent separately with the shared VAE decoder
        targets = []
        for latent in torch.split(target_latents, 4, dim=1):
            z = self.vae.post_quant_conv(latent)
            targets.append(self.vae.decoder(z))
        return torch.cat(targets, dim=1)

    @staticmethod
    def get_pil_resample_method(method_str: str) -> Resampling:
        """
        Map a resampling name (`bilinear`, `bicubic`, `nearest`) to the PIL `Resampling` enum.

        Raises:
            ValueError: If `method_str` is not one of the supported names.
        """
        resample_method_dic = {
            "bilinear": Resampling.BILINEAR,
            "bicubic": Resampling.BICUBIC,
            "nearest": Resampling.NEAREST,
        }
        resample_method = resample_method_dic.get(method_str, None)
        if resample_method is None:
            # Report the offending input; the looked-up value is always None here.
            raise ValueError(f"Unknown resampling method: {method_str}")
        return resample_method

    @staticmethod
    def resize_max_res(
        img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR
    ) -> Image.Image:
        """
        Resize image so that its longer edge equals `max_edge_resolution` while keeping aspect ratio.
        Images smaller than that are upscaled. Each output edge is at least 1 pixel.
        """
        original_width, original_height = img.size
        downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)

        # Guard against 0-pixel edges for extreme aspect ratios, which PIL cannot resize to.
        new_width = max(1, int(original_width * downscale_factor))
        new_height = max(1, int(original_height * downscale_factor))

        return img.resize((new_width, new_height), resample=resample_method)

    @staticmethod
    def chw2hwc(chw):
        """
        Move the leading channel axis of a 3-D tensor/array to the end ([C, H, W] -> [H, W, C]).

        Raises:
            TypeError: If `chw` is neither a `torch.Tensor` nor an `np.ndarray`.
        """
        assert 3 == len(chw.shape)
        if isinstance(chw, torch.Tensor):
            return torch.permute(chw, (1, 2, 0))
        if isinstance(chw, np.ndarray):
            return np.moveaxis(chw, 0, -1)
        raise TypeError(f"Unsupported input type: {type(chw)}")

    @staticmethod
    def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
        """
        Automatically search for suitable operating batch size.

        Args:
            ensemble_size (`int`):
                Number of predictions to be ensembled.
            input_res (`int`):
                Operating resolution of the input image.
            dtype (`torch.dtype`):
                Inference dtype; only table entries with this dtype are considered.

        Returns:
            `int`: Operating batch size (1 when CUDA is unavailable or no table entry fits).
        """
        # Search table for suggested max. inference batch size
        bs_search_table = [
            # tested on A100-PCIE-80GB
            {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
            {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
            # tested on A100-PCIE-40GB
            {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
            {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
            {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
            {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
            # tested on RTX3090, RTX4090
            {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
            {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
            {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
            {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
            {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
            {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
            # tested on GTX1080Ti
            {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
            {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
            {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
            {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
            {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
        ]

        if not torch.cuda.is_available():
            return 1

        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
        filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
        # Smallest sufficient resolution first; within it, the largest VRAM tier first.
        for settings in sorted(
            filtered_bs_search_table,
            key=lambda k: (k["res"], -k["total_vram"]),
        ):
            if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
                bs = settings["bs"]
                if bs > ensemble_size:
                    bs = ensemble_size
                elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
                    # Split the ensemble into two roughly equal batches instead of one large + one tiny.
                    bs = math.ceil(ensemble_size / 2)
                return bs

        return 1