lora_test / ppdiffusers /models /cross_attention.py
1toTree's picture
Upload with huggingface_hub
21231ee
raw
history blame
17.6 kB
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ..initializer import normal_, zeros_
class CrossAttention(nn.Layer):
    r"""
    A cross attention layer.

    The layer owns the projection weights and a few tensor helpers; the actual
    attention computation is delegated to a pluggable processor object (see
    `set_processor`).

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            If `True`, query and key are cast to float32 before the attention scores are computed.
        upcast_softmax (`bool`, *optional*, defaults to False):
            If `True`, the attention scores are cast to float32 before the softmax.
        added_kv_proj_dim (`int`, *optional*):
            If given, additional `add_k_proj` / `add_v_proj` linear layers with this input size are created.
        norm_num_groups (`int`, *optional*):
            If given, a `GroupNorm` with this many groups is created as `self.group_norm`;
            otherwise `self.group_norm` is `None`.
        processor (`AttnProcessor`, *optional*):
            The attention processor to use. Defaults to a new `CrossAttnProcessor`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.scale = dim_head**-0.5
        self.num_heads = heads
        self.head_dim = inner_dim // heads

        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, epsilon=1e-5)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        # to_out[0] is the output projection, to_out[1] the dropout.
        self.to_out = nn.LayerList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        processor = processor if processor is not None else CrossAttnProcessor()
        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        """Select the processor matching `slice_size` and `added_kv_proj_dim`.

        Args:
            slice_size: Number of rows of the flattened (batch * heads) axis to
                process per step, or `None` to disable slicing.

        Raises:
            ValueError: If `slice_size` is smaller than 1 or larger than
                `self.sliceable_head_dim`.
        """
        if slice_size is not None:
            # A non-positive slice size cannot drive the slicing loop of the
            # sliced processors, so reject it here with a clear message.
            if slice_size < 1:
                raise ValueError(f"slice_size {slice_size} has to be a positive integer.")
            if slice_size > self.sliceable_head_dim:
                raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = CrossAttnAddedKVProcessor()
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """Store `processor` as the callable used by `forward`."""
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Run the configured processor on the inputs and return its result."""
        # The `CrossAttention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """Swap axes 1 and 2 of a 4-D tensor and merge the last two axes into one."""
        tensor = tensor.transpose([0, 2, 1, 3])
        # `0` entries keep the corresponding input dimension.
        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
        return tensor

    def head_to_batch_dim(self, tensor):
        """Split the last axis into (num_heads, head_dim) and move the head axis to position 1."""
        tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
        tensor = tensor.transpose([0, 2, 1, 3])
        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """Compute softmax(query @ key^T * scale + attention_mask) over the last axis.

        Args:
            query: Query tensor.
            key: Key tensor; multiplied transposed against `query`.
            attention_mask: Optional additive mask applied to the scaled scores.

        Returns:
            The attention probabilities. When any upcast flag is enabled, the
            result is cast back to the dtype `query` had on entry, so it can
            be multiplied with a value tensor of that dtype.
        """
        # Remember the caller's dtype before any upcast rebinds `query`.
        input_dtype = query.dtype

        if self.upcast_attention:
            query = query.cast("float32")
            key = key.cast("float32")

        attention_scores = paddle.matmul(query, key, transpose_y=True) * self.scale

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.cast("float32")

        attention_probs = F.softmax(attention_scores, axis=-1)

        # Undo whichever upcast happened above.
        if self.upcast_attention or self.upcast_softmax:
            attention_probs = attention_probs.cast(input_dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length):
        """Pad the mask if needed and repeat it `num_heads` times along axis 0.

        Returns `None` unchanged when no mask is given.
        """
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            # NOTE(review): this appends `target_length` zeros on the right of the
            # last axis rather than padding up to `target_length` - confirm intended.
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0, data_format="NCL")
            attention_mask = attention_mask.repeat_interleave(self.num_heads, axis=0)
        return attention_mask
