chenlin commited on
Commit
a285f80
1 Parent(s): 60d2df3

fix import

Browse files
Files changed (1) hide show
  1. modeling_InternLM.py +44 -53
modeling_InternLM.py CHANGED
@@ -2,12 +2,10 @@ import math
2
  from typing import List, Union
3
  from typing import Optional, Tuple
4
 
5
- import rotary_emb
6
  import torch
7
  import torch.utils.checkpoint
8
  import torch.utils.checkpoint
9
  from einops import rearrange
10
- from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
11
  from torch import nn
12
  from torch.nn import CrossEntropyLoss
13
  from transformers.activations import ACT2FN
@@ -23,51 +21,70 @@ logger = logging.get_logger(__name__)
23
  _CONFIG_FOR_DOC = "InternLMXComposerConfig"
24
 
25
 
26
- class ApplyRotaryEmbQKV_(torch.autograd.Function):
27
- """
28
- ApplyRotaryEmbQKV_
29
- """
 
 
 
 
 
 
 
30
  @staticmethod
31
- def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
32
  """
33
- qkv: (total, 3, nheads, headdim)
34
  cos, sin: (seqlen, rotary_dim / 2)
35
  cos_k, sin_k: (seqlen, rotary_dim / 2), optional
 
 
36
  rotary_dim must be <= headdim
37
  Apply rotary embedding *inplace* to the first rotary_dim of q and k.
38
  """
39
- _, three, _, headdim = qkv.shape
40
  assert three == 3
41
  rotary_seqlen, rotary_dim = cos.shape
42
  rotary_dim *= 2
43
  assert rotary_dim <= headdim
 
44
  cos_k = cos if cos_k is None else cos_k
45
  sin_k = sin if sin_k is None else sin_k
46
- assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
47
- rotary_dim // 2)
48
- q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
49
- rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"),
50
- rearrange(sin, "s d -> s 1 d"), q1, q2, False)
51
- k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
52
- rotary_emb.apply_rotary(k1, k2, rearrange(cos_k, "s d -> s 1 d"),
53
- rearrange(sin_k, "s d -> s 1 d"), k1, k2,
54
- False)
 
 
 
 
55
  ctx.save_for_backward(cos, sin, cos_k, sin_k)
 
56
  return qkv
57
 
58
  @staticmethod
59
  def backward(ctx, dqkv):
60
  cos, sin, cos_k, sin_k = ctx.saved_tensors
 
61
  rotary_dim = cos.shape[-1]
62
  rotary_dim *= 2
63
- dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
64
- rotary_emb.apply_rotary(dq1, dq2, rearrange(cos, "s d -> s 1 d"),
65
- rearrange(sin, "s d -> s 1 d"), dq1, dq2, True)
66
- dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
67
- rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k, "s d -> s 1 d"),
68
- rearrange(sin_k, "s d -> s 1 d"), dk1, dk2,
69
- True)
70
- return dqkv, None, None, None, None
 
 
 
71
 
72
 
73
  class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
@@ -120,23 +137,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
120
  self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
121
  self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
122
 
123
- def forward(self,
124
- qkv: torch.Tensor,
125
- indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
126
- self._update_cos_sin_cache(qkv, indexes)
127
- if self.scale is None:
128
- return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes],
129
- self._sin_cached[indexes]).to(
130
- qkv.dtype)
131
- else:
132
- return apply_rotary_emb_qkv_(
133
- qkv,
134
- self._cos_cached[indexes],
135
- self._sin_cached[indexes],
136
- self._cos_k_cached[indexes],
137
- self._sin_k_cached[indexes],
138
- ).to(qkv.dtype)
139
-
140
  def eval_forward(self, qkv, seqlen_offset=0):
141
  """
142
  seqlen_offset: can be used in generation where the qkv being passed in is only the last
@@ -157,7 +157,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
157
  )
158
 
159
 
160
- apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
161
  legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
162
 
163
 
@@ -487,7 +486,6 @@ class InternLMPreTrainedModel(PreTrainedModel):
487
  class InternLMModel(InternLMPreTrainedModel):
488
  """
489
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
490
-
491
  Args:
492
  config: InternLMXComposerConfig
