Upload modeling_indictrans.py with huggingface_hub
Browse files- modeling_indictrans.py +271 -83
modeling_indictrans.py
CHANGED
@@ -34,7 +34,7 @@ from transformers.modeling_outputs import (
|
|
34 |
from transformers.utils import logging
|
35 |
from transformers.modeling_utils import PreTrainedModel
|
36 |
|
37 |
-
from
|
38 |
|
39 |
|
40 |
logger = logging.get_logger(__name__)
|
@@ -45,7 +45,9 @@ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
|
|
45 |
|
46 |
|
47 |
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
48 |
-
def shift_tokens_right(
|
|
|
|
|
49 |
"""
|
50 |
Shift input ids one token to the right.
|
51 |
"""
|
@@ -63,7 +65,10 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
|
|
63 |
|
64 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
65 |
def _make_causal_mask(
|
66 |
-
input_ids_shape: torch.Size,
|
|
|
|
|
|
|
67 |
):
|
68 |
"""
|
69 |
Make causal mask used for bi-directional self-attention.
|
@@ -75,8 +80,18 @@ def _make_causal_mask(
|
|
75 |
mask = mask.to(dtype)
|
76 |
|
77 |
if past_key_values_length > 0:
|
78 |
-
mask = torch.cat(
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
|
82 |
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
@@ -91,17 +106,23 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|
91 |
|
92 |
inverted_mask = 1.0 - expanded_mask
|
93 |
|
94 |
-
return inverted_mask.masked_fill(
|
|
|
|
|
95 |
|
96 |
|
97 |
-
def create_position_ids_from_input_ids(
|
|
|
|
|
98 |
"""
|
99 |
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
100 |
are ignored. This is modified from fairseq's `utils.make_positions`.
|
101 |
"""
|
102 |
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
103 |
mask = input_ids.ne(padding_idx).int()
|
104 |
-
incremental_indices = (
|
|
|
|
|
105 |
return incremental_indices.long() + padding_idx
|
106 |
|
107 |
|
@@ -109,23 +130,31 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
|
|
109 |
class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
110 |
"""This module produces sinusoidal positional embeddings of any length."""
|
111 |
|
112 |
-
def __init__(
|
|
|
|
|
113 |
super().__init__()
|
114 |
self.offset = 2
|
115 |
self.embedding_dim = embedding_dim
|
116 |
self.padding_idx = padding_idx
|
117 |
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
|
118 |
|
119 |
-
def make_weights(
|
|
|
|
|
120 |
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
|
121 |
if hasattr(self, "weights"):
|
122 |
# in forward put the weights on the correct dtype and device of the param
|
123 |
-
emb_weights = emb_weights.to(
|
|
|
|
|
124 |
|
125 |
self.register_buffer("weights", emb_weights, persistent=False)
|
126 |
|
127 |
@staticmethod
|
128 |
-
def get_embedding(
|
|
|
|
|
129 |
"""
|
130 |
Build sinusoidal embeddings.
|
131 |
|
@@ -135,8 +164,12 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
|
135 |
half_dim = embedding_dim // 2
|
136 |
emb = math.log(10000) / (half_dim - 1)
|
137 |
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
138 |
-
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
|
139 |
-
|
|
|
|
|
|
|
|
|
140 |
if embedding_dim % 2 == 1:
|
141 |
# zero pad
|
142 |
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
@@ -147,26 +180,39 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
|
147 |
|
148 |
@torch.no_grad()
|
149 |
def forward(
|
150 |
-
self,
|
|
|
|
|
|
|
151 |
):
|
152 |
if input_ids is not None:
|
153 |
bsz, seq_len = input_ids.size()
|
154 |
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
155 |
-
position_ids = create_position_ids_from_input_ids(
|
156 |
-
input_ids.
|
157 |
-
)
|
158 |
else:
|
159 |
bsz, seq_len = inputs_embeds.size()[:-1]
|
160 |
-
position_ids = self.create_position_ids_from_inputs_embeds(
|
|
|
|
|
161 |
|
162 |
# expand embeddings if needed
|
163 |
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
|
164 |
if max_pos > self.weights.size(0):
|
165 |
-
self.make_weights(
|
|
|
|
|
166 |
|
167 |
-
return
|
|
|
|
|
|
|
|
|
168 |
|
169 |
-
def create_position_ids_from_inputs_embeds(
|
|
|
|
|
170 |
"""
|
171 |
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
172 |
|
@@ -179,9 +225,15 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
|
179 |
sequence_length = input_shape[1]
|
180 |
|
181 |
position_ids = torch.arange(
|
182 |
-
self.padding_idx + 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
)
|
184 |
-
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
|
185 |
|
186 |
|
187 |
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
|
@@ -216,7 +268,11 @@ class IndicTransAttention(nn.Module):
|
|
216 |
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
217 |
|
218 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
219 |
-
return
|
|
|
|
|
|
|
|
|
220 |
|
221 |
def forward(
|
222 |
self,
|
@@ -293,7 +349,10 @@ class IndicTransAttention(nn.Module):
|
|
293 |
raise ValueError(
|
294 |
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
295 |
)
|
296 |
-
attn_weights =
|
|
|
|
|
|
|
297 |
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
298 |
|
299 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
@@ -304,7 +363,9 @@ class IndicTransAttention(nn.Module):
|
|
304 |
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
305 |
f" {layer_head_mask.size()}"
|
306 |
)
|
307 |
-
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
|
|
|
|
|
308 |
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
309 |
|
310 |
if output_attentions:
|
@@ -312,8 +373,12 @@ class IndicTransAttention(nn.Module):
|
|
312 |
# make sure that attn_weights keeps its gradient.
|
313 |
# In order to do so, attn_weights have to be reshaped
|
314 |
# twice and have to be reused in the following
|
315 |
-
attn_weights_reshaped = attn_weights.view(
|
316 |
-
|
|
|
|
|
|
|
|
|
317 |
else:
|
318 |
attn_weights_reshaped = None
|
319 |
|
@@ -394,7 +459,9 @@ class IndicTransEncoderLayer(nn.Module):
|
|
394 |
if self.normalize_before:
|
395 |
hidden_states = self.final_layer_norm(hidden_states)
|
396 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
397 |
-
hidden_states = F.dropout(
|
|
|
|
|
398 |
hidden_states = self.fc2(hidden_states)
|
399 |
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
400 |
hidden_states = residual + hidden_states
|
@@ -405,7 +472,9 @@ class IndicTransEncoderLayer(nn.Module):
|
|
405 |
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
406 |
):
|
407 |
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
408 |
-
hidden_states = torch.clamp(
|
|
|
|
|
409 |
|
410 |
outputs = (hidden_states,)
|
411 |
|
@@ -480,7 +549,9 @@ class IndicTransDecoderLayer(nn.Module):
|
|
480 |
|
481 |
# Self Attention
|
482 |
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
483 |
-
self_attn_past_key_value =
|
|
|
|
|
484 |
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
485 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
486 |
hidden_states=hidden_states,
|
@@ -503,8 +574,14 @@ class IndicTransDecoderLayer(nn.Module):
|
|
503 |
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
504 |
|
505 |
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
506 |
-
cross_attn_past_key_value =
|
507 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
hidden_states=hidden_states,
|
509 |
key_value_states=encoder_hidden_states,
|
510 |
attention_mask=encoder_attention_mask,
|
@@ -512,7 +589,9 @@ class IndicTransDecoderLayer(nn.Module):
|
|
512 |
past_key_value=cross_attn_past_key_value,
|
513 |
output_attentions=output_attentions,
|
514 |
)
|
515 |
-
hidden_states = F.dropout(
|
|
|
|
|
516 |
hidden_states = residual + hidden_states
|
517 |
if not self.normalize_before:
|
518 |
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
@@ -525,7 +604,9 @@ class IndicTransDecoderLayer(nn.Module):
|
|
525 |
if self.normalize_before:
|
526 |
hidden_states = self.final_layer_norm(hidden_states)
|
527 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
528 |
-
hidden_states = F.dropout(
|
|
|
|
|
529 |
hidden_states = self.fc2(hidden_states)
|
530 |
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
531 |
hidden_states = residual + hidden_states
|
@@ -577,7 +658,9 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
577 |
embed_tokens (nn.Embedding): output embedding
|
578 |
"""
|
579 |
|
580 |
-
def __init__(
|
|
|
|
|
581 |
super().__init__(config)
|
582 |
|
583 |
self.dropout = config.dropout
|
@@ -588,7 +671,9 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
588 |
self.max_source_positions = config.max_source_positions
|
589 |
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
590 |
|
591 |
-
self.embed_tokens = nn.Embedding(
|
|
|
|
|
592 |
|
593 |
if embed_tokens is not None:
|
594 |
self.embed_tokens.weight = embed_tokens.weight
|
@@ -598,9 +683,15 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
598 |
embed_dim,
|
599 |
self.padding_idx,
|
600 |
)
|
601 |
-
self.layers = nn.ModuleList(
|
602 |
-
|
603 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
604 |
|
605 |
self.gradient_checkpointing = False
|
606 |
# Initialize weights and apply final processing
|
@@ -652,15 +743,25 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
652 |
return_dict (`bool`, *optional*):
|
653 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
654 |
"""
|
655 |
-
output_attentions =
|
|
|
|
|
|
|
|
|
656 |
output_hidden_states = (
|
657 |
-
output_hidden_states
|
|
|
|
|
|
|
|
|
|
|
658 |
)
|
659 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
660 |
|
661 |
# retrieve input_ids and inputs_embeds
|
662 |
if input_ids is not None and inputs_embeds is not None:
|
663 |
-
raise ValueError(
|
|
|
|
|
664 |
elif input_ids is not None:
|
665 |
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
666 |
input_shape = input_ids.size()
|
@@ -705,7 +806,11 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
705 |
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
706 |
dropout_probability = torch.rand([])
|
707 |
|
708 |
-
skip_the_layer =
|
|
|
|
|
|
|
|
|
709 |
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
710 |
# under deepspeed zero3 all gpus must run in sync
|
711 |
|
@@ -727,7 +832,9 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
727 |
layer_outputs = encoder_layer(
|
728 |
hidden_states,
|
729 |
attention_mask,
|
730 |
-
layer_head_mask=(
|
|
|
|
|
731 |
output_attentions=output_attentions,
|
732 |
)
|
733 |
|
@@ -746,9 +853,15 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
|
|
746 |
encoder_states = encoder_states + (hidden_states,)
|
747 |
|
748 |
if not return_dict:
|
749 |
-
return tuple(
|
|
|
|
|
|
|
|
|
750 |
return BaseModelOutput(
|
751 |
-
last_hidden_state=hidden_states,
|
|
|
|
|
752 |
)
|
753 |
|
754 |
|
@@ -762,7 +875,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
762 |
embed_tokens (nn.Embedding): output embedding
|
763 |
"""
|
764 |
|
765 |
-
def __init__(
|
|
|
|
|
766 |
super().__init__(config)
|
767 |
self.dropout = config.dropout
|
768 |
self.layerdrop = config.decoder_layerdrop
|
@@ -772,7 +887,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
772 |
self.max_target_positions = config.max_target_positions
|
773 |
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
774 |
|
775 |
-
self.embed_tokens = nn.Embedding(
|
|
|
|
|
776 |
|
777 |
if embed_tokens is not None:
|
778 |
self.embed_tokens.weight = embed_tokens.weight
|
@@ -782,9 +899,15 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
782 |
embed_dim,
|
783 |
self.padding_idx,
|
784 |
)
|
785 |
-
self.layers = nn.ModuleList(
|
786 |
-
|
787 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
788 |
|
789 |
self.gradient_checkpointing = False
|
790 |
# Initialize weights and apply final processing
|
@@ -870,26 +993,40 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
870 |
return_dict (`bool`, *optional*):
|
871 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
872 |
"""
|
873 |
-
output_attentions =
|
|
|
|
|
|
|
|
|
874 |
output_hidden_states = (
|
875 |
-
output_hidden_states
|
|
|
|
|
876 |
)
|
877 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
878 |
-
return_dict =
|
|
|
|
|
879 |
|
880 |
# retrieve input_ids and inputs_embeds
|
881 |
if input_ids is not None and inputs_embeds is not None:
|
882 |
-
raise ValueError(
|
|
|
|
|
883 |
elif input_ids is not None:
|
884 |
input_shape = input_ids.size()
|
885 |
input_ids = input_ids.view(-1, input_shape[-1])
|
886 |
elif inputs_embeds is not None:
|
887 |
input_shape = inputs_embeds.size()[:-1]
|
888 |
else:
|
889 |
-
raise ValueError(
|
|
|
|
|
890 |
|
891 |
# past_key_values_length
|
892 |
-
past_key_values_length =
|
|
|
|
|
893 |
|
894 |
if inputs_embeds is None:
|
895 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
@@ -914,10 +1051,14 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
914 |
# expand encoder attention mask
|
915 |
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
916 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
917 |
-
encoder_attention_mask = _expand_mask(
|
|
|
|
|
918 |
|
919 |
# embed positions
|
920 |
-
positions = self.embed_positions(
|
|
|
|
|
921 |
positions = positions.to(inputs_embeds.device)
|
922 |
|
923 |
hidden_states = inputs_embeds + positions
|
@@ -929,7 +1070,8 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
929 |
if self.gradient_checkpointing and self.training:
|
930 |
if use_cache:
|
931 |
logger.warning_once(
|
932 |
-
"`use_cache=True` is incompatible with gradient checkpointing. Setting"
|
|
|
933 |
)
|
934 |
use_cache = False
|
935 |
|
@@ -940,7 +1082,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
940 |
next_decoder_cache = () if use_cache else None
|
941 |
|
942 |
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
943 |
-
for attn_mask, mask_name in zip(
|
|
|
|
|
944 |
if attn_mask is not None:
|
945 |
if attn_mask.size()[0] != len(self.layers):
|
946 |
raise ValueError(
|
@@ -956,11 +1100,17 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
956 |
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
957 |
dropout_probability = torch.rand([])
|
958 |
|
959 |
-
skip_the_layer =
|
|
|
|
|
|
|
|
|
960 |
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
961 |
# under deepspeed zero3 all gpus must run in sync
|
962 |
|
963 |
-
past_key_value =
|
|
|
|
|
964 |
|
965 |
if self.gradient_checkpointing and self.training:
|
966 |
|
@@ -978,7 +1128,9 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
978 |
encoder_hidden_states,
|
979 |
encoder_attention_mask,
|
980 |
head_mask[idx] if head_mask is not None else None,
|
981 |
-
cross_attn_head_mask[idx]
|
|
|
|
|
982 |
None,
|
983 |
)
|
984 |
else:
|
@@ -987,9 +1139,13 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
987 |
attention_mask=combined_attention_mask,
|
988 |
encoder_hidden_states=encoder_hidden_states,
|
989 |
encoder_attention_mask=encoder_attention_mask,
|
990 |
-
layer_head_mask=(
|
|
|
|
|
991 |
cross_attn_layer_head_mask=(
|
992 |
-
cross_attn_head_mask[idx]
|
|
|
|
|
993 |
),
|
994 |
past_key_value=past_key_value,
|
995 |
output_attentions=output_attentions,
|
@@ -1019,7 +1175,13 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
|
|
1019 |
if not return_dict:
|
1020 |
return tuple(
|
1021 |
v
|
1022 |
-
for v in [
|
|
|
|
|
|
|
|
|
|
|
|
|
1023 |
if v is not None
|
1024 |
)
|
1025 |
return BaseModelOutputWithPastAndCrossAttentions(
|
@@ -1037,7 +1199,7 @@ class IndicTransModel(IndicTransPreTrainedModel):
|
|
1037 |
|
1038 |
def __init__(self, config: IndicTransConfig):
|
1039 |
super().__init__(config)
|
1040 |
-
|
1041 |
self.encoder = IndicTransEncoder(config)
|
1042 |
self.decoder = IndicTransDecoder(config)
|
1043 |
|
@@ -1068,12 +1230,20 @@ class IndicTransModel(IndicTransPreTrainedModel):
|
|
1068 |
output_hidden_states: Optional[bool] = None,
|
1069 |
return_dict: Optional[bool] = None,
|
1070 |
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
1071 |
-
output_attentions =
|
|
|
|
|
|
|
|
|
1072 |
output_hidden_states = (
|
1073 |
-
output_hidden_states
|
|
|
|
|
1074 |
)
|
1075 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1076 |
-
return_dict =
|
|
|
|
|
1077 |
|
1078 |
if encoder_outputs is None:
|
1079 |
encoder_outputs = self.encoder(
|
@@ -1128,17 +1298,20 @@ class IndicTransModel(IndicTransPreTrainedModel):
|
|
1128 |
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
1129 |
base_model_prefix = "model"
|
1130 |
_tied_weights_keys = None
|
|
|
1131 |
|
1132 |
def __init__(self, config: IndicTransConfig):
|
1133 |
super().__init__(config)
|
1134 |
self.model = IndicTransModel(config)
|
1135 |
-
self.lm_head = nn.Linear(
|
|
|
|
|
1136 |
|
1137 |
if config.share_decoder_input_output_embed:
|
1138 |
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1139 |
-
|
1140 |
self.post_init()
|
1141 |
-
|
1142 |
def tie_weights(self):
|
1143 |
pass
|
1144 |
|
@@ -1153,6 +1326,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1153 |
|
1154 |
def set_output_embeddings(self, new_embeddings):
|
1155 |
self.lm_head = new_embeddings
|
|
|
|
|
|
|
1156 |
|
1157 |
def forward(
|
1158 |
self,
|
@@ -1181,7 +1357,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1181 |
|
1182 |
Returns:
|
1183 |
"""
|
1184 |
-
return_dict =
|
|
|
|
|
1185 |
|
1186 |
if labels is not None:
|
1187 |
if decoder_input_ids is None:
|
@@ -1212,12 +1390,18 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1212 |
if labels is not None:
|
1213 |
# move labels to the correct device to enable PP
|
1214 |
labels = labels.to(lm_logits.device)
|
1215 |
-
|
1216 |
-
|
|
|
|
|
|
|
|
|
1217 |
|
1218 |
if not return_dict:
|
1219 |
output = (lm_logits,) + outputs[1:]
|
1220 |
-
return (
|
|
|
|
|
1221 |
|
1222 |
return Seq2SeqLMOutput(
|
1223 |
loss=masked_lm_loss,
|
@@ -1263,5 +1447,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1263 |
def _reorder_cache(past_key_values, beam_idx):
|
1264 |
reordered_past = ()
|
1265 |
for layer_past in past_key_values:
|
1266 |
-
reordered_past += (
|
1267 |
-
|
|
|
|
|
|
|
|
|
|
34 |
from transformers.utils import logging
|
35 |
from transformers.modeling_utils import PreTrainedModel
|
36 |
|
37 |
+
from configuration_indictrans import IndicTransConfig
|
38 |
|
39 |
|
40 |
logger = logging.get_logger(__name__)
|
|
|
45 |
|
46 |
|
47 |
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
48 |
+
def shift_tokens_right(
|
49 |
+
input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
|
50 |
+
):
|
51 |
"""
|
52 |
Shift input ids one token to the right.
|
53 |
"""
|
|
|
65 |
|
66 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
67 |
def _make_causal_mask(
|
68 |
+
input_ids_shape: torch.Size,
|
69 |
+
dtype: torch.dtype,
|
70 |
+
device: torch.device,
|
71 |
+
past_key_values_length: int = 0,
|
72 |
):
|
73 |
"""
|
74 |
Make causal mask used for bi-directional self-attention.
|
|
|
80 |
mask = mask.to(dtype)
|
81 |
|
82 |
if past_key_values_length > 0:
|
83 |
+
mask = torch.cat(
|
84 |
+
[
|
85 |
+
torch.zeros(
|
86 |
+
tgt_len, past_key_values_length, dtype=dtype, device=device
|
87 |
+
),
|
88 |
+
mask,
|
89 |
+
],
|
90 |
+
dim=-1,
|
91 |
+
)
|
92 |
+
return mask[None, None, :, :].expand(
|
93 |
+
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
94 |
+
)
|
95 |
|
96 |
|
97 |
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
|
|
106 |
|
107 |
inverted_mask = 1.0 - expanded_mask
|
108 |
|
109 |
+
return inverted_mask.masked_fill(
|
110 |
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
111 |
+
)
|
112 |
|
113 |
|
114 |
+
def create_position_ids_from_input_ids(
|
115 |
+
input_ids, padding_idx, past_key_values_length=0
|
116 |
+
):
|
117 |
"""
|
118 |
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
119 |
are ignored. This is modified from fairseq's `utils.make_positions`.
|
120 |
"""
|
121 |
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
122 |
mask = input_ids.ne(padding_idx).int()
|
123 |
+
incremental_indices = (
|
124 |
+
torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
|
125 |
+
) * mask
|
126 |
return incremental_indices.long() + padding_idx
|
127 |
|
128 |
|
|
|
130 |
class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
131 |
"""This module produces sinusoidal positional embeddings of any length."""
|
132 |
|
133 |
+
def __init__(
|
134 |
+
self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
|
135 |
+
):
|
136 |
super().__init__()
|
137 |
self.offset = 2
|
138 |
self.embedding_dim = embedding_dim
|
139 |
self.padding_idx = padding_idx
|
140 |
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
|
141 |
|
142 |
+
def make_weights(
|
143 |
+
self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
|
144 |
+
):
|
145 |
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
|
146 |
if hasattr(self, "weights"):
|
147 |
# in forward put the weights on the correct dtype and device of the param
|
148 |
+
emb_weights = emb_weights.to(
|
149 |
+
dtype=self.weights.dtype, device=self.weights.device
|
150 |
+
)
|
151 |
|
152 |
self.register_buffer("weights", emb_weights, persistent=False)
|
153 |
|
154 |
@staticmethod
|
155 |
+
def get_embedding(
|
156 |
+
num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
|
157 |
+
):
|
158 |
"""
|
159 |
Build sinusoidal embeddings.
|
160 |
|
|
|
164 |
half_dim = embedding_dim // 2
|
165 |
emb = math.log(10000) / (half_dim - 1)
|
166 |
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
167 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
|
168 |
+
1
|
169 |
+
) * emb.unsqueeze(0)
|
170 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
|
171 |
+
num_embeddings, -1
|
172 |
+
)
|
173 |
if embedding_dim % 2 == 1:
|
174 |
# zero pad
|
175 |
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
|
|
180 |
|
181 |
@torch.no_grad()
|
182 |
def forward(
|
183 |
+
self,
|
184 |
+
input_ids: torch.Tensor = None,
|
185 |
+
inputs_embeds: torch.Tensor = None,
|
186 |
+
past_key_values_length: int = 0,
|
187 |
):
|
188 |
if input_ids is not None:
|
189 |
bsz, seq_len = input_ids.size()
|
190 |
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
191 |
+
position_ids = create_position_ids_from_input_ids(
|
192 |
+
input_ids, self.padding_idx, past_key_values_length
|
193 |
+
).to(input_ids.device)
|
194 |
else:
|
195 |
bsz, seq_len = inputs_embeds.size()[:-1]
|
196 |
+
position_ids = self.create_position_ids_from_inputs_embeds(
|
197 |
+
inputs_embeds, past_key_values_length
|
198 |
+
)
|
199 |
|
200 |
# expand embeddings if needed
|
201 |
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
|
202 |
if max_pos > self.weights.size(0):
|
203 |
+
self.make_weights(
|
204 |
+
max_pos + self.offset, self.embedding_dim, self.padding_idx
|
205 |
+
)
|
206 |
|
207 |
+
return (
|
208 |
+
self.weights.index_select(0, position_ids.view(-1))
|
209 |
+
.view(bsz, seq_len, self.weights.shape[-1])
|
210 |
+
.detach()
|
211 |
+
)
|
212 |
|
213 |
+
def create_position_ids_from_inputs_embeds(
|
214 |
+
self, inputs_embeds, past_key_values_length
|
215 |
+
):
|
216 |
"""
|
217 |
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
218 |
|
|
|
225 |
sequence_length = input_shape[1]
|
226 |
|
227 |
position_ids = torch.arange(
|
228 |
+
self.padding_idx + 1,
|
229 |
+
sequence_length + self.padding_idx + 1,
|
230 |
+
dtype=torch.long,
|
231 |
+
device=inputs_embeds.device,
|
232 |
+
)
|
233 |
+
return (
|
234 |
+
position_ids.unsqueeze(0).expand(input_shape).contiguous()
|
235 |
+
+ past_key_values_length
|
236 |
)
|
|
|
237 |
|
238 |
|
239 |
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
|
|
|
268 |
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
269 |
|
270 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
271 |
+
return (
|
272 |
+
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
273 |
+
.transpose(1, 2)
|
274 |
+
.contiguous()
|
275 |
+
)
|
276 |
|
277 |
def forward(
|
278 |
self,
|
|
|
349 |
raise ValueError(
|
350 |
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
351 |
)
|
352 |
+
attn_weights = (
|
353 |
+
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
354 |
+
+ attention_mask
|
355 |
+
)
|
356 |
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
357 |
|
358 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
|
363 |
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
364 |
f" {layer_head_mask.size()}"
|
365 |
)
|
366 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
|
367 |
+
bsz, self.num_heads, tgt_len, src_len
|
368 |
+
)
|
369 |
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
370 |
|
371 |
if output_attentions:
|
|
|
373 |
# make sure that attn_weights keeps its gradient.
|
374 |
# In order to do so, attn_weights have to be reshaped
|
375 |
# twice and have to be reused in the following
|
376 |
+
attn_weights_reshaped = attn_weights.view(
|
377 |
+
bsz, self.num_heads, tgt_len, src_len
|
378 |
+
)
|
379 |
+
attn_weights = attn_weights_reshaped.view(
|
380 |
+
bsz * self.num_heads, tgt_len, src_len
|
381 |
+
)
|
382 |
else:
|
383 |
attn_weights_reshaped = None
|
384 |
|
|
|
459 |
if self.normalize_before:
|
460 |
hidden_states = self.final_layer_norm(hidden_states)
|
461 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
462 |
+
hidden_states = F.dropout(
|
463 |
+
hidden_states, p=self.activation_dropout, training=self.training
|
464 |
+
)
|
465 |
hidden_states = self.fc2(hidden_states)
|
466 |
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
467 |
hidden_states = residual + hidden_states
|
|
|
472 |
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
473 |
):
|
474 |
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
475 |
+
hidden_states = torch.clamp(
|
476 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
477 |
+
)
|
478 |
|
479 |
outputs = (hidden_states,)
|
480 |
|
|
|
549 |
|
550 |
# Self Attention
|
551 |
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
552 |
+
self_attn_past_key_value = (
|
553 |
+
past_key_value[:2] if past_key_value is not None else None
|
554 |
+
)
|
555 |
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
556 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
557 |
hidden_states=hidden_states,
|
|
|
574 |
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
575 |
|
576 |
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
577 |
+
cross_attn_past_key_value = (
|
578 |
+
past_key_value[-2:] if past_key_value is not None else None
|
579 |
+
)
|
580 |
+
(
|
581 |
+
hidden_states,
|
582 |
+
cross_attn_weights,
|
583 |
+
cross_attn_present_key_value,
|
584 |
+
) = self.encoder_attn(
|
585 |
hidden_states=hidden_states,
|
586 |
key_value_states=encoder_hidden_states,
|
587 |
attention_mask=encoder_attention_mask,
|
|
|
589 |
past_key_value=cross_attn_past_key_value,
|
590 |
output_attentions=output_attentions,
|
591 |
)
|
592 |
+
hidden_states = F.dropout(
|
593 |
+
hidden_states, p=self.dropout, training=self.training
|
594 |
+
)
|
595 |
hidden_states = residual + hidden_states
|
596 |
if not self.normalize_before:
|
597 |
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
|
|
604 |
if self.normalize_before:
|
605 |
hidden_states = self.final_layer_norm(hidden_states)
|
606 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
607 |
+
hidden_states = F.dropout(
|
608 |
+
hidden_states, p=self.activation_dropout, training=self.training
|
609 |
+
)
|
610 |
hidden_states = self.fc2(hidden_states)
|
611 |
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
612 |
hidden_states = residual + hidden_states
|
|
|
658 |
embed_tokens (nn.Embedding): output embedding
|
659 |
"""
|
660 |
|
661 |
+
def __init__(
|
662 |
+
self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
|
663 |
+
):
|
664 |
super().__init__(config)
|
665 |
|
666 |
self.dropout = config.dropout
|
|
|
671 |
self.max_source_positions = config.max_source_positions
|
672 |
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
673 |
|
674 |
+
self.embed_tokens = nn.Embedding(
|
675 |
+
config.encoder_vocab_size, embed_dim, self.padding_idx
|
676 |
+
)
|
677 |
|
678 |
if embed_tokens is not None:
|
679 |
self.embed_tokens.weight = embed_tokens.weight
|
|
|
683 |
embed_dim,
|
684 |
self.padding_idx,
|
685 |
)
|
686 |
+
self.layers = nn.ModuleList(
|
687 |
+
[IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
|
688 |
+
)
|
689 |
+
self.layer_norm = (
|
690 |
+
nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
|
691 |
+
)
|
692 |
+
self.layernorm_embedding = (
|
693 |
+
nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
694 |
+
)
|
695 |
|
696 |
self.gradient_checkpointing = False
|
697 |
# Initialize weights and apply final processing
|
|
|
743 |
return_dict (`bool`, *optional*):
|
744 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
745 |
"""
|
746 |
+
output_attentions = (
|
747 |
+
output_attentions
|
748 |
+
if output_attentions is not None
|
749 |
+
else self.config.output_attentions
|
750 |
+
)
|
751 |
output_hidden_states = (
|
752 |
+
output_hidden_states
|
753 |
+
if output_hidden_states is not None
|
754 |
+
else self.config.output_hidden_states
|
755 |
+
)
|
756 |
+
return_dict = (
|
757 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
758 |
)
|
|
|
759 |
|
760 |
# retrieve input_ids and inputs_embeds
|
761 |
if input_ids is not None and inputs_embeds is not None:
|
762 |
+
raise ValueError(
|
763 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
764 |
+
)
|
765 |
elif input_ids is not None:
|
766 |
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
767 |
input_shape = input_ids.size()
|
|
|
806 |
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
807 |
dropout_probability = torch.rand([])
|
808 |
|
809 |
+
skip_the_layer = (
|
810 |
+
True
|
811 |
+
if self.training and (dropout_probability < self.layerdrop)
|
812 |
+
else False
|
813 |
+
)
|
814 |
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
815 |
# under deepspeed zero3 all gpus must run in sync
|
816 |
|
|
|
832 |
layer_outputs = encoder_layer(
|
833 |
hidden_states,
|
834 |
attention_mask,
|
835 |
+
layer_head_mask=(
|
836 |
+
head_mask[idx] if head_mask is not None else None
|
837 |
+
),
|
838 |
output_attentions=output_attentions,
|
839 |
)
|
840 |
|
|
|
853 |
encoder_states = encoder_states + (hidden_states,)
|
854 |
|
855 |
if not return_dict:
|
856 |
+
return tuple(
|
857 |
+
v
|
858 |
+
for v in [hidden_states, encoder_states, all_attentions]
|
859 |
+
if v is not None
|
860 |
+
)
|
861 |
return BaseModelOutput(
|
862 |
+
last_hidden_state=hidden_states,
|
863 |
+
hidden_states=encoder_states,
|
864 |
+
attentions=all_attentions,
|
865 |
)
|
866 |
|
867 |
|
|
|
875 |
embed_tokens (nn.Embedding): output embedding
|
876 |
"""
|
877 |
|
878 |
+
def __init__(
|
879 |
+
self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
|
880 |
+
):
|
881 |
super().__init__(config)
|
882 |
self.dropout = config.dropout
|
883 |
self.layerdrop = config.decoder_layerdrop
|
|
|
887 |
self.max_target_positions = config.max_target_positions
|
888 |
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
889 |
|
890 |
+
self.embed_tokens = nn.Embedding(
|
891 |
+
config.decoder_vocab_size, embed_dim, self.padding_idx
|
892 |
+
)
|
893 |
|
894 |
if embed_tokens is not None:
|
895 |
self.embed_tokens.weight = embed_tokens.weight
|
|
|
899 |
embed_dim,
|
900 |
self.padding_idx,
|
901 |
)
|
902 |
+
self.layers = nn.ModuleList(
|
903 |
+
[IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
|
904 |
+
)
|
905 |
+
self.layer_norm = (
|
906 |
+
nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
|
907 |
+
)
|
908 |
+
self.layernorm_embedding = (
|
909 |
+
nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
910 |
+
)
|
911 |
|
912 |
self.gradient_checkpointing = False
|
913 |
# Initialize weights and apply final processing
|
|
|
993 |
return_dict (`bool`, *optional*):
|
994 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
995 |
"""
|
996 |
+
output_attentions = (
|
997 |
+
output_attentions
|
998 |
+
if output_attentions is not None
|
999 |
+
else self.config.output_attentions
|
1000 |
+
)
|
1001 |
output_hidden_states = (
|
1002 |
+
output_hidden_states
|
1003 |
+
if output_hidden_states is not None
|
1004 |
+
else self.config.output_hidden_states
|
1005 |
)
|
1006 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1007 |
+
return_dict = (
|
1008 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1009 |
+
)
|
1010 |
|
1011 |
# retrieve input_ids and inputs_embeds
|
1012 |
if input_ids is not None and inputs_embeds is not None:
|
1013 |
+
raise ValueError(
|
1014 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
1015 |
+
)
|
1016 |
elif input_ids is not None:
|
1017 |
input_shape = input_ids.size()
|
1018 |
input_ids = input_ids.view(-1, input_shape[-1])
|
1019 |
elif inputs_embeds is not None:
|
1020 |
input_shape = inputs_embeds.size()[:-1]
|
1021 |
else:
|
1022 |
+
raise ValueError(
|
1023 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
1024 |
+
)
|
1025 |
|
1026 |
# past_key_values_length
|
1027 |
+
past_key_values_length = (
|
1028 |
+
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
1029 |
+
)
|
1030 |
|
1031 |
if inputs_embeds is None:
|
1032 |
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
|
|
1051 |
# expand encoder attention mask
|
1052 |
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
1053 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1054 |
+
encoder_attention_mask = _expand_mask(
|
1055 |
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
1056 |
+
)
|
1057 |
|
1058 |
# embed positions
|
1059 |
+
positions = self.embed_positions(
|
1060 |
+
input_ids, inputs_embeds, past_key_values_length
|
1061 |
+
)
|
1062 |
positions = positions.to(inputs_embeds.device)
|
1063 |
|
1064 |
hidden_states = inputs_embeds + positions
|
|
|
1070 |
if self.gradient_checkpointing and self.training:
|
1071 |
if use_cache:
|
1072 |
logger.warning_once(
|
1073 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting"
|
1074 |
+
" `use_cache=False`..."
|
1075 |
)
|
1076 |
use_cache = False
|
1077 |
|
|
|
1082 |
next_decoder_cache = () if use_cache else None
|
1083 |
|
1084 |
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
1085 |
+
for attn_mask, mask_name in zip(
|
1086 |
+
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
|
1087 |
+
):
|
1088 |
if attn_mask is not None:
|
1089 |
if attn_mask.size()[0] != len(self.layers):
|
1090 |
raise ValueError(
|
|
|
1100 |
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
1101 |
dropout_probability = torch.rand([])
|
1102 |
|
1103 |
+
skip_the_layer = (
|
1104 |
+
True
|
1105 |
+
if self.training and (dropout_probability < self.layerdrop)
|
1106 |
+
else False
|
1107 |
+
)
|
1108 |
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
1109 |
# under deepspeed zero3 all gpus must run in sync
|
1110 |
|
1111 |
+
past_key_value = (
|
1112 |
+
past_key_values[idx] if past_key_values is not None else None
|
1113 |
+
)
|
1114 |
|
1115 |
if self.gradient_checkpointing and self.training:
|
1116 |
|
|
|
1128 |
encoder_hidden_states,
|
1129 |
encoder_attention_mask,
|
1130 |
head_mask[idx] if head_mask is not None else None,
|
1131 |
+
cross_attn_head_mask[idx]
|
1132 |
+
if cross_attn_head_mask is not None
|
1133 |
+
else None,
|
1134 |
None,
|
1135 |
)
|
1136 |
else:
|
|
|
1139 |
attention_mask=combined_attention_mask,
|
1140 |
encoder_hidden_states=encoder_hidden_states,
|
1141 |
encoder_attention_mask=encoder_attention_mask,
|
1142 |
+
layer_head_mask=(
|
1143 |
+
head_mask[idx] if head_mask is not None else None
|
1144 |
+
),
|
1145 |
cross_attn_layer_head_mask=(
|
1146 |
+
cross_attn_head_mask[idx]
|
1147 |
+
if cross_attn_head_mask is not None
|
1148 |
+
else None
|
1149 |
),
|
1150 |
past_key_value=past_key_value,
|
1151 |
output_attentions=output_attentions,
|
|
|
1175 |
if not return_dict:
|
1176 |
return tuple(
|
1177 |
v
|
1178 |
+
for v in [
|
1179 |
+
hidden_states,
|
1180 |
+
next_cache,
|
1181 |
+
all_hidden_states,
|
1182 |
+
all_self_attns,
|
1183 |
+
all_cross_attentions,
|
1184 |
+
]
|
1185 |
if v is not None
|
1186 |
)
|
1187 |
return BaseModelOutputWithPastAndCrossAttentions(
|
|
|
1199 |
|
1200 |
def __init__(self, config: IndicTransConfig):
|
1201 |
super().__init__(config)
|
1202 |
+
|
1203 |
self.encoder = IndicTransEncoder(config)
|
1204 |
self.decoder = IndicTransDecoder(config)
|
1205 |
|
|
|
1230 |
output_hidden_states: Optional[bool] = None,
|
1231 |
return_dict: Optional[bool] = None,
|
1232 |
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
1233 |
+
output_attentions = (
|
1234 |
+
output_attentions
|
1235 |
+
if output_attentions is not None
|
1236 |
+
else self.config.output_attentions
|
1237 |
+
)
|
1238 |
output_hidden_states = (
|
1239 |
+
output_hidden_states
|
1240 |
+
if output_hidden_states is not None
|
1241 |
+
else self.config.output_hidden_states
|
1242 |
)
|
1243 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1244 |
+
return_dict = (
|
1245 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1246 |
+
)
|
1247 |
|
1248 |
if encoder_outputs is None:
|
1249 |
encoder_outputs = self.encoder(
|
|
|
1298 |
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
1299 |
base_model_prefix = "model"
|
1300 |
_tied_weights_keys = None
|
1301 |
+
_label_smoothing = 0.0
|
1302 |
|
1303 |
def __init__(self, config: IndicTransConfig):
|
1304 |
super().__init__(config)
|
1305 |
self.model = IndicTransModel(config)
|
1306 |
+
self.lm_head = nn.Linear(
|
1307 |
+
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1308 |
+
)
|
1309 |
|
1310 |
if config.share_decoder_input_output_embed:
|
1311 |
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1312 |
+
|
1313 |
self.post_init()
|
1314 |
+
|
1315 |
def tie_weights(self):
|
1316 |
pass
|
1317 |
|
|
|
1326 |
|
1327 |
def set_output_embeddings(self, new_embeddings):
|
1328 |
self.lm_head = new_embeddings
|
1329 |
+
|
1330 |
+
def set_label_smoothing(self, label_smoothing):
|
1331 |
+
self._label_smoothing = label_smoothing
|
1332 |
|
1333 |
def forward(
|
1334 |
self,
|
|
|
1357 |
|
1358 |
Returns:
|
1359 |
"""
|
1360 |
+
return_dict = (
|
1361 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1362 |
+
)
|
1363 |
|
1364 |
if labels is not None:
|
1365 |
if decoder_input_ids is None:
|
|
|
1390 |
if labels is not None:
|
1391 |
# move labels to the correct device to enable PP
|
1392 |
labels = labels.to(lm_logits.device)
|
1393 |
+
masked_lm_loss = F.cross_entropy(
|
1394 |
+
input=lm_logits.view(-1, self.config.decoder_vocab_size),
|
1395 |
+
target=labels.view(-1),
|
1396 |
+
ignore_index=self.config.pad_token_id,
|
1397 |
+
label_smoothing=self._label_smoothing,
|
1398 |
+
)
|
1399 |
|
1400 |
if not return_dict:
|
1401 |
output = (lm_logits,) + outputs[1:]
|
1402 |
+
return (
|
1403 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1404 |
+
)
|
1405 |
|
1406 |
return Seq2SeqLMOutput(
|
1407 |
loss=masked_lm_loss,
|
|
|
1447 |
def _reorder_cache(past_key_values, beam_idx):
|
1448 |
reordered_past = ()
|
1449 |
for layer_past in past_key_values:
|
1450 |
+
reordered_past += (
|
1451 |
+
tuple(
|
1452 |
+
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1453 |
+
),
|
1454 |
+
)
|
1455 |
+
return reordered_past
|