x54-729 commited on
Commit
4e70767
1 Parent(s): 254d72c

support flash attn 2

Browse files
Files changed (2) hide show
  1. configuration_internlm.py +5 -0
  2. modeling_internlm2.py +158 -55
configuration_internlm.py CHANGED
@@ -108,6 +108,7 @@ class InternLMConfig(PretrainedConfig):
108
  bias=True,
109
  rope_theta=10000,
110
  rope_scaling=None,
 
111
  **kwargs,
112
  ):
113
  self.vocab_size = vocab_size
@@ -129,6 +130,10 @@ class InternLMConfig(PretrainedConfig):
129
  self.rope_theta = rope_theta
130
  self.rope_scaling = rope_scaling
131
  self._rope_scaling_validation()
 
 
 
 
132
  super().__init__(
133
  pad_token_id=pad_token_id,
134
  bos_token_id=bos_token_id,
 
108
  bias=True,
109
  rope_theta=10000,
110
  rope_scaling=None,
111
+ attn_implementation="eager",
112
  **kwargs,
113
  ):
114
  self.vocab_size = vocab_size
 
130
  self.rope_theta = rope_theta
131
  self.rope_scaling = rope_scaling
132
  self._rope_scaling_validation()
133
+
134
+ self.attn_implementation = attn_implementation
135
+ if self.attn_implementation is None:
136
+ self.attn_implementation = "eager"
137
  super().__init__(
138
  pad_token_id=pad_token_id,
139
  bos_token_id=bos_token_id,
modeling_internlm2.py CHANGED
@@ -1,10 +1,6 @@
1
- # coding=utf-8
2
- # # Copyright (c) InternLM. All rights reserved.
3
  #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -25,6 +21,7 @@ import warnings
25
  from typing import List, Optional, Tuple, Union
26
 
27
  import torch
 
28
  import torch.utils.checkpoint
29
  from einops import rearrange
30
  from torch import nn
@@ -54,6 +51,18 @@ logger = logging.get_logger(__name__)
54
 
55
  _CONFIG_FOR_DOC = "InternLM2Config"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
59
  def _make_causal_mask(
@@ -88,6 +97,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
88
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
89
 
90
 
 
91
  class InternLM2RMSNorm(nn.Module):
92
  def __init__(self, hidden_size, eps=1e-6):
93
  """
@@ -105,6 +115,7 @@ class InternLM2RMSNorm(nn.Module):
105
  return self.weight * hidden_states.to(input_dtype)
106
 
107
 
 
108
  class InternLM2RotaryEmbedding(nn.Module):
109
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
110
  super().__init__()
@@ -141,6 +152,7 @@ class InternLM2RotaryEmbedding(nn.Module):
141
  )
142
 
143
 
 
144
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
145
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
146
 
@@ -160,6 +172,7 @@ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
160
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
161
 
162
 
 
163
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
164
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
165
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
@@ -188,6 +201,7 @@ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
188
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
189
 
190
 
 
191
  def rotate_half(x):
192
  """Rotates half the hidden dims of the input."""
193
  x1 = x[..., : x.shape[-1] // 2]
@@ -195,12 +209,13 @@ def rotate_half(x):
195
  return torch.cat((-x2, x1), dim=-1)
196
 
197
 
198
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
199
- cos = cos[position_ids].unsqueeze(1)
200
- sin = sin[position_ids].unsqueeze(1)
 
 
201
  q_embed = (q * cos) + (rotate_half(q) * sin)
202
  k_embed = (k * cos) + (rotate_half(k) * sin)
203
-
204
  return q_embed, k_embed
205
 
206
 
@@ -221,6 +236,7 @@ class InternLM2MLP(nn.Module):
221
  return down_proj
222
 
223
 
 
224
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
225
  """
226
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -233,6 +249,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
233
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
234
 
235
 
 
236
  class InternLM2Attention(nn.Module):
237
  """Multi-headed attention from 'Attention Is All You Need' paper"""
238
 
@@ -277,14 +294,14 @@ class InternLM2Attention(nn.Module):
277
  self.head_dim,
278
  max_position_embeddings=self.max_position_embeddings,
279
  base=self.config.rope_theta,
280
- scaling_factor=scaling_factor
281
  )
282
  elif scaling_type == "linear":
283
  self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
284
  self.head_dim,
285
  max_position_embeddings=self.max_position_embeddings,
286
  base=self.config.rope_theta,
287
- scaling_factor=scaling_factor
288
  )
289
  else:
290
  raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
@@ -381,6 +398,7 @@ class InternLM2Attention(nn.Module):
381
  return attn_output, attn_weights, past_key_value
382
 
383
 
 
384
  class InternLM2FlashAttention2(InternLM2Attention):
