zxdu20 commited on
Commit
cd8041e
1 Parent(s): 65bb3f0

Fix past_key_values

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +4 -1
modeling_chatglm.py CHANGED
@@ -952,6 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
952
  self,
953
  input_ids: torch.LongTensor,
954
  past: Optional[torch.Tensor] = None,
 
955
  attention_mask: Optional[torch.Tensor] = None,
956
  **kwargs
957
  ) -> dict:
@@ -966,7 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
966
  raise ValueError("You have to add either [MASK] or [gMASK] in your input")
967
 
968
  # only last token for input_ids if past is not None
969
- if past:
970
  context_length = seq.index(150004)
971
  last_token = input_ids[:, -1].unsqueeze(-1)
972
  if self.position_encoding_2d:
@@ -975,6 +976,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
975
  else:
976
  position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
977
 
 
 
978
  return {
979
  "input_ids": last_token,
980
  "past_key_values": past,
 
952
  self,
953
  input_ids: torch.LongTensor,
954
  past: Optional[torch.Tensor] = None,
955
+ past_key_values: Optional[torch.Tensor] = None,
956
  attention_mask: Optional[torch.Tensor] = None,
957
  **kwargs
958
  ) -> dict:
 
967
  raise ValueError("You have to add either [MASK] or [gMASK] in your input")
968
 
969
  # only last token for input_ids if past is not None
970
+ if past is not None or past_key_values is not None:
971
  context_length = seq.index(150004)
972
  last_token = input_ids[:, -1].unsqueeze(-1)
973
  if self.position_encoding_2d:
 
976
  else:
977
  position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
978
 
979
+ if past is None:
980
+ past = past_key_values
981
  return {
982
  "input_ids": last_token,
983
  "past_key_values": past,