davda54 commited on
Commit
c05156b
1 Parent(s): 3d4b7e5

Delete convert_weight.py

Browse files
Files changed (1) hide show
  1. convert_weight.py +0 -72
convert_weight.py DELETED
@@ -1,72 +0,0 @@
1
- import torch
2
-
3
-
4
- input_dir_path = "/scratch/project_462000086/norwegian_gpt/Megatron-DeepSpeed-fixed/checkpoints/global_step120000"
5
- output_dir_path = "/scratch/project_462000086/norwegian_gpt/Megatron-DeepSpeed-fixed/hf_pilot_checkpoint_120k"
6
-
7
- n_hidden = 4096
8
- n_heads = 32
9
- n_layers = 32
10
- n_tp = 4
11
-
12
-
13
- weights = {}
14
-
15
- # embedding
16
- embedding_weights = []
17
- for i in range(n_tp):
18
- path = f"{input_dir_path}/layer_01-model_0{i}-model_states.pt"
19
- checkpoint = torch.load(path)
20
-
21
- embedding_weights.append(checkpoint["word_embeddings.weight"].bfloat16())
22
-
23
- weights[f"transformer.word_embeddings_layernorm.weight"] = checkpoint["word_embeddings.norm.weight"].bfloat16()
24
- weights[f"transformer.word_embeddings_layernorm.bias"] = checkpoint["word_embeddings.norm.bias"].bfloat16()
25
-
26
- weights[f"transformer.word_embeddings.weight"] = torch.cat(embedding_weights, dim=0)
27
- weights[f"lm_head.weight"] = torch.cat(embedding_weights, dim=0)
28
- del embedding_weights
29
-
30
-
31
- # transformer layers
32
- for layer in range(n_layers):
33
- qkv_weights = []
34
- qkv_biases = []
35
- o_weights = []
36
- up_weights = []
37
- up_biases = []
38
- down_weights = []
39
-
40
- for i in range(n_tp):
41
- path = f"{input_dir_path}/layer_{layer+3:02d}-model_0{i}-model_states.pt"
42
- checkpoint = torch.load(path)
43
-
44
- weights[f"transformer.h.{layer}.input_layernorm.weight"] = checkpoint["input_layernorm.weight"].bfloat16()
45
- weights[f"transformer.h.{layer}.input_layernorm.bias"] = checkpoint["input_layernorm.bias"].bfloat16()
46
- weights[f"transformer.h.{layer}.self_attention.dense.bias"] = checkpoint["self_attention.dense.bias"].bfloat16()
47
- weights[f"transformer.h.{layer}.post_attention_layernorm.weight"] = checkpoint["post_attention_layernorm.weight"].bfloat16()
48
- weights[f"transformer.h.{layer}.post_attention_layernorm.bias"] = checkpoint["post_attention_layernorm.bias"].bfloat16()
49
- weights[f"transformer.h.{layer}.mlp.dense_4h_to_h.bias"] = checkpoint["mlp.dense_4h_to_h.bias"].bfloat16()
50
-
51
- qkv_weights.append(checkpoint["self_attention.query_key_value.weight"].bfloat16())
52
- qkv_biases.append(checkpoint["self_attention.query_key_value.bias"].bfloat16())
53
- o_weights.append(checkpoint["self_attention.dense.weight"].bfloat16())
54
- up_weights.append(checkpoint["mlp.dense_h_to_4h.weight"].bfloat16())
55
- up_biases.append(checkpoint["mlp.dense_h_to_4h.bias"].bfloat16())
56
- down_weights.append(checkpoint["mlp.dense_4h_to_h.weight"].bfloat16())
57
-
58
- weights[f"transformer.h.{layer}.self_attention.query_key_value.weight"] = torch.cat(qkv_weights, dim=0)
59
- weights[f"transformer.h.{layer}.self_attention.query_key_value.bias"] = torch.cat(qkv_biases, dim=0)
60
- weights[f"transformer.h.{layer}.self_attention.dense.weight"] = torch.cat(o_weights, dim=1)
61
- weights[f"transformer.h.{layer}.mlp.dense_h_to_4h.weight"] = torch.cat(up_weights, dim=0)
62
- weights[f"transformer.h.{layer}.mlp.dense_h_to_4h.bias"] = torch.cat(up_biases, dim=0)
63
- weights[f"transformer.h.{layer}.mlp.dense_4h_to_h.weight"] = torch.cat(down_weights, dim=1)
64
-
65
- # output layer norm
66
- path = f"{input_dir_path}/layer_36-model_00-model_states.pt"
67
- checkpoint = torch.load(path)
68
-
69
- weights[f"transformer.ln_f.bias"] = checkpoint["bias"].bfloat16()
70
- weights[f"transformer.ln_f.weight"] = checkpoint["weight"].bfloat16()
71
-
72
- torch.save(weights, f"{output_dir_path}/pytorch_model.bin")