katuni4ka committed
Commit 1578ff5 · verified · Parent(s): 80a12c4

Update modeling_chatglm.py

Files changed (1)
  1. modeling_chatglm.py (+10 -6)
modeling_chatglm.py CHANGED

@@ -516,6 +516,7 @@ class GLMBlock(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number, device=None):
         super(GLMBlock, self).__init__()
         self.layer_number = layer_number
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
 
@@ -524,7 +525,7 @@ class GLMBlock(torch.nn.Module):
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                             dtype=dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -532,7 +533,7 @@ class GLMBlock(torch.nn.Module):
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                      dtype=config.torch_dtype)
+                                                      dtype=dtype)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -600,9 +601,10 @@ class GLMTransformer(torch.nn.Module):
 
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                 dtype=config.torch_dtype)
+                                                 dtype=dtype)
 
         self.gradient_checkpointing = False
 
@@ -711,13 +713,14 @@ class Embedding(torch.nn.Module):
 
     def __init__(self, config: ChatGLMConfig, device=None):
         super(Embedding, self).__init__()
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
         self.word_embeddings = nn.Embedding(
             config.padded_vocab_size,
             self.hidden_size,
-            dtype=config.torch_dtype,
+            dtype=dtype,
             device=device
         )
         self.fp32_residual_connection = config.fp32_residual_connection
@@ -748,6 +751,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -756,10 +760,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         )
 
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
-                                              dtype=config.torch_dtype)
+                                              dtype=dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
-                                        dtype=config.torch_dtype, **init_kwargs)
+                                        dtype=dtype, **init_kwargs)
         self.pre_seq_len = config.pre_seq_len
         self.prefix_projection = config.prefix_projection
         if self.pre_seq_len is not None:
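For context, the conversion the diff adds to each __init__ can be exercised on its own. The sketch below is not part of the commit, and the helper name resolve_dtype is made up for illustration; it only mirrors the one-liner that now resolves config.torch_dtype, which may arrive as a plain string (e.g. "float16" from config.json) rather than a torch.dtype.

import torch

def resolve_dtype(torch_dtype):
    # Accept either a dtype name such as "float16" or an actual torch.dtype.
    if isinstance(torch_dtype, str):
        return getattr(torch, torch_dtype)  # e.g. "float16" -> torch.float16
    return torch_dtype

print(resolve_dtype("float16"))       # torch.float16
print(resolve_dtype(torch.bfloat16))  # torch.bfloat16

# Layers accept the resolved dtype; passing the raw string would fail when
# torch tries to interpret "float16" as a dtype object.
layer = torch.nn.Linear(8, 8, dtype=resolve_dtype("float16"))
print(layer.weight.dtype)             # torch.float16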