Crystalcareai committed
Commit ad43155
1 Parent(s): 5450314

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +73 -63
modeling_gemmoe.py CHANGED
@@ -24,8 +24,6 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-import flash_attn_cuda_utils
-
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_attn_mask_utils import (
@@ -375,72 +373,84 @@ class GemmoeFlashAttention2(GemmoeAttention):
         # TODO: Remove this attribute once Flash Attention for RoCm is bumped to 2.1.
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
-
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        dropout_rate = self.attention_dropout if self.training else 0.0
-
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        attn_output = flash_attn_cuda_utils.pyt_flash_scaled_dot_attention(
-            query_states, key_states, value_states, attn_mask=attention_mask, dropout_prob=dropout_rate
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (GemmoeRMSNorm handles it correctly)
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
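
For context on the new code path: the updated forward no longer calls flash_attn_cuda_utils.pyt_flash_scaled_dot_attention and instead delegates to the class's own _flash_attention_forward helper, whose signature appears in the trailing context above. The sketch below is not the code from this commit; it is a minimal illustration, assuming the upstream flash_attn package is installed and ignoring the padded/varlen masking path that a full helper would also handle. The function name simple_flash_attention_forward is hypothetical.

# Minimal sketch (assumption, not this commit's implementation) of what a
# _flash_attention_forward-style helper typically wraps.
from typing import Optional

import torch
from flash_attn import flash_attn_func  # provided by the flash-attn package


def simple_flash_attention_forward(
    query_states: torch.Tensor,   # [batch, seq_len, num_heads, head_dim]
    key_states: torch.Tensor,     # [batch, seq_len, num_kv_heads, head_dim]
    value_states: torch.Tensor,   # [batch, seq_len, num_kv_heads, head_dim]
    dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = True,
) -> torch.Tensor:
    # flash_attn_func expects [batch, seq_len, num_heads, head_dim] inputs, which is
    # why the forward pass above transposes the states out of the
    # [batch, num_heads, seq_len, head_dim] layout before calling the helper.
    return flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=causal,
    )

The returned tensor keeps the [batch, seq_len, num_heads, head_dim] layout, which is why the forward above can reshape it directly to [batch, seq_len, hidden_size] before o_proj.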