x54-729 commited on
Commit
2f2f1b1
·
1 Parent(s): 497af06

support flash attn 2

Browse files
Files changed (2) hide show
  1. configuration_internlm.py +4 -0
  2. modeling_internlm.py +186 -18
configuration_internlm.py CHANGED
@@ -91,6 +91,7 @@ class InternLMConfig(PretrainedConfig):
91
  tie_word_embeddings=False,
92
  bias=True,
93
  rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102
 
94
  **kwargs,
95
  ):
96
  self.vocab_size = vocab_size
@@ -105,6 +106,9 @@ class InternLMConfig(PretrainedConfig):
105
  self.use_cache = use_cache
106
  self.bias = bias
107
  self.rotary = rotary
 
 
 
108
  super().__init__(
109
  pad_token_id=pad_token_id,
110
  bos_token_id=bos_token_id,
 
91
  tie_word_embeddings=False,
92
  bias=True,
93
  rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102
94
+ attn_implementation="eager",
95
  **kwargs,
96
  ):
97
  self.vocab_size = vocab_size
 
106
  self.use_cache = use_cache
107
  self.bias = bias
108
  self.rotary = rotary
109
+ self.attn_implementation = attn_implementation
110
+ if self.attn_implementation is None:
111
+ self.attn_implementation = "eager"
112
  super().__init__(
113
  pad_token_id=pad_token_id,
114
  bos_token_id=bos_token_id,
modeling_internlm.py CHANGED
@@ -1,10 +1,6 @@
1
- # coding=utf-8
2
- # Copyright (c) InternLM. All rights reserved.
3
  #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -52,6 +48,17 @@ logger = logging.get_logger(__name__)
52
 
53
  _CONFIG_FOR_DOC = "InternLMConfig"
54
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
57
  def _make_causal_mask(
@@ -85,7 +92,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
85
 
86
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
87
 
88
-
89
  class InternLMRMSNorm(nn.Module):
90
  """RMSNorm implemention."""
91
 
@@ -228,8 +234,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
228
  k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
229
  k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
230
  else:
231
- cos = cos[position_ids].unsqueeze(1).expand(q.shape)
232
- sin = sin[position_ids].unsqueeze(1).expand(q.shape)
233
  q_embed = (q * cos) + (rotate_half(q) * sin)
234
  k_embed = (k * cos) + (rotate_half(k) * sin)
235
  return q_embed, k_embed
@@ -273,6 +279,7 @@ class InternLMAttention(nn.Module):
273
  self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
274
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
275
  self.rotary_emb = self._init_rope()
 
276
 
277
  def _init_rope(self):
278
  if self.config.rotary["type"] == "origin":
@@ -356,13 +363,167 @@ class InternLMAttention(nn.Module):
356
  attn_weights = None
357
 
358
  return attn_output, attn_weights, past_key_value
 
 
 
 
 
 
 
359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
  class InternLMDecoderLayer(nn.Module):
362
  def __init__(self, config: InternLMConfig):
363
  super().__init__()
364
  self.hidden_size = config.hidden_size
365
- self.self_attn = InternLMAttention(config=config)
 
 
366
  self.mlp = InternLMMLP(
367
  hidden_size=self.hidden_size,
368
  intermediate_size=config.intermediate_size,
@@ -539,8 +700,10 @@ class InternLMModel(InternLMPreTrainedModel):
539
  super().__init__(config)
540
  self.padding_idx = config.pad_token_id
541
  self.vocab_size = config.vocab_size
 
542
 
543
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
544
  self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
545
  self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
546
 
@@ -627,14 +790,16 @@ class InternLMModel(InternLMPreTrainedModel):
627
 
628
  if inputs_embeds is None:
629
  inputs_embeds = self.embed_tokens(input_ids)
630
- # embed positions
631
- if attention_mask is None:
632
- attention_mask = torch.ones(
633
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
 
 
 
634
  )
635
- attention_mask = self._prepare_decoder_attention_mask(
636
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
637
- )
638
 
639
  hidden_states = inputs_embeds
640
 
@@ -759,6 +924,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
759
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
760
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
761
  Returns:
 
762
  Example:
763
  ```python
764
  >>> from transformers import AutoTokenizer, InternLMForCausalLM
@@ -770,7 +936,9 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
770
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
771
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
772
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
773
- ```"""
 
 
774
 
775
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
776
  output_hidden_states = (
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
 
2
  #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
 
 
 
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
48
 
49
  _CONFIG_FOR_DOC = "InternLMConfig"
50
 
51
+ def _get_unpad_data(attention_mask):
52
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
53
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
54
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
55
+ cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
56
+ return (
57
+ indices,
58
+ cu_seqlens,
59
+ max_seqlen_in_batch,
60
+ )
61
+
62
 
63
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
64
  def _make_causal_mask(
 
92
 
93
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
94
 
 
95
  class InternLMRMSNorm(nn.Module):
96
  """RMSNorm implemention."""
97
 
 
234
  k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
235
  k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
236
  else:
237
+ cos = cos[position_ids].unsqueeze(1)
238
+ sin = sin[position_ids].unsqueeze(1)
239
  q_embed = (q * cos) + (rotate_half(q) * sin)
240
  k_embed = (k * cos) + (rotate_half(k) * sin)
241
  return q_embed, k_embed
 
279
  self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
280
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
281
  self.rotary_emb = self._init_rope()
282
+ self.is_causal = True
283
 
284
  def _init_rope(self):
285
  if self.config.rotary["type"] == "origin":
 
363
  attn_weights = None
364
 
365
  return attn_output, attn_weights, past_key_value
366
+
367
+ class InternLMFlashAttention2(InternLMAttention):
368
+ """
369
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
370
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
371
+ flash attention and deal with padding tokens in case the input contains any of them.
372
+ """
373
 
374
+ def forward(
375
+ self,
376
+ hidden_states: torch.Tensor,
377
+ attention_mask: Optional[torch.LongTensor] = None,
378
+ position_ids: Optional[torch.LongTensor] = None,
379
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
380
+ output_attentions: bool = False,
381
+ use_cache: bool = False,
382
+ **kwargs,
383
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
384
+ # InternLM2FlashAttention2 attention does not support output_attentions
385
+ bsz, q_len, _ = hidden_states.size()
386
+
387
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
388
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
389
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
390
+
391
+ if past_key_value is not None:
392
+ # reuse k, v, self_attention
393
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
394
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
395
+
396
+ past_key_value = (key_states, value_states) if use_cache else None
397
+
398
+ kv_seq_len = key_states.shape[-2]
399
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
400
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
401
+
402
+ query_states = query_states.transpose(1, 2)
403
+ key_states = key_states.transpose(1, 2)
404
+ value_states = value_states.transpose(1, 2)
405
+
406
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
407
+
408
+ attn_output = self._flash_attention_forward(
409
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
410
+ )
411
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
412
+ attn_output = self.o_proj(attn_output)
413
+
414
+ if not output_attentions:
415
+ attn_weights = None
416
+
417
+ return attn_output, attn_weights, past_key_value
418
+
419
+ def _flash_attention_forward(
420
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
421
+ ):
422
+ """
423
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
424
+ first unpad the input, then computes the attention scores and pad the final attention scores.
425
+
426
+ Args:
427
+ query_states (`torch.Tensor`):
428
+ Input query states to be passed to Flash Attention API
429
+ key_states (`torch.Tensor`):
430
+ Input key states to be passed to Flash Attention API
431
+ value_states (`torch.Tensor`):
432
+ Input value states to be passed to Flash Attention API
433
+ attention_mask (`torch.Tensor`):
434
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
435
+ position of padding tokens and 1 for the position of non-padding tokens.
436
+ dropout (`int`, *optional*):
437
+ Attention dropout
438
+ softmax_scale (`float`, *optional*):
439
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
440
+ """
441
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
442
+ from flash_attn.bert_padding import pad_input
443
+ # Contains at least one padding token in the sequence
444
+ causal = self.is_causal and query_length != 1
445
+ if attention_mask is not None:
446
+ batch_size = query_states.shape[0]
447
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
448
+ query_states, key_states, value_states, attention_mask, query_length
449
+ )
450
+
451
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
452
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
453
+
454
+ attn_output_unpad = flash_attn_varlen_func(
455
+ query_states,
456
+ key_states,
457
+ value_states,
458
+ cu_seqlens_q=cu_seqlens_q,
459
+ cu_seqlens_k=cu_seqlens_k,
460
+ max_seqlen_q=max_seqlen_in_batch_q,
461
+ max_seqlen_k=max_seqlen_in_batch_k,
462
+ dropout_p=dropout,
463
+ softmax_scale=softmax_scale,
464
+ causal=causal,
465
+ )
466
+
467
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
468
+ else:
469
+ attn_output = flash_attn_func(
470
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
471
+ )
472
+
473
+ return attn_output
474
+
475
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
476
+ from flash_attn.bert_padding import index_first_axis, unpad_input
477
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
478
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
479
+
480
+ key_layer = index_first_axis(
481
+ key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
482
+ )
483
+ value_layer = index_first_axis(
484
+ value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
485
+ )
486
+
487
+ if query_length == kv_seq_len:
488
+ query_layer = index_first_axis(
489
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
490
+ )
491
+ cu_seqlens_q = cu_seqlens_k
492
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
493
+ indices_q = indices_k
494
+ elif query_length == 1:
495
+ max_seqlen_in_batch_q = 1
496
+ cu_seqlens_q = torch.arange(
497
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
498
+ ) # There is a memcpy here, that is very bad.
499
+ indices_q = cu_seqlens_q[:-1]
500
+ query_layer = query_layer.squeeze(1)
501
+ else:
502
+ # The -q_len: slice assumes left padding.
503
+ attention_mask = attention_mask[:, -query_length:]
504
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
505
+
506
+ return (
507
+ query_layer,
508
+ key_layer,
509
+ value_layer,
510
+ indices_q.to(torch.int64),
511
+ (cu_seqlens_q, cu_seqlens_k),
512
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
513
+ )
514
+
515
+ INTERNLM_ATTENTION_CLASSES = {
516
+ "eager": InternLMAttention,
517
+ "flash_attention_2": InternLMFlashAttention2,
518
+ }
519
 
520
  class InternLMDecoderLayer(nn.Module):
521
  def __init__(self, config: InternLMConfig):
522
  super().__init__()
523
  self.hidden_size = config.hidden_size
524
+
525
+ self.self_attn = INTERNLM_ATTENTION_CLASSES[config.attn_implementation](config=config)
526
+
527
  self.mlp = InternLMMLP(
528
  hidden_size=self.hidden_size,
529
  intermediate_size=config.intermediate_size,
 
700
  super().__init__(config)
701
  self.padding_idx = config.pad_token_id
702
  self.vocab_size = config.vocab_size
703
+ self.config = config
704
 
705
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
706
+
707
  self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
708
  self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
709
 
 
790
 
791
  if inputs_embeds is None:
792
  inputs_embeds = self.embed_tokens(input_ids)
793
+ if self.config.attn_implementation == "flash_attention_2":
794
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
795
+ else:
796
+ if attention_mask is None:
797
+ attention_mask = torch.ones(
798
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
799
+ )
800
+ attention_mask = self._prepare_decoder_attention_mask(
801
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
802
  )
 
 
 
803
 
804
  hidden_states = inputs_embeds
805
 
 
924
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
925
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
926
  Returns:
927
+
928
  Example:
929
  ```python
930
  >>> from transformers import AutoTokenizer, InternLMForCausalLM
 
936
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
937
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
938
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
939
+ ```
940
+
941
+ """
942
 
943
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
944
  output_hidden_states = (