jeffreygao committed
Commit a3e1065 • 1 parent: 6e571c3

Update modeling_bluelm.py

modeling_bluelm.py (+20, -22)
@@ -32,7 +32,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_bluelm import BlueLMConfig
-
+from flash_attn.flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+)
 
 try:
     from xformers import ops as xops
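The commit imports the flash-attn kernels unconditionally, so flash-attn becomes a hard requirement at import time, unlike xformers, which is already wrapped in a try/except just below. Purely as a hedged sketch (not part of this commit), the same imports could be guarded in that style so the module still imports when flash-attn is absent:

# Hypothetical fallback, mirroring the existing xformers guard; not in the commit.
# If flash-attn is not installed, the names resolve to None and callers can
# fall back to another attention path.
try:
    from flash_attn.flash_attn_interface import (
        flash_attn_func,
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
    )
except ImportError:
    flash_attn_func = None
    flash_attn_kvpacked_func = None
    flash_attn_qkvpacked_func = None
    flash_attn_varlen_kvpacked_func = None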
@@ -213,6 +218,11 @@ class BlueLMAttention(nn.Module):
             hidden_size,
             bias=False,
         )
+        self.register_buffer(
+            "norm_factor",
+            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
+            persistent=False,
+        )
         self.rotary_emb = BlueLMRotaryEmbedding(self.head_dim)
         if xops is not None:
             self.causal_mask = xops.LowerTriangularMask()
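The new norm_factor buffer stores sqrt(head_dim); registering it with persistent=False keeps it out of the state dict while still letting it follow the module across .to(device/dtype) calls. Its reciprocal, 1.0 / norm_factor, is the standard 1/sqrt(d) attention scale passed to the flash-attention call later in the diff. A minimal sketch (head_dim = 128 is only an illustrative value, not taken from the config):

import torch

# Illustrative head_dim; the real value comes from the model config.
head_dim = 128
norm_factor = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

# 1 / norm_factor reproduces the usual 1/sqrt(head_dim) softmax scale.
assert torch.isclose(1.0 / norm_factor, torch.tensor(head_dim ** -0.5))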
@@ -230,7 +240,8 @@ class BlueLMAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
 
-        bsz, q_len, _ = hidden_states.size()
+        bsz, q_len, h_size = hidden_states.size()
+        has_layer_past = past_key_value is not None
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
@@ -245,7 +256,7 @@ class BlueLMAttention(nn.Module):
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)
         # [bsz, t, nh, hd]
 
-        if past_key_value is not None:
+        if has_layer_past:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=1)
             value_states = torch.cat([past_key_value[1], value_states], dim=1)
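When a cached past_key_value is present, the keys and values for the current step are appended along dim=1, since this attention keeps its tensors in [bsz, seq_len, num_heads, head_dim] layout. A small shape sketch of that bookkeeping during single-token decoding (sizes are illustrative only, not BlueLM's real configuration):

import torch

bsz, num_heads, head_dim = 2, 4, 8
past_len, q_len = 5, 1                                        # decoding one new token

past_key = torch.zeros(bsz, past_len, num_heads, head_dim)    # cached keys
new_key = torch.zeros(bsz, q_len, num_heads, head_dim)        # keys for the new token
key_states = torch.cat([past_key, new_key], dim=1)            # [bsz, past_len + q_len, nh, hd]
assert key_states.shape == (bsz, past_len + q_len, num_heads, head_dim)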
@@ -260,25 +271,12 @@ class BlueLMAttention(nn.Module):
             )
         else:
             # [bsz, t, nh, hd]
-            attn_weights = torch.einsum("bqnh,bknh->bnqk", query_states, key_states) / math.sqrt(self.head_dim)
-
-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
-                )
-
-            if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            kv = torch.stack([key_states, value_states], 2)
+            attn_outputs = flash_attn_kvpacked_func(
+                query_states, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
+            attn_output = attn_outputs[0] if output_attentions else attn_outputs
+            attn_weights = attn_outputs[2] if output_attentions else None
 
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.einsum("bnqk,bknh->bqnh", attn_weights, value_states)
 
         if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
             raise ValueError(
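This hunk replaces the manual einsum/softmax attention with a single fused call. flash_attn_kvpacked_func takes the query as [bsz, q_len, num_heads, head_dim] and the keys/values stacked into [bsz, kv_len, 2, num_heads, head_dim]; softmax_scale=1.0/self.norm_factor supplies the 1/sqrt(head_dim) scaling, and causal masking is enabled only when there is no cached past (during incremental decoding q_len is 1, so no mask is needed). With return_attn_probs=True the call returns extra outputs, which is consistent with the attn_outputs[0] and attn_outputs[2] indexing above. A minimal usage sketch, assuming flash-attn 2.x, a CUDA device, and fp16 inputs (all sizes are illustrative):

import torch
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func

bsz, q_len, kv_len, num_heads, head_dim = 2, 16, 16, 4, 64
device, dtype = "cuda", torch.float16

q = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(bsz, kv_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(bsz, kv_len, num_heads, head_dim, device=device, dtype=dtype)

kv = torch.stack([k, v], dim=2)               # [bsz, kv_len, 2, nh, hd]
out = flash_attn_kvpacked_func(
    q, kv,
    dropout_p=0.0,
    softmax_scale=1.0 / head_dim ** 0.5,      # same role as 1.0 / norm_factor
    causal=True,                              # prefill case: no cached past
)
assert out.shape == (bsz, q_len, num_heads, head_dim)

Since the FlashAttention kernel accumulates the softmax internally in fp32, the explicit fp32 upcast from the removed path is no longer needed.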
@@ -612,7 +610,7 @@ class BlueLMModel(BlueLMPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[
+            past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)