Adapt to the new transformers release (https://github.com/huggingface/transformers/pull/31116)

#58
Files changed (1)
  1. modeling_chatglm.py +10 -1
modeling_chatglm.py
@@ -936,9 +936,18 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past_key_values
-        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+        past_output = self._extract_past_from_model_output(
             outputs, standardize_cache_format=standardize_cache_format
         )
+        # adapt transformers update (https://github.com/huggingface/transformers/pull/31116)
+        if isinstance(past_output, tuple) and isinstance(past_output[0], str):
+            if past_output[0] == "past_key_values":
+                model_kwargs["past_key_values"] = past_output[1]
+            else:
+                model_kwargs["past_key_values"] = None
+                print(f'WARN: got "{past_output[0]}" from self._extract_past_from_model_output, not "past_key_values"')
+        else:
+            model_kwargs["past_key_values"] = past_output
 
         # update attention mask
         if "attention_mask" in model_kwargs: