zRzRzRzRzRzRzR katuni4ka commited on
Commit
67d005d
·
verified ·
1 Parent(s): 91a0561

compatibility with new transformers (#60)

Browse files

- compatibility with new transformers (9e21dac48837929ca4df28e3dcb6ae04c184573d)
- Update modeling_chatglm.py (c59cdd3bafd43d5f9b3e82c88c088eccb3925e02)


Co-authored-by: Ekaterina Aidova <katuni4ka@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_chatglm.py +17 -3
modeling_chatglm.py CHANGED
@@ -14,6 +14,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
14
  from torch.nn.utils import skip_init
15
  from typing import Optional, Tuple, Union, List, Callable, Dict, Any
16
  from copy import deepcopy
 
17
 
18
  from transformers.modeling_outputs import (
19
  BaseModelOutputWithPast,
@@ -45,6 +46,9 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
  # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
46
  ]
47
 
 
 
 
48
 
49
  def default_init(cls, *args, **kwargs):
50
  return cls(*args, **kwargs)
@@ -872,9 +876,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
872
  standardize_cache_format: bool = False,
873
  ) -> Dict[str, Any]:
874
  # update past_key_values
875
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
876
- outputs, standardize_cache_format=standardize_cache_format
877
- )
 
 
 
 
 
 
 
 
 
 
878
 
879
  # update attention mask
880
  if "attention_mask" in model_kwargs:
 
14
  from torch.nn.utils import skip_init
15
  from typing import Optional, Tuple, Union, List, Callable, Dict, Any
16
  from copy import deepcopy
17
+ import transformers
18
 
19
  from transformers.modeling_outputs import (
20
  BaseModelOutputWithPast,
 
46
  # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
47
  ]
48
 
49
+ is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
50
+ is_transformers_4_44_or_higher = int(transformers.__version__.split(".")[1]) >= 44
51
+
52
 
53
  def default_init(cls, *args, **kwargs):
54
  return cls(*args, **kwargs)
 
876
  standardize_cache_format: bool = False,
877
  ) -> Dict[str, Any]:
878
  # update past_key_values
879
+ if is_transformers_4_44_or_higher:
880
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
881
+ outputs
882
+ )[1]
883
+ elif is_transformers_4_42_or_higher:
884
+ # update past_key_values
885
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
886
+ outputs, standardize_cache_format=standardize_cache_format
887
+ )[1]
888
+ else:
889
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
890
+ outputs, standardize_cache_format=standardize_cache_format
891
+ )
892
 
893
  # update attention mask
894
  if "attention_mask" in model_kwargs: