# Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
from typing import Any, Dict, Optional

import torch
from einops import rearrange
from models_diffusers.camera.attention import TemporalPoseCondTransformerBlock as TemporalBasicTransformerBlock
from diffusers.models.attention import BasicTransformerBlock
from torch import nn


def torch_dfs(model: torch.nn.Module):
    """Return ``model`` followed by all of its descendant modules, depth-first.

    Order is pre-order: the module itself, then each child (in
    ``model.children()`` order) together with that child's descendants.
    """
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


def _chunked_feed_forward(
    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
    """Apply ``ff`` to ``hidden_states`` in equal chunks along ``chunk_dim``.

    "feed_forward_chunk_size" can be used to save memory: each chunk is fed
    through ``ff`` separately and the outputs are concatenated back along
    ``chunk_dim``.

    Args:
        ff: Feed-forward module applied to every chunk.
        hidden_states: Input tensor to split.
        chunk_dim: Dimension along which to split and re-concatenate.
        chunk_size: Size of each chunk; must evenly divide
            ``hidden_states.shape[chunk_dim]``.
        lora_scale: When not ``None``, forwarded to ``ff`` as ``scale=``.

    Returns:
        The concatenated feed-forward output.

    Raises:
        ValueError: If ``chunk_size`` does not divide the chunked dimension.
    """
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    # TODO(Patrick): the LoRA `scale` kwarg can be removed once the PEFT refactor is complete.
    ff_kwargs = {} if lora_scale is None else {"scale": lora_scale}
    return torch.cat(
        [ff(hid_slice, **ff_kwargs) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )


class ReferenceAttentionControl:
    """Patch a UNet's ``BasicTransformerBlock``s for reference ("mutual") self-attention.

    In ``"write"`` mode, each hooked block appends (a clone of) its normalized
    hidden states to ``module.bank`` before running self-attention normally.
    In ``"read"`` mode, each hooked block runs self-attention with its own
    normalized hidden states as queries and, as keys/values, those states
    concatenated along dim 1 (presumably the token dimension — confirm) with
    the tensors in its ``bank``. ``update`` copies banks from a writer
    controller into this controller's blocks; ``clear`` empties them.
    """

    def __init__(
        self,
        unet,
        mode="write",
        do_classifier_free_guidance=False,
        attention_auto_machine_weight=float("inf"),
        gn_auto_machine_weight=1.0,
        style_fidelity=1.0,
        reference_attn=True,
        reference_adain=False,
        fusion_blocks="midup",
        batch_size=1,
    ) -> None:
        # 10. Modify self attention and group norm
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.reference_adain = reference_adain
        self.fusion_blocks = fusion_blocks
        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            # Passed by keyword: the next positional parameter of
            # register_reference_hooks is `dtype`, not `fusion_blocks`.
            fusion_blocks=fusion_blocks,
            batch_size=batch_size,
        )

    def _attn_modules(self, unet):
        """Return the ``BasicTransformerBlock``s of ``unet`` selected by ``self.fusion_blocks``.

        ``"midup"`` searches ``unet.mid_block`` and ``unet.up_blocks``;
        ``"full"`` searches the whole UNet. The result is stably sorted by
        decreasing ``norm1.normalized_shape[0]``, so identically-structured
        UNets yield lists that can be paired position by position.

        Raises:
            ValueError: If ``self.fusion_blocks`` is neither "midup" nor "full".
        """
        if self.fusion_blocks == "midup":
            candidates = torch_dfs(unet.mid_block) + torch_dfs(unet.up_blocks)
        elif self.fusion_blocks == "full":
            candidates = torch_dfs(unet)
        else:
            raise ValueError(f"Unsupported fusion_blocks: {self.fusion_blocks!r}")
        # TemporalBasicTransformerBlock modules are not hooked; upstream kept that
        # as a commented-out alternative.
        blocks = [module for module in candidates if isinstance(module, BasicTransformerBlock)]
        return sorted(blocks, key=lambda x: -x.norm1.normalized_shape[0])

    def register_reference_hooks(
        self,
        mode,
        do_classifier_free_guidance,
        attention_auto_machine_weight,
        gn_auto_machine_weight,
        style_fidelity,
        reference_attn,
        reference_adain,
        dtype=torch.float16,
        batch_size=1,
        num_images_per_prompt=1,
        device=torch.device("cpu"),
        fusion_blocks="midup",
    ):
        """Replace ``forward`` on the selected transformer blocks with the read/write version.

        Only ``mode`` is used here, together with ``self.reference_attn`` and
        ``self.fusion_blocks``. The other parameters are accepted for API
        compatibility and currently have no effect.

        Each hooked module receives:
            ``_original_inner_forward``: its previous bound ``forward``.
            ``forward``: the hacked forward below, bound to the module.
            ``bank``: a fresh empty list.
            ``attn_weight``: ``i / len(modules)`` for its sorted position ``i``.
        """
        # The inner forward has its own (unused) `mode` parameter that would shadow
        # the outer one, so the closure reads the mode through a distinct name.
        MODE = mode

        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
            self_attention_additional_feats=None,
            mode=None,
        ):
            """BasicTransformerBlock forward with banked reference self-attention.

            ``video_length``, ``self_attention_additional_feats`` and ``mode``
            are accepted so callers can pass them, but are not used.
            """
            batch_size = hidden_states.shape[0]

            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
                )
            elif self.use_layer_norm:
                norm_hidden_states = self.norm1(hidden_states)
            elif self.use_ada_layer_norm_single:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
                ).chunk(6, dim=1)
                norm_hidden_states = self.norm1(hidden_states)
                norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
                norm_hidden_states = norm_hidden_states.squeeze(1)
            else:
                raise ValueError("Incorrect norm used")

            if self.pos_embed is not None:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            # 1. Retrieve lora scale.
            lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

            # 2. Prepare GLIGEN inputs
            cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
            gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            elif MODE == "write":
                # Record this block's normalized states for a later "read" pass.
                self.bank.append(norm_hidden_states.clone())
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            elif MODE == "read":
                # Bank entries with a leading size of 1 are repeated to this
                # batch's size; others are used as-is.
                bank_fea = [
                    d.repeat(norm_hidden_states.shape[0], 1, 1) if d.shape[0] == 1 else d
                    for d in self.bank
                ]
                modify_norm_hidden_states = torch.cat([norm_hidden_states] + bank_fea, dim=1)
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=modify_norm_hidden_states,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            elif self.use_ada_layer_norm_single:
                attn_output = gate_msa * attn_output

            hidden_states = attn_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

            # 2.5 GLIGEN Control
            if gligen_kwargs is not None:
                hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

            # 3. Cross-Attention
            if self.attn2 is not None:
                if self.use_ada_layer_norm:
                    norm_hidden_states = self.norm2(hidden_states, timestep)
                elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                    norm_hidden_states = self.norm2(hidden_states)
                elif self.use_ada_layer_norm_single:
                    # For PixArt norm2 isn't applied here:
                    # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
                    norm_hidden_states = hidden_states
                else:
                    raise ValueError("Incorrect norm")

                if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
                    norm_hidden_states = self.pos_embed(norm_hidden_states)

                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 4. Feed-forward
            if not self.use_ada_layer_norm_single:
                norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

            if self.use_ada_layer_norm_single:
                norm_hidden_states = self.norm2(hidden_states)
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                ff_output = _chunked_feed_forward(
                    self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
                )
            else:
                ff_output = self.ff(norm_hidden_states, scale=lora_scale)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            elif self.use_ada_layer_norm_single:
                ff_output = gate_mlp * ff_output

            hidden_states = ff_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

            return hidden_states

        if self.reference_attn:
            attn_modules = self._attn_modules(self.unet)
            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))

    def update(self, writer, dtype=torch.float16):
        """Copy each writer block's ``bank`` into the matching block of this UNet.

        Blocks of both UNets are selected with this controller's
        ``fusion_blocks`` and paired by sorted position; ``zip`` stops at the
        shorter list. Every banked tensor is cloned and cast to ``dtype``. The
        writer's banks are left untouched (use ``writer.clear()`` to reset them).
        """
        if self.reference_attn:
            reader_attn_modules = self._attn_modules(self.unet)
            writer_attn_modules = self._attn_modules(writer.unet)
            for r, w in zip(reader_attn_modules, writer_attn_modules):
                r.bank = [v.clone().to(dtype) for v in w.bank]

    def clear(self):
        """Empty ``bank`` on every selected transformer block of this controller's UNet."""
        if self.reference_attn:
            for r in self._attn_modules(self.unet):
                r.bank.clear()