Adapt to the new transformers release | adapt transformers update (https://github.com/huggingface/transformers/pull/31116)
#58 opened by HibernantBear
modeling_chatglm.py  CHANGED  (+10 −1)

```diff
@@ -936,9 +936,18 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past_key_values
-        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+        past_output = self._extract_past_from_model_output(
             outputs, standardize_cache_format=standardize_cache_format
         )
+        # adapt transformers update (https://github.com/huggingface/transformers/pull/31116)
+        if isinstance(past_output, tuple) and isinstance(past_output[0], str):
+            if past_output[0] == "past_key_values":
+                model_kwargs["past_key_values"] = past_output[1]
+            else:
+                model_kwargs["past_key_values"] = None
+                print(f'WARN: got "{past_output[0]}" from self._extract_past_from_model_output, not "past_key_values"')
+        else:
+            model_kwargs["past_key_values"] = past_output
 
         # update attention mask
         if "attention_mask" in model_kwargs:
```
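For context: as the shim above implies, before huggingface/transformers#31116 `_extract_past_from_model_output` returned the cache object directly, while afterwards it returns a `(cache_name, cache)` tuple, with `cache_name` normally `"past_key_values"`. Below is a minimal, self-contained sketch of the same compatibility logic; the names `normalize_past` and `fake_cache` are illustrative stand-ins, not part of the patch or of transformers:

```python
# A minimal sketch of the compatibility check used in the patch above.
# Assumption inferred from the diff: post-#31116 the helper returns a
# ("past_key_values", cache) tuple; older releases return the cache alone.
from typing import Any, Optional


def normalize_past(past_output: Any) -> Optional[Any]:
    # New shape: a (cache_name, cache) tuple, where cache_name is a str.
    if isinstance(past_output, tuple) and isinstance(past_output[0], str):
        cache_name, cache = past_output
        if cache_name == "past_key_values":
            return cache
        # Unexpected cache name (other architectures use e.g. "mems");
        # warn and drop the cache, as the patch does.
        print(f'WARN: got "{cache_name}", expected "past_key_values"')
        return None
    # Old shape: the cache object itself. A legacy cache may be a tuple of
    # tensor tuples, but its first element is not a str, so the branch
    # above is not taken.
    return past_output


fake_cache = object()  # stand-in for real key/value tensors
assert normalize_past(fake_cache) is fake_cache                       # old API
assert normalize_past(("past_key_values", fake_cache)) is fake_cache  # new API
assert normalize_past(("mems", fake_cache)) is None                   # unexpected name
```

The `isinstance(past_output[0], str)` test is what distinguishes the two shapes: a legacy cache is itself a tuple, so checking `isinstance(past_output, tuple)` alone would misclassify it as the new return format.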