zxdu20 commited on
Commit
696788d
·
1 Parent(s): ff8a42a

Fix batch beam search

Browse files
Files changed (1) hide show
  1. modeling_glm.py +82 -10
modeling_glm.py CHANGED
@@ -30,6 +30,7 @@ from transformers.utils import (
30
  from transformers.modeling_outputs import (
31
  BaseModelOutputWithPastAndCrossAttentions,
32
  ModelOutput,
 
33
  )
34
 
35
  from transformers.modeling_utils import (
@@ -780,17 +781,15 @@ class GLMModel(GLMPreTrainedModel):
780
  attention_mask = torch.zeros(batch_size)
781
  # Transformer.
782
  transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
783
- logits, hidden_layers = transformer_output
784
- # outputs = hidden_layers
785
  if self.output_predict:
786
- # Parallel logits.
787
- # logits_parallel = mpu.copy_to_model_parallel_region(
788
- # logits)
789
- logits = F.linear(logits, self.word_embeddings.weight)
790
 
791
  return ModelOutput(
 
792
  logits=logits,
793
- mems=hidden_layers,
794
  )
795
 
796
 
@@ -815,7 +814,7 @@ class GLMForMultipleChoice(GLMPreTrainedModel):
815
  mems=None,
816
  **kwargs
817
  ):
818
- model_output = self.glm.forward(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
819
  lm_logits = model_output.logits
820
  log_probs = []
821
  for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
@@ -874,6 +873,16 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
874
  position_ids = position_ids[:, :, :seq_length]
875
  if attention_mask is not None:
876
  attention_mask = attention_mask[:, :, :seq_length, :seq_length]
 
 
 
 
 
 
 
 
 
 
877
  return {
878
  "input_ids": input_ids,
879
  "position_ids": position_ids,
@@ -890,7 +899,7 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
890
  mems=None,
891
  **kwargs
892
  ):
893
- model_output = self.glm.forward(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
894
  lm_logits = model_output.logits
895
  loss = None
896
  if labels is not None:
@@ -900,4 +909,67 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
900
  loss=loss,
901
  logits=lm_logits,
902
  mems=model_output.mems
903
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  from transformers.modeling_outputs import (
31
  BaseModelOutputWithPastAndCrossAttentions,
32
  ModelOutput,
33
+ SequenceClassifierOutput,
34
  )
35
 
36
  from transformers.modeling_utils import (
 
781
  attention_mask = torch.zeros(batch_size)
782
  # Transformer.
783
  transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
784
+ last_hidden_states, mems = transformer_output
785
+ logits = None
786
  if self.output_predict:
787
+ logits = F.linear(last_hidden_states, self.word_embeddings.weight)
 
 
 
788
 
789
  return ModelOutput(
790
+ last_hidden_states=last_hidden_states,
791
  logits=logits,
792
+ mems=mems,
793
  )
794
 
795
 
 
814
  mems=None,
815
  **kwargs
816
  ):
817
+ model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
818
  lm_logits = model_output.logits
819
  log_probs = []
820
  for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
 
873
  position_ids = position_ids[:, :, :seq_length]
874
  if attention_mask is not None:
875
  attention_mask = attention_mask[:, :, :seq_length, :seq_length]
876
+ if position_ids is not None and input_ids.size(0) > position_ids.size(0):
877
+ batch_size = position_ids.size(0)
878
+ num_beams = input_ids.size(0) // batch_size
879
+ position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
880
+ position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])
881
+ if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):
882
+ batch_size = attention_mask.size(0)
883
+ num_beams = input_ids.size(0) // batch_size
884
+ attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)
885
+ attention_mask = attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])
886
  return {
887
  "input_ids": input_ids,
888
  "position_ids": position_ids,
 
899
  mems=None,
900
  **kwargs
901
  ):
902
+ model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
903
  lm_logits = model_output.logits
904
  loss = None
905
  if labels is not None:
 
909
  loss=loss,
910
  logits=lm_logits,
911
  mems=model_output.mems
912
+ )
913
+
914
+
915
+ @add_start_docstrings(
916
+ """GLM Model transformer with a sequence classification/regression head on top (a linear layer on top of
917
+ the pooled output) e.g. for GLUE tasks. """,
918
+ GLM_START_DOCSTRING,
919
+ )
920
+ class GLMForSequenceClassification(GLMPreTrainedModel):
921
+ def __init__(self, config: GLMConfig, hidden_dropout=None, num_class=1):
922
+ super().__init__(config)
923
+ self.pool_token = config.pool_token
924
+ self.glm = GLMModel(config)
925
+ self.glm.output_predict = False
926
+ self.num_class = num_class
927
+ # Multi-choice head.
928
+ self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
929
+ classifier_dropout = (
930
+ config.classifier_dropout if config.classifier_dropout is not None else config.output_dropout_prob
931
+ )
932
+ self.dropout = torch.nn.Dropout(classifier_dropout)
933
+ self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)
934
+
935
+ # Initialize weights and apply final processing
936
+ self.post_init()
937
+
938
+ @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
939
+ @add_code_sample_docstrings(
940
+ processor_class=_TOKENIZER_FOR_DOC,
941
+ checkpoint=_CHECKPOINT_FOR_DOC,
942
+ output_type=SequenceClassifierOutput,
943
+ config_class=_CONFIG_FOR_DOC,
944
+ )
945
+ def forward(self,
946
+ input_ids=None,
947
+ position_ids=None,
948
+ attention_mask=None,
949
+ labels=None):
950
+
951
+ num_choices = None
952
+
953
+ if len(input_ids.shape) == 3:
954
+ batch_size, num_choices = input_ids.shape[:2]
955
+ input_ids = input_ids.reshape(-1, input_ids.size(-1))
956
+ attention_mask = attention_mask.reshape(-1, *attention_mask.size()[2:])
957
+ position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
958
+ model_out = self.glm(input_ids, position_ids, attention_mask)
959
+ outputs, mems = model_out.last_hidden_states, model_out.mems
960
+
961
+ output = outputs[:, 0, :]
962
+ output = self.dropout(output)
963
+ output = torch.tanh(self.dense(output))
964
+ output = self.dropout(output)
965
+ logits = self.out_proj(output)
966
+ if num_choices is not None:
967
+ logits = logits.view(-1, num_choices)
968
+ loss = None
969
+ if labels is not None:
970
+ loss_fct = CrossEntropyLoss()
971
+ loss = loss_fct(logits, labels)
972
+ # loss = F.cross_entropy(logits.contiguous().float(), labels.long())
973
+ return SequenceClassifierOutput(loss=loss,
974
+ logits=logits,
975
+ hidden_states=outputs)