493
  """
@@ -631,7 +629,7 @@ class InternLMModel(InternLMPreTrainedModel):
631
  past_key_value = past_key_values[
632
  idx] if past_key_values is not None else None
633
 
634
- if self.gradient_checkpointing and self.training and idx % 2 == 0:
635
 
636
  def create_custom_forward(module):
637
  def custom_forward(*inputs):
@@ -696,7 +694,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
696
  setattr(config, 'kqvo_bias', config.kqvo_bias)
697
  else:
698
  setattr(config, 'kqvo_bias', False)
699
-
700
  self.model = InternLMModel(config)
701
 
702
  self.lm_head = nn.Linear(config.hidden_size,
@@ -762,20 +759,14 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
762
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
763
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
764
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
765
-
766
  Returns:
767
-
768
  Example:
769
-
770
  ```python
771
  >>> from transformers import AutoTokenizer, InternLMForCausalLM
772
-
773
  >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
774
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
775
-
776
  >>> prompt = "Hey, are you consciours? Can you talk to me?"
777
  >>> inputs = tokenizer(prompt, return_tensors="pt")
778
-
779
  >>> # Generate
780
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
781
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
2
  from typing import List, Union
3
  from typing import Optional, Tuple
4
 
 
5
  import torch
6
  import torch.utils.checkpoint
7
  import torch.utils.checkpoint
8
  from einops import rearrange
 
9
  from torch import nn
10
  from torch.nn import CrossEntropyLoss
11
  from transformers.activations import ACT2FN
 
21
  _CONFIG_FOR_DOC = "InternLMXComposerConfig"
22
 
23
 
24
+ def rotary_embed(x1, x2, cos, sin, conj):
25
+ x1, x2 = x1.float(), x2.float()
26
+ if conj:
27
+ x1, x2 = x1 * cos + x2 * sin, x1 * sin + x2 * cos
28
+ else:
29
+ x1, x2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos
30
+ return x1, x2
31
+
32
+
33
+ class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
34
+
35
  @staticmethod
36
+ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
37
  """
38
+ qkv: (batch_size, seqlen, 3, nheads, headdim)
39
  cos, sin: (seqlen, rotary_dim / 2)
40
  cos_k, sin_k: (seqlen, rotary_dim / 2), optional
41
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
42
+ 1st half and 2nd half (GPT-NeoX style).
43
  rotary_dim must be <= headdim
44
  Apply rotary embedding *inplace* to the first rotary_dim of q and k.
45
  """
46
+ batch, seqlen, three, nheads, headdim = qkv.shape
47
  assert three == 3
48
  rotary_seqlen, rotary_dim = cos.shape
49
  rotary_dim *= 2
50
  assert rotary_dim <= headdim
51
+ assert seqlen <= rotary_seqlen
52
  cos_k = cos if cos_k is None else cos_k
53
  sin_k = sin if sin_k is None else sin_k
54
+ assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
55
+ q_ro = qkv[:, :, 0, :, :rotary_dim]
56
+ q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
57
+ # rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
58
+ # rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
59
+ q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), False)
60
+ qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
61
+ k_ro = qkv[:, :, 1, :, :rotary_dim]
62
+ k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
63
+ # rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
64
+ # rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
65
+ k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
66
+ qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
67
  ctx.save_for_backward(cos, sin, cos_k, sin_k)
68
+ ctx.interleaved = interleaved
69
  return qkv
70
 
71
  @staticmethod
72
  def backward(ctx, dqkv):
73
  cos, sin, cos_k, sin_k = ctx.saved_tensors
74
+ _, seqlen, _, _, headdim = dqkv.shape
75
  rotary_dim = cos.shape[-1]
76
  rotary_dim *= 2
77
+ dq_ro = dqkv[:, :, 0, :, :rotary_dim]
78
+ dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
79
+ else (dq_ro[..., ::2], dq_ro[..., 1::2]))
80
+ rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
81
+ rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
82
+ dk_ro = dqkv[:, :, 1, :, :rotary_dim]
83
+ dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
84
+ else (dk_ro[..., ::2], dk_ro[..., 1::2]))
85
+ rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
86
+ rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
87
+ return dqkv, None, None, None, None, None
88
 
89
 
90
  class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
 
137
  self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
138
  self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  def eval_forward(self, qkv, seqlen_offset=0):
141
  """
142
  seqlen_offset: can be used in generation where the qkv being passed in is only the last
 
157
  )
158
 
159
 
 
160
  legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
161
 
162
 
 
486
  class InternLMModel(InternLMPreTrainedModel):
487
  """
488
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
 
489
  Args:
490
  config: InternLMXComposerConfig
491
  """
 
629
  past_key_value = past_key_values[
630
  idx] if past_key_values is not None else None
631
 
632
+ if self.gradient_checkpointing and self.training:
633
 
634
  def create_custom_forward(module):
635
  def custom_forward(*inputs):
 
694
  setattr(config, 'kqvo_bias', config.kqvo_bias)
695
  else:
696
  setattr(config, 'kqvo_bias', False)
 
697
  self.model = InternLMModel(config)
698
 
699
  self.lm_head = nn.Linear(config.hidden_size,
 
759
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
760
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
761
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
762
  Returns:
 
763
  Example:
 
764
  ```python
765
  >>> from transformers import AutoTokenizer, InternLMForCausalLM
 
766
  >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
767
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
 
768
  >>> prompt = "Hey, are you consciours? Can you talk to me?"
769
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
770
  >>> # Generate
771
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
772
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]