jeffreygao committed on
Commit a3e1065 · 1 Parent(s): 6e571c3

Update modeling_bluelm.py

Files changed (1)
  1. modeling_bluelm.py +20 -22
modeling_bluelm.py CHANGED
@@ -32,7 +32,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_bluelm import BlueLMConfig
-
+from flash_attn.flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+)

 try:
     from xformers import ops as xops
@@ -213,6 +218,11 @@ class BlueLMAttention(nn.Module):
             hidden_size,
             bias=False,
         )
+        self.register_buffer(
+            "norm_factor",
+            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
+            persistent=False,
+        )
         self.rotary_emb = BlueLMRotaryEmbedding(self.head_dim)
         if xops is not None:
             self.causal_mask = xops.LowerTriangularMask()
@@ -230,7 +240,8 @@ class BlueLMAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""

-        bsz, q_len, _ = hidden_states.size()
+        bsz, q_len, h_size = hidden_states.size()
+        has_layer_past = past_key_value is not None

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
@@ -245,7 +256,7 @@ class BlueLMAttention(nn.Module):
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)
         # [bsz, t, nh, hd]

-        if past_key_value is not None:
+        if has_layer_past:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=1)
             value_states = torch.cat([past_key_value[1], value_states], dim=1)
@@ -260,25 +271,12 @@ class BlueLMAttention(nn.Module):
             )
         else:
             # [bsz, t, nh, hd]
-            attn_weights = torch.einsum("bqnh,bknh->bnqk", query_states, key_states) / math.sqrt(self.head_dim)
-
-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
-                )
-
-            if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            kv = torch.stack([key_states, value_states], 2)
+            attn_outputs = flash_attn_kvpacked_func(
+                query_states, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
+            attn_output = attn_outputs[0] if output_attentions else attn_outputs
+            attn_weights = attn_outputs[2] if output_attentions else None

-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.einsum("bnqk,bknh->bqnh", attn_weights, value_states)

         if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
             raise ValueError(
@@ -612,7 +610,7 @@ class BlueLMModel(BlueLMPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]
+            past_key_values_length = past_key_values[0][0].shape[2]
         seq_length_with_past = seq_length_with_past + past_key_values_length
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
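
The core of this change swaps the hand-rolled einsum attention for flash_attn_kvpacked_func. The snippet below is a minimal standalone sketch, not part of the commit, that checks the packed FlashAttention call against the einsum reference removed above. It assumes flash-attn 2.x is installed and a CUDA device with fp16 support is available; the tensor sizes and the explicit causal mask are illustrative only.

import math

import torch
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func

bsz, q_len, num_heads, head_dim = 2, 16, 4, 64
device, dtype = "cuda", torch.float16

q = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)

# Reference path (the code removed by this commit): explicit scores,
# additive causal mask, fp32 softmax, then weighted sum over values.
scores = torch.einsum("bqnh,bknh->bnqk", q, k) / math.sqrt(head_dim)
causal_mask = torch.triu(
    torch.full((q_len, q_len), torch.finfo(dtype).min, device=device), diagonal=1
)
scores = scores + causal_mask
probs = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(dtype)
ref_out = torch.einsum("bnqk,bknh->bqnh", probs, v)  # [bsz, t, nh, hd]

# New path: K and V packed along dim 2; scaling and causal masking happen
# inside the fused kernel (softmax_scale = 1 / sqrt(head_dim), i.e. norm_factor).
kv = torch.stack([k, v], dim=2)  # [bsz, t, 2, nh, hd]
flash_out = flash_attn_kvpacked_func(
    q, kv, dropout_p=0.0, softmax_scale=1.0 / math.sqrt(head_dim), causal=True
)

print((ref_out - flash_out).abs().max())  # expect only fp16-level differences

Note that the commit passes causal=(not has_layer_past): the causal mask is applied only on the full-prefix forward pass, while during cached decoding the new query position is allowed to attend to everything already in the KV cache.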