class CrossAttnProcessor:
    """Default processor: unsliced attention using the projections owned by `attn`."""

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Attend from `hidden_states` to `encoder_hidden_states` (or to itself when that is `None`)."""
        batch_size, sequence_length, _ = hidden_states.shape

        mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        if mask is not None:
            # Give the mask an explicit head axis.
            mask = mask.reshape([batch_size, attn.num_heads, -1, mask.shape[-1]])

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        # Without encoder states, keys and values come from the hidden states.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        probs = attn.get_attention_scores(query, key, mask)
        attended = attn.batch_to_head_dim(paddle.matmul(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        return attn.to_out[1](projected)
class LoRALinearLayer(nn.Layer):
    """Two bias-free linear layers in sequence: `in_features -> rank -> out_features`.

    The `up` weight is initialised to zeros and the `down` weight from a normal
    distribution with `std = 1 / rank`.
    """

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        rank_too_large = rank > min(in_features, out_features)
        if rank_too_large:
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        # Both layers are created before either is re-initialised.
        self.down = nn.Linear(in_features, rank, bias_attr=False)
        self.up = nn.Linear(rank, out_features, bias_attr=False)
        self.scale = 1.0

        normal_(self.down.weight, std=1 / rank)
        zeros_(self.up.weight)

    def forward(self, hidden_states):
        """Apply down then up projection in the weights' dtype; return in the input's dtype."""
        input_dtype = hidden_states.dtype
        weight_dtype = self.down.weight.dtype

        low_rank = self.down(hidden_states.cast(weight_dtype))
        expanded = self.up(low_rank)
        return expanded.cast(input_dtype)
class LoRACrossAttnProcessor(nn.Layer):
    """Attention processor that adds a scaled LoRA term to the q/k/v/out projections of `attn`."""

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        # Key/value adapters read from the encoder width when one is given.
        kv_in_features = cross_attention_dim or hidden_size

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(kv_in_features, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(kv_in_features, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        """Same flow as the plain processor, with `scale * lora(x)` added to each projection."""
        batch_size, sequence_length, _ = hidden_states.shape

        mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        if mask is not None:
            mask = mask.reshape([batch_size, attn.num_heads, -1, mask.shape[-1]])

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        # Without encoder states, keys and values come from the hidden states.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        key = attn.to_k(context) + scale * self.to_k_lora(context)
        value = attn.to_v(context) + scale * self.to_v_lora(context)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        probs = attn.get_attention_scores(query, key, mask)
        attended = attn.batch_to_head_dim(paddle.matmul(probs, value))

        # Output projection plus its LoRA term, then dropout.
        projected = attn.to_out[0](attended) + scale * self.to_out_lora(attended)
        return attn.to_out[1](projected)
class CrossAttnAddedKVProcessor:
    """Processor that prepends projected encoder keys/values to the self-attention keys/values.

    Uses `attn.group_norm`, `attn.add_k_proj` and `attn.add_v_proj`, and adds
    the input back to the result as a residual.
    """

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        # Flatten everything after axis 1, then swap axes 1 and 2.
        flat = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1])
        hidden_states = flat.transpose([0, 2, 1])
        batch_size, sequence_length, _ = hidden_states.shape

        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        if mask is not None:
            mask = mask.reshape([batch_size, attn.num_heads, -1, mask.shape[-1]])

        # The group norm is applied with axes 1 and 2 swapped, then swapped back.
        normed = attn.group_norm(hidden_states.transpose([0, 2, 1]))
        hidden_states = normed.transpose([0, 2, 1])

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        self_key = attn.head_to_batch_dim(attn.to_k(hidden_states))
        self_value = attn.head_to_batch_dim(attn.to_v(hidden_states))

        added_key = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states))
        added_value = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states))

        # Encoder-derived entries come first along axis 2.
        key = paddle.concat([added_key, self_key], axis=2)
        value = paddle.concat([added_value, self_value], axis=2)

        probs = attn.get_attention_scores(query, key, mask)
        attended = attn.batch_to_head_dim(paddle.matmul(probs, value))

        # linear proj, then dropout
        projected = attn.to_out[1](attn.to_out[0](attended))

        # Restore the input layout and add the residual.
        restored = projected.transpose([0, 2, 1]).reshape(residual.shape)
        return restored + residual
class SlicedAttnProcessor:
    """Attention processor that computes attention in slices of the flattened (batch * heads) axis.

    Args:
        slice_size: Number of rows of the flattened axis handled per step.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Attend from `hidden_states` to `encoder_hidden_states` (or to itself when that is `None`)."""
        _, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Merge the first two axes so slicing runs over a single axis.
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        # Cover the whole flattened axis, including a final shorter slice when
        # `slice_size` does not divide it; otherwise the trailing rows would
        # keep their zero initialisation.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = min(start_idx + self.slice_size, batch_size_attention)

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape back to [bs, num_heads, seqlen, head_dim]
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        # reshape hidden_states
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
class SlicedAttnAddedKVProcessor:
    """Sliced attention with projected encoder keys/values prepended to the self-attention keys/values.

    Uses `attn.group_norm`, `attn.add_k_proj` and `attn.add_v_proj`, and adds
    the input back to the result as a residual.

    Args:
        slice_size: Number of rows of the flattened (batch * heads) axis handled per step.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        # Flatten everything after axis 1, then swap axes 1 and 2.
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        _, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        # The group norm is applied with axes 1 and 2 swapped, then swapped back.
        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        # Encoder-derived entries come first along axis 2.
        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        # Merge the first two axes so slicing runs over a single axis.
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        # Cover the whole flattened axis, including a final shorter slice when
        # `slice_size` does not divide it; otherwise the trailing rows would
        # keep their zero initialisation.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = min(start_idx + self.slice_size, batch_size_attention)

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape back to [bs, num_heads, seqlen, head_dim]
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        # reshape hidden_states
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the input layout and add the residual.
        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual
        return hidden_states
# Type alias for every attention processor defined in this module.
# LoRACrossAttnProcessor is included because its `__call__` takes the same
# leading arguments as the other processors (plus an optional `scale`).
AttnProcessor = Union[
    CrossAttnProcessor,
    LoRACrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
]