import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init import logging from diffusers.models.attention import Attention from diffusers.utils import USE_PEFT_BACKEND, is_xformers_available from typing import Optional, Callable from einops import rearrange if is_xformers_available(): import xformers import xformers.ops else: xformers = None logger = logging.getLogger(__name__) class AttnProcessor: r""" Default processor for performing attention-related computations. """ def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, pose_feature=None, # the only difference to the original code ) -> torch.Tensor: residual = hidden_states args = () if USE_PEFT_BACKEND else (scale,) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, pose_feature=None ) -> torch.FloatTensor: residual = hidden_states args = () if USE_PEFT_BACKEND else (scale,) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) args = () if USE_PEFT_BACKEND else (scale,) query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class XFormersAttnProcessor: r""" Processor for implementing memory efficient attention using xFormers. Args: attention_op (`Callable`, *optional*, defaults to `None`): The base [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. """ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, pose_feature=None, # the only difference to the original code ) -> torch.FloatTensor: residual = hidden_states args = () if USE_PEFT_BACKEND else (scale,) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, key_tokens, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) if attention_mask is not None: # expand our mask's singleton query_tokens dimension: # [batch*heads, 1, key_tokens] -> # [batch*heads, query_tokens, key_tokens] # so that it can be added as a bias onto the attention scores that xformers computes: # [batch*heads, query_tokens, key_tokens] # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. _, query_tokens, _ = hidden_states.shape attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention( query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class PoseAdaptorAttnProcessor(nn.Module): def __init__( self, hidden_size, # dimension of hidden state pose_feature_dim=None, # dimension of the pose feature cross_attention_dim=None, # dimension of the text embedding query_condition=False, key_value_condition=False, scale=1.0, ): super().__init__() self.hidden_size = hidden_size self.pose_feature_dim = pose_feature_dim self.cross_attention_dim = cross_attention_dim self.scale = scale self.query_condition = query_condition self.key_value_condition = key_value_condition assert hidden_size == pose_feature_dim if self.query_condition and self.key_value_condition: self.qkv_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.qkv_merge.weight) init.zeros_(self.qkv_merge.bias) elif self.query_condition: self.q_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.q_merge.weight) init.zeros_(self.q_merge.bias) else: self.kv_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.kv_merge.weight) init.zeros_(self.kv_merge.bias) def forward( self, attn, hidden_states, pose_feature, encoder_hidden_states=None, attention_mask=None, temb=None, scale=None, ): assert pose_feature is not None pose_embedding_scale = (scale or self.scale) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) assert hidden_states.ndim == 3 and pose_feature.ndim == 3 if self.query_condition and self.key_value_condition: assert encoder_hidden_states is None if encoder_hidden_states is None: encoder_hidden_states = hidden_states assert encoder_hidden_states.ndim == 3 batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.query_condition and self.key_value_condition: # only self attention query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states key_value_hidden_state = query_hidden_state elif self.query_condition: query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states key_value_hidden_state = encoder_hidden_states else: key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states query_hidden_state = hidden_states # original attention query = attn.to_q(query_hidden_state) key = attn.to_k(key_value_hidden_state) value = attn.to_v(key_value_hidden_state) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class PoseAdaptorAttnProcessor2_0(nn.Module): def __init__( self, hidden_size, # dimension of hidden state pose_feature_dim=None, # dimension of the pose feature cross_attention_dim=None, # dimension of the text embedding query_condition=False, key_value_condition=False, scale=1.0, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.hidden_size = hidden_size self.pose_feature_dim = pose_feature_dim self.cross_attention_dim = cross_attention_dim self.scale = scale self.query_condition = query_condition self.key_value_condition = key_value_condition assert hidden_size == pose_feature_dim if self.query_condition and self.key_value_condition: self.qkv_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.qkv_merge.weight) init.zeros_(self.qkv_merge.bias) elif self.query_condition: self.q_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.q_merge.weight) init.zeros_(self.q_merge.bias) else: self.kv_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.kv_merge.weight) init.zeros_(self.kv_merge.bias) def forward( self, attn, hidden_states, pose_feature, encoder_hidden_states=None, attention_mask=None, temb=None, scale=None, ): assert pose_feature is not None pose_embedding_scale = (scale or self.scale) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) assert hidden_states.ndim == 3 and pose_feature.ndim == 3 if self.query_condition and self.key_value_condition: assert encoder_hidden_states is None if encoder_hidden_states is None: encoder_hidden_states = hidden_states assert encoder_hidden_states.ndim == 3 batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.query_condition and self.key_value_condition: # only self attention query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states key_value_hidden_state = query_hidden_state elif self.query_condition: query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states key_value_hidden_state = encoder_hidden_states else: key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states query_hidden_state = hidden_states # original attention query = attn.to_q(query_hidden_state) key = attn.to_k(key_value_hidden_state) value = attn.to_v(key_value_hidden_state) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # [bs, seq_len, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) # [bs, nhead, seq_len, head_dim] hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # [bs, seq_len, dim] hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class PoseAdaptorXFormersAttnProcessor(nn.Module): def __init__( self, hidden_size, # dimension of hidden state pose_feature_dim=None, # dimension of the pose feature cross_attention_dim=None, # dimension of the text embedding query_condition=False, key_value_condition=False, scale=1.0, attention_op: Optional[Callable] = None, ): super().__init__() self.hidden_size = hidden_size self.pose_feature_dim = pose_feature_dim self.cross_attention_dim = cross_attention_dim self.scale = scale self.query_condition = query_condition self.key_value_condition = key_value_condition self.attention_op = attention_op assert hidden_size == pose_feature_dim if self.query_condition and self.key_value_condition: self.qkv_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.qkv_merge.weight) init.zeros_(self.qkv_merge.bias) elif self.query_condition: self.q_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.q_merge.weight) init.zeros_(self.q_merge.bias) else: self.kv_merge = nn.Linear(hidden_size, hidden_size) init.zeros_(self.kv_merge.weight) init.zeros_(self.kv_merge.bias) def forward( self, attn, hidden_states, pose_feature, encoder_hidden_states=None, attention_mask=None, temb=None, scale=None, ): assert pose_feature is not None pose_embedding_scale = (scale or self.scale) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) assert hidden_states.ndim == 3 and pose_feature.ndim == 3 if self.query_condition and self.key_value_condition: assert encoder_hidden_states is None if encoder_hidden_states is None: encoder_hidden_states = hidden_states assert encoder_hidden_states.ndim == 3 batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size) if attention_mask is not None: # expand our mask's singleton query_tokens dimension: # [batch*heads, 1, key_tokens] -> # [batch*heads, query_tokens, key_tokens] # so that it can be added as a bias onto the attention scores that xformers computes: # [batch*heads, query_tokens, key_tokens] # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. _, query_tokens, _ = hidden_states.shape attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.query_condition and self.key_value_condition: # only self attention query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states key_value_hidden_state = query_hidden_state elif self.query_condition: query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states key_value_hidden_state = encoder_hidden_states else: key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states query_hidden_state = hidden_states # original attention query = attn.to_q(query_hidden_state) key = attn.to_k(key_value_hidden_state) value = attn.to_v(key_value_hidden_state) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() hidden_states = xformers.ops.memory_efficient_attention( query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states