385
  """
386
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
@@ -417,9 +435,8 @@ class InternLM2FlashAttention2(InternLM2Attention):
417
  qkv_states = rearrange(
418
  qkv_states,
419
  "b q (h gs d) -> b q h gs d",
420
- gs=self.num_heads + 2 * self.num_key_value_heads,
421
  d=self.head_dim,
422
- q=q_len,
423
  )
424
 
425
  query_states = qkv_states[..., : self.num_key_value_groups, :]
@@ -427,6 +444,10 @@ class InternLM2FlashAttention2(InternLM2Attention):
427
  key_states = qkv_states[..., -2, :]
428
  value_states = qkv_states[..., -1, :]
429
 
 
 
 
 
430
  kv_seq_len = key_states.shape[-2]
431
  if past_key_value is not None:
432
  kv_seq_len += past_key_value[0].shape[-2]
@@ -448,34 +469,9 @@ class InternLM2FlashAttention2(InternLM2Attention):
448
 
449
  dropout_rate = 0.0 if not self.training else self.attention_dropout
450
 
451
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
452
- # therefore the input hidden states gets silently casted in float32. Hence, we need
453
- # cast them back in the correct dtype just to be sure everything works as expected.
454
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
455
- # in fp32. (InternLM2RMSNorm handles it correctly)
456
-
457
- input_dtype = query_states.dtype
458
- if input_dtype == torch.float32:
459
- # Handle the case where the model is quantized
460
- if hasattr(self.config, "_pre_quantization_dtype"):
461
- target_dtype = self.config._pre_quantization_dtype
462
- else:
463
- target_dtype = self.q_proj.weight.dtype
464
-
465
- logger.warning_once(
466
- f"The input hidden states seems to be silently casted in float32, this might be related to"
467
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back "
468
- f"the input in {target_dtype}."
469
- )
470
-
471
- query_states = query_states.to(target_dtype)
472
- key_states = key_states.to(target_dtype)
473
- value_states = value_states.to(target_dtype)
474
-
475
  attn_output = self._flash_attention_forward(
476
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
477
  )
478
-
479
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
480
  attn_output = self.wo(attn_output)
481
 
@@ -484,16 +480,115 @@ class InternLM2FlashAttention2(InternLM2Attention):
484
 
485
  return attn_output, attn_weights, past_key_value
486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  class InternLM2DecoderLayer(nn.Module):
489
  def __init__(self, config: InternLM2Config):
490
  super().__init__()
491
  self.hidden_size = config.hidden_size
492
- self.attention = (
493
- InternLM2Attention(config=config)
494
- if not getattr(config, "_flash_attn_2_enabled", False)
495
- else InternLM2FlashAttention2(config=config)
496
- )
497
  self.feed_forward = InternLM2MLP(config)
498
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
499
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -578,6 +673,7 @@ InternLM2_START_DOCSTRING = r"""
578
  """
579
 
580
 
 
581
  @add_start_docstrings(
582
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
583
  InternLM2_START_DOCSTRING,
@@ -588,7 +684,6 @@ class InternLM2PreTrainedModel(PreTrainedModel):
588
  supports_gradient_checkpointing = True
589
  _no_split_modules = ["InternLM2DecoderLayer"]
590
  _skip_keys_device_placement = "past_key_values"
591
- _supports_flash_attn_2 = True
592
 
593
  def _init_weights(self, module):
594
  std = self.config.initializer_range
@@ -667,6 +762,7 @@ InternLM2_INPUTS_DOCSTRING = r"""
667
  """
668
 
669
 
 
670
  @add_start_docstrings(
671
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
672
  InternLM2_START_DOCSTRING,
@@ -685,8 +781,10 @@ class InternLM2Model(InternLM2PreTrainedModel):
685
  super().__init__(config)
686
  self.padding_idx = config.pad_token_id
687
  self.vocab_size = config.vocab_size
 
688
 
689
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
690
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
691
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
692
 
@@ -700,7 +798,6 @@ class InternLM2Model(InternLM2PreTrainedModel):
700
  def set_input_embeddings(self, value):
701
  self.tok_embeddings = value
702
 
703
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
704
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
705
  # create causal mask
706
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -770,14 +867,18 @@ class InternLM2Model(InternLM2PreTrainedModel):
770
 
771
  if inputs_embeds is None:
772
  inputs_embeds = self.tok_embeddings(input_ids)
773
- # embed positions
774
- if attention_mask is None:
775
- attention_mask = torch.ones(
776
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
 
 
 
 
 
777
  )
778
- attention_mask = self._prepare_decoder_attention_mask(
779
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
780
- )
781
 
782
  # embed positions
783
  hidden_states = inputs_embeds
@@ -851,6 +952,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
851
  )
852
 
853
 
 
854
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
855
  _auto_class = "AutoModelForCausalLM"
856
 
@@ -1043,8 +1145,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1043
  temperature: float = 0.8,
1044
  top_p: float = 0.8,
1045
  meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1046
- "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1047
- "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1048
  **kwargs,
1049
  ):
1050
  inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
@@ -1149,6 +1251,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1149
  return consumer()
1150
 
1151
 
 
1152
  @add_start_docstrings(
1153
  """
