import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotary frequencies exp(i * t * theta_j) for positions 0..end-1.

    Returns a complex tensor of shape (end, dim // 2).
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude complex numbers
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape freqs_cis of shape (seq_len, head_dim/2) so it broadcasts against
    x of shape (batch, seq_len, num_head, head_dim/2)."""
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors.

    xq: (batch, q_len, num_head, rotate_dim); xk: (batch, k_len, num_head, rotate_dim).
    freqs_cis must cover absolute positions 0..k_len-1. Queries are assumed to occupy
    the last q_len of those positions, which covers both full-sequence attention
    (q_len == k_len) and cached decoding (new queries appended at the end).
    """
    assert xq.shape[-1] == xk.shape[-1], "Query and Key must have the same embedding dimension"
    assert xq.shape[-1] % 2 == 0, "Embedding dimension must be even"

    q_len = xq.shape[1]
    k_len = xk.shape[1]
    assert freqs_cis.shape[0] >= k_len, "freqs_cis does not cover the key sequence length"

    # Keys sit at absolute positions [0, k_len); queries at the last q_len of them.
    q_freqs = freqs_cis[k_len - q_len : k_len]
    k_freqs = freqs_cis[:k_len]

    # View the last dimension as pairs and convert to complex: (..., rotate_dim) -> (..., rotate_dim/2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    q_freqs = reshape_for_broadcast(q_freqs, xq_)
    k_freqs = reshape_for_broadcast(k_freqs, xk_)

    # Rotate via complex multiplication, then flatten the trailing (pairs, 2) dims back to rotate_dim.
    xq_out = torch.view_as_real(xq_ * q_freqs).flatten(xq.ndim - 1)
    xk_out = torch.view_as_real(xk_ * k_freqs).flatten(xk.ndim - 1)

    return xq_out.type_as(xq), xk_out.type_as(xk)
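

# A minimal shape sketch of the RoPE helpers above. The dimensions are arbitrary
# illustrative values, not anything prescribed by the module:
#
#   freqs_cis = precompute_freqs_cis(dim=64, end=2048)    # (2048, 32) complex
#   xq = torch.randn(2, 10, 8, 64)                        # (batch, q_len, heads, rotate_dim)
#   xk = torch.randn(2, 10, 8, 64)                        # (batch, k_len, heads, rotate_dim)
#   xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)  # same shapes as the inputs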


class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA) module, as in the DeepSeek-V2 paper.

    Key innovations over standard MHA:
    1. Low-rank key-value joint compression
    2. Decoupled rotary position embedding

    Args:
        d_model: Total dimension of the model.
        num_head: Number of attention heads.
        d_embed: Embedding dimension of the input.
        d_c: K/V compression dimension.
        d_c1: Q compression dimension.
        d_rotate: Dimension of the decoupled rotary position embedding.
        dropout: Dropout rate for attention scores.
        bias: Whether to include bias in linear projections.

        d_head: Inferred as d_model // num_head.

    Inputs:
        sequence: Input sequence for self-attention, and the query for cross-attention.
        key_value_states: Input for the keys/values in cross-attention.
    """

    def __init__(
        self,
        d_model,
        num_head,
        d_embed,
        d_c,
        d_c1,
        d_rotate,
        dropout=0.1,
        bias=True,
        max_batch_size=32,
        max_seq_len=2048,
    ):
        super().__init__()

        assert d_model % num_head == 0, f"d_model ({d_model}) must be divisible by num_head ({num_head})"
        assert d_c < d_embed, "Compression dim should be smaller than embedding dim"
        assert d_c1 < d_embed, "Query compression dim should be smaller than embedding dim"

        self.d_model = d_model
        self.num_head = num_head
        self.d_head = d_model // num_head
        self.d_embed = d_embed
        self.d_c = d_c
        self.d_c1 = d_c1
        self.d_rotate = d_rotate
        self.dropout_rate = dropout

        # Down-projections: compress hidden states into the K/V and Q latent spaces.
        self.DKV_proj = nn.Linear(d_embed, d_c, bias=bias)
        self.DQ_proj = nn.Linear(d_embed, d_c1, bias=bias)

        # Up-projections: expand the latents back to per-head queries, keys and values.
        self.UQ_proj = nn.Linear(d_c1, d_model, bias=bias)
        self.UK_proj = nn.Linear(d_c, d_model, bias=bias)
        self.UV_proj = nn.Linear(d_c, d_model, bias=bias)

        # Decoupled RoPE projections: per-head rotary queries, shared (single-head) rotary key.
        self.RQ_proj = nn.Linear(d_c1, num_head * d_rotate, bias=bias)
        self.RK_proj = nn.Linear(d_embed, d_rotate, bias=bias)

        self.output_proj = nn.Linear(d_model, d_model, bias=bias)

        self.dropout = nn.Dropout(p=dropout)

        # Scale by the full per-head query/key dimension (content head_dim + rotary dim).
        self.scaler = float(1.0 / math.sqrt(self.d_head + d_rotate))

        # Inference caches: compressed K/V latents and the shared (pre-rotation) rotary keys.
        self.cache_kv = torch.zeros(
            (max_batch_size, max_seq_len, d_c)
        )
        self.cache_rk = torch.zeros(
            (max_batch_size, max_seq_len, d_rotate)
        )

        # Precomputed rotary frequencies for up to max_seq_len * 2 positions.
        self.freqs_cis = precompute_freqs_cis(
            d_rotate, max_seq_len * 2
        )

    def forward(
        self,
        sequence,
        key_value_states=None,
        att_mask=None,
        use_cache=False,
        start_pos: int = 0,
    ):
        """
        Forward pass supporting both standard attention and cached inference.

        Input shape: [batch_size, seq_len, d_model = num_head * d_head]

        Args:
            sequence: Input sequence [batch_size, seq_len, d_model].
            key_value_states: Optional states for cross-attention.
            att_mask: Optional attention mask.
            use_cache: Whether to use KV caching (for inference).
            start_pos: Position in the sequence when using the KV cache.
        """
        batch_size, seq_len, model_dim = sequence.size()

        # Rotary frequencies for absolute positions starting at 0. apply_rotary_emb slices
        # out the ranges needed for the queries and for the (possibly cached) keys, so the
        # full table is passed rather than an offset slice.
        self.freqs_cis = self.freqs_cis.to(sequence.device)
        freqs_cis = self.freqs_cis

        assert model_dim == self.d_model, f"Input dimension {model_dim} doesn't match model dimension {self.d_model}"
        if key_value_states is not None:
            assert key_value_states.size(-1) == self.d_model, \
                f"Cross attention key/value dimension {key_value_states.size(-1)} doesn't match model dimension {self.d_model}"

        is_cross_attention = key_value_states is not None

        # Length of the key/value inputs for this call (before any cached prefix is prepended).
        kv_seq_len = key_value_states.size(1) if is_cross_attention else seq_len

        # Query path: compress, then up-project the content part and project the rotary part.
        C_Q = self.DQ_proj(sequence)
        Q_state = self.UQ_proj(C_Q)
        Q_rotate = self.RQ_proj(C_Q)

        if use_cache:
            # Cache the compressed K/V latent for positions [start_pos, start_pos + kv_seq_len).
            self.cache_kv = self.cache_kv.to(sequence.device)
            current_kv = self.DKV_proj(key_value_states if is_cross_attention else sequence)
            self.cache_kv[:batch_size, start_pos:start_pos + kv_seq_len] = current_kv
            C_KV = self.cache_kv[:batch_size, :start_pos + kv_seq_len]

            # Cache the shared (pre-rotation) rotary key for the same positions.
            assert self.cache_rk.size(-1) == self.d_rotate, "RoPE cache dimension mismatch"
            self.cache_rk = self.cache_rk.to(sequence.device)
            current_K_rotate = self.RK_proj(key_value_states if is_cross_attention else sequence)
            self.cache_rk[:batch_size, start_pos:start_pos + kv_seq_len] = current_K_rotate
            K_rotate = self.cache_rk[:batch_size, :start_pos + kv_seq_len]
"""handling attention mask""" |
|
if att_mask is not None: |
|
|
|
mask_size = att_mask.size(-1) |
|
cached_len = start_pos + kv_seq_len |
|
assert C_KV.size(1) == cached_len, \ |
|
f"Cached key/value length {C_KV.size(1)} doesn't match theoretical length {cached_len}" |
|
|
|
|
|
extended_mask = torch.zeros( |
|
(batch_size, 1, seq_len, cached_len), |
|
device=att_mask.device, |
|
dtype=att_mask.dtype |
|
) |
|
|
|
|
|
|
|
for i in range(seq_len): |
|
extended_mask[:, :, i, :(start_pos + i + 1)] = 0 |
|
extended_mask[:, :, i, (start_pos + i + 1):] = float('-inf') |
|
|
|
att_mask = extended_mask |
|
else: |
|
|
|
C_KV = self.DKV_proj(key_value_states if is_cross_attention else sequence) |
|
|
|
K_rotate = self.RK_proj(key_value_states if is_cross_attention else sequence) |
|
|
|
|
|
|
|

        # Up-project the compressed latent into per-head keys and values.
        K_state = self.UK_proj(C_KV)
        V_state = self.UV_proj(C_KV)

        # Split heads: [batch, len, d_model] -> [batch, len, num_head, d_head].
        Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head)

        actual_kv_len = K_state.size(1)
        K_state = K_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)
        V_state = V_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)

        # Decoupled RoPE: per-head rotary queries; the shared rotary key is broadcast across heads.
        Q_rotate = Q_rotate.view(batch_size, seq_len, self.num_head, self.d_rotate)
        K_rotate = K_rotate.unsqueeze(2).expand(-1, -1, self.num_head, -1)
        Q_rotate, K_rotate = apply_rotary_emb(Q_rotate, K_rotate, freqs_cis=freqs_cis)

        # Concatenate the content and rotary parts of the queries and keys.
        Q_state = torch.cat([Q_state, Q_rotate], dim=-1)
        K_state = torch.cat([K_state, K_rotate], dim=-1)

        # Scale the queries and move the head dimension forward for attention.
        Q_state = Q_state * self.scaler
        Q_state = Q_state.transpose(1, 2)
        K_state = K_state.transpose(1, 2)
        V_state = V_state.transpose(1, 2)

        # Scaled dot-product attention scores: [batch, num_head, q_len, kv_len].
        self.att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2))

        if att_mask is not None and not isinstance(att_mask, torch.Tensor):
            raise TypeError("att_mask must be a torch.Tensor")

        # Additive mask: 0 for allowed positions, -inf for masked ones.
        if att_mask is not None:
            self.att_matrix = self.att_matrix + att_mask

        att_score = F.softmax(self.att_matrix, dim=-1)
        att_score = self.dropout(att_score)

        # Weighted sum of values: [batch, num_head, q_len, d_head].
        att_output = torch.matmul(att_score, V_state)
        assert att_output.size(0) == batch_size, "Batch size mismatch"
        assert att_output.size(2) == seq_len, "Output sequence length should match query sequence length"

        # Merge heads back: [batch, q_len, num_head * d_head].
        att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_head * self.d_head)

        # Final output projection.
        att_output = self.output_proj(att_output)
        assert att_output.size() == (batch_size, seq_len, self.d_model), \
            f"Final output shape {att_output.size()} incorrect"

        return att_output
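

# A minimal, self-contained smoke test sketching how this module might be used. The
# dimensions below are arbitrary illustrative choices, not values prescribed by the
# DeepSeek-V2 paper; adjust them to your own configuration.
if __name__ == "__main__":
    torch.manual_seed(0)

    d_model, num_head = 256, 8
    mla = MultiHeadLatentAttention(
        d_model=d_model,
        num_head=num_head,
        d_embed=d_model,   # assumes the input embedding dim equals d_model
        d_c=64,            # K/V compression dim
        d_c1=128,          # Q compression dim
        d_rotate=32,       # decoupled RoPE dim
        dropout=0.0,
        max_batch_size=4,
        max_seq_len=128,
    )
    mla.eval()

    # Standard (non-cached) self-attention over a full sequence.
    x = torch.randn(2, 16, d_model)
    with torch.no_grad():
        out = mla(x)
    print("full-sequence output:", out.shape)  # (2, 16, 256)

    # Cached decoding: process a prompt, then one token at a time.
    with torch.no_grad():
        prompt = torch.randn(2, 8, d_model)
        causal = torch.zeros(2, 1, 8, 8)  # placeholder mask; rebuilt internally when caching
        _ = mla(prompt, att_mask=causal, use_cache=True, start_pos=0)

        next_tok = torch.randn(2, 1, d_model)
        step_mask = torch.zeros(2, 1, 1, 9)
        out_step = mla(next_tok, att_mask=step_mask, use_cache=True, start_pos=8)
    print("cached step output:", out_step.shape)  # (2, 1, 256)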