# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import re from typing import Dict, List, Tuple, Union import torch import torch.nn as nn from ...models.attention_processor import ( Attention, AttentionProcessor, PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, ) from ...utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name class PAGMixin: r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1).""" def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): r""" Set the attention processor for the PAG layers. """ pag_attn_processors = self._pag_attn_processors if pag_attn_processors is None: raise ValueError( "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters." ) pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] if hasattr(self, "unet"): model: nn.Module = self.unet else: model: nn.Module = self.transformer def is_self_attn(module: nn.Module) -> bool: r""" Check if the module is self-attention module based on its name. """ return isinstance(module, Attention) and not module.is_cross_attention def is_fake_integral_match(layer_id, name): layer_id = layer_id.split(".")[-1] name = name.split(".")[-1] return layer_id.isnumeric() and name.isnumeric() and layer_id == name for layer_id in pag_applied_layers: # for each PAG layer input, we find corresponding self-attention layers in the unet model target_modules = [] for name, module in model.named_modules(): # Identify the following simple cases: # (1) Self Attention layer existing # (2) Whether the module name matches pag layer id even partially # (3) Make sure it's not a fake integral match if the layer_id ends with a number # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" if ( is_self_attn(module) and re.search(layer_id, name) is not None and not is_fake_integral_match(layer_id, name) ): logger.debug(f"Applying PAG to layer: {name}") target_modules.append(module) if len(target_modules) == 0: raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") for module in target_modules: module.processor = pag_attn_proc def _get_pag_scale(self, t): r""" Get the scale factor for the perturbed attention guidance at timestep `t`. """ if self.do_pag_adaptive_scaling: signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) if signal_scale < 0: signal_scale = 0 return signal_scale else: return self.pag_scale def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): r""" Apply perturbed attention guidance to the noise prediction. Args: noise_pred (torch.Tensor): The noise prediction tensor. do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. guidance_scale (float): The scale factor for the guidance term. t (int): The current time step. Returns: torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. """ pag_scale = self._get_pag_scale(t) if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) noise_pred = ( noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + pag_scale * (noise_pred_text - noise_pred_perturb) ) else: noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) return noise_pred def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): """ Prepares the perturbed attention guidance for the PAG model. Args: cond (torch.Tensor): The conditional input tensor. uncond (torch.Tensor): The unconditional input tensor. do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. Returns: torch.Tensor: The prepared perturbed attention guidance tensor. """ cond = torch.cat([cond] * 2, dim=0) if do_classifier_free_guidance: cond = torch.cat([uncond, cond], dim=0) return cond def set_pag_applied_layers( self, pag_applied_layers: Union[str, List[str]], pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = ( PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0(), ), ): r""" Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid. Args: pag_applied_layers (`str` or `List[str]`): One or more strings identifying the layer names, or a simple regex for matching multiple layers, where PAG is to be applied. A few ways of expected usage are as follows: - Single layers specified as - "blocks.{layer_index}" - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] - Multiple layers as a block name - "mid" - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" pag_attn_processors: (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second attention processor is for PAG with CFG disabled (unconditional only). """ if not hasattr(self, "_pag_attn_processors"): self._pag_attn_processors = None if not isinstance(pag_applied_layers, list): pag_applied_layers = [pag_applied_layers] if pag_attn_processors is not None: if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: raise ValueError("Expected a tuple of two attention processors") for i in range(len(pag_applied_layers)): if not isinstance(pag_applied_layers[i], str): raise ValueError( f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" ) self.pag_applied_layers = pag_applied_layers self._pag_attn_processors = pag_attn_processors @property def pag_scale(self) -> float: r"""Get the scale factor for the perturbed attention guidance.""" return self._pag_scale @property def pag_adaptive_scale(self) -> float: r"""Get the adaptive scale factor for the perturbed attention guidance.""" return self._pag_adaptive_scale @property def do_pag_adaptive_scaling(self) -> bool: r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property def do_perturbed_attention_guidance(self) -> bool: r"""Check if the perturbed attention guidance is enabled.""" return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 @property def pag_attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model with the key as the name of the layer. """ if self._pag_attn_processors is None: return {} valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} processors = {} # We could have iterated through the self.components.items() and checked if a component is # `ModelMixin` subclassed but that can include a VAE too. if hasattr(self, "unet"): denoiser_module = self.unet elif hasattr(self, "transformer"): denoiser_module = self.transformer else: raise ValueError("No denoiser module found.") for name, proc in denoiser_module.attn_processors.items(): if proc.__class__ in valid_attn_processors: processors[name] = proc return processors