Crystalcareai
committed on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +73 -63
modeling_gemmoe.py
CHANGED
@@ -24,8 +24,6 @@ import torch.nn.functional as F
|
|
24 |
import torch.utils.checkpoint
|
25 |
from torch import nn
|
26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
-
import flash_attn_cuda_utils
|
28 |
-
|
29 |
from transformers.activations import ACT2FN
|
30 |
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
31 |
from transformers.modeling_attn_mask_utils import (
|
@@ -375,72 +373,84 @@ class GemmoeFlashAttention2(GemmoeAttention):
|
|
375 |
# TODO: Remove this attribute once Flash Attention for ROCm is bumped to 2.1.
|
376 |
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
377 |
|
378 |
-
def forward(
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
389 |
-
|
390 |
-
|
391 |
-
bsz, q_len, _ = hidden_states.size()
|
392 |
-
|
393 |
-
query_states = self.q_proj(hidden_states)
|
394 |
-
key_states = self.k_proj(hidden_states)
|
395 |
-
value_states = self.v_proj(hidden_states)
|
396 |
-
|
397 |
-
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
398 |
-
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
399 |
-
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
400 |
-
|
401 |
-
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
402 |
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
403 |
-
|
404 |
-
past_key_value = getattr(self, "past_key_value", past_key_value)
|
405 |
-
if past_key_value is not None:
|
406 |
-
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
407 |
-
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
408 |
-
|
409 |
-
query_states = query_states.transpose(1, 2)
|
410 |
-
key_states = key_states.transpose(1, 2)
|
411 |
-
value_states = value_states.transpose(1, 2)
|
412 |
-
|
413 |
-
dropout_rate = self.attention_dropout if self.training else 0.0
|
414 |
-
|
415 |
-
input_dtype = query_states.dtype
|
416 |
-
if input_dtype == torch.float32:
|
417 |
-
if torch.is_autocast_enabled():
|
418 |
-
target_dtype = torch.get_autocast_gpu_dtype()
|
419 |
-
elif hasattr(self.config, "_pre_quantization_dtype"):
|
420 |
-
target_dtype = self.config._pre_quantization_dtype
|
421 |
-
else:
|
422 |
-
target_dtype = self.q_proj.weight.dtype
|
423 |
|
424 |
-
|
425 |
-
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
426 |
-
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
427 |
-
f" {target_dtype}."
|
428 |
-
)
|
429 |
-
query_states = query_states.to(target_dtype)
|
430 |
-
key_states = key_states.to(target_dtype)
|
431 |
-
value_states = value_states.to(target_dtype)
|
432 |
|
433 |
-
|
434 |
-
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
|
437 |
-
|
438 |
-
|
439 |
|
440 |
-
|
441 |
-
|
|
|
|
|
|
|
442 |
|
443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
|
445 |
def _flash_attention_forward(
|
446 |
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
|
|
24 |
import torch.utils.checkpoint
|
25 |
from torch import nn
|
26 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
|
27 |
from transformers.activations import ACT2FN
|
28 |
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
29 |
from transformers.modeling_attn_mask_utils import (
|
|
|
373 |
# TODO: Remove this attribute once Flash Attention for ROCm is bumped to 2.1.
|
374 |
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
375 |
|
376 |
+
def forward(
|
377 |
+
self,
|
378 |
+
hidden_states: torch.Tensor,
|
379 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
380 |
+
position_ids: Optional[torch.LongTensor] = None,
|
381 |
+
past_key_value: Optional[Cache] = None,
|
382 |
+
output_attentions: bool = False,
|
383 |
+
use_cache: bool = False,
|
384 |
+
cache_position: Optional[torch.LongTensor] = None,
|
385 |
+
**kwargs,
|
386 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
387 |
+
output_attentions = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
|
389 |
+
bsz, q_len, _ = hidden_states.size()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
|
391 |
+
query_states = self.q_proj(hidden_states)
|
392 |
+
key_states = self.k_proj(hidden_states)
|
393 |
+
value_states = self.v_proj(hidden_states)
|
394 |
+
|
395 |
+
# Flash attention requires the input to have the shape
|
396 |
+
# batch_size x seq_length x num_heads x head_dim
|
397 |
+
# therefore we just need to keep the original shape
|
398 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
399 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
400 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
401 |
|
402 |
+
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
403 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
|
404 |
|
405 |
+
past_key_value = getattr(self, "past_key_value", past_key_value)
|
406 |
+
if past_key_value is not None:
|
407 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
408 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
409 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
410 |
|
411 |
+
# TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
412 |
+
# to be able to avoid many of these transpose/reshape/view.
|
413 |
+
query_states = query_states.transpose(1, 2)
|
414 |
+
key_states = key_states.transpose(1, 2)
|
415 |
+
value_states = value_states.transpose(1, 2)
|
416 |
+
|
417 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
418 |
+
|
419 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
420 |
+
# therefore the input hidden states get silently cast to float32. Hence, we need
|
421 |
+
# to cast them back to the correct dtype just to be sure everything works as expected.
|
422 |
+
# This might slow down training & inference so it is recommended to not cast the LayerNorms
|
423 |
+
# in fp32. (GemmoeRMSNorm handles it correctly)
|
424 |
+
input_dtype = query_states.dtype
|
425 |
+
if input_dtype == torch.float32:
|
426 |
+
if torch.is_autocast_enabled():
|
427 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
428 |
+
# Handle the case where the model is quantized
|
429 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
430 |
+
target_dtype = self.config._pre_quantization_dtype
|
431 |
+
else:
|
432 |
+
target_dtype = self.q_proj.weight.dtype
|
433 |
+
|
434 |
+
logger.warning_once(
|
435 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
436 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
437 |
+
f" {target_dtype}."
|
438 |
+
)
|
439 |
+
query_states = query_states.to(target_dtype)
|
440 |
+
key_states = key_states.to(target_dtype)
|
441 |
+
value_states = value_states.to(target_dtype)
|
442 |
+
|
443 |
+
attn_output = self._flash_attention_forward(
|
444 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
445 |
+
)
|
446 |
+
|
447 |
+
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
448 |
+
attn_output = self.o_proj(attn_output)
|
449 |
+
|
450 |
+
if not output_attentions:
|
451 |
+
attn_weights = None
|
452 |
+
|
453 |
+
return attn_output, attn_weights, past_key_value
|
454 |
|
455 |
def _flash_attention_forward(
|
456 |
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|