1154
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
 
2
  #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
 
 
 
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
21
  from typing import List, Optional, Tuple, Union
22
 
23
  import torch
24
+ import torch.nn.functional as F
25
  import torch.utils.checkpoint
26
  from einops import rearrange
27
  from torch import nn
 
51
 
52
  _CONFIG_FOR_DOC = "InternLM2Config"
53
 
54
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
55
+ def _get_unpad_data(attention_mask):
56
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
57
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
58
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
59
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
60
+ return (
61
+ indices,
62
+ cu_seqlens,
63
+ max_seqlen_in_batch,
64
+ )
65
+
66
 
67
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
68
  def _make_causal_mask(
 
97
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
98
 
99
 
100
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
101
  class InternLM2RMSNorm(nn.Module):
102
  def __init__(self, hidden_size, eps=1e-6):
103
  """
 
115
  return self.weight * hidden_states.to(input_dtype)
116
 
117
 
118
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
119
  class InternLM2RotaryEmbedding(nn.Module):
120
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
121
  super().__init__()
 
152
  )
153
 
154
 
155
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
156
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
157
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
158
 
 
172
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
173
 
174
 
175
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
176
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
177
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
178
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
 
201
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
202
 
203
 
204
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
205
  def rotate_half(x):
206
  """Rotates half the hidden dims of the input."""
207
  x1 = x[..., : x.shape[-1] // 2]
 
209
  return torch.cat((-x2, x1), dim=-1)
210
 
211
 
212
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
213
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
214
+ """Applies Rotary Position Embedding to the query and key tensors."""
215
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
216
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
217
  q_embed = (q * cos) + (rotate_half(q) * sin)
218
  k_embed = (k * cos) + (rotate_half(k) * sin)
 
219
  return q_embed, k_embed
220
 
221
 
 
236
  return down_proj
237
 
238
 
239
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
240
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
241
  """
242
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
249
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
250
 
251
 
252
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
253
  class InternLM2Attention(nn.Module):
254
  """Multi-headed attention from 'Attention Is All You Need' paper"""
255
 
 
294
  self.head_dim,
295
  max_position_embeddings=self.max_position_embeddings,
296
  base=self.config.rope_theta,
297
+ scaling_factor=scaling_factor,
298
  )
299
  elif scaling_type == "linear":
300
  self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
301
  self.head_dim,
302
  max_position_embeddings=self.max_position_embeddings,
303
  base=self.config.rope_theta,
304
+ scaling_factor=scaling_factor,
305
  )
306
  else:
307
  raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
 
398
  return attn_output, attn_weights, past_key_value
399
 
400
 
401
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
402
  class InternLM2FlashAttention2(InternLM2Attention):
