alaeddine-13 committed on
Commit
e1b325c
1 Parent(s): c41d17d

feat: add dummy mask to validate sliding window

Files changed (1)
  1. modeling_bert.py +236 -214
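This commit threads a new optional sliding_window argument from JinaBertModel.encode through JinaBertEncoder, JinaBertLayer and JinaBertAttention down to JinaBertSelfAttention, where a dummy band mask built by create_k_diag_mask zeroes the attention probabilities of token pairs that are at least sliding_window positions apart. A minimal usage sketch, assuming a checkpoint that ships this modeling_bert.py (the model id below is an illustration, not part of the commit):

    from transformers import AutoModel

    # assumed checkpoint; any JinaBert checkpoint served by this modeling file applies
    model = AutoModel.from_pretrained(
        "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
    )

    # sliding_window is the argument added in this commit: each token keeps
    # attention only to positions less than 64 steps away
    embeddings = model.encode(
        ["A first example sentence.", "A second example sentence."],
        batch_size=2,
        sliding_window=64,
    )
    print(embeddings.shape)  # (2, hidden_size) as a numpy array by default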
modeling_bert.py CHANGED
@@ -16,7 +16,6 @@
16
  # limitations under the License.
17
  """PyTorch BERT model."""
18
 
19
-
20
  import math
21
  import os
22
  import warnings
@@ -96,6 +95,15 @@ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
96
  _SEQ_CLASS_EXPECTED_LOSS = 0.01
97
 
98
 
99
  def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
100
  """Load tf checkpoints in a pytorch model."""
101
  try:
@@ -126,15 +134,15 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
126
  # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
127
  # which are not required for using pretrained model
128
  if any(
129
- n
130
- in [
131
- "adam_v",
132
- "adam_m",
133
- "AdamWeightDecayOptimizer",
134
- "AdamWeightDecayOptimizer_1",
135
- "global_step",
136
- ]
137
- for n in name
138
  ):
139
  logger.info(f"Skipping {'/'.join(name)}")
140
  continue
@@ -214,12 +222,12 @@ class JinaBertEmbeddings(nn.Module):
214
  )
215
 
216
  def forward(
217
- self,
218
- input_ids: Optional[torch.LongTensor] = None,
219
- token_type_ids: Optional[torch.LongTensor] = None,
220
- position_ids: Optional[torch.LongTensor] = None,
221
- inputs_embeds: Optional[torch.FloatTensor] = None,
222
- past_key_values_length: int = 0,
223
  ) -> torch.Tensor:
224
  if input_ids is not None:
225
  input_shape = input_ids.size()
@@ -230,8 +238,8 @@ class JinaBertEmbeddings(nn.Module):
230
 
231
  if position_ids is None:
232
  position_ids = self.position_ids[
233
- :, past_key_values_length : seq_length + past_key_values_length
234
- ]
235
 
236
  # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
237
  # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@@ -265,13 +273,13 @@ class JinaBertSelfAttention(nn.Module):
265
  def __init__(self, config: JinaBertConfig, position_embedding_type=None):
266
  super().__init__()
267
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
268
- config, "embedding_size"
269
  ):
270
  raise ValueError(
271
  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
272
  f"heads ({config.num_attention_heads})"
273
  )
274
-
275
  self.attn_implementation = config.attn_implementation
276
  self.num_attention_heads = config.num_attention_heads
277
  self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -286,8 +294,8 @@ class JinaBertSelfAttention(nn.Module):
286
  config, "position_embedding_type", "absolute"
287
  )
288
  if (
289
- self.position_embedding_type == "relative_key"
290
- or self.position_embedding_type == "relative_key_query"
291
  ):
292
  self.max_position_embeddings = config.max_position_embeddings
293
  self.distance_embedding = nn.Embedding(
@@ -305,15 +313,16 @@ class JinaBertSelfAttention(nn.Module):
305
  return x.permute(0, 2, 1, 3)
306
 
307
  def forward(
308
- self,
309
- hidden_states: torch.Tensor,
310
- attention_mask: Optional[torch.FloatTensor] = None,
311
- head_mask: Optional[torch.FloatTensor] = None,
312
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
313
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
314
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
315
- output_attentions: Optional[bool] = False,
316
- bias: Optional[torch.FloatTensor] = None,
 
317
  ) -> Tuple[torch.Tensor]:
318
  mixed_query_layer = self.query(hidden_states)
319
 
@@ -364,8 +373,8 @@ class JinaBertSelfAttention(nn.Module):
364
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
365
 
366
  if (
367
- self.position_embedding_type == "relative_key"
368
- or self.position_embedding_type == "relative_key_query"
369
  ):
370
  query_length, key_length = query_layer.shape[2], key_layer.shape[2]
371
  if use_cache:
@@ -401,9 +410,9 @@ class JinaBertSelfAttention(nn.Module):
401
  "bhrd,lrd->bhlr", key_layer, positional_embedding
402
  )
403
  attention_scores = (
404
- attention_scores
405
- + relative_position_scores_query
406
- + relative_position_scores_key
407
  )
408
 
409
  attention_scores = attention_scores / math.sqrt(self.attention_head_size)
@@ -414,6 +423,10 @@ class JinaBertSelfAttention(nn.Module):
414
  # Normalize the attention scores to probabilities.
415
  attention_probs = nn.functional.softmax(attention_scores + bias, dim=-1)
416
 
417
  # This is actually dropping out entire tokens to attend to, which might
418
  # seem a bit unusual, but is taken from the original Transformer paper.
419
  attention_probs = self.dropout(attention_probs)
@@ -445,7 +458,7 @@ class JinaBertSelfOutput(nn.Module):
445
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
446
 
447
  def forward(
448
- self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
449
  ) -> torch.Tensor:
450
  hidden_states = self.dense(hidden_states)
451
  hidden_states = self.dropout(hidden_states)
@@ -481,20 +494,21 @@ class JinaBertAttention(nn.Module):
481
  # Update hyper params and store pruned heads
482
  self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
483
  self.self.all_head_size = (
484
- self.self.attention_head_size * self.self.num_attention_heads
485
  )
486
  self.pruned_heads = self.pruned_heads.union(heads)
487
 
488
  def forward(
489
- self,
490
- hidden_states: torch.Tensor,
491
- attention_mask: Optional[torch.FloatTensor] = None,
492
- head_mask: Optional[torch.FloatTensor] = None,
493
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
494
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
495
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
496
- output_attentions: Optional[bool] = False,
497
- bias: Optional[torch.FloatTensor] = None,
 
498
  ) -> Tuple[torch.Tensor]:
499
  self_outputs = self.self(
500
  hidden_states,
@@ -505,11 +519,12 @@ class JinaBertAttention(nn.Module):
505
  past_key_value,
506
  output_attentions,
507
  bias,
 
508
  )
509
  attention_output = self.output(self_outputs[0], hidden_states)
510
  outputs = (attention_output,) + self_outputs[
511
- 1:
512
- ] # add attentions if we output them
513
  return outputs
514
 
515
 
@@ -536,7 +551,7 @@ class JinaBertOutput(nn.Module):
536
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
537
 
538
  def forward(
539
- self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
540
  ) -> torch.Tensor:
541
  hidden_states = self.dense(hidden_states)
542
  hidden_states = self.dropout(hidden_states)
@@ -568,7 +583,7 @@ class JinaBertGLUMLP(nn.Module):
568
  # compute the activation
569
  hidden_states = self.gated_layers(hidden_states)
570
  gated = hidden_states[:, :, : self.config.intermediate_size]
571
- non_gated = hidden_states[:, :, self.config.intermediate_size :]
572
  hidden_states = self.act(gated) * non_gated
573
  hidden_states = self.dropout(hidden_states)
574
  # multiply by the second matrix
@@ -602,15 +617,16 @@ class JinaBertLayer(nn.Module):
602
  self.output = JinaBertOutput(config)
603
 
604
  def forward(
605
- self,
606
- hidden_states: torch.Tensor,
607
- attention_mask: Optional[torch.FloatTensor] = None,
608
- head_mask: Optional[torch.FloatTensor] = None,
609
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
610
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
611
- bias: Optional[torch.FloatTensor] = None,
612
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
613
- output_attentions: Optional[bool] = False,
 
614
  ) -> Tuple[torch.Tensor]:
615
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
616
  self_attn_past_key_value = (
@@ -623,6 +639,7 @@ class JinaBertLayer(nn.Module):
623
  output_attentions=output_attentions,
624
  past_key_value=self_attn_past_key_value,
625
  bias=bias,
 
626
  )
627
  attention_output = self_attention_outputs[0]
628
 
@@ -632,8 +649,8 @@ class JinaBertLayer(nn.Module):
632
  present_key_value = self_attention_outputs[-1]
633
  else:
634
  outputs = self_attention_outputs[
635
- 1:
636
- ] # add self attentions if we output attention weights
637
 
638
  cross_attn_present_key_value = None
639
  if self.is_decoder and encoder_hidden_states is not None:
@@ -658,7 +675,7 @@ class JinaBertLayer(nn.Module):
658
  )
659
  attention_output = cross_attention_outputs[0]
660
  outputs = (
661
- outputs + cross_attention_outputs[1:-1]
662
  ) # add cross attentions if we output attention weights
663
 
664
  # add cross-attn cache to positions 3,4 of present_key_value tuple
@@ -704,7 +721,7 @@ class JinaBertEncoder(nn.Module):
704
  )
705
 
706
  def rebuild_alibi_tensor(
707
- self, size: int, device: Optional[Union[torch.device, str]] = None
708
  ):
709
  # Alibi
710
  # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
@@ -717,7 +734,7 @@ class JinaBertEncoder(nn.Module):
717
  def get_slopes_power_of_2(n):
718
  start = 2 ** (-(2 ** -(math.log2(n) - 3)))
719
  ratio = start
720
- return [start * ratio**i for i in range(n)]
721
 
722
  if math.log2(n_heads).is_integer():
723
  return get_slopes_power_of_2(
@@ -728,10 +745,10 @@ class JinaBertEncoder(nn.Module):
728
  math.log2(n_heads)
729
  ) # when the number of heads is not a power of 2, we use this workaround.
730
  return (
731
- get_slopes_power_of_2(closest_power_of_2)
732
- + _get_alibi_head_slopes(2 * closest_power_of_2)[0::2][
733
- : n_heads - closest_power_of_2
734
- ]
735
  )
736
 
737
  context_position = torch.arange(size, device=device)[:, None]
@@ -749,17 +766,18 @@ class JinaBertEncoder(nn.Module):
749
  return alibi
750
 
751
  def forward(
752
- self,
753
- hidden_states: torch.Tensor,
754
- attention_mask: Optional[torch.FloatTensor] = None,
755
- head_mask: Optional[torch.FloatTensor] = None,
756
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
757
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
758
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
759
- use_cache: Optional[bool] = None,
760
- output_attentions: Optional[bool] = False,
761
- output_hidden_states: Optional[bool] = False,
762
- return_dict: Optional[bool] = True,
 
763
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
764
  all_hidden_states = () if output_hidden_states else None
765
  all_self_attentions = () if output_attentions else None
@@ -828,6 +846,7 @@ class JinaBertEncoder(nn.Module):
828
  alibi_bias,
829
  past_key_value,
830
  output_attentions,
 
831
  )
832
 
833
  hidden_states = layer_outputs[0]
@@ -1117,16 +1136,17 @@ class JinaBertModel(JinaBertPreTrainedModel):
1117
 
1118
  @torch.inference_mode()
1119
  def encode(
1120
- self: 'JinaBertModel',
1121
- sentences: Union[str, List[str]],
1122
- batch_size: int = 32,
1123
- show_progress_bar: Optional[bool] = None,
1124
- output_value: str = 'sentence_embedding',
1125
- convert_to_numpy: bool = True,
1126
- convert_to_tensor: bool = False,
1127
- device: Optional[torch.device] = None,
1128
- normalize_embeddings: bool = False,
1129
- **tokenizer_kwargs,
 
1130
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
1131
  """
1132
  Computes sentence embeddings
@@ -1172,8 +1192,8 @@ class JinaBertModel(JinaBertPreTrainedModel):
1172
 
1173
  if show_progress_bar is None:
1174
  show_progress_bar = (
1175
- logger.getEffectiveLevel() == logging.INFO
1176
- or logger.getEffectiveLevel() == logging.DEBUG
1177
  )
1178
 
1179
  if convert_to_tensor:
@@ -1215,11 +1235,11 @@ class JinaBertModel(JinaBertPreTrainedModel):
1215
 
1216
  for i in range_iter:
1217
  encoded_input = self.tokenizer(
1218
- sentences[i : i + batch_size],
1219
  return_tensors='pt',
1220
  **tokenizer_kwargs,
1221
  ).to(self.device)
1222
- token_embs = self.forward(**encoded_input)[0]
1223
 
1224
  # Accumulate in fp32 to avoid overflow
1225
  token_embs = token_embs.float()
@@ -1254,7 +1274,7 @@ class JinaBertModel(JinaBertPreTrainedModel):
1254
  return all_embeddings
1255
 
1256
  def mean_pooling(
1257
- self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
1258
  ):
1259
  input_mask_expanded = (
1260
  attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
@@ -1286,20 +1306,21 @@ class JinaBertModel(JinaBertPreTrainedModel):
1286
  config_class=_CONFIG_FOR_DOC,
1287
  )
1288
  def forward(
1289
- self,
1290
- input_ids: Optional[torch.Tensor] = None,
1291
- attention_mask: Optional[torch.Tensor] = None,
1292
- token_type_ids: Optional[torch.Tensor] = None,
1293
- position_ids: Optional[torch.Tensor] = None,
1294
- head_mask: Optional[torch.Tensor] = None,
1295
- inputs_embeds: Optional[torch.Tensor] = None,
1296
- encoder_hidden_states: Optional[torch.Tensor] = None,
1297
- encoder_attention_mask: Optional[torch.Tensor] = None,
1298
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1299
- use_cache: Optional[bool] = None,
1300
- output_attentions: Optional[bool] = None,
1301
- output_hidden_states: Optional[bool] = None,
1302
- return_dict: Optional[bool] = None,
 
1303
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1304
  r"""
1305
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1425,6 +1446,7 @@ class JinaBertModel(JinaBertPreTrainedModel):
1425
  output_attentions=output_attentions,
1426
  output_hidden_states=output_hidden_states,
1427
  return_dict=return_dict,
 
1428
  )
1429
  sequence_output = encoder_outputs[0]
1430
  pooled_output = (
@@ -1476,18 +1498,18 @@ class JinaBertForPreTraining(JinaBertPreTrainedModel):
1476
  output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
1477
  )
1478
  def forward(
1479
- self,
1480
- input_ids: Optional[torch.Tensor] = None,
1481
- attention_mask: Optional[torch.Tensor] = None,
1482
- token_type_ids: Optional[torch.Tensor] = None,
1483
- position_ids: Optional[torch.Tensor] = None,
1484
- head_mask: Optional[torch.Tensor] = None,
1485
- inputs_embeds: Optional[torch.Tensor] = None,
1486
- labels: Optional[torch.Tensor] = None,
1487
- next_sentence_label: Optional[torch.Tensor] = None,
1488
- output_attentions: Optional[bool] = None,
1489
- output_hidden_states: Optional[bool] = None,
1490
- return_dict: Optional[bool] = None,
1491
  ) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
1492
  r"""
1493
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1586,21 +1608,21 @@ class JinaBertLMHeadModel(JinaBertPreTrainedModel):
1586
  config_class=_CONFIG_FOR_DOC,
1587
  )
1588
  def forward(
1589
- self,
1590
- input_ids: Optional[torch.Tensor] = None,
1591
- attention_mask: Optional[torch.Tensor] = None,
1592
- token_type_ids: Optional[torch.Tensor] = None,
1593
- position_ids: Optional[torch.Tensor] = None,
1594
- head_mask: Optional[torch.Tensor] = None,
1595
- inputs_embeds: Optional[torch.Tensor] = None,
1596
- encoder_hidden_states: Optional[torch.Tensor] = None,
1597
- encoder_attention_mask: Optional[torch.Tensor] = None,
1598
- labels: Optional[torch.Tensor] = None,
1599
- past_key_values: Optional[List[torch.Tensor]] = None,
1600
- use_cache: Optional[bool] = None,
1601
- output_attentions: Optional[bool] = None,
1602
- output_hidden_states: Optional[bool] = None,
1603
- return_dict: Optional[bool] = None,
1604
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1605
  r"""
1606
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1676,12 +1698,12 @@ class JinaBertLMHeadModel(JinaBertPreTrainedModel):
1676
  )
1677
 
1678
  def prepare_inputs_for_generation(
1679
- self,
1680
- input_ids,
1681
- past_key_values=None,
1682
- attention_mask=None,
1683
- use_cache=True,
1684
- **model_kwargs,
1685
  ):
1686
  input_shape = input_ids.shape
1687
  # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@@ -1748,19 +1770,19 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
1748
  expected_loss=0.88,
1749
  )
1750
  def forward(
1751
- self,
1752
- input_ids: Optional[torch.Tensor] = None,
1753
- attention_mask: Optional[torch.Tensor] = None,
1754
- token_type_ids: Optional[torch.Tensor] = None,
1755
- position_ids: Optional[torch.Tensor] = None,
1756
- head_mask: Optional[torch.Tensor] = None,
1757
- inputs_embeds: Optional[torch.Tensor] = None,
1758
- encoder_hidden_states: Optional[torch.Tensor] = None,
1759
- encoder_attention_mask: Optional[torch.Tensor] = None,
1760
- labels: Optional[torch.Tensor] = None,
1761
- output_attentions: Optional[bool] = None,
1762
- output_hidden_states: Optional[bool] = None,
1763
- return_dict: Optional[bool] = None,
1764
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1765
  r"""
1766
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1811,7 +1833,7 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
1811
  )
1812
 
1813
  def prepare_inputs_for_generation(
1814
- self, input_ids, attention_mask=None, **model_kwargs
1815
  ):
1816
  input_shape = input_ids.shape
1817
  effective_batch_size = input_shape[0]
@@ -1856,18 +1878,18 @@ class JinaBertForNextSentencePrediction(JinaBertPreTrainedModel):
1856
  output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
1857
  )
1858
  def forward(
1859
- self,
1860
- input_ids: Optional[torch.Tensor] = None,
1861
- attention_mask: Optional[torch.Tensor] = None,
1862
- token_type_ids: Optional[torch.Tensor] = None,
1863
- position_ids: Optional[torch.Tensor] = None,
1864
- head_mask: Optional[torch.Tensor] = None,
1865
- inputs_embeds: Optional[torch.Tensor] = None,
1866
- labels: Optional[torch.Tensor] = None,
1867
- output_attentions: Optional[bool] = None,
1868
- output_hidden_states: Optional[bool] = None,
1869
- return_dict: Optional[bool] = None,
1870
- **kwargs,
1871
  ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
1872
  r"""
1873
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1967,17 +1989,17 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
1967
  expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1968
  )
1969
  def forward(
1970
- self,
1971
- input_ids: Optional[torch.Tensor] = None,
1972
- attention_mask: Optional[torch.Tensor] = None,
1973
- token_type_ids: Optional[torch.Tensor] = None,
1974
- position_ids: Optional[torch.Tensor] = None,
1975
- head_mask: Optional[torch.Tensor] = None,
1976
- inputs_embeds: Optional[torch.Tensor] = None,
1977
- labels: Optional[torch.Tensor] = None,
1978
- output_attentions: Optional[bool] = None,
1979
- output_hidden_states: Optional[bool] = None,
1980
- return_dict: Optional[bool] = None,
1981
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1982
  r"""
1983
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -2012,7 +2034,7 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
2012
  if self.num_labels == 1:
2013
  self.config.problem_type = "regression"
2014
  elif self.num_labels > 1 and (
2015
- labels.dtype == torch.long or labels.dtype == torch.int
2016
  ):
2017
  self.config.problem_type = "single_label_classification"
2018
  else:
@@ -2074,17 +2096,17 @@ class JinaBertForMultipleChoice(JinaBertPreTrainedModel):
2074
  config_class=_CONFIG_FOR_DOC,
2075
  )
2076
  def forward(
2077
- self,
2078
- input_ids: Optional[torch.Tensor] = None,
2079
- attention_mask: Optional[torch.Tensor] = None,
2080
- token_type_ids: Optional[torch.Tensor] = None,
2081
- position_ids: Optional[torch.Tensor] = None,
2082
- head_mask: Optional[torch.Tensor] = None,
2083
- inputs_embeds: Optional[torch.Tensor] = None,
2084
- labels: Optional[torch.Tensor] = None,
2085
- output_attentions: Optional[bool] = None,
2086
- output_hidden_states: Optional[bool] = None,
2087
- return_dict: Optional[bool] = None,
2088
  ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
2089
  r"""
2090
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -2193,17 +2215,17 @@ class JinaBertForTokenClassification(JinaBertPreTrainedModel):
2193
  expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
2194
  )
2195
  def forward(
2196
- self,
2197
- input_ids: Optional[torch.Tensor] = None,
2198
- attention_mask: Optional[torch.Tensor] = None,
2199
- token_type_ids: Optional[torch.Tensor] = None,
2200
- position_ids: Optional[torch.Tensor] = None,
2201
- head_mask: Optional[torch.Tensor] = None,
2202
- inputs_embeds: Optional[torch.Tensor] = None,
2203
- labels: Optional[torch.Tensor] = None,
2204
- output_attentions: Optional[bool] = None,
2205
- output_hidden_states: Optional[bool] = None,
2206
- return_dict: Optional[bool] = None,
2207
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
2208
  r"""
2209
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2278,18 +2300,18 @@ class JinaBertForQuestionAnswering(JinaBertPreTrainedModel):
2278
  expected_loss=_QA_EXPECTED_LOSS,
2279
  )
2280
  def forward(
2281
- self,
2282
- input_ids: Optional[torch.Tensor] = None,
2283
- attention_mask: Optional[torch.Tensor] = None,
2284
- token_type_ids: Optional[torch.Tensor] = None,
2285
- position_ids: Optional[torch.Tensor] = None,
2286
- head_mask: Optional[torch.Tensor] = None,
2287
- inputs_embeds: Optional[torch.Tensor] = None,
2288
- start_positions: Optional[torch.Tensor] = None,
2289
- end_positions: Optional[torch.Tensor] = None,
2290
- output_attentions: Optional[bool] = None,
2291
- output_hidden_states: Optional[bool] = None,
2292
- return_dict: Optional[bool] = None,
2293
  ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
2294
  r"""
2295
  start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
16
  # limitations under the License.
17
  """PyTorch BERT model."""
18
 
 
19
  import math
20
  import os
21
  import warnings
 
95
  _SEQ_CLASS_EXPECTED_LOSS = 0.01
96
 
97
 
98
+ def create_k_diag_mask(k, n):
99
+ mask = torch.zeros(n, n, dtype=bool)
100
+ for i in range(n):
101
+ for j in range(n):
102
+ if not math.fabs(i - j) < k:
103
+ mask[i, j] = True
104
+ return mask
105
+
106
+
107
  def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
108
  """Load tf checkpoints in a pytorch model."""
109
  try:
 
134
  # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
135
  # which are not required for using pretrained model
136
  if any(
137
+ n
138
+ in [
139
+ "adam_v",
140
+ "adam_m",
141
+ "AdamWeightDecayOptimizer",
142
+ "AdamWeightDecayOptimizer_1",
143
+ "global_step",
144
+ ]
145
+ for n in name
146
  ):
147
  logger.info(f"Skipping {'/'.join(name)}")
148
  continue
 
222
  )
223
 
224
  def forward(
225
+ self,
226
+ input_ids: Optional[torch.LongTensor] = None,
227
+ token_type_ids: Optional[torch.LongTensor] = None,
228
+ position_ids: Optional[torch.LongTensor] = None,
229
+ inputs_embeds: Optional[torch.FloatTensor] = None,
230
+ past_key_values_length: int = 0,
231
  ) -> torch.Tensor:
232
  if input_ids is not None:
233
  input_shape = input_ids.size()
 
238
 
239
  if position_ids is None:
240
  position_ids = self.position_ids[
241
+ :, past_key_values_length: seq_length + past_key_values_length
242
+ ]
243
 
244
  # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
245
  # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
 
273
  def __init__(self, config: JinaBertConfig, position_embedding_type=None):
274
  super().__init__()
275
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
276
+ config, "embedding_size"
277
  ):
278
  raise ValueError(
279
  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
280
  f"heads ({config.num_attention_heads})"
281
  )
282
+
283
  self.attn_implementation = config.attn_implementation
284
  self.num_attention_heads = config.num_attention_heads
285
  self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
 
294
  config, "position_embedding_type", "absolute"
295
  )
296
  if (
297
+ self.position_embedding_type == "relative_key"
298
+ or self.position_embedding_type == "relative_key_query"
299
  ):
300
  self.max_position_embeddings = config.max_position_embeddings
301
  self.distance_embedding = nn.Embedding(
 
313
  return x.permute(0, 2, 1, 3)
314
 
315
  def forward(
316
+ self,
317
+ hidden_states: torch.Tensor,
318
+ attention_mask: Optional[torch.FloatTensor] = None,
319
+ head_mask: Optional[torch.FloatTensor] = None,
320
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
321
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
322
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
323
+ output_attentions: Optional[bool] = False,
324
+ bias: Optional[torch.FloatTensor] = None,
325
+ sliding_window: Optional[int] = None,
326
  ) -> Tuple[torch.Tensor]:
327
  mixed_query_layer = self.query(hidden_states)
328
 
 
373
  attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
374
 
375
  if (
376
+ self.position_embedding_type == "relative_key"
377
+ or self.position_embedding_type == "relative_key_query"
378
  ):
379
  query_length, key_length = query_layer.shape[2], key_layer.shape[2]
380
  if use_cache:
 
410
  "bhrd,lrd->bhlr", key_layer, positional_embedding
411
  )
412
  attention_scores = (
413
+ attention_scores
414
+ + relative_position_scores_query
415
+ + relative_position_scores_key
416
  )
417
 
418
  attention_scores = attention_scores / math.sqrt(self.attention_head_size)
 
423
  # Normalize the attention scores to probabilities.
424
  attention_probs = nn.functional.softmax(attention_scores + bias, dim=-1)
425
 
426
+ if sliding_window is not None:
427
+ mask = create_k_diag_mask(sliding_window, int(attention_scores.size(dim=2)))
428
+ attention_probs.masked_fill_(mask, 0)
429
+
430
  # This is actually dropping out entire tokens to attend to, which might
431
  # seem a bit unusual, but is taken from the original Transformer paper.
432
  attention_probs = self.dropout(attention_probs)
 
458
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
459
 
460
  def forward(
461
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
462
  ) -> torch.Tensor:
463
  hidden_states = self.dense(hidden_states)
464
  hidden_states = self.dropout(hidden_states)
 
494
  # Update hyper params and store pruned heads
495
  self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
496
  self.self.all_head_size = (
497
+ self.self.attention_head_size * self.self.num_attention_heads
498
  )
499
  self.pruned_heads = self.pruned_heads.union(heads)
500
 
501
  def forward(
502
+ self,
503
+ hidden_states: torch.Tensor,
504
+ attention_mask: Optional[torch.FloatTensor] = None,
505
+ head_mask: Optional[torch.FloatTensor] = None,
506
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
507
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
508
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
509
+ output_attentions: Optional[bool] = False,
510
+ bias: Optional[torch.FloatTensor] = None,
511
+ sliding_window: Optional[int] = None,
512
  ) -> Tuple[torch.Tensor]:
513
  self_outputs = self.self(
514
  hidden_states,
 
519
  past_key_value,
520
  output_attentions,
521
  bias,
522
+ sliding_window=sliding_window
523
  )
524
  attention_output = self.output(self_outputs[0], hidden_states)
525
  outputs = (attention_output,) + self_outputs[
526
+ 1:
527
+ ] # add attentions if we output them
528
  return outputs
529
 
530
 
 
551
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
552
 
553
  def forward(
554
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
555
  ) -> torch.Tensor:
556
  hidden_states = self.dense(hidden_states)
557
  hidden_states = self.dropout(hidden_states)
 
583
  # compute the activation
584
  hidden_states = self.gated_layers(hidden_states)
585
  gated = hidden_states[:, :, : self.config.intermediate_size]
586
+ non_gated = hidden_states[:, :, self.config.intermediate_size:]
587
  hidden_states = self.act(gated) * non_gated
588
  hidden_states = self.dropout(hidden_states)
589
  # multiply by the second matrix
 
617
  self.output = JinaBertOutput(config)
618
 
619
  def forward(
620
+ self,
621
+ hidden_states: torch.Tensor,
622
+ attention_mask: Optional[torch.FloatTensor] = None,
623
+ head_mask: Optional[torch.FloatTensor] = None,
624
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
625
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
626
+ bias: Optional[torch.FloatTensor] = None,
627
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
628
+ output_attentions: Optional[bool] = False,
629
+ sliding_window: Optional[int] = None,
630
  ) -> Tuple[torch.Tensor]:
631
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
632
  self_attn_past_key_value = (
 
639
  output_attentions=output_attentions,
640
  past_key_value=self_attn_past_key_value,
641
  bias=bias,
642
+ sliding_window=sliding_window
643
  )
644
  attention_output = self_attention_outputs[0]
645
 
 
649
  present_key_value = self_attention_outputs[-1]
650
  else:
651
  outputs = self_attention_outputs[
652
+ 1:
653
+ ] # add self attentions if we output attention weights
654
 
655
  cross_attn_present_key_value = None
656
  if self.is_decoder and encoder_hidden_states is not None:
 
675
  )
676
  attention_output = cross_attention_outputs[0]
677
  outputs = (
678
+ outputs + cross_attention_outputs[1:-1]
679
  ) # add cross attentions if we output attention weights
680
 
681
  # add cross-attn cache to positions 3,4 of present_key_value tuple
 
721
  )
722
 
723
  def rebuild_alibi_tensor(
724
+ self, size: int, device: Optional[Union[torch.device, str]] = None
725
  ):
726
  # Alibi
727
  # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
 
734
  def get_slopes_power_of_2(n):
735
  start = 2 ** (-(2 ** -(math.log2(n) - 3)))
736
  ratio = start
737
+ return [start * ratio ** i for i in range(n)]
738
 
739
  if math.log2(n_heads).is_integer():
740
  return get_slopes_power_of_2(
 
745
  math.log2(n_heads)
746
  ) # when the number of heads is not a power of 2, we use this workaround.
747
  return (
748
+ get_slopes_power_of_2(closest_power_of_2)
749
+ + _get_alibi_head_slopes(2 * closest_power_of_2)[0::2][
750
+ : n_heads - closest_power_of_2
751
+ ]
752
  )
753
 
754
  context_position = torch.arange(size, device=device)[:, None]
 
766
  return alibi
767
 
768
  def forward(
769
+ self,
770
+ hidden_states: torch.Tensor,
771
+ attention_mask: Optional[torch.FloatTensor] = None,
772
+ head_mask: Optional[torch.FloatTensor] = None,
773
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
774
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
775
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
776
+ use_cache: Optional[bool] = None,
777
+ output_attentions: Optional[bool] = False,
778
+ output_hidden_states: Optional[bool] = False,
779
+ return_dict: Optional[bool] = True,
780
+ sliding_window: Optional[int] = None,
781
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
782
  all_hidden_states = () if output_hidden_states else None
783
  all_self_attentions = () if output_attentions else None
 
846
  alibi_bias,
847
  past_key_value,
848
  output_attentions,
849
+ sliding_window
850
  )
851
 
852
  hidden_states = layer_outputs[0]
 
1136
 
1137
  @torch.inference_mode()
1138
  def encode(
1139
+ self: 'JinaBertModel',
1140
+ sentences: Union[str, List[str]],
1141
+ batch_size: int = 32,
1142
+ show_progress_bar: Optional[bool] = None,
1143
+ output_value: str = 'sentence_embedding',
1144
+ convert_to_numpy: bool = True,
1145
+ convert_to_tensor: bool = False,
1146
+ device: Optional[torch.device] = None,
1147
+ normalize_embeddings: bool = False,
1148
+ sliding_window: Optional[int] = None,
1149
+ **tokenizer_kwargs,
1150
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
1151
  """
1152
  Computes sentence embeddings
 
1192
 
1193
  if show_progress_bar is None:
1194
  show_progress_bar = (
1195
+ logger.getEffectiveLevel() == logging.INFO
1196
+ or logger.getEffectiveLevel() == logging.DEBUG
1197
  )
1198
 
1199
  if convert_to_tensor:
 
1235
 
1236
  for i in range_iter:
1237
  encoded_input = self.tokenizer(
1238
+ sentences[i: i + batch_size],
1239
  return_tensors='pt',
1240
  **tokenizer_kwargs,
1241
  ).to(self.device)
1242
+ token_embs = self.forward(sliding_window=sliding_window, **encoded_input)[0]
1243
 
1244
  # Accumulate in fp32 to avoid overflow
1245
  token_embs = token_embs.float()
 
1274
  return all_embeddings
1275
 
1276
  def mean_pooling(
1277
+ self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
1278
  ):
1279
  input_mask_expanded = (
1280
  attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
 
1306
  config_class=_CONFIG_FOR_DOC,
1307
  )
1308
  def forward(
1309
+ self,
1310
+ input_ids: Optional[torch.Tensor] = None,
1311
+ attention_mask: Optional[torch.Tensor] = None,
1312
+ token_type_ids: Optional[torch.Tensor] = None,
1313
+ position_ids: Optional[torch.Tensor] = None,
1314
+ head_mask: Optional[torch.Tensor] = None,
1315
+ inputs_embeds: Optional[torch.Tensor] = None,
1316
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1317
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1318
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1319
+ use_cache: Optional[bool] = None,
1320
+ output_attentions: Optional[bool] = None,
1321
+ output_hidden_states: Optional[bool] = None,
1322
+ return_dict: Optional[bool] = None,
1323
+ sliding_window: Optional[int] = None,
1324
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
1325
  r"""
1326
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
 
1446
  output_attentions=output_attentions,
1447
  output_hidden_states=output_hidden_states,
1448
  return_dict=return_dict,
1449
+ sliding_window=sliding_window
1450
  )
1451
  sequence_output = encoder_outputs[0]
1452
  pooled_output = (
 
1498
  output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
1499
  )
1500
  def forward(
1501
+ self,
1502
+ input_ids: Optional[torch.Tensor] = None,
1503
+ attention_mask: Optional[torch.Tensor] = None,
1504
+ token_type_ids: Optional[torch.Tensor] = None,
1505
+ position_ids: Optional[torch.Tensor] = None,
1506
+ head_mask: Optional[torch.Tensor] = None,
1507
+ inputs_embeds: Optional[torch.Tensor] = None,
1508
+ labels: Optional[torch.Tensor] = None,
1509
+ next_sentence_label: Optional[torch.Tensor] = None,
1510
+ output_attentions: Optional[bool] = None,
1511
+ output_hidden_states: Optional[bool] = None,
1512
+ return_dict: Optional[bool] = None,
1513
  ) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
1514
  r"""
1515
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
1608
  config_class=_CONFIG_FOR_DOC,
1609
  )
1610
  def forward(
1611
+ self,
1612
+ input_ids: Optional[torch.Tensor] = None,
1613
+ attention_mask: Optional[torch.Tensor] = None,
1614
+ token_type_ids: Optional[torch.Tensor] = None,
1615
+ position_ids: Optional[torch.Tensor] = None,
1616
+ head_mask: Optional[torch.Tensor] = None,
1617
+ inputs_embeds: Optional[torch.Tensor] = None,
1618
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1619
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1620
+ labels: Optional[torch.Tensor] = None,
1621
+ past_key_values: Optional[List[torch.Tensor]] = None,
1622
+ use_cache: Optional[bool] = None,
1623
+ output_attentions: Optional[bool] = None,
1624
+ output_hidden_states: Optional[bool] = None,
1625
+ return_dict: Optional[bool] = None,
1626
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
1627
  r"""
1628
  encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
 
1698
  )
1699
 
1700
  def prepare_inputs_for_generation(
1701
+ self,
1702
+ input_ids,
1703
+ past_key_values=None,
1704
+ attention_mask=None,
1705
+ use_cache=True,
1706
+ **model_kwargs,
1707
  ):
1708
  input_shape = input_ids.shape
1709
  # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
 
1770
  expected_loss=0.88,
1771
  )
1772
  def forward(
1773
+ self,
1774
+ input_ids: Optional[torch.Tensor] = None,
1775
+ attention_mask: Optional[torch.Tensor] = None,
1776
+ token_type_ids: Optional[torch.Tensor] = None,
1777
+ position_ids: Optional[torch.Tensor] = None,
1778
+ head_mask: Optional[torch.Tensor] = None,
1779
+ inputs_embeds: Optional[torch.Tensor] = None,
1780
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1781
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1782
+ labels: Optional[torch.Tensor] = None,
1783
+ output_attentions: Optional[bool] = None,
1784
+ output_hidden_states: Optional[bool] = None,
1785
+ return_dict: Optional[bool] = None,
1786
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1787
  r"""
1788
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
1833
  )
1834
 
1835
  def prepare_inputs_for_generation(
1836
+ self, input_ids, attention_mask=None, **model_kwargs
1837
  ):
1838
  input_shape = input_ids.shape
1839
  effective_batch_size = input_shape[0]
 
1878
  output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
1879
  )
1880
  def forward(
1881
+ self,
1882
+ input_ids: Optional[torch.Tensor] = None,
1883
+ attention_mask: Optional[torch.Tensor] = None,
1884
+ token_type_ids: Optional[torch.Tensor] = None,
1885
+ position_ids: Optional[torch.Tensor] = None,
1886
+ head_mask: Optional[torch.Tensor] = None,
1887
+ inputs_embeds: Optional[torch.Tensor] = None,
1888
+ labels: Optional[torch.Tensor] = None,
1889
+ output_attentions: Optional[bool] = None,
1890
+ output_hidden_states: Optional[bool] = None,
1891
+ return_dict: Optional[bool] = None,
1892
+ **kwargs,
1893
  ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
1894
  r"""
1895
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1989
  expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
1990
  )
1991
  def forward(
1992
+ self,
1993
+ input_ids: Optional[torch.Tensor] = None,
1994
+ attention_mask: Optional[torch.Tensor] = None,
1995
+ token_type_ids: Optional[torch.Tensor] = None,
1996
+ position_ids: Optional[torch.Tensor] = None,
1997
+ head_mask: Optional[torch.Tensor] = None,
1998
+ inputs_embeds: Optional[torch.Tensor] = None,
1999
+ labels: Optional[torch.Tensor] = None,
2000
+ output_attentions: Optional[bool] = None,
2001
+ output_hidden_states: Optional[bool] = None,
2002
+ return_dict: Optional[bool] = None,
2003
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
2004
  r"""
2005
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
2034
  if self.num_labels == 1:
2035
  self.config.problem_type = "regression"
2036
  elif self.num_labels > 1 and (
2037
+ labels.dtype == torch.long or labels.dtype == torch.int
2038
  ):
2039
  self.config.problem_type = "single_label_classification"
2040
  else:
 
2096
  config_class=_CONFIG_FOR_DOC,
2097
  )
2098
  def forward(
2099
+ self,
2100
+ input_ids: Optional[torch.Tensor] = None,
2101
+ attention_mask: Optional[torch.Tensor] = None,
2102
+ token_type_ids: Optional[torch.Tensor] = None,
2103
+ position_ids: Optional[torch.Tensor] = None,
2104
+ head_mask: Optional[torch.Tensor] = None,
2105
+ inputs_embeds: Optional[torch.Tensor] = None,
2106
+ labels: Optional[torch.Tensor] = None,
2107
+ output_attentions: Optional[bool] = None,
2108
+ output_hidden_states: Optional[bool] = None,
2109
+ return_dict: Optional[bool] = None,
2110
  ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
2111
  r"""
2112
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
2215
  expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
2216
  )
2217
  def forward(
2218
+ self,
2219
+ input_ids: Optional[torch.Tensor] = None,
2220
+ attention_mask: Optional[torch.Tensor] = None,
2221
+ token_type_ids: Optional[torch.Tensor] = None,
2222
+ position_ids: Optional[torch.Tensor] = None,
2223
+ head_mask: Optional[torch.Tensor] = None,
2224
+ inputs_embeds: Optional[torch.Tensor] = None,
2225
+ labels: Optional[torch.Tensor] = None,
2226
+ output_attentions: Optional[bool] = None,
2227
+ output_hidden_states: Optional[bool] = None,
2228
+ return_dict: Optional[bool] = None,
2229
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
2230
  r"""
2231
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
2300
  expected_loss=_QA_EXPECTED_LOSS,
2301
  )
2302
  def forward(
2303
+ self,
2304
+ input_ids: Optional[torch.Tensor] = None,
2305
+ attention_mask: Optional[torch.Tensor] = None,
2306
+ token_type_ids: Optional[torch.Tensor] = None,
2307
+ position_ids: Optional[torch.Tensor] = None,
2308
+ head_mask: Optional[torch.Tensor] = None,
2309
+ inputs_embeds: Optional[torch.Tensor] = None,
2310
+ start_positions: Optional[torch.Tensor] = None,
2311
+ end_positions: Optional[torch.Tensor] = None,
2312
+ output_attentions: Optional[bool] = None,
2313
+ output_hidden_states: Optional[bool] = None,
2314
+ return_dict: Optional[bool] = None,
2315
  ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
2316
  r"""
2317
  start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
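For reference, the band mask produced by the new create_k_diag_mask helper (True wherever |i - j| >= k, i.e. outside the window) can also be built without the Python double loop. A sketch of an equivalent vectorized construction, with a device argument added here as an assumption; the committed dummy helper allocates the mask on CPU, so it would need to be moved to attention_probs.device before masked_fill_ when running on GPU:

    import torch

    def create_k_diag_mask_vectorized(k: int, n: int, device=None) -> torch.Tensor:
        # True outside the band |i - j| < k, matching the loop-based helper in this commit
        idx = torch.arange(n, device=device)
        return (idx[None, :] - idx[:, None]).abs() >= k

    # example: a 6x6 mask with window half-width 2
    print(create_k_diag_mask_vectorized(2, 6))

Note that the validation path zeroes probabilities after the softmax, so masked rows no longer sum to one; that is enough to validate the window shape, whereas a full sliding-window implementation would typically add a large negative bias to attention_scores before the softmax.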