katuni4ka committed on
Commit
44385a0
·
verified ·
1 Parent(s): 1578ff5

Update modeling_chatglm.py

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +13 -3
modeling_chatglm.py CHANGED
@@ -455,7 +455,7 @@ class SelfAttention(torch.nn.Module):
455
 
456
  def _config_to_kwargs(args):
457
  common_kwargs = {
458
- "dtype": args.torch_dtype,
459
  }
460
  return common_kwargs
461
 
@@ -746,7 +746,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
746
  init_method = default_init
747
  init_kwargs = {}
748
  if device is not None:
749
- init_kwargs["device"] = device
750
  self.embedding = init_method(Embedding, config, **init_kwargs)
751
  self.num_layers = config.num_layers
752
  self.multi_query_group_num = config.multi_query_group_num
@@ -868,6 +868,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
868
  if self.config.quantization_bit:
869
  self.quantize(self.config.quantization_bit, empty_init=True)
870
 
 
 
 
 
 
 
 
 
 
 
871
  def _update_model_kwargs_for_generation(
872
  self,
873
  outputs: ModelOutput,
@@ -1300,4 +1310,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1300
  past_key_values=transformer_outputs.past_key_values,
1301
  hidden_states=transformer_outputs.hidden_states,
1302
  attentions=transformer_outputs.attentions,
1303
- )
 
455
 
456
  def _config_to_kwargs(args):
457
  common_kwargs = {
458
+ "dtype": args.torch_dtype if not isinstance(args.torch_dtype, str) else getattr(torch, args.torch_dtype)
459
  }
460
  return common_kwargs
461
 
 
746
  init_method = default_init
747
  init_kwargs = {}
748
  if device is not None:
749
+ init_kwargs["device"] = device if not isinstance(device, str) else torch.device(device)
750
  self.embedding = init_method(Embedding, config, **init_kwargs)
751
  self.num_layers = config.num_layers
752
  self.multi_query_group_num = config.multi_query_group_num
 
868
  if self.config.quantization_bit:
869
  self.quantize(self.config.quantization_bit, empty_init=True)
870
 
871
+
872
+ @staticmethod
873
+ def _extract_past_from_model_output(outputs: ModelOutput, *args, **kwargs):
874
+ past_key_values = None
875
+ if "past_key_values" in outputs:
876
+ past_key_values = outputs.past_key_values
877
+ if is_transformers_4_42_or_higher:
878
+ return None, past_key_values
879
+ return past_key_values
880
+
881
  def _update_model_kwargs_for_generation(
882
  self,
883
  outputs: ModelOutput,
 
1310
  past_key_values=transformer_outputs.past_key_values,
1311
  hidden_states=transformer_outputs.hidden_states,
1312
  attentions=transformer_outputs.attentions,
1313
+ )