403
  """
404
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
 
435
  qkv_states = rearrange(
436
  qkv_states,
437
  "b q (h gs d) -> b q h gs d",
438
+ gs=2 + self.num_key_value_groups,
439
  d=self.head_dim,
 
440
  )
441
 
442
  query_states = qkv_states[..., : self.num_key_value_groups, :]
 
444
  key_states = qkv_states[..., -2, :]
445
  value_states = qkv_states[..., -1, :]
446
 
447
+ query_states = query_states.transpose(1, 2)
448
+ key_states = key_states.transpose(1, 2)
449
+ value_states = value_states.transpose(1, 2)
450
+
451
  kv_seq_len = key_states.shape[-2]
452
  if past_key_value is not None:
453
  kv_seq_len += past_key_value[0].shape[-2]
 
469
 
470
  dropout_rate = 0.0 if not self.training else self.attention_dropout
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  attn_output = self._flash_attention_forward(
473
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
474
  )
 
475
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
476
  attn_output = self.wo(attn_output)
477
 
 
480
 
481
  return attn_output, attn_weights, past_key_value
482
 
483
+ def _flash_attention_forward(
484
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
485
+ ):
486
+ """
487
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
488
+ first unpad the input, then computes the attention scores and pad the final attention scores.
489
+
490
+ Args:
491
+ query_states (`torch.Tensor`):
492
+ Input query states to be passed to Flash Attention API
493
+ key_states (`torch.Tensor`):
494
+ Input key states to be passed to Flash Attention API
495
+ value_states (`torch.Tensor`):
496
+ Input value states to be passed to Flash Attention API
497
+ attention_mask (`torch.Tensor`):
498
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
499
+ position of padding tokens and 1 for the position of non-padding tokens.
500
+ dropout (`int`, *optional*):
501
+ Attention dropout
502
+ softmax_scale (`float`, *optional*):
503
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
504
+ """
505
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
506
+ from flash_attn.bert_padding import pad_input
507
+ # Contains at least one padding token in the sequence
508
+ causal = self.is_causal and query_length != 1
509
+ if attention_mask is not None:
510
+ batch_size = query_states.shape[0]
511
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
512
+ query_states, key_states, value_states, attention_mask, query_length
513
+ )
514
+
515
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
516
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
517
+
518
+ attn_output_unpad = flash_attn_varlen_func(
519
+ query_states,
520
+ key_states,
521
+ value_states,
522
+ cu_seqlens_q=cu_seqlens_q,
523
+ cu_seqlens_k=cu_seqlens_k,
524
+ max_seqlen_q=max_seqlen_in_batch_q,
525
+ max_seqlen_k=max_seqlen_in_batch_k,
526
+ dropout_p=dropout,
527
+ softmax_scale=softmax_scale,
528
+ causal=causal,
529
+ )
530
 
531
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
532
+ else:
533
+ attn_output = flash_attn_func(
534
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
535
+ )
536
+
537
+ return attn_output
538
+
539
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
540
+ from flash_attn.bert_padding import index_first_axis, unpad_input
541
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
542
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
543
+
544
+ key_layer = index_first_axis(
545
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
546
+ )
547
+ value_layer = index_first_axis(
548
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
549
+ )
550
+
551
+ if query_length == kv_seq_len:
552
+ query_layer = index_first_axis(
553
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
554
+ )
555
+ cu_seqlens_q = cu_seqlens_k
556
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
557
+ indices_q = indices_k
558
+ elif query_length == 1:
559
+ max_seqlen_in_batch_q = 1
560
+ cu_seqlens_q = torch.arange(
561
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
562
+ ) # There is a memcpy here, that is very bad.
563
+ indices_q = cu_seqlens_q[:-1]
564
+ query_layer = query_layer.squeeze(1)
565
+ else:
566
+ # The -q_len: slice assumes left padding.
567
+ attention_mask = attention_mask[:, -query_length:]
568
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
569
+
570
+ return (
571
+ query_layer,
572
+ key_layer,
573
+ value_layer,
574
+ indices_q.to(torch.int64),
575
+ (cu_seqlens_q, cu_seqlens_k),
576
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
577
+ )
578
+
579
+ INTERNLM2_ATTENTION_CLASSES = {
580
+ "eager": InternLM2Attention,
581
+ "flash_attention_2": InternLM2FlashAttention2,
582
+ }
583
+
584
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
585
  class InternLM2DecoderLayer(nn.Module):
586
  def __init__(self, config: InternLM2Config):
587
  super().__init__()
588
  self.hidden_size = config.hidden_size
589
+
590
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
591
+
 
 
592
  self.feed_forward = InternLM2MLP(config)
593
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
594
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
673
  """
674
 
675
 
676
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
677
  @add_start_docstrings(
678
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
679
  InternLM2_START_DOCSTRING,
 
684
  supports_gradient_checkpointing = True
685
  _no_split_modules = ["InternLM2DecoderLayer"]
686
  _skip_keys_device_placement = "past_key_values"
 
687
 
688
  def _init_weights(self, module):
689
  std = self.config.initializer_range
 
762
  """
763
 
764
 
765
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
766
  @add_start_docstrings(
767
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
768
  InternLM2_START_DOCSTRING,
 
781
  super().__init__(config)
782
  self.padding_idx = config.pad_token_id
783
  self.vocab_size = config.vocab_size
784
+ self.config = config
785
 
786
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
787
+
788
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
789
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
790
 
 
798
  def set_input_embeddings(self, value):
799
  self.tok_embeddings = value
800
 
 
801
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
802
  # create causal mask
803
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
 
867
 
868
  if inputs_embeds is None:
869
  inputs_embeds = self.tok_embeddings(input_ids)
870
+
871
+ if self.config.attn_implementation == "flash_attention_2":
872
+ # 2d mask is passed through the layers
873
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
874
+ else:
875
+ if attention_mask is None:
876
+ attention_mask = torch.ones(
877
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
878
+ )
879
+ attention_mask = self._prepare_decoder_attention_mask(
880
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
881
  )
 
 
 
882
 
883
  # embed positions
884
  hidden_states = inputs_embeds
 
952
  )
953
 
954
 
955
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
956
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
957
  _auto_class = "AutoModelForCausalLM"
958
 
 
1145
  temperature: float = 0.8,
1146
  top_p: float = 0.8,
1147
  meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1148
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1149
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1150
  **kwargs,
1151
  ):
1152
  inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
 
1251
  return consumer()
1252
 
1253
 
1254
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1255
  @add_start_docstrings(
1256
  """
1257
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).