Fix past_key_values

Changed file: modeling_chatglm.py (+4 −1)
@@ -952,6 +952,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             self,
             input_ids: torch.LongTensor,
             past: Optional[torch.Tensor] = None,
+            past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
             **kwargs
     ) -> dict:
@@ -966,7 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             raise ValueError("You have to add either [MASK] or [gMASK] in your input")

         # only last token for input_ids if past is not None
-        if past:
+        if past is not None or past_key_values is not None:
             context_length = seq.index(150004)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
@@ -975,6 +976,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             else:
                 position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)

+        if past is None:
+            past = past_key_values
         return {
             "input_ids": last_token,
             "past_key_values": past,