OpenNLPLab committed
Commit ec5dbd2
1 Parent(s): a5b83ec

Upload modeling_transnormer.py

Files changed (1)
  1. modeling_transnormer.py +4 -136
modeling_transnormer.py CHANGED
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2023 OpenNLPLab
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# coding=utf-8
+# coding=utf-8
 """ PyTorch Transnormer model."""
 import math
 import os
@@ -30,7 +28,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -752,6 +749,9 @@ class TransnormerModel(TransnormerPreTrainedModel):
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
+
+            # if idx == 0:
+            #     break
 
         hidden_states = self.final_norm(hidden_states)
 
@@ -939,135 +939,3 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
             )
         return reordered_past
 
-
-@add_start_docstrings(
-    """
-    The LLaMa Model transformer with a sequence classification head on top (linear layer).
-
-    [`TransnormerForSequenceClassification`] uses the last token in order to do the classification, as other causal models
-    (e.g. GPT-2) do.
-
-    Since it does classification on the last token, it requires to know the position of the last token. If a
-    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
-    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
-    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
-    each row of the batch).
-    """,
-    TRANSNORMER_START_DOCSTRING,
-)
-class TransnormerForSequenceClassification(TransnormerPreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.model = TransnormerModel(config)
-        self.score = nn.Linear(config.decoder_embed_dim, self.num_labels, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.model.embed_tokens
-
-    def set_input_embeddings(self, value):
-        self.model.embed_tokens = value
-
-    @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attn_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-
-        transformer_outputs = self.model(
-            input_ids,
-            attn_padding_mask=attn_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-
-        logits = self.score(hidden_states)
-
-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError(
-                "Cannot handle batch sizes > 1 if no padding token is defined."
-            )
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                sequence_lengths = (
-                    torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
-                ).to(logits.device)
-            else:
-                sequence_lengths = -1
-
-        pooled_logits = logits[
-            torch.arange(batch_size, device=logits.device), sequence_lengths
-        ]
-
-        loss = None
-        if labels is not None:
-            labels = labels.to(logits.device)
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (
-                    labels.dtype == torch.long or labels.dtype == torch.int
-                ):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(
-                    pooled_logits.view(-1, self.num_labels), labels.view(-1)
-                )
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-        if not return_dict:
-            output = (pooled_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutputWithPast(
-            loss=loss,
-            logits=pooled_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
 
 
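The removed class's docstring describes last-token pooling: it picks the last non-padding token in each row (or simply the final position when no pad_token_id is configured) and classifies from the logits at that position. Below is a minimal, standalone sketch of that pooling step on toy tensors; the names and values here (input_ids, logits, pad_token_id) are illustrative stand-ins, not part of the uploaded file.

import torch

# Toy shapes and a hypothetical pad id, for illustration only.
batch_size, seq_len, num_labels = 2, 6, 3
pad_token_id = 0

# Right-padded token ids: row 0 has 4 real tokens, row 1 has 2.
input_ids = torch.tensor([
    [5, 7, 9, 2, 0, 0],
    [3, 4, 0, 0, 0, 0],
])
# Per-token classification logits, shape (batch_size, seq_len, num_labels).
logits = torch.randn(batch_size, seq_len, num_labels)

# Index of the last non-padding token per row, as in the removed forward():
# count the non-pad tokens and subtract one.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1  # tensor([3, 1])

# Gather the logits at those positions -> shape (batch_size, num_labels).
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])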