sy-chen committed
Commit b6ce8bd (1 parent: 10bacc3)

Implement MLA inference optimizations to DeepseekV2Attention


This patched DeepseekV2Model contains the following modifications to DeepseekV2Attention, which reduce VRAM consumption and improve efficiency:

1. Instead of caching the decompressed key/value states, we cache only the low-rank key-value joint compression and the decoupled RoPE part of the keys. To reuse the cache utilities of the transformers library, we treat k_pe as key_states and compressed_kv as value_states.
2. We implement the absorption technique described in the DeepseekV2 paper by changing the multiplication order when computing the query and output vectors. This not only reduces the memory consumed by intermediate tensors but also lowers the number of floating-point operations.
3. We compute the RoPE part and the non-RoPE part of the attention score separately and then sum them. The original implementation concatenates the two parts of the query/key vectors, which proves quite inefficient when caching compressed key/value states because of unnecessary data broadcasts and memory round-trips. A toy sketch of the computation resulting from items 2 and 3 follows this list.
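The sketch below is a minimal, self-contained toy version of the decode-time score/output computation described in items 2 and 3, with made-up shapes, random weights, and no caching or masking. It only illustrates the multiplication order and the split RoPE/non-RoPE scores; it is not the patched module itself.

# Toy sketch of MLA decoding with absorbed projections (illustrative shapes only).
import torch

bsz, num_heads, q_len, kv_len = 1, 16, 1, 1024          # decoding: one new query token
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 512, 128, 64, 128

# What the cache holds under this scheme: the joint compression and the RoPE keys,
# shared across heads, instead of per-head decompressed K/V.
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)          # treated as "value_states"
k_pe          = torch.randn(bsz, 1, kv_len, qk_rope_head_dim)   # treated as "key_states"

# Per-head query parts produced by the q projection (q_pe is assumed already rotated).
q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
q_pe   = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim)

# Split kv_b_proj's weight into the two blocks that get "absorbed".
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
w = kv_b_proj_weight.view(num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
q_absorb   = w[:, :qk_nope_head_dim, :]   # (h, d_nope, r): folded into the query
out_absorb = w[:, qk_nope_head_dim:, :]   # (h, d_v, r):    folded into the output

softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5

# Absorption: project q_nope into the compressed space once, then score directly
# against compressed_kv; RoPE and non-RoPE scores are computed separately and summed.
q_nope_c = torch.matmul(q_nope, q_absorb)                            # (b, h, q, r)
scores = (torch.matmul(q_pe, k_pe.mT)                                # RoPE part
          + torch.matmul(q_nope_c, compressed_kv.unsqueeze(-3).mT)   # non-RoPE part
          ) * softmax_scale                                          # (b, h, q, L)

attn = scores.softmax(dim=-1)
ctx  = torch.einsum("bhql,blc->bhqc", attn, compressed_kv)           # (b, h, q, r)
out  = torch.matmul(ctx, out_absorb.mT)                              # (b, h, q, d_v)
print(out.shape)  # torch.Size([1, 16, 1, 128])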

With these changes, the MLA module achieves up to a 20.4x speedup for a single request and a 3.63x speedup for 32 batched requests on an NVIDIA A100-PCIE-40GB GPU during the decoding phase, and speedups of 26.2x and 3.52x, respectively, on an NVIDIA GeForce RTX 4080.

A more detailed description of the modifications (in Chinese) is available at https://zhuanlan.zhihu.com/p/700214123?utm_psn=1779287628619632640 and https://github.com/madsys-dev/deepseekv2-profile/blob/924174cb5dc11fad24bdaad3fd820ebf87506368/workspace/blog/optimizing-mla.md.

Files changed (1)
  1. modeling_deepseek.py +18 -28
modeling_deepseek.py CHANGED
@@ -822,17 +822,10 @@ class DeepseekV2Attention(nn.Module):
         compressed_kv, k_pe = torch.split(
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
         )
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
-        kv = (
-            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            .transpose(1, 2)
-        )

-        k_nope, value_states = torch.split(
-            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
-        )
-        kv_seq_len = value_states.shape[-2]
+        kv_seq_len = k_pe.shape[-2]
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
@@ -841,27 +834,22 @@ class DeepseekV2Attention(nn.Module):
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

+        cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

-        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
-        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
-
-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, cache_kwargs
-            )
-
-        attn_weights = (
-            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
-        )
-
+            compressed_kv = compressed_kv.unsqueeze(1)
+            k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
+            compressed_kv = compressed_kv.squeeze(1)
+
+        kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
+        q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
+        out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
+
+        q_nope = torch.matmul(q_nope, q_absorb)
+        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@@ -878,11 +866,13 @@ class DeepseekV2Attention(nn.Module):
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(
             attn_weights, dim=-1, dtype=torch.float32
-        ).to(query_states.dtype)
+        ).to(q_pe.dtype)
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.attention_dropout, training=self.training
         )
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
+
+        attn_output = torch.matmul(attn_output, out_absorb.mT)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
             raise ValueError(
@@ -1902,4 +1892,4 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
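For reference, one possible way to exercise the patched attention end to end is sketched below. The repository id is a placeholder for whichever Hub repo ships this modeling_deepseek.py; loading with trust_remote_code=True is what makes transformers pick up the custom modeling file.

# Hypothetical usage sketch; the repo id is a placeholder, not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "your-namespace/DeepSeek-V2-mla-patched"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,      # required so the patched modeling_deepseek.py is used
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

inputs = tokenizer("An example prompt.", return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))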