katuni4ka committed on
Commit
de19195
1 Parent(s): 71090c6

Update modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py +6 -1
modeling_chatglm.py CHANGED
@@ -47,6 +47,7 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
+is_transformers_4_44_or_higher = int(transformers.__version__.split(".")[1]) >= 44
 
 
 def default_init(cls, *args, **kwargs):
@@ -870,7 +871,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
-        if is_transformers_4_42_or_higher:
+        if is_transformers_4_44_or_higher:
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs
+            )[1]
+        elif is_transformers_4_42_or_higher:
             # update past_key_values
             model_kwargs["past_key_values"] = self._extract_past_from_model_output(
                 outputs, standardize_cache_format=standardize_cache_format
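
For context on the new branch: in transformers 4.44 the private helper _extract_past_from_model_output was changed to return a (cache_name, cache) tuple and to drop the standardize_cache_format argument, which is why the patched code indexes [1]; in 4.42–4.43 it still returns the cache directly. Below is a minimal, self-contained sketch of the same version-gating idea, kept outside the model class. The helper name extract_past_compat and the packaging-based comparison are illustrative assumptions, not part of this commit.

# Sketch only: version-gated compatibility shim for the transformers API change.
# extract_past_compat is an illustrative helper, not a function from modeling_chatglm.py.
import transformers
from packaging import version

_TF_VERSION = version.parse(transformers.__version__)
IS_TF_4_44_PLUS = _TF_VERSION >= version.parse("4.44.0")
IS_TF_4_42_PLUS = _TF_VERSION >= version.parse("4.42.0")


def extract_past_compat(model, outputs, standardize_cache_format=False):
    """Return past_key_values from `outputs` across transformers versions."""
    if IS_TF_4_44_PLUS:
        # 4.44+: _extract_past_from_model_output returns (cache_name, cache)
        # and no longer accepts standardize_cache_format.
        return model._extract_past_from_model_output(outputs)[1]
    if IS_TF_4_42_PLUS:
        # 4.42–4.43: returns the cache directly and still accepts the kwarg.
        return model._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
    # Older versions: fall back to reading the attribute off the model output.
    return getattr(outputs, "past_key_values", None)

Parsing the version with packaging.version also avoids the pitfall of the split(".")[1] comparison in the patch itself, which looks only at the minor version and would misbehave across a future major-version bump.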