gugarosa committed
Commit 0ef07d7
1 parent: d931c54

Update modeling_phi3.py

Files changed (1)
  1. modeling_phi3.py +18 -28
modeling_phi3.py CHANGED
@@ -25,6 +25,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -43,9 +44,9 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-
 from .configuration_phi3 import Phi3Config
 
+
 logger = logging.get_logger(__name__)
 
 # Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
@@ -86,7 +87,7 @@ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
 class Phi3RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5):
+    def __init__(self, hidden_size, eps=1e-6):
         """
         Phi3RMSNorm is equivalent to T5LayerNorm
         """
@@ -120,7 +121,7 @@ def _get_unpad_data(attention_mask):
 
 # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi3
 class Phi3RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=4096, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
         self.dim = dim
@@ -228,7 +229,6 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
@@ -608,7 +608,7 @@ class Phi3FlashAttention2(Phi3Attention):
 
         return attn_output, attn_weights, past_key_value
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
     def _flash_attention_forward(
         self,
         query_states,
@@ -650,14 +650,9 @@ class Phi3FlashAttention2(Phi3Attention):
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
-            (
-                query_states,
-                key_states,
-                value_states,
-                indices_q,
-                cu_seq_lens,
-                max_seq_lens,
-            ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
 
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -687,10 +682,7 @@ class Phi3FlashAttention2(Phi3Attention):
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
                 causal=causal,
-                window_size=(
-                    self.config.sliding_window,
-                    self.config.sliding_window,
-                ),
+                window_size=(self.config.sliding_window, self.config.sliding_window),
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
@@ -712,15 +704,12 @@ class Phi3FlashAttention2(Phi3Attention):
                 dropout,
                 softmax_scale=softmax_scale,
                 causal=causal,
-                window_size=(
-                    self.config.sliding_window,
-                    self.config.sliding_window,
-                ),
+                window_size=(self.config.sliding_window, self.config.sliding_window),
             )
 
         return attn_output
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
@@ -737,8 +726,7 @@ class Phi3FlashAttention2(Phi3Attention):
 
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
-                indices_k,
+                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -1233,7 +1221,7 @@ class Phi3Model(Phi3PreTrainedModel):
 class Phi3ForCausalLM(Phi3PreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3,bias=False->bias=True
+    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
     def __init__(self, config):
         super().__init__(config)
         self.model = Phi3Model(config)
@@ -1439,7 +1427,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
     """,
     PHI3_START_DOCSTRING,
 )
-# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3 with self.transformer->self.model, transformer_outputs->model_outputs
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
 class Phi3ForSequenceClassification(Phi3PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1555,7 +1543,7 @@ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
     """,
     PHI3_START_DOCSTRING,
 )
-# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,self.transformer->self.model,transformer_outputs->model_outputs
+# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
 class Phi3ForTokenClassification(Phi3PreTrainedModel):
     def __init__(self, config: Phi3Config):
         super().__init__(config)
@@ -1622,7 +1610,9 @@ class Phi3ForTokenClassification(Phi3PreTrainedModel):
             labels = labels.to(logits.device)
             batch_size, seq_length = labels.shape
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
+            loss = loss_fct(
+                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+            )
 
         if not return_dict:
             output = (logits,) + model_outputs[2:]
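Note for readers skimming the diff: the hunks above either reflow code, re-point "Copied from" markers at the Mistral/Llama sources, or restore upstream signature defaults (RMS-norm eps back to 1e-6, rotary max_position_embeddings back to 2048). As a reading aid only, and not part of the commit, here is a minimal functional sketch of the RMS-norm computation that the eps default feeds into; it omits the float32 upcast the real Phi3RMSNorm performs, and in the model the value normally comes from config.rms_norm_eps rather than the signature default.

    import torch

    def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Root-mean-square normalization as in Phi3RMSNorm / T5LayerNorm: scale by the
        # reciprocal RMS over the last dimension; eps keeps the rsqrt finite for all-zero rows.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return weight * hidden_states * torch.rsqrt(variance + eps)

    x = torch.randn(2, 4, 8)
    y = rms_norm(x, torch.ones(8))  # same shape as x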
 
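The flash-attention hunks collapse the _upad_input unpacking and the window_size tuple onto single lines without changing behaviour. As context only (again, not from the commit), a small self-contained sketch of the unpadding bookkeeping those helpers rely on, mirroring _get_unpad_data from the same file: a padding mask becomes flattened token indices plus cumulative sequence lengths for the varlen flash-attention kernel, and window_size=(sliding_window, sliding_window) then bounds how far each token may attend on either side.

    import torch
    import torch.nn.functional as F

    # Two right-padded sequences of lengths 3 and 2 (toy example).
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of real tokens
    max_seqlen_in_batch = int(seqlens_in_batch.max())                            # 3
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])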
 
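The final hunk only rewraps the token-classification loss call. The reshape it performs is the usual way to feed per-token logits to CrossEntropyLoss, which expects (N, C) inputs and (N,) targets; a toy illustration with made-up shapes (not from the commit):

    import torch
    from torch.nn import CrossEntropyLoss

    batch_size, seq_length, num_labels = 2, 5, 3  # toy shapes
    logits = torch.randn(batch_size, seq_length, num_labels)
    labels = torch.randint(0, num_labels, (batch_size, seq_length))

    # Fold batch and sequence dimensions together before the call.
    loss = CrossEntropyLoss()(
        logits.view(batch_size * seq_length, num_labels),
        labels.view(batch_size * seq_length),
    )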