izhx committed on
Commit
e244c93
1 Parent(s): fcceab0

Update `NewTokenClassifierOutput`

Files changed (1):
  1. modeling.py +14 -3
modeling.py CHANGED
@@ -16,6 +16,7 @@
 """PyTorch NEW model."""
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -30,7 +31,7 @@ from transformers.modeling_outputs import (
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
-    TokenClassifierOutput,
+    ModelOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
@@ -1249,6 +1250,15 @@ class NewForMultipleChoice(NewPreTrainedModel):
         )
 
 
+@dataclass
+class NewTokenClassifierOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
 class NewForTokenClassification(NewPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1277,7 +1287,7 @@ class NewForTokenClassification(NewPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         unpad_inputs: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+    ) -> Union[Tuple[torch.Tensor], NewTokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
@@ -1311,9 +1321,10 @@ class NewForTokenClassification(NewPreTrainedModel):
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
-        return TokenClassifierOutput(
+        return NewTokenClassifierOutput(
             loss=loss,
             logits=logits,
+            last_hidden_state=sequence_output,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
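
The practical effect of this commit is that `NewForTokenClassification` now returns the encoder's final hidden states (`sequence_output`) alongside the per-token logits, so callers can reuse the token representations without a second forward pass. A minimal usage sketch follows; the checkpoint id `your-org/new-model` is a placeholder (not taken from this commit), and `trust_remote_code=True` is assumed because this modeling.py ships with the Hub repo rather than with transformers itself.

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Placeholder checkpoint id; substitute the repo that actually ships this modeling.py.
ckpt = "your-org/new-model"
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForTokenClassification.from_pretrained(ckpt, trust_remote_code=True)

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)  # a NewTokenClassifierOutput when return_dict=True

pred_ids = outputs.logits.argmax(dim=-1)  # (batch, seq_len) predicted label ids
token_reprs = outputs.last_hidden_state   # (batch, seq_len, hidden) encoder states

Because `NewTokenClassifierOutput` subclasses `ModelOutput`, it should remain compatible with code written against the stock `TokenClassifierOutput`: fields are accessible by attribute, key, or index, and `None` fields are skipped when the output is treated as a tuple.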