# deepseek-mla / src/mla.py
# bird-of-paradise — "Update class names to MultiHeadLatentAttention" (commit 2d7348d)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
import math
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the RoPE rotation table as unit complex numbers.

    Row ``t`` (position) and column ``k`` hold ``exp(i * t * theta ** (-2k / dim))``.

    Args:
        dim: Rotary dimension; ``dim // 2`` frequencies are produced.
        end: Number of positions (rows), starting at 0.
        theta: RoPE base.

    Returns:
        complex64 tensor of shape ``[end, dim // 2]``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Magnitude 1, phase = angle  ->  complex64 phasors
    return torch.polar(torch.ones_like(angles), angles)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps axis 1 and the last axis of ``x`` and uses size 1 for every other
    axis, e.g. ``[1, seq, 1, d]`` for a 4-D ``x``.

    Raises:
        AssertionError: If ``x`` has fewer than 2 dims or the shapes disagree.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    kept_axes = {1, rank - 1}
    broadcast_shape = [size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)]
    return freqs_cis.view(*broadcast_shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with rotary position embeddings.

    Each tensor's last dimension is read as consecutive (real, imag) pairs and
    multiplied by the complex rotation for its position. Axis 1 is taken as the
    sequence axis. ``xq`` uses the first ``xq.shape[1]`` rows of ``freqs_cis``
    and ``xk`` uses the first ``xk.shape[1]`` rows, so the two may have
    different sequence lengths and even different ranks. Examples are
    ``[bsz, seqlen, dim]`` and ``[bsz, seqlen, n_heads, head_dim]``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes and dtypes as the inputs.
    """
    assert xq.shape[-1] == xk.shape[-1], "Query and Key must have same embedding dimension"
    assert xq.shape[-1] % 2 == 0, "Embedding dimension must be even"

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # [..., dim] -> [..., dim/2] complex, computed in float32.
        as_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rotation = reshape_for_broadcast(freqs_cis[: x.shape[1]], as_complex)
        # Back to real pairs, then re-merge the pair axis into the last dim.
        merged = torch.view_as_real(as_complex * rotation).flatten(x.ndim - 1)
        return merged.type_as(x)

    return _rotate(xq), _rotate(xk)
class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA) module, as in the DeepSeek-V2 paper.

    Key differences from standard multi-head attention:
        1. Low-rank key-value joint compression. The key/value input is
           down-projected to a latent C_KV of size ``d_c``, and keys/values
           are up-projected from it. The inference cache therefore stores
           C_KV (plus the shared RoPE key) instead of full per-head keys and
           values.
        2. Decoupled Rotary Position Embedding. RoPE is applied on a separate
           ``d_rotate``-wide pathway: one rotary query per head and a single
           rotary key shared by all heads. This pathway is concatenated onto
           the content queries/keys.

    Args:
        d_model: Total dimension of the model; d_head = d_model // num_head.
        num_head: Number of attention heads.
        d_embed: Input size of the down-projections. ``forward`` asserts that
            the input's last dimension equals ``d_model``, so in practice this
            must equal ``d_model``.
        d_c: K/V compression (latent) dimension.
        d_c1: Query compression dimension.
        d_rotate: Per-head dimension of the RoPE pathway (must be even).
        dropout: Dropout rate applied to the attention scores.
        bias: Whether to include bias in linear projections.
        max_batch_size: Batch capacity of the inference cache.
        max_seq_len: Sequence capacity of the inference cache. RoPE
            frequencies are precomputed for 2 * max_seq_len positions.

    Inputs:
        sequence: Input sequence for self-attention, and the query for
            cross-attention.
        key_value_states: Key/value input for cross-attention.
    """

    def __init__(
        self,
        d_model,  # d_head is inferred as d_model // num_head
        num_head,
        d_embed,
        d_c,
        d_c1,
        d_rotate,
        dropout=0.1,
        bias=True,
        max_batch_size=32,  # For KV cache sizing
        max_seq_len=2048,  # For KV cache sizing
    ):
        super().__init__()
        assert d_model % num_head == 0, "d_model must be divisible by num_head"
        assert d_c < d_embed, "Compression dim should be smaller than embedding dim"
        assert d_c1 < d_embed, "Query compression dim should be smaller than embedding dim"

        self.d_model = d_model
        self.num_head = num_head
        self.d_head = d_model // num_head
        self.d_embed = d_embed
        self.d_c = d_c
        self.d_c1 = d_c1
        self.d_rotate = d_rotate
        self.dropout_rate = dropout  # Stored separately; not read within this class

        # Down-projections (compression) of the input into latents.
        self.DKV_proj = nn.Linear(d_embed, d_c, bias=bias)
        self.DQ_proj = nn.Linear(d_embed, d_c1, bias=bias)

        # Up-projections from the latents to content queries/keys/values.
        self.UQ_proj = nn.Linear(d_c1, d_model, bias=bias)
        self.UK_proj = nn.Linear(d_c, d_model, bias=bias)
        self.UV_proj = nn.Linear(d_c, d_model, bias=bias)

        # Decoupled RoPE projections: one rotary query per head, one shared rotary key.
        self.RQ_proj = nn.Linear(d_c1, num_head * d_rotate, bias=bias)
        self.RK_proj = nn.Linear(d_embed, d_rotate, bias=bias)

        self.output_proj = nn.Linear(d_model, d_model, bias=bias)
        self.dropout = nn.Dropout(p=dropout)

        # Each head's score is a (d_head + d_rotate)-wide dot product.
        self.scaler = float(1.0 / math.sqrt(self.d_head + d_rotate))

        # Inference caches for the latent C_KV and the (rotated) shared RoPE key.
        # They are plain tensor attributes, not registered buffers, and
        # `forward` moves them to the input's device on demand.
        self.cache_kv = torch.zeros((max_batch_size, max_seq_len, d_c))
        self.cache_rk = torch.zeros((max_batch_size, max_seq_len, d_rotate))

        # RoPE table covering absolute positions 0 .. 2 * max_seq_len - 1.
        self.freqs_cis = precompute_freqs_cis(d_rotate, max_seq_len * 2)

    @staticmethod
    def _cached_causal_mask(att_mask, batch_size, seq_len, start_pos, cached_len):
        """Build an additive causal mask of shape [batch_size, 1, seq_len, cached_len].

        Query i sits at absolute position start_pos + i. It may attend to key
        positions 0 .. start_pos + i, which get value 0; later positions get
        -inf. Only the device and dtype of ``att_mask`` are used; its contents
        are ignored.
        """
        device = att_mask.device
        query_pos = torch.arange(start_pos, start_pos + seq_len, device=device)
        key_pos = torch.arange(cached_len, device=device)
        blocked = key_pos.unsqueeze(0) > query_pos.unsqueeze(1)  # [seq_len, cached_len]
        extended_mask = torch.zeros(
            (batch_size, 1, seq_len, cached_len),  # [batch, head, query_len, key_len]
            device=device,
            dtype=att_mask.dtype,
        )
        return extended_mask.masked_fill(blocked, float('-inf'))

    def forward(
        self,
        sequence,
        key_value_states=None,
        att_mask=None,
        use_cache=False,
        start_pos: int = 0,
    ):
        """
        Forward pass supporting both standard attention and cached inference.

        Args:
            sequence: Input [batch_size, seq_len, d_model]. It is the query
                source, and also the key/value source for self-attention.
            key_value_states: Optional [batch_size, kv_seq_len, d_model]
                key/value source. When given, the layer runs as
                cross-attention.
            att_mask: Optional additive mask (a torch.Tensor) that is added to
                the [batch_size, num_head, seq_len, kv_len] score matrix. With
                ``use_cache=True``, only its device and dtype are used: it is
                replaced by a causal mask over the cached positions.
            use_cache: Write the current latent C_KV and rotated shared RoPE
                key into the internal cache at ``start_pos``, then attend over
                cached positions [0, start_pos + kv_seq_len).
            start_pos: Absolute position of the first token in ``sequence``.
                It offsets RoPE and the cache write location.

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].

        Raises:
            TypeError: If ``att_mask`` is given but is not a torch.Tensor.
        """
        batch_size, seq_len, model_dim = sequence.size()

        # Validate inputs before doing any work.
        assert model_dim == self.d_model, f"Input dimension {model_dim} doesn't match model dimension {self.d_model}"
        if key_value_states is not None:
            assert key_value_states.size(-1) == self.d_model, \
                f"Cross attention key/value dimension {key_value_states.size(-1)} doesn't match model dimension {self.d_model}"
        if att_mask is not None and not isinstance(att_mask, torch.Tensor):
            raise TypeError("att_mask must be a torch.Tensor")

        # RoPE table offset so that row 0 is absolute position start_pos.
        self.freqs_cis = self.freqs_cis.to(sequence.device)
        freqs_cis = self.freqs_cis[start_pos:]

        # If key_value_states is provided, this layer is a cross-attention layer.
        is_cross_attention = key_value_states is not None
        kv_input = key_value_states if is_cross_attention else sequence
        kv_seq_len = kv_input.size(1)

        # Query: down-project to the latent C_Q, then up-project the content
        # part and the per-head RoPE part.
        C_Q = self.DQ_proj(sequence)  # [batch_size, seq_len, d_c1]
        Q_state = self.UQ_proj(C_Q)  # [batch_size, seq_len, d_model]
        Q_rotate = self.RQ_proj(C_Q).view(batch_size, seq_len, self.num_head, self.d_rotate)

        # Latent C_KV and shared RoPE key for the *current* tokens only.
        current_kv = self.DKV_proj(kv_input)  # [batch_size, kv_seq_len, d_c]
        current_K_rotate = self.RK_proj(kv_input)  # [batch_size, kv_seq_len, d_rotate]

        # Rotate queries and current keys at absolute positions
        # start_pos .. start_pos + len - 1. Keys are rotated *before* caching,
        # so each cached key keeps the rotation of its own position. Re-rotating
        # the whole cache with the start_pos-offset table would shift every
        # cached key by start_pos and corrupt relative positions during decoding.
        Q_rotate, current_K_rotate = apply_rotary_emb(Q_rotate, current_K_rotate, freqs_cis=freqs_cis)

        if use_cache:
            cached_len = start_pos + kv_seq_len  # includes previously cached keys

            # Equation (41) in the DeepSeek-V2 paper: cache c^{KV}_t.
            # Match the projection dtype so half/bf16 models don't feed a
            # float32 latent into UK_proj/UV_proj.
            self.cache_kv = self.cache_kv.to(device=sequence.device, dtype=current_kv.dtype)
            self.cache_kv[:batch_size, start_pos:cached_len] = current_kv
            C_KV = self.cache_kv[:batch_size, :cached_len]

            # Equation (43): cache the (already rotated) shared RoPE key k^R_t.
            assert self.cache_rk.size(-1) == self.d_rotate, "RoPE cache dimension mismatch"
            self.cache_rk = self.cache_rk.to(device=sequence.device, dtype=current_K_rotate.dtype)
            self.cache_rk[:batch_size, start_pos:cached_len] = current_K_rotate
            K_rotate = self.cache_rk[:batch_size, :cached_len]  # [batch_size, cached_len, d_rotate]

            if att_mask is not None:
                assert C_KV.size(1) == cached_len, \
                    f"Cached key/value length {C_KV.size(1)} doesn't match theoretical length {cached_len}"
                att_mask = self._cached_causal_mask(att_mask, batch_size, seq_len, start_pos, cached_len)
        else:
            C_KV = current_kv
            K_rotate = current_K_rotate

        # Up-project the (possibly cached) latent into per-head keys and values.
        K_state = self.UK_proj(C_KV)  # [batch_size, kv_seq_len/cached_len, d_model]
        V_state = self.UV_proj(C_KV)  # [batch_size, kv_seq_len/cached_len, d_model]

        # In cross-attention (or with a cache), the key length can differ from seq_len.
        actual_kv_len = K_state.size(1)
        Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head)
        K_state = K_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)
        V_state = V_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)

        # The single shared RoPE key is broadcast to every head.
        K_rotate = K_rotate.unsqueeze(2).expand(-1, -1, self.num_head, -1)  # [batch, kv_len, num_head, d_rotate]

        # Concatenate content and RoPE parts along the per-head feature dim.
        Q_state = torch.cat([Q_state, Q_rotate], dim=-1)  # [batch_size, seq_len, num_head, d_head + d_rotate]
        K_state = torch.cat([K_state, K_rotate], dim=-1)  # [batch_size, actual_kv_len, num_head, d_head + d_rotate]

        # Scale Q by 1/sqrt(d_head + d_rotate), then move heads before sequence.
        Q_state = (Q_state * self.scaler).transpose(1, 2)  # [batch_size, num_head, seq_len, head_dim]
        K_state = K_state.transpose(1, 2)  # [batch_size, num_head, actual_kv_len, head_dim]
        V_state = V_state.transpose(1, 2)  # [batch_size, num_head, actual_kv_len, d_head]

        # Attention scores QK^T (kept on self for inspection), plus the additive mask.
        self.att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2))
        if att_mask is not None:
            self.att_matrix = self.att_matrix + att_mask

        att_score = self.dropout(F.softmax(self.att_matrix, dim=-1))
        att_output = torch.matmul(att_score, V_state)  # softmax(QK^T)V
        assert att_output.size(0) == batch_size, "Batch size mismatch"
        assert att_output.size(2) == seq_len, "Output sequence length should match query sequence length"

        # Concatenate all heads and apply the output projection.
        att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_head * self.d_head)
        att_output = self.output_proj(att_output)
        assert att_output.size() == (batch_size, seq_len, self.d_model), \
            f"Final output shape {att_output.size()} incorrect"
        return att_output