Upload 7 files
modeling_indictrans.py +116 -104
modeling_indictrans.py
CHANGED
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch IndicTrans model."""
 
-
 import math
 from typing import List, Optional, Tuple, Union
 
@@ -36,12 +35,12 @@ from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_indictrans import IndicTransConfig
 
-
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "IndicTransConfig"
 
 INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
+eos_token_id = 2
 
 
 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
@@ -61,9 +60,19 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids
 
 
+def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
+    mask = (decoder_input_ids == eos_token_id)
+    decoder_input_ids[mask] = 1
+    decoder_attention_mask[mask] = 0
+
+    labels = decoder_input_ids[:, 1:]
+
+    return decoder_input_ids, decoder_attention_mask, labels
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
-
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 ):
     """
     Make causal mask used for bi-directional self-attention.
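
The new prepare_decoder_input_ids_label helper above replaces eos positions in the decoder inputs with token id 1, zeroes them out of the attention mask, and derives labels by dropping the first position. A minimal sanity-check sketch of that behaviour on a made-up batch; only eos_token_id = 2 and the replacement id 1 come from the code above, all other ids are illustrative:

import torch

# Toy batch; ids other than eos_token_id (2) and the replacement id (1) are made up.
decoder_input_ids = torch.tensor([[2, 57, 934, 2, 1],
                                  [2, 61, 2, 1, 1]])
decoder_attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                                       [1, 1, 1, 0, 0]])

# Same three steps as prepare_decoder_input_ids_label, inlined so the snippet runs standalone.
mask = decoder_input_ids == 2        # positions holding the eos token
decoder_input_ids[mask] = 1          # overwrite them in place with id 1
decoder_attention_mask[mask] = 0     # and drop them from the attention mask
labels = decoder_input_ids[:, 1:]    # labels: decoder inputs with the first position removed

print(decoder_input_ids)       # [[1, 57, 934, 1, 1], [1, 61, 1, 1, 1]]
print(decoder_attention_mask)  # [[0, 1, 1, 0, 0], [0, 1, 0, 0, 0]]
print(labels)                  # [[57, 934, 1, 1], [61, 1, 1, 1]]

Note that the helper mutates decoder_input_ids and decoder_attention_mask in place, and labels is a view of the already-modified decoder_input_ids.
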
@@ -147,7 +156,7 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
 
     @torch.no_grad()
     def forward(
-
+        self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
     ):
         if input_ids is not None:
             bsz, seq_len = input_ids.size()
@@ -189,12 +198,12 @@ class IndicTransAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
     def __init__(
-
-
-
-
-
-
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -207,7 +216,7 @@ class IndicTransAttention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim
+        self.scaling = self.head_dim ** -0.5
        self.is_decoder = is_decoder
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
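
The scaling change above swaps the old value (head_dim itself) for the usual 1/sqrt(head_dim) factor of scaled dot-product attention. A small sketch of where that factor enters the attention weights; this is the generic formulation with illustrative shapes, not code taken from this file:

import torch

bsz, num_heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

scaling = head_dim ** -0.5                        # 1 / sqrt(head_dim), as in the fix above
attn_weights = torch.matmul(q * scaling, k.transpose(-1, -2))
attn_probs = attn_weights.softmax(dim=-1)
attn_output = torch.matmul(attn_probs, v)         # (bsz, num_heads, seq_len, head_dim)

Without the exponent, the previous value multiplied the logits by head_dim instead of dividing by sqrt(head_dim), which saturates the softmax as the head dimension grows.
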
@@ -219,13 +228,13 @@ class IndicTransAttention(nn.Module):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
 
@@ -242,9 +251,9 @@ class IndicTransAttention(nn.Module):
         # is checking that the `sequence_length` of the `past_key_value` is the same as
         # the provided `key_value_states` to support prefix tuning
         if (
-
-
-
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
         ):
             # reuse k,v, cross_attentions
             key_states = past_key_value[0]
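
The restored condition above only reuses the cached cross-attention keys/values when the cached sequence length matches the current encoder output length, which is the prefix-tuning case the comment refers to. A toy shape check with made-up sizes:

import torch

bsz, num_heads, src_len, head_dim = 2, 4, 7, 16
key_value_states = torch.randn(bsz, src_len, num_heads * head_dim)   # encoder hidden states
past_key_value = (
    torch.randn(bsz, num_heads, src_len, head_dim),                  # cached keys
    torch.randn(bsz, num_heads, src_len, head_dim),                  # cached values
)
is_cross_attention = True

reuse = (
    is_cross_attention
    and past_key_value is not None
    and past_key_value[0].shape[2] == key_value_states.shape[1]      # 7 == 7
)
print(reuse)   # True; with a different encoder length the keys/values would be recomputed instead
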
@@ -359,11 +368,11 @@ class IndicTransEncoderLayer(nn.Module):
         self.normalize_before = config.encoder_normalize_before
 
     def forward(
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -402,7 +411,7 @@ class IndicTransEncoderLayer(nn.Module):
         hidden_states = self.final_layer_norm(hidden_states)
 
         if hidden_states.dtype == torch.float16 and (
-
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
         ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
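
The float16 guard above clamps the activations only when an inf or nan actually appears. A standalone illustration with a deliberately overflowing value; the tensor contents are made up:

import torch

hidden_states = torch.tensor([1.0, float("inf"), -65504.0], dtype=torch.float16)

if hidden_states.dtype == torch.float16 and (
    torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000   # 65504 - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

print(hidden_states)   # the inf and -65504.0 entries are pulled back just inside the float16 range
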
@@ -445,16 +454,16 @@ class IndicTransDecoderLayer(nn.Module):
         self.normalize_before = config.decoder_normalize_before
 
     def forward(
-
-
-
-
-
-
-
-
-
-
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
     ) -> torch.Tensor:
         """
         Args:
@@ -618,14 +627,14 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
         return sentence_embedding.unsqueeze(1)
 
     def forward(
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
@@ -810,19 +819,19 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         self.embed_tokens = value
 
     def forward(
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
@@ -1056,7 +1065,7 @@ class IndicTransModel(IndicTransPreTrainedModel):
 
     def __init__(self, config: IndicTransConfig):
         super().__init__(config)
-
+
         self.encoder = IndicTransEncoder(config)
         self.decoder = IndicTransDecoder(config)
 
@@ -1070,22 +1079,22 @@ class IndicTransModel(IndicTransPreTrainedModel):
         return self.decoder
 
     def forward(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1155,9 +1164,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
 
         if config.share_decoder_input_output_embed:
             self.lm_head.weight = self.model.decoder.embed_tokens.weight
-
+
         self.post_init()
-
+
     def tie_weights(self):
         pass
 
@@ -1174,23 +1183,23 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         self.lm_head = new_embeddings
 
     def forward(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1208,6 +1217,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
             #     labels, self.config.pad_token_id, self.config.decoder_start_token_id
             # )
 
+        decoder_input_ids, decoder_attention_mask, labels = prepare_decoder_input_ids_label(decoder_input_ids,
+                                                                                             decoder_attention_mask)
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -1251,16 +1263,16 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-
-
-
-
-
-
-
-
-
-
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
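
prepare_inputs_for_generation trims the decoder inputs when a cache is passed, so that each incremental decoding step only feeds the newly generated token. The slicing itself falls outside this hunk; the sketch below shows the usual pattern in comparable transformers seq2seq models, with made-up ids, and should be read as an assumption rather than the exact code in this file:

import torch

decoder_input_ids = torch.tensor([[2, 57, 934, 12]])   # tokens generated so far (ids are made up)
past_key_values = object()                             # stand-in for a non-empty cache

# cut decoder_input_ids if past is used
if past_key_values is not None:
    decoder_input_ids = decoder_input_ids[:, -1:]      # keep only the most recent token

print(decoder_input_ids)   # tensor([[12]])
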