VarunGumma committed on
Commit
5f3ea6b
1 Parent(s): 59feb3e

Upload modeling_rotary_indictrans.py with huggingface_hub

Files changed (1)
  1. modeling_rotary_indictrans.py +93 -233
modeling_rotary_indictrans.py CHANGED
@@ -1,20 +1,3 @@
1
- # coding=utf-8
2
- # Copyright 2023 The RotaryIndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch RotaryIndicTrans model."""
16
-
17
-
18
  import math
19
  from typing import List, Optional, Tuple, Union
20
 
@@ -38,36 +21,24 @@ from transformers.modeling_outputs import (
38
  Seq2SeqModelOutput,
39
  )
40
 
41
- from transformers.utils import (
42
- logging,
43
- is_flash_attn_2_available,
44
- is_flash_attn_greater_or_equal_2_10,
45
- )
46
 
47
- from einops import rearrange
48
- from transformers.modeling_utils import PreTrainedModel
49
- from .configuration_rotary_indictrans import RotaryIndicTransConfig
50
 
51
- try:
52
- from rotary_embedding_torch import RotaryEmbedding
53
- except ImportError:
54
- raise ImportError("Please install the rotary-embedding-torch>=0.6.4")
55
 
56
 
57
  logger = logging.get_logger(__name__)
58
-
59
- ROTARY_INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
60
-
61
- try:
62
- if is_flash_attn_2_available():
63
- from flash_attn import flash_attn_func, flash_attn_varlen_func
64
- from flash_attn.bert_padding import (
65
- index_first_axis,
66
- pad_input,
67
- unpad_input,
68
- ) # noqa
69
- except:
70
- pass
71
 
72
 
73
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -87,29 +58,20 @@ def _get_unpad_data(attention_mask):
87
  def shift_tokens_right(
88
  input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
89
  ):
90
- """
91
- Shift input ids one token to the right.
92
- """
93
  shifted_input_ids = input_ids.new_zeros(input_ids.shape)
94
  shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
95
  shifted_input_ids[:, 0] = decoder_start_token_id
96
 
97
  if pad_token_id is None:
98
  raise ValueError("self.model.config.pad_token_id has to be defined.")
99
- # replace possible -100 values in labels by `pad_token_id`
100
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
101
 
 
102
  return shifted_input_ids
103
 
104
 
105
  def create_position_ids_from_input_ids(
106
  input_ids, padding_idx, past_key_values_length=0
107
  ):
108
- """
109
- Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
110
- are ignored. This is modified from fairseq's `utils.make_positions`.
111
- """
112
- # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
113
  mask = input_ids.ne(padding_idx).int()
114
  incremental_indices = (
115
  torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
@@ -117,10 +79,64 @@ def create_position_ids_from_input_ids(
117
  return incremental_indices.long() + padding_idx
118
 
119
 
120
  # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->RotaryIndicTrans
121
  class RotaryIndicTransAttention(nn.Module):
122
- """Multi-headed attention from 'Attention Is All You Need' paper"""
123
-
124
  def __init__(
125
  self,
126
  embed_dim: int,
@@ -133,12 +149,11 @@ class RotaryIndicTransAttention(nn.Module):
133
  config: Optional[RotaryIndicTransConfig] = None,
134
  ):
135
  super().__init__()
 
136
  self.embed_dim = embed_dim
137
  self.num_heads = num_heads
138
  self.dropout = dropout
139
  self.head_dim = embed_dim // num_heads
140
- self.config = config
141
- self.rope_args = config.rope_args
142
 
143
  if (self.head_dim * num_heads) != self.embed_dim:
144
  raise ValueError(
@@ -149,15 +164,12 @@ class RotaryIndicTransAttention(nn.Module):
149
  self.is_decoder = is_decoder
150
  self.is_causal = is_causal
151
 
152
- self.xpos = self.rope_args.get("use_xpos", False)
153
-
154
  # partial rotation in RoPE
155
  self.rotary_pos_embed = (
156
  RotaryEmbedding(
157
  dim=self.head_dim // 2,
158
- use_xpos=self.xpos,
159
- theta=self.rope_args.get("theta", 10000),
160
- xpos_scale_base=self.rope_args.get("xpos_scale_base", 512),
161
  )
162
  if not is_cross_attention
163
  else None
@@ -179,14 +191,10 @@ class RotaryIndicTransAttention(nn.Module):
179
  q = rearrange(q, "(b h) t d -> b h t d", h=self.num_heads)
180
  k = rearrange(k, "(b h) t d -> b h t d", h=self.num_heads)
181
 
182
- if is_inference:
183
- q, k = self.rotary_pos_embed.rotate_queries_with_cached_keys(q, k)
184
- else:
185
- if not self.xpos:
186
- q = self.rotary_pos_embed.rotate_queries_or_keys(q)
187
- k = self.rotary_pos_embed.rotate_queries_or_keys(k)
188
- else:
189
- q, k = self.rotary_pos_embed.rotate_queries_and_keys(q, k)
190
 
191
  q = rearrange(q, "b h t d -> (b h) t d")
192
  k = rearrange(k, "b h t d -> (b h) t d")
@@ -203,49 +211,32 @@ class RotaryIndicTransAttention(nn.Module):
203
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
204
  """Input shape: Batch x Time x Channel"""
205
 
206
- # if key_value_states are provided this layer is used as a cross-attention layer
207
- # for the decoder
208
  is_cross_attention = key_value_states is not None
209
 
210
  bsz, tgt_len, _ = hidden_states.size()
211
 
212
- # get query proj
213
  query_states = self.q_proj(hidden_states) * self.scaling
214
- # get key, value proj
215
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
216
- # is checking that the `sequence_length` of the `past_key_value` is the same as
217
- # the provided `key_value_states` to support prefix tuning
218
  if (
219
  is_cross_attention
220
  and past_key_value is not None
221
  and past_key_value[0].shape[2] == key_value_states.shape[1]
222
  ):
223
- # reuse k,v, cross_attentions
224
  key_states = past_key_value[0]
225
  value_states = past_key_value[1]
226
  elif is_cross_attention:
227
- # cross_attentions
228
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
229
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
230
  elif past_key_value is not None:
231
- # reuse k, v, self_attention
232
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
233
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
234
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
235
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
236
  else:
237
- # self_attention
238
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
239
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
240
 
241
  if self.is_decoder:
242
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
243
- # Further calls to cross_attention layer can then reuse all cross-attention
244
- # key/value_states (first "if" case)
245
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
246
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
247
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
248
- # if encoder bi-directional self-attention `past_key_value` is always `None`
249
  past_key_value = (key_states, value_states)
250
 
251
  proj_shape = (bsz * self.num_heads, -1, self.head_dim)
@@ -293,10 +284,6 @@ class RotaryIndicTransAttention(nn.Module):
293
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
294
 
295
  if output_attentions:
296
- # this operation is a bit awkward, but it's required to
297
- # make sure that attn_weights keeps its gradient.
298
- # In order to do so, attn_weights have to be reshaped
299
- # twice and have to be reused in the following
300
  attn_weights_reshaped = attn_weights.view(
301
  bsz, self.num_heads, tgt_len, src_len
302
  )
@@ -316,34 +303,19 @@ class RotaryIndicTransAttention(nn.Module):
316
  f" {attn_output.size()}"
317
  )
318
 
319
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
320
- attn_output = attn_output.transpose(1, 2)
321
-
322
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
323
- # partitioned across GPUs when using tensor-parallelism.
324
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
325
 
326
  attn_output = self.out_proj(attn_output)
327
-
328
  return attn_output, attn_weights_reshaped, past_key_value
329
 
330
 
331
  class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
332
- """
333
- RotaryIndicTrans flash attention module. This module inherits from `RotaryIndicTransAttention` as the weights of the module stays
334
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
335
- flash attention and deal with padding tokens in case the input contains any of them.
336
- """
337
-
338
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
339
  def __init__(self, *args, **kwargs):
340
  super().__init__(*args, **kwargs)
341
 
342
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
343
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
344
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
345
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
346
-
347
  def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
348
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
349
 
@@ -362,32 +334,23 @@ class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
362
  "RotaryIndicTransFlashAttention2 attention does not support output_attentions"
363
  )
364
 
365
- # if key_value_states are provided this layer is used as a cross-attention layer
366
- # for the decoder
367
  is_cross_attention = key_value_states is not None
368
 
369
  bsz, q_len, _ = hidden_states.size()
370
 
371
- # get query proj
372
  query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
373
- # get key, value proj
374
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
375
- # is checking that the `sequence_length` of the `past_key_value` is the same as
376
- # the provided `key_value_states` to support prefix tuning
377
  if (
378
  is_cross_attention
379
  and past_key_value is not None
380
  and past_key_value[0].shape[2] == key_value_states.shape[1]
381
  ):
382
- # reuse k,v, cross_attentions
383
  key_states = past_key_value[0].transpose(1, 2)
384
  value_states = past_key_value[1].transpose(1, 2)
385
  elif is_cross_attention:
386
- # cross_attentions
387
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
388
  value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
389
  elif past_key_value is not None:
390
- # reuse k, v, self_attention
391
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
392
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
393
  key_states = torch.cat(
@@ -397,30 +360,16 @@ class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
397
  [past_key_value[1].transpose(1, 2), value_states], dim=1
398
  )
399
  else:
400
- # self_attention
401
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
402
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
403
 
404
  if self.is_decoder:
405
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
406
- # Further calls to cross_attention layer can then reuse all cross-attention
407
- # key/value_states (first "if" case)
408
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
409
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
410
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
411
- # if encoder bi-directional self-attention `past_key_value` is always `None`
412
  past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
413
 
414
  kv_seq_len = key_states.shape[-2]
415
  if past_key_value is not None:
416
  kv_seq_len += past_key_value[0].shape[-2]
417
 
418
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
419
- # therefore the input hidden states gets silently casted in float32. Hence, we need
420
- # cast them back in the correct dtype just to be sure everything works as expected.
421
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
422
- # in fp32. (LlamaRMSNorm handles it correctly)
423
-
424
  input_dtype = query_states.dtype
425
  if input_dtype == torch.float32:
426
  if torch.is_autocast_enabled():
@@ -493,12 +442,6 @@ class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
493
  softmax_scale (`float`, *optional*):
494
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
495
  """
496
- if not self._flash_attn_uses_top_left_mask:
497
- causal = self.is_causal
498
- else:
499
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
500
- causal = self.is_causal and query_length != 1
501
-
502
  # Contains at least one padding token in the sequence
503
  if attention_mask is not None:
504
  batch_size = query_states.shape[0]
@@ -526,7 +469,7 @@ class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
526
  max_seqlen_k=max_seqlen_in_batch_k,
527
  dropout_p=dropout,
528
  softmax_scale=softmax_scale,
529
- causal=causal,
530
  )
531
 
532
  attn_output = pad_input(
@@ -539,7 +482,7 @@ class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
539
  value_states,
540
  dropout,
541
  softmax_scale=softmax_scale,
542
- causal=causal,
543
  )
544
 
545
  return attn_output
@@ -571,11 +514,10 @@ class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
571
  max_seqlen_in_batch_q = 1
572
  cu_seqlens_q = torch.arange(
573
  batch_size + 1, dtype=torch.int32, device=query_layer.device
574
- ) # There is a memcpy here, that is very bad.
575
  indices_q = cu_seqlens_q[:-1]
576
  query_layer = query_layer.squeeze(1)
577
  else:
578
- # The -q_len: slice assumes left padding.
579
  attention_mask = attention_mask[:, -query_length:]
580
  query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
581
  query_layer, attention_mask
@@ -603,7 +545,6 @@ class RotaryIndicTransSdpaAttention(RotaryIndicTransAttention):
603
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
604
  """Input shape: Batch x Time x Channel"""
605
  if output_attentions or layer_head_mask is not None:
606
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
607
  logger.warning_once(
608
  "RotaryIndicTransModel is using RotaryIndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
609
  ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
@@ -617,49 +558,32 @@ class RotaryIndicTransSdpaAttention(RotaryIndicTransAttention):
617
  output_attentions=output_attentions,
618
  )
619
 
620
- # if key_value_states are provided this layer is used as a cross-attention layer
621
- # for the decoder
622
  is_cross_attention = key_value_states is not None
623
 
624
  bsz, tgt_len, _ = hidden_states.size()
625
 
626
- # get query proj
627
  query_states = self.q_proj(hidden_states)
628
- # get key, value proj
629
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
630
- # is checking that the `sequence_length` of the `past_key_value` is the same as
631
- # the provided `key_value_states` to support prefix tuning
632
  if (
633
  is_cross_attention
634
  and past_key_value is not None
635
  and past_key_value[0].shape[2] == key_value_states.shape[1]
636
  ):
637
- # reuse k,v, cross_attentions
638
  key_states = past_key_value[0]
639
  value_states = past_key_value[1]
640
  elif is_cross_attention:
641
- # cross_attentions
642
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
643
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
644
  elif past_key_value is not None:
645
- # reuse k, v, self_attention
646
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
647
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
648
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
649
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
650
  else:
651
- # self_attention
652
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
653
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
654
 
655
  if self.is_decoder:
656
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
657
- # Further calls to cross_attention layer can then reuse all cross-attention
658
- # key/value_states (first "if" case)
659
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
660
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
661
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
662
- # if encoder bi-directional self-attention `past_key_value` is always `None`
663
  past_key_value = (key_states, value_states)
664
 
665
  query_states = self._shape(query_states, tgt_len, bsz)
@@ -669,15 +593,12 @@ class RotaryIndicTransSdpaAttention(RotaryIndicTransAttention):
669
  query_states, key_states, is_inference=past_key_value is not None
670
  )
671
 
672
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
673
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
674
  attn_output = F.scaled_dot_product_attention(
675
  query_states,
676
  key_states,
677
  value_states,
678
  attn_mask=attention_mask,
679
  dropout_p=self.dropout if self.training else 0.0,
680
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
681
  is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
682
  )
683
 
@@ -687,14 +608,10 @@ class RotaryIndicTransSdpaAttention(RotaryIndicTransAttention):
687
  f" {attn_output.size()}"
688
  )
689
 
690
- attn_output = attn_output.transpose(1, 2)
691
-
692
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
693
- # partitioned across GPUs when using tensor-parallelism.
694
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
695
-
696
  attn_output = self.out_proj(attn_output)
697
-
698
  return attn_output, None, past_key_value
699
 
700
 
@@ -859,12 +776,10 @@ class RotaryIndicTransDecoderLayer(nn.Module):
859
  if self.normalize_before:
860
  hidden_states = self.self_attn_layer_norm(hidden_states)
861
 
862
- # Self Attention
863
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
864
  self_attn_past_key_value = (
865
  past_key_value[:2] if past_key_value is not None else None
866
  )
867
- # add present self-attn cache to positions 1,2 of present_key_value tuple
868
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
869
  hidden_states=hidden_states,
870
  past_key_value=self_attn_past_key_value,
@@ -877,7 +792,6 @@ class RotaryIndicTransDecoderLayer(nn.Module):
877
  if not self.normalize_before:
878
  hidden_states = self.self_attn_layer_norm(hidden_states)
879
 
880
- # Cross-Attention Block
881
  cross_attn_present_key_value = None
882
  cross_attn_weights = None
883
  if encoder_hidden_states is not None:
@@ -885,7 +799,6 @@ class RotaryIndicTransDecoderLayer(nn.Module):
885
  if self.normalize_before:
886
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
887
 
888
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
889
  cross_attn_past_key_value = (
890
  past_key_value[-2:] if past_key_value is not None else None
891
  )
@@ -908,10 +821,8 @@ class RotaryIndicTransDecoderLayer(nn.Module):
908
  if not self.normalize_before:
909
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
910
 
911
- # add cross-attn to positions 3,4 of present_key_value tuple
912
  present_key_value = present_key_value + cross_attn_present_key_value
913
 
914
- # Fully Connected
915
  residual = hidden_states
916
  if self.normalize_before:
917
  hidden_states = self.final_layer_norm(hidden_states)
@@ -961,15 +872,6 @@ class RotaryIndicTransPreTrainedModel(PreTrainedModel):
961
 
962
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->RotaryIndicTrans
963
  class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
964
- """
965
- Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
966
- [`RotaryIndicTransEncoderLayer`].
967
-
968
- Args:
969
- config: RotaryIndicTransConfig
970
- embed_tokens (nn.Embedding): output embedding
971
- """
972
-
973
  def __init__(
974
  self,
975
  config: RotaryIndicTransConfig,
@@ -1005,7 +907,6 @@ class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
1005
  self._use_sdpa = config._attn_implementation == "sdpa"
1006
 
1007
  self.gradient_checkpointing = False
1008
- # Initialize weights and apply final processing
1009
  self.post_init()
1010
 
1011
  def forward(
@@ -1068,7 +969,6 @@ class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
1068
  return_dict if return_dict is not None else self.config.use_return_dict
1069
  )
1070
 
1071
- # retrieve input_ids and inputs_embeds
1072
  if input_ids is not None and inputs_embeds is not None:
1073
  raise ValueError(
1074
  "You cannot specify both input_ids and inputs_embeds at the same time"
@@ -1095,14 +995,10 @@ class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
1095
  if self._use_flash_attention_2:
1096
  attention_mask = attention_mask if 0 in attention_mask else None
1097
  elif self._use_sdpa and head_mask is None and not output_attentions:
1098
- # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1099
- # the manual implementation that requires a 4D causal mask in all cases.
1100
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1101
  attention_mask = _prepare_4d_attention_mask_for_sdpa(
1102
  attention_mask, inputs_embeds.dtype
1103
  )
1104
  else:
1105
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1106
  attention_mask = _prepare_4d_attention_mask(
1107
  attention_mask, inputs_embeds.dtype
1108
  )
@@ -1110,7 +1006,6 @@ class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
1110
  encoder_states = () if output_hidden_states else None
1111
  all_attentions = () if output_attentions else None
1112
 
1113
- # check if head_mask has a correct number of layers specified if desired
1114
  if head_mask is not None:
1115
  if head_mask.size()[0] != len(self.layers):
1116
  raise ValueError(
@@ -1123,7 +1018,6 @@ class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
1123
  if output_hidden_states:
1124
  encoder_states = encoder_states + (hidden_states,)
1125
 
1126
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1127
  dropout_probability = torch.rand([])
1128
 
1129
  skip_the_layer = (
@@ -1132,10 +1026,8 @@ class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
1132
  else False
1133
  )
1134
  if not skip_the_layer or deepspeed_zero3_is_enabled:
1135
- # under deepspeed zero3 all gpus must run in sync
1136
-
1137
  if self.gradient_checkpointing and self.training:
1138
- # create gradient checkpointing function
1139
  def create_custom_forward(module):
1140
  def custom_forward(*inputs):
1141
  return module(*inputs, output_attentions)
@@ -1187,14 +1079,6 @@ class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
1187
 
1188
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->RotaryIndicTrans
1189
  class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1190
- """
1191
- Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`RotaryIndicTransDecoderLayer`]
1192
-
1193
- Args:
1194
- config: RotaryIndicTransConfig
1195
- embed_tokens (nn.Embedding): output embedding
1196
- """
1197
-
1198
  def __init__(
1199
  self,
1200
  config: RotaryIndicTransConfig,
@@ -1229,7 +1113,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1229
  self._use_sdpa = config._attn_implementation == "sdpa"
1230
 
1231
  self.gradient_checkpointing = False
1232
- # Initialize weights and apply final processing
1233
  self.post_init()
1234
 
1235
  def forward(
@@ -1327,7 +1210,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1327
  return_dict if return_dict is not None else self.config.use_return_dict
1328
  )
1329
 
1330
- # retrieve input_ids and inputs_embeds
1331
  if input_ids is not None and inputs_embeds is not None:
1332
  raise ValueError(
1333
  "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
@@ -1342,7 +1224,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1342
  "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1343
  )
1344
 
1345
- # past_key_values_length
1346
  past_key_values_length = (
1347
  past_key_values[0][0].shape[2] if past_key_values is not None else 0
1348
  )
@@ -1351,15 +1232,12 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1351
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1352
 
1353
  if self._use_flash_attention_2:
1354
- # 2d mask is passed through the layers
1355
  attention_mask = (
1356
  attention_mask
1357
  if (attention_mask is not None and 0 in attention_mask)
1358
  else None
1359
  )
1360
  elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1361
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1362
- # the manual implementation that requires a 4D causal mask in all cases.
1363
  attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1364
  attention_mask,
1365
  input_shape,
@@ -1367,12 +1245,10 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1367
  past_key_values_length,
1368
  )
1369
  else:
1370
- # 4d mask is passed through the layers
1371
  attention_mask = _prepare_4d_causal_attention_mask(
1372
  attention_mask, input_shape, inputs_embeds, past_key_values_length
1373
  )
1374
 
1375
- # expand encoder attention mask
1376
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
1377
  if self._use_flash_attention_2:
1378
  encoder_attention_mask = (
@@ -1383,16 +1259,12 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1383
  and cross_attn_head_mask is None
1384
  and not output_attentions
1385
  ):
1386
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1387
- # the manual implementation that requires a 4D causal mask in all cases.
1388
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1389
  encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1390
  encoder_attention_mask,
1391
  inputs_embeds.dtype,
1392
  tgt_len=input_shape[-1],
1393
  )
1394
  else:
1395
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1396
  encoder_attention_mask = _prepare_4d_attention_mask(
1397
  encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1398
  )
@@ -1412,13 +1284,11 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1412
  )
1413
  use_cache = False
1414
 
1415
- # decoder layers
1416
  all_hidden_states = () if output_hidden_states else None
1417
  all_self_attns = () if output_attentions else None
1418
  all_cross_attentions = () if output_attentions else None
1419
  next_decoder_cache = () if use_cache else None
1420
 
1421
- # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1422
  for attn_mask, mask_name in zip(
1423
  [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1424
  ):
@@ -1434,7 +1304,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1434
  if output_hidden_states:
1435
  all_hidden_states += (hidden_states,)
1436
 
1437
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1438
  dropout_probability = torch.rand([])
1439
 
1440
  skip_the_layer = (
@@ -1443,8 +1312,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1443
  else False
1444
  )
1445
  if not skip_the_layer or deepspeed_zero3_is_enabled:
1446
- # under deepspeed zero3 all gpus must run in sync
1447
-
1448
  past_key_value = (
1449
  past_key_values[idx] if past_key_values is not None else None
1450
  )
@@ -1506,7 +1373,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1506
  if self.layer_norm is not None:
1507
  hidden_states = self.layer_norm(hidden_states)
1508
 
1509
- # add hidden states from the last decoder layer
1510
  if output_hidden_states:
1511
  all_hidden_states += (hidden_states,)
1512
 
@@ -1541,8 +1407,6 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
1541
 
1542
  self.encoder = RotaryIndicTransEncoder(config)
1543
  self.decoder = RotaryIndicTransDecoder(config)
1544
-
1545
- # Initialize weights and apply final processing
1546
  self.post_init()
1547
 
1548
  def get_encoder(self):
@@ -1594,7 +1458,6 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
1594
  output_hidden_states=output_hidden_states,
1595
  return_dict=return_dict,
1596
  )
1597
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1598
  elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1599
  encoder_outputs = BaseModelOutput(
1600
  last_hidden_state=encoder_outputs[0],
@@ -1602,7 +1465,6 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
1602
  attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1603
  )
1604
 
1605
- # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1606
  decoder_outputs = self.decoder(
1607
  input_ids=decoder_input_ids,
1608
  attention_mask=decoder_attention_mask,
@@ -1727,7 +1589,6 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel):
1727
 
1728
  masked_lm_loss = None
1729
  if labels is not None:
1730
- # move labels to the correct device to enable PP
1731
  labels = labels.to(lm_logits.device)
1732
  masked_lm_loss = F.cross_entropy(
1733
  input=lm_logits.view(-1, self.config.decoder_vocab_size),
@@ -1766,12 +1627,11 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel):
1766
  encoder_outputs=None,
1767
  **kwargs,
1768
  ):
1769
- # cut decoder_input_ids if past is used
1770
  if past_key_values is not None:
1771
  decoder_input_ids = decoder_input_ids[:, -1:]
1772
 
1773
  return {
1774
- "input_ids": None, # encoder_outputs is defined. input_ids not needed
1775
  "encoder_outputs": encoder_outputs,
1776
  "past_key_values": past_key_values,
1777
  "decoder_input_ids": decoder_input_ids,
@@ -1779,7 +1639,7 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel):
1779
  "head_mask": head_mask,
1780
  "decoder_head_mask": decoder_head_mask,
1781
  "cross_attn_head_mask": cross_attn_head_mask,
1782
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1783
  }
1784
 
1785
  @staticmethod
 
1
  import math
2
  from typing import List, Optional, Tuple, Union
3
 
 
21
  Seq2SeqModelOutput,
22
  )
23
 
24
+ from transformers.utils import logging
25
+ from einops import rearrange, repeat
 
 
 
26
 
27
+ from torch.amp import autocast
28
+ from torch import einsum
 
29
 
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from configuration_rotary_indictrans import RotaryIndicTransConfig
 
 
32
 
33
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
34
+ from flash_attn.bert_padding import (
35
+ index_first_axis,
36
+ pad_input,
37
+ unpad_input,
38
+ )
39
 
40
  logger = logging.get_logger(__name__)
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
 
43
 
44
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
58
  def shift_tokens_right(
59
  input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
60
  ):
 
 
 
61
  shifted_input_ids = input_ids.new_zeros(input_ids.shape)
62
  shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
63
  shifted_input_ids[:, 0] = decoder_start_token_id
64
 
65
  if pad_token_id is None:
66
  raise ValueError("self.model.config.pad_token_id has to be defined.")
 
 
67
 
68
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
69
  return shifted_input_ids
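A minimal usage sketch of shift_tokens_right as uploaded above; the token IDs, pad_token_id=1, and decoder_start_token_id=2 are made-up values for illustration:

import torch

labels = torch.tensor([[5, 6, 7, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
# The start token is prepended, the sequence shifts right by one, and any surviving
# -100 ignore-index entries are replaced by pad_token_id.
# decoder_input_ids -> tensor([[2, 5, 6, 7]])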
70
 
71
 
72
  def create_position_ids_from_input_ids(
73
  input_ids, padding_idx, past_key_values_length=0
74
  ):
75
  mask = input_ids.ne(padding_idx).int()
76
  incremental_indices = (
77
  torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
 
79
  return incremental_indices.long() + padding_idx
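A small worked example of the position-id helper above, assuming left padding and padding_idx=1 with no cached past: padding positions keep padding_idx and real tokens count up from padding_idx + 1.

import torch

input_ids = torch.tensor([[1, 1, 5, 6, 7]])  # two leading pad tokens
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)
# mask          -> [[0, 0, 1, 1, 1]]
# position_ids  -> [[1, 1, 2, 3, 4]]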
80
 
81
 
82
+ def rotate_half(x):
83
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
84
+ x1, x2 = x.unbind(dim=-1)
85
+ x = torch.stack((-x2, x1), dim=-1)
86
+ return rearrange(x, "... d r -> ... (d r)")
87
+
88
+
89
+ @autocast("cuda", enabled=False)
90
+ def apply_rotary_emb(cos, sin, t):
91
+ rot_dim = cos.shape[-1]
92
+ assert rot_dim <= t.shape[-1] and cos.shape == sin.shape
93
+ t_left, t_right = t[..., :rot_dim], t[..., rot_dim:]
94
+ t_transformed = (t_left * cos) + (rotate_half(t_left) * sin)
95
+ return torch.cat((t_transformed, t_right), dim=-1).type(t.dtype)
96
+
97
+
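A quick self-contained check of the rotation convention implemented above: rotate_half treats the last dimension as adjacent (x1, x2) pairs and maps them to (-x2, x1), and apply_rotary_emb rotates only the first rot_dim channels, passing the rest through unchanged. Values are illustrative:

import torch

t = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
rotate_half(t)                 # tensor([[-2., 1., -4., 3.]])

cos = torch.ones(1, 2)         # position 0: cos=1, sin=0, so the rotated half is unchanged
sin = torch.zeros(1, 2)
apply_rotary_emb(cos, sin, t)  # tensor([[1., 2., 3., 4.]]); the last 2 channels are pass-through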
98
+ class RotaryEmbedding(torch.nn.Module):
99
+ def __init__(
100
+ self, dim, theta=10000, interpolate_factor=1.0, cache_max_seq_len=8192
101
+ ):
102
+ super().__init__()
103
+
104
+ freqs_ = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
105
+ self.cache_max_seq_len = cache_max_seq_len
106
+ self.interpolate_factor = interpolate_factor
107
+
108
+ self.freqs = torch.nn.Parameter(freqs_, requires_grad=False).to(device)
109
+ self.apply_rotary_emb = staticmethod(apply_rotary_emb)
110
+ self.precompute_freqs(cache_max_seq_len)
111
+
112
+ def precompute_freqs(self, max_seq_len):
113
+ thetas = self.forward(max_seq_len, device=device)
114
+ self.register_buffer("cached_cos", thetas.cos(), persistent=False)
115
+ self.register_buffer("cached_sin", thetas.sin(), persistent=False)
116
+
117
+ def rotate_queries_or_keys(self, t, seq_dim=-2, offset=0):
118
+ seq_len = t.shape[seq_dim]
119
+
120
+ if seq_len > self.cache_max_seq_len:
121
+ self.cache_max_seq_len = seq_len * 2
122
+ self.precompute_freqs(self.cache_max_seq_len)
123
+
124
+ cos, sin = (
125
+ self.cached_cos[offset : (offset + seq_len)],
126
+ self.cached_sin[offset : (offset + seq_len)],
127
+ )
128
+ return apply_rotary_emb(cos, sin, t)
129
+
130
+ @autocast("cuda", enabled=False)
131
+ def forward(self, seq_len, device):
132
+ seq = torch.arange(seq_len, device=device) / self.interpolate_factor
133
+ thetas = einsum("..., f -> ... f", seq, self.freqs)
134
+ thetas = repeat(thetas, "... n -> ... (n r)", r=2)
135
+ return thetas
136
+
137
+
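A minimal sketch of how the RotaryEmbedding module above is used; the shapes are assumptions ((batch, heads, seq_len, head_dim) with half of head_dim rotated), and the tensors must live on the same device as the cached cos/sin tables:

import torch

rope = RotaryEmbedding(dim=32)                  # e.g. head_dim = 64 -> rotate 32 channels
q = torch.randn(2, 8, 128, 64, device=device)
k = torch.randn(2, 8, 128, 64, device=device)
q = rope.rotate_queries_or_keys(q)              # rotated at positions 0..127
k = rope.rotate_queries_or_keys(k)
# Sequences longer than cache_max_seq_len make precompute_freqs rebuild the cached
# cos/sin tables at twice the new length.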
138
  # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->RotaryIndicTrans
139
  class RotaryIndicTransAttention(nn.Module):
 
 
140
  def __init__(
141
  self,
142
  embed_dim: int,
 
149
  config: Optional[RotaryIndicTransConfig] = None,
150
  ):
151
  super().__init__()
152
+ self.config = config
153
  self.embed_dim = embed_dim
154
  self.num_heads = num_heads
155
  self.dropout = dropout
156
  self.head_dim = embed_dim // num_heads
 
 
157
 
158
  if (self.head_dim * num_heads) != self.embed_dim:
159
  raise ValueError(
 
164
  self.is_decoder = is_decoder
165
  self.is_causal = is_causal
166
 
 
 
167
  # partial rotation in RoPE
168
  self.rotary_pos_embed = (
169
  RotaryEmbedding(
170
  dim=self.head_dim // 2,
171
+ theta=config.rope_args.get("theta", 10000),
172
+ interpolate_factor=config.rope_args.get("interpolate_factor", 1.0),
 
173
  )
174
  if not is_cross_attention
175
  else None
 
191
  q = rearrange(q, "(b h) t d -> b h t d", h=self.num_heads)
192
  k = rearrange(k, "(b h) t d -> b h t d", h=self.num_heads)
193
 
194
+ offset = (k.shape[-2] - 1) if is_inference else 0
195
+
196
+ q = self.rotary_pos_embed.rotate_queries_or_keys(q, offset=offset)
197
+ k = self.rotary_pos_embed.rotate_queries_or_keys(k)
 
 
 
 
198
 
199
  q = rearrange(q, "b h t d -> (b h) t d")
200
  k = rearrange(k, "b h t d -> (b h) t d")
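To make the new offset logic concrete: the key/value cache is saved before the rotation is applied, so on every decoding step the concatenated keys are re-rotated from position 0, while the single incoming query is rotated at its absolute position k.shape[-2] - 1. A positions-only illustration with hypothetical sizes:

k_len = 8           # 7 cached keys plus 1 fresh key, rotated at positions 0..7
offset = k_len - 1  # the lone new query is rotated at absolute position 7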
 
211
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
212
  """Input shape: Batch x Time x Channel"""
213
 
 
 
214
  is_cross_attention = key_value_states is not None
215
 
216
  bsz, tgt_len, _ = hidden_states.size()
217
 
 
218
  query_states = self.q_proj(hidden_states) * self.scaling
219
+
 
 
 
220
  if (
221
  is_cross_attention
222
  and past_key_value is not None
223
  and past_key_value[0].shape[2] == key_value_states.shape[1]
224
  ):
 
225
  key_states = past_key_value[0]
226
  value_states = past_key_value[1]
227
  elif is_cross_attention:
 
228
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
229
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
230
  elif past_key_value is not None:
 
231
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
232
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
233
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
234
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
235
  else:
 
236
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
237
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
238
 
239
  if self.is_decoder:
240
  past_key_value = (key_states, value_states)
241
 
242
  proj_shape = (bsz * self.num_heads, -1, self.head_dim)
 
284
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
285
 
286
  if output_attentions:
 
 
 
 
287
  attn_weights_reshaped = attn_weights.view(
288
  bsz, self.num_heads, tgt_len, src_len
289
  )
 
303
  f" {attn_output.size()}"
304
  )
305
 
306
+ attn_output = rearrange(
307
+ attn_output, "(b h) t d -> b t (h d)", h=self.num_heads, d=self.head_dim
308
+ )
 
 
 
309
 
310
  attn_output = self.out_proj(attn_output)
 
311
  return attn_output, attn_weights_reshaped, past_key_value
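The single rearrange call above folds the previous view -> transpose -> reshape sequence into one step; a small equivalence check with assumed sizes:

import torch
from einops import rearrange

bsz, num_heads, tgt_len, head_dim = 2, 4, 5, 8
x = torch.randn(bsz * num_heads, tgt_len, head_dim)
merged = rearrange(x, "(b h) t d -> b t (h d)", h=num_heads, d=head_dim)
reference = (
    x.view(bsz, num_heads, tgt_len, head_dim)
    .transpose(1, 2)
    .reshape(bsz, tgt_len, num_heads * head_dim)
)
assert torch.equal(merged, reference)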
312
 
313
 
314
  class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
315
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
316
  def __init__(self, *args, **kwargs):
317
  super().__init__(*args, **kwargs)
318
319
  def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
320
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
321
 
 
334
  "RotaryIndicTransFlashAttention2 attention does not support output_attentions"
335
  )
336
 
 
 
337
  is_cross_attention = key_value_states is not None
338
 
339
  bsz, q_len, _ = hidden_states.size()
340
 
 
341
  query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
342
+
 
 
 
343
  if (
344
  is_cross_attention
345
  and past_key_value is not None
346
  and past_key_value[0].shape[2] == key_value_states.shape[1]
347
  ):
 
348
  key_states = past_key_value[0].transpose(1, 2)
349
  value_states = past_key_value[1].transpose(1, 2)
350
  elif is_cross_attention:
 
351
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
352
  value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
353
  elif past_key_value is not None:
 
354
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
355
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
356
  key_states = torch.cat(
 
360
  [past_key_value[1].transpose(1, 2), value_states], dim=1
361
  )
362
  else:
 
363
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
364
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
365
 
366
  if self.is_decoder:
367
  past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
368
 
369
  kv_seq_len = key_states.shape[-2]
370
  if past_key_value is not None:
371
  kv_seq_len += past_key_value[0].shape[-2]
372
 
373
  input_dtype = query_states.dtype
374
  if input_dtype == torch.float32:
375
  if torch.is_autocast_enabled():
 
442
  softmax_scale (`float`, *optional*):
443
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
444
  """
445
  # Contains at least one padding token in the sequence
446
  if attention_mask is not None:
447
  batch_size = query_states.shape[0]
 
469
  max_seqlen_k=max_seqlen_in_batch_k,
470
  dropout_p=dropout,
471
  softmax_scale=softmax_scale,
472
+ causal=self.is_causal,
473
  )
474
 
475
  attn_output = pad_input(
 
482
  value_states,
483
  dropout,
484
  softmax_scale=softmax_scale,
485
+ causal=self.is_causal,
486
  )
487
 
488
  return attn_output
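For context on the hard-coded causal flag above: flash_attn_func expects (batch, seq_len, num_heads, head_dim) inputs in fp16/bf16 on a CUDA device, and causal=True applies the lower-triangular mask needed for decoder self-attention (bottom-right aligned in flash-attn >= 2.1, per the comments removed in this commit). A minimal sketch under those assumptions:

import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
out = flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True)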
 
514
  max_seqlen_in_batch_q = 1
515
  cu_seqlens_q = torch.arange(
516
  batch_size + 1, dtype=torch.int32, device=query_layer.device
517
+ )
518
  indices_q = cu_seqlens_q[:-1]
519
  query_layer = query_layer.squeeze(1)
520
  else:
 
521
  attention_mask = attention_mask[:, -query_length:]
522
  query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
523
  query_layer, attention_mask
 
545
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
546
  """Input shape: Batch x Time x Channel"""
547
  if output_attentions or layer_head_mask is not None:
 
548
  logger.warning_once(
549
  "RotaryIndicTransModel is using RotaryIndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
550
  ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
 
558
  output_attentions=output_attentions,
559
  )
560
 
 
 
561
  is_cross_attention = key_value_states is not None
562
 
563
  bsz, tgt_len, _ = hidden_states.size()
564
 
 
565
  query_states = self.q_proj(hidden_states)
566
+
 
 
 
567
  if (
568
  is_cross_attention
569
  and past_key_value is not None
570
  and past_key_value[0].shape[2] == key_value_states.shape[1]
571
  ):
 
572
  key_states = past_key_value[0]
573
  value_states = past_key_value[1]
574
  elif is_cross_attention:
 
575
  key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
576
  value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
577
  elif past_key_value is not None:
 
578
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
579
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
580
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
581
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
582
  else:
 
583
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
584
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
585
 
586
  if self.is_decoder:
587
  past_key_value = (key_states, value_states)
588
 
589
  query_states = self._shape(query_states, tgt_len, bsz)
 
593
  query_states, key_states, is_inference=past_key_value is not None
594
  )
595
 
 
 
596
  attn_output = F.scaled_dot_product_attention(
597
  query_states,
598
  key_states,
599
  value_states,
600
  attn_mask=attention_mask,
601
  dropout_p=self.dropout if self.training else 0.0,
 
602
  is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
603
  )
604
 
 
608
  f" {attn_output.size()}"
609
  )
610
 
611
+ attn_output = rearrange(
612
+ attn_output, "b h t d -> b t (h d)", h=self.num_heads, d=self.head_dim
613
+ )
 
 
 
614
  attn_output = self.out_proj(attn_output)
 
615
  return attn_output, None, past_key_value
616
 
617
 
 
776
  if self.normalize_before:
777
  hidden_states = self.self_attn_layer_norm(hidden_states)
778
 
 
 
779
  self_attn_past_key_value = (
780
  past_key_value[:2] if past_key_value is not None else None
781
  )
782
+
783
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
784
  hidden_states=hidden_states,
785
  past_key_value=self_attn_past_key_value,
 
792
  if not self.normalize_before:
793
  hidden_states = self.self_attn_layer_norm(hidden_states)
794
 
 
795
  cross_attn_present_key_value = None
796
  cross_attn_weights = None
797
  if encoder_hidden_states is not None:
 
799
  if self.normalize_before:
800
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
801
 
 
802
  cross_attn_past_key_value = (
803
  past_key_value[-2:] if past_key_value is not None else None
804
  )
 
821
  if not self.normalize_before:
822
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
823
 
 
824
  present_key_value = present_key_value + cross_attn_present_key_value
825
 
 
826
  residual = hidden_states
827
  if self.normalize_before:
828
  hidden_states = self.final_layer_norm(hidden_states)
 
872
 
873
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->RotaryIndicTrans
874
  class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
875
  def __init__(
876
  self,
877
  config: RotaryIndicTransConfig,
 
907
  self._use_sdpa = config._attn_implementation == "sdpa"
908
 
909
  self.gradient_checkpointing = False
 
910
  self.post_init()
911
 
912
  def forward(
 
969
  return_dict if return_dict is not None else self.config.use_return_dict
970
  )
971
 
 
972
  if input_ids is not None and inputs_embeds is not None:
973
  raise ValueError(
974
  "You cannot specify both input_ids and inputs_embeds at the same time"
 
995
  if self._use_flash_attention_2:
996
  attention_mask = attention_mask if 0 in attention_mask else None
997
  elif self._use_sdpa and head_mask is None and not output_attentions:
 
 
 
998
  attention_mask = _prepare_4d_attention_mask_for_sdpa(
999
  attention_mask, inputs_embeds.dtype
1000
  )
1001
  else:
 
1002
  attention_mask = _prepare_4d_attention_mask(
1003
  attention_mask, inputs_embeds.dtype
1004
  )
 
1006
  encoder_states = () if output_hidden_states else None
1007
  all_attentions = () if output_attentions else None
1008
 
 
1009
  if head_mask is not None:
1010
  if head_mask.size()[0] != len(self.layers):
1011
  raise ValueError(
 
1018
  if output_hidden_states:
1019
  encoder_states = encoder_states + (hidden_states,)
1020
 
 
1021
  dropout_probability = torch.rand([])
1022
 
1023
  skip_the_layer = (
 
1026
  else False
1027
  )
1028
  if not skip_the_layer or deepspeed_zero3_is_enabled:
 
 
1029
  if self.gradient_checkpointing and self.training:
1030
+
1031
  def create_custom_forward(module):
1032
  def custom_forward(*inputs):
1033
  return module(*inputs, output_attentions)
 
1079
 
1080
  # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->RotaryIndicTrans
1081
  class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
1082
  def __init__(
1083
  self,
1084
  config: RotaryIndicTransConfig,
 
1113
  self._use_sdpa = config._attn_implementation == "sdpa"
1114
 
1115
  self.gradient_checkpointing = False
 
1116
  self.post_init()
1117
 
1118
  def forward(
 
1210
  return_dict if return_dict is not None else self.config.use_return_dict
1211
  )
1212
 
 
1213
  if input_ids is not None and inputs_embeds is not None:
1214
  raise ValueError(
1215
  "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
 
1224
  "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1225
  )
1226
 
 
1227
  past_key_values_length = (
1228
  past_key_values[0][0].shape[2] if past_key_values is not None else 0
1229
  )
 
1232
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1233
 
1234
  if self._use_flash_attention_2:
 
1235
  attention_mask = (
1236
  attention_mask
1237
  if (attention_mask is not None and 0 in attention_mask)
1238
  else None
1239
  )
1240
  elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
 
 
1241
  attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1242
  attention_mask,
1243
  input_shape,
 
1245
  past_key_values_length,
1246
  )
1247
  else:
 
1248
  attention_mask = _prepare_4d_causal_attention_mask(
1249
  attention_mask, input_shape, inputs_embeds, past_key_values_length
1250
  )
1251
 
 
1252
  if encoder_hidden_states is not None and encoder_attention_mask is not None:
1253
  if self._use_flash_attention_2:
1254
  encoder_attention_mask = (
 
1259
  and cross_attn_head_mask is None
1260
  and not output_attentions
1261
  ):
 
 
 
1262
  encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1263
  encoder_attention_mask,
1264
  inputs_embeds.dtype,
1265
  tgt_len=input_shape[-1],
1266
  )
1267
  else:
 
1268
  encoder_attention_mask = _prepare_4d_attention_mask(
1269
  encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1270
  )
 
1284
  )
1285
  use_cache = False
1286
 
 
1287
  all_hidden_states = () if output_hidden_states else None
1288
  all_self_attns = () if output_attentions else None
1289
  all_cross_attentions = () if output_attentions else None
1290
  next_decoder_cache = () if use_cache else None
1291
 
 
1292
  for attn_mask, mask_name in zip(
1293
  [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1294
  ):
 
1304
  if output_hidden_states:
1305
  all_hidden_states += (hidden_states,)
1306
 
 
1307
  dropout_probability = torch.rand([])
1308
 
1309
  skip_the_layer = (
 
1312
  else False
1313
  )
1314
  if not skip_the_layer or deepspeed_zero3_is_enabled:
 
 
1315
  past_key_value = (
1316
  past_key_values[idx] if past_key_values is not None else None
1317
  )
 
1373
  if self.layer_norm is not None:
1374
  hidden_states = self.layer_norm(hidden_states)
1375
 
 
1376
  if output_hidden_states:
1377
  all_hidden_states += (hidden_states,)
1378
 
 
1407
 
1408
  self.encoder = RotaryIndicTransEncoder(config)
1409
  self.decoder = RotaryIndicTransDecoder(config)
 
 
1410
  self.post_init()
1411
 
1412
  def get_encoder(self):
 
1458
  output_hidden_states=output_hidden_states,
1459
  return_dict=return_dict,
1460
  )
 
1461
  elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1462
  encoder_outputs = BaseModelOutput(
1463
  last_hidden_state=encoder_outputs[0],
 
1465
  attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1466
  )
1467
 
 
1468
  decoder_outputs = self.decoder(
1469
  input_ids=decoder_input_ids,
1470
  attention_mask=decoder_attention_mask,
 
1589
 
1590
  masked_lm_loss = None
1591
  if labels is not None:
 
1592
  labels = labels.to(lm_logits.device)
1593
  masked_lm_loss = F.cross_entropy(
1594
  input=lm_logits.view(-1, self.config.decoder_vocab_size),
 
1627
  encoder_outputs=None,
1628
  **kwargs,
1629
  ):
 
1630
  if past_key_values is not None:
1631
  decoder_input_ids = decoder_input_ids[:, -1:]
1632
 
1633
  return {
1634
+ "input_ids": None,
1635
  "encoder_outputs": encoder_outputs,
1636
  "past_key_values": past_key_values,
1637
  "decoder_input_ids": decoder_input_ids,
 
1639
  "head_mask": head_mask,
1640
  "decoder_head_mask": decoder_head_mask,
1641
  "cross_attn_head_mask": cross_attn_head_mask,
1642
+ "use_cache": use_cache,
1643
  }
1644
 
1645
  @staticmethod