Update modeling_chatglm.py
modeling_chatglm.py  +10 -6
@@ -516,6 +516,7 @@ class GLMBlock(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number, device=None):
         super(GLMBlock, self).__init__()
         self.layer_number = layer_number
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
 
@@ -524,7 +525,7 @@ class GLMBlock(torch.nn.Module):
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                             dtype=dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -532,7 +533,7 @@ class GLMBlock(torch.nn.Module):
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                      dtype=config.torch_dtype)
+                                                      dtype=dtype)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -600,9 +601,10 @@ class GLMTransformer(torch.nn.Module):
 
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                 dtype=config.torch_dtype)
+                                                 dtype=dtype)
 
         self.gradient_checkpointing = False
 
@@ -711,13 +713,14 @@ class Embedding(torch.nn.Module):
 
     def __init__(self, config: ChatGLMConfig, device=None):
         super(Embedding, self).__init__()
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
         self.word_embeddings = nn.Embedding(
             config.padded_vocab_size,
             self.hidden_size,
-            dtype=config.torch_dtype,
+            dtype=dtype,
             device=device
         )
         self.fp32_residual_connection = config.fp32_residual_connection
@@ -748,6 +751,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -756,10 +760,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         )
 
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
-                                              dtype=config.torch_dtype)
+                                              dtype=dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
-                                        dtype=config.torch_dtype, **init_kwargs)
+                                        dtype=dtype, **init_kwargs)
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
         if self.pre_seq_len is not None:
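The change is the same in all four touched classes (GLMBlock, GLMTransformer, Embedding, ChatGLMModel): resolve config.torch_dtype to an actual torch.dtype before handing it to layer constructors, since the value can arrive as a plain string such as "float16" when it comes from config.json. Below is a minimal standalone sketch of that resolution; DummyConfig and resolve_dtype are illustrative names, not part of the file.

# Minimal sketch (not part of the commit) of the dtype resolution the patch
# repeats in GLMBlock, GLMTransformer, Embedding and ChatGLMModel.
import torch

class DummyConfig:
    # Hypothetical stand-in for ChatGLMConfig, for illustration only.
    def __init__(self, torch_dtype):
        self.torch_dtype = torch_dtype

def resolve_dtype(config):
    # config.torch_dtype may be a string (e.g. "float16" from config.json) or
    # already a torch.dtype; getattr maps the string form onto the dtype object.
    return getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype

print(resolve_dtype(DummyConfig("float16")))       # torch.float16
print(resolve_dtype(DummyConfig(torch.bfloat16)))  # torch.bfloat16, passed through unchanged

Passing the unresolved string straight through, as the removed lines did via dtype=config.torch_dtype, fails when the config carries a string, because nn.Embedding and the LayerNorm/RMSNorm constructors expect a torch.dtype; the added local dtype variable avoids that.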