Raghavan committed
Commit f65ef53
1 Parent(s): ba500e5

Upload 7 files

Files changed (1):
  1. modeling_indictrans.py +116 -104
modeling_indictrans.py CHANGED

@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch IndicTrans model."""

-
 import math
 from typing import List, Optional, Tuple, Union

@@ -36,12 +35,12 @@ from transformers.modeling_utils import PreTrainedModel

 from .configuration_indictrans import IndicTransConfig

-
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "IndicTransConfig"

 INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
+eos_token_id = 2


 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
@@ -61,9 +60,19 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids


+def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
+    mask = (decoder_input_ids == eos_token_id)
+    decoder_input_ids[mask] = 1
+    decoder_attention_mask[mask] = 0
+
+    labels = decoder_input_ids[:, 1:]
+
+    return decoder_input_ids, decoder_attention_mask, labels
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 ):
     """
     Make causal mask used for bi-directional self-attention.
@@ -147,7 +156,7 @@ class IndicTransSinusoidalPositionalEmbedding(nn.Module):

     @torch.no_grad()
     def forward(
-        self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
+        self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
     ):
         if input_ids is not None:
             bsz, seq_len = input_ids.size()
@@ -189,12 +198,12 @@ class IndicTransAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        dropout: float = 0.0,
-        is_decoder: bool = False,
-        bias: bool = True,
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -207,7 +216,7 @@ class IndicTransAttention(nn.Module):
                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                 f" and `num_heads`: {num_heads})."
             )
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5
         self.is_decoder = is_decoder

         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -219,13 +228,13 @@ class IndicTransAttention(nn.Module):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""

@@ -242,9 +251,9 @@ class IndicTransAttention(nn.Module):
         # is checking that the `sequence_length` of the `past_key_value` is the same as
         # the provided `key_value_states` to support prefix tuning
         if (
-            is_cross_attention
-            and past_key_value is not None
-            and past_key_value[0].shape[2] == key_value_states.shape[1]
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
         ):
             # reuse k,v, cross_attentions
             key_states = past_key_value[0]
@@ -359,11 +368,11 @@ class IndicTransEncoderLayer(nn.Module):
         self.normalize_before = config.encoder_normalize_before

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_head_mask: torch.Tensor,
-        output_attentions: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -402,7 +411,7 @@ class IndicTransEncoderLayer(nn.Module):
             hidden_states = self.final_layer_norm(hidden_states)

         if hidden_states.dtype == torch.float16 and (
-            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
         ):
             clamp_value = torch.finfo(hidden_states.dtype).max - 1000
             hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
@@ -445,16 +454,16 @@ class IndicTransDecoderLayer(nn.Module):
         self.normalize_before = config.decoder_normalize_before

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = True,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
     ) -> torch.Tensor:
         """
         Args:
@@ -618,14 +627,14 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
         return sentence_embedding.unsqueeze(1)

     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
@@ -810,19 +819,19 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         self.embed_tokens = value

     def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         r"""
         Args:
@@ -1056,7 +1065,7 @@ class IndicTransModel(IndicTransPreTrainedModel):

     def __init__(self, config: IndicTransConfig):
         super().__init__(config)
-
+
         self.encoder = IndicTransEncoder(config)
         self.decoder = IndicTransDecoder(config)

@@ -1070,22 +1079,22 @@ class IndicTransModel(IndicTransPreTrainedModel):
         return self.decoder

     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        decoder_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1155,9 +1164,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):

         if config.share_decoder_input_output_embed:
             self.lm_head.weight = self.model.decoder.embed_tokens.weight
-
+
         self.post_init()
-
+
     def tie_weights(self):
         pass

@@ -1174,23 +1183,23 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         self.lm_head = new_embeddings

     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        decoder_head_mask: Optional[torch.Tensor] = None,
-        cross_attn_head_mask: Optional[torch.Tensor] = None,
-        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1208,6 +1217,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         # labels, self.config.pad_token_id, self.config.decoder_start_token_id
         # )

+        decoder_input_ids, decoder_attention_mask, labels = prepare_decoder_input_ids_label(decoder_input_ids,
+                                                                                            decoder_attention_mask)
+
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -1251,16 +1263,16 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
         )

     def prepare_inputs_for_generation(
-        self,
-        decoder_input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
-        **kwargs,
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
     ):
         # cut decoder_input_ids if past is used
         if past_key_values is not None:
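For reference, a minimal, self-contained sketch of what the new prepare_decoder_input_ids_label helper (added in the hunk at new line 60 above) does to a toy decoder batch. The tensor values below are hypothetical; eos_token_id = 2 and the replacement id 1 are taken from the diff.

import torch

eos_token_id = 2

def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
    # Replace eos positions in place, drop them from the attention mask,
    # then derive labels by dropping the first position of the modified ids.
    mask = (decoder_input_ids == eos_token_id)
    decoder_input_ids[mask] = 1
    decoder_attention_mask[mask] = 0

    labels = decoder_input_ids[:, 1:]

    return decoder_input_ids, decoder_attention_mask, labels

decoder_input_ids = torch.tensor([[2, 57, 964, 13, 2]])        # hypothetical target ids, eos = 2
decoder_attention_mask = torch.ones_like(decoder_input_ids)

ids, attn, labels = prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask)
print(ids)     # [[1, 57, 964, 13, 1]]  -> eos positions overwritten with 1 (in place)
print(attn)    # [[0, 1, 1, 1, 0]]      -> eos positions removed from the attention mask
print(labels)  # [[57, 964, 13, 1]]     -> the modified ids with the first position dropped

In the updated forward of IndicTransForConditionalGeneration, the returned decoder_input_ids, decoder_attention_mask, and labels rebind the values that were passed in, immediately before the call to self.model.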