from typing import Optional, Union

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..initializer import normal_, zeros_

class CrossAttention(nn.Layer):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to `False`):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.scale = dim_head**-0.5
        self.num_heads = heads
        self.head_dim = inner_dim // heads

        # Upper bound checked by `set_attention_slice`.
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, epsilon=1e-5)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        self.to_out = nn.LayerList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        # Default to the vanilla processor when none is supplied.
        processor = processor if processor is not None else CrossAttnProcessor()
        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller than or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = CrossAttnAddedKVProcessor()
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The layer delegates the actual attention computation to its processor,
        # so the implementation (sliced, added-KV, LoRA, ...) can be swapped at
        # runtime without touching the layer itself.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        # [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads * head_dim]
        # (a `0` in Paddle's `reshape` keeps the corresponding input dimension)
        tensor = tensor.transpose([0, 2, 1, 3])
        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
        return tensor

    def head_to_batch_dim(self, tensor):
        # [batch, seq_len, heads * head_dim] -> [batch, heads, seq_len, head_dim]
        tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
        tensor = tensor.transpose([0, 2, 1, 3])
        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        if self.upcast_attention:
            query = query.cast("float32")
            key = key.cast("float32")

        attention_scores = paddle.matmul(query, key, transpose_y=True) * self.scale

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.cast("float32")

        attention_probs = F.softmax(attention_scores, axis=-1)
        if self.upcast_softmax:
            attention_probs = attention_probs.cast(query.dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length):
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0, data_format="NCL")
        # Repeat the mask once per attention head.
        attention_mask = attention_mask.repeat_interleave(self.num_heads, axis=0)
        return attention_mask
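
# Illustrative usage sketch (comments only, not executed as part of this module):
# the values below are hypothetical and only show the expected tensor shapes.
# Self-attention is obtained by omitting `encoder_hidden_states`; cross-attention
# by passing a context tensor whose channel count matches `cross_attention_dim`.
#
#     attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
#     x = paddle.randn([2, 64, 320])       # [batch, seq_len, query_dim]
#     ctx = paddle.randn([2, 77, 768])     # [batch, ctx_len, cross_attention_dim]
#     out = attn(x, encoder_hidden_states=ctx)   # -> [2, 64, 320]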


class CrossAttnProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        # Fall back to self-attention when no encoder states are provided.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
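
# A processor is a plain callable that receives the `CrossAttention` layer plus its
# inputs, so alternative implementations can be swapped in with `set_processor`.
# Hypothetical sketch:
#
#     attn.set_processor(CrossAttnProcessor())   # default dense attention
#     attn.set_attention_slice(4)                # switches to SlicedAttnProcessor(4)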


class LoRALinearLayer(nn.Layer):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias_attr=False)
        self.up = nn.Linear(rank, out_features, bias_attr=False)
        self.scale = 1.0

        # The up projection starts at zero so the LoRA branch is a no-op until trained.
        normal_(self.down.weight, std=1 / rank)
        zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.cast(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.cast(orig_dtype)
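
# LoRALinearLayer computes a low-rank update: for input x it returns up(down(x)),
# a rank-`rank` correction. Because `up` is zero-initialized, the layer outputs
# zeros at first, so adding it to a frozen base projection leaves the pretrained
# behaviour unchanged at the start of fine-tuning. Hypothetical shape check:
#
#     lora = LoRALinearLayer(in_features=320, out_features=320, rank=4)
#     y = lora(paddle.randn([2, 64, 320]))   # -> [2, 64, 320], all zeros before training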


class LoRACrossAttnProcessor(nn.Layer):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        # Each projection is the frozen base weight plus a scaled LoRA correction.
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
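
# To use the LoRA variant, a processor sized to the layer is attached via
# `set_processor`; `scale` then blends the LoRA correction into each projection.
# Hypothetical sketch (dimensions must match the wrapped layer):
#
#     lora_proc = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
#     attn.set_processor(lora_proc)
#     out = attn(x, encoder_hidden_states=ctx, scale=0.5)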


class CrossAttnAddedKVProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        # Flatten spatial dims: [batch, channels, *spatial] -> [batch, seq_len, channels]
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        batch_size, sequence_length, _ = hidden_states.shape
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        # Prepend the projected encoder states along the sequence axis.
        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Merge batch and head dims so the attention is computed `slice_size` rows at a time.
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # Restore [batch, heads, seq_len, head_dim] before merging heads back.
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
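
# Sliced attention trades speed for memory: the merged batch * heads dimension is
# processed `slice_size` rows at a time, so only one slice of the attention matrix
# is materialized at once. Hypothetical sketch:
#
#     attn.set_attention_slice(2)   # process 2 of the batch*heads rows per step
#     out = attn(x)                 # same result as dense attention (when slice_size
#                                   # divides batch*heads evenly), lower peak memory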


class SlicedAttnAddedKVProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        # Merge batch and head dims so the attention can be computed slice by slice.
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)
        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


AttnProcessor = Union[
    CrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    LoRACrossAttnProcessor,
]