Update modeling_chatglm.py
Browse files- modeling_chatglm.py +13 -3
modeling_chatglm.py
CHANGED
@@ -455,7 +455,7 @@ class SelfAttention(torch.nn.Module):
|
|
455 |
|
456 |
def _config_to_kwargs(args):
|
457 |
common_kwargs = {
|
458 |
-
"dtype": args.torch_dtype,
|
459 |
}
|
460 |
return common_kwargs
|
461 |
|
@@ -746,7 +746,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
746 |
init_method = default_init
|
747 |
init_kwargs = {}
|
748 |
if device is not None:
|
749 |
-
init_kwargs["device"] = device
|
750 |
self.embedding = init_method(Embedding, config, **init_kwargs)
|
751 |
self.num_layers = config.num_layers
|
752 |
self.multi_query_group_num = config.multi_query_group_num
|
@@ -868,6 +868,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
868 |
if self.config.quantization_bit:
|
869 |
self.quantize(self.config.quantization_bit, empty_init=True)
|
870 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
871 |
def _update_model_kwargs_for_generation(
|
872 |
self,
|
873 |
outputs: ModelOutput,
|
@@ -1300,4 +1310,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
1300 |
past_key_values=transformer_outputs.past_key_values,
|
1301 |
hidden_states=transformer_outputs.hidden_states,
|
1302 |
attentions=transformer_outputs.attentions,
|
1303 |
-
)
|
|
|
455 |
|
456 |
def _config_to_kwargs(args):
    """Build the keyword arguments shared by modules constructed from a config.

    Args:
        args: Configuration object exposing a ``torch_dtype`` attribute. The
            attribute may already be a ``torch.dtype`` (or ``None``), or it
            may be a string naming one, e.g. ``"float16"`` or
            ``"torch.float16"``.

    Returns:
        dict: ``{"dtype": <value>}`` where a string ``torch_dtype`` has been
        resolved to the matching ``torch.dtype``; non-string values are passed
        through unchanged.

    Raises:
        ValueError: If ``torch_dtype`` is a string that does not name a
            ``torch.dtype``.
    """
    dtype = args.torch_dtype
    if isinstance(dtype, str):
        # Accept both the bare name ("float16") and the form produced by
        # str(torch.float16) ("torch.float16").
        prefix = "torch."
        dtype_name = dtype[len(prefix):] if dtype.startswith(prefix) else dtype
        resolved = getattr(torch, dtype_name, None)
        # A plain getattr would happily return any torch attribute (a module,
        # a function, ...); only real dtypes are valid here.
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unsupported torch_dtype string: {dtype!r}")
        dtype = resolved
    common_kwargs = {
        "dtype": dtype,
    }
    return common_kwargs
|
461 |
|
|
|
746 |
init_method = default_init
|
747 |
init_kwargs = {}
|
748 |
if device is not None:
|
749 |
+
init_kwargs["device"] = device if not isinstance(device, str) else torch.device(device)
|
750 |
self.embedding = init_method(Embedding, config, **init_kwargs)
|
751 |
self.num_layers = config.num_layers
|
752 |
self.multi_query_group_num = config.multi_query_group_num
|
|
|
868 |
if self.config.quantization_bit:
|
869 |
self.quantize(self.config.quantization_bit, empty_init=True)
|
870 |
|
871 |
+
|
872 |
+
@staticmethod
def _extract_past_from_model_output(outputs: ModelOutput, *args, **kwargs):
    """Return the cached key/value states carried by ``outputs``, if any.

    Args:
        outputs: Model output to inspect for a ``past_key_values`` entry.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        ``outputs.past_key_values`` when that entry is present, otherwise
        ``None``. When ``is_transformers_4_42_or_higher`` is truthy the value
        is wrapped as the second element of a ``(None, value)`` tuple.
    """
    cache = outputs.past_key_values if "past_key_values" in outputs else None
    # NOTE(review): the tuple form presumably mirrors the (cache_name, cache)
    # pair that newer transformers releases expect from this hook -- confirm
    # against the installed transformers version.
    return (None, cache) if is_transformers_4_42_or_higher else cache
|
880 |
+
|
881 |
def _update_model_kwargs_for_generation(
|
882 |
self,
|
883 |
outputs: ModelOutput,
|
|
|
1310 |
past_key_values=transformer_outputs.past_key_values,
|
1311 |
hidden_states=transformer_outputs.hidden_states,
|
1312 |
attentions=transformer_outputs.attentions,
|
1313 |
+
)
|