chenlin committed
Commit • a285f80
1 Parent(s): 60d2df3

fix import

Files changed: modeling_InternLM.py  +44 -53

modeling_InternLM.py  CHANGED
@@ -2,12 +2,10 @@ import math
 from typing import List, Union
 from typing import Optional, Tuple
 
-import rotary_emb
 import torch
 import torch.utils.checkpoint
 import torch.utils.checkpoint
 from einops import rearrange
-from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
@@ -23,51 +21,70 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "InternLMXComposerConfig"
 
 
-
-
-
-
+def rotary_embed(x1, x2, cos, sin, conj):
+    x1, x2 = x1.float(), x2.float()
+    if conj:
+        x1, x2 = x1 * cos + x2 * sin, x1 * sin + x2 * cos
+    else:
+        x1, x2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos
+    return x1, x2
+
+
+class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
+
     @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
         """
-        qkv: (
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
         cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
         rotary_dim must be <= headdim
         Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
-
+        batch, seqlen, three, nheads, headdim = qkv.shape
         assert three == 3
         rotary_seqlen, rotary_dim = cos.shape
         rotary_dim *= 2
         assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
         cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
-        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
-
-        q1, q2 =
-        rotary_emb.apply_rotary(q1, q2, rearrange(cos,
-
-
-
-
-
+        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
+        q_ro = qkv[:, :, 0, :, :rotary_dim]
+        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
+        # rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+        #                         rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
+        q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), False)
+        qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
+        k_ro = qkv[:, :, 1, :, :rotary_dim]
+        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
+        # rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+        #                         rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
+        qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
         ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        ctx.interleaved = interleaved
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
         cos, sin, cos_k, sin_k = ctx.saved_tensors
+        _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
-
-
-
-
-
-
-
-
+        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
+        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
+        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
+        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None, None, None, None
 
 
 class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
@@ -120,23 +137,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
         self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
         self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def forward(self,
-                qkv: torch.Tensor,
-                indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
-        if self.scale is None:
-            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes],
-                                         self._sin_cached[indexes]).to(
-                                             qkv.dtype)
-        else:
-            return apply_rotary_emb_qkv_(
-                qkv,
-                self._cos_cached[indexes],
-                self._sin_cached[indexes],
-                self._cos_k_cached[indexes],
-                self._sin_k_cached[indexes],
-            ).to(qkv.dtype)
-
     def eval_forward(self, qkv, seqlen_offset=0):
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
@@ -157,7 +157,6 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
         )
 
 
-apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
 legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
 
 
@@ -487,7 +486,6 @@ class InternLMPreTrainedModel(PreTrainedModel):
 class InternLMModel(InternLMPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
-
     Args:
         config: InternLMXComposerConfig
     """
@@ -631,7 +629,7 @@ class InternLMModel(InternLMPreTrainedModel):
             past_key_value = past_key_values[
                 idx] if past_key_values is not None else None
 
-            if self.gradient_checkpointing and self.training
+            if self.gradient_checkpointing and self.training:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -696,7 +694,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
             setattr(config, 'kqvo_bias', config.kqvo_bias)
         else:
             setattr(config, 'kqvo_bias', False)
-
         self.model = InternLMModel(config)
 
         self.lm_head = nn.Linear(config.hidden_size,
@@ -762,20 +759,14 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
         Returns:
-
         Example:
-
         ```python
         >>> from transformers import AutoTokenizer, InternLMForCausalLM
-
        >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
-
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]