bababababooey committed
Commit • d7e01c8 • 1 Parent(s): ffd11f3
Upload 32to31.py

swapper/32to31.py  ADDED  (+182 -0)
@@ -0,0 +1,182 @@
import json
import os
import re
from huggingface_hub import snapshot_download

import torch
from safetensors import safe_open
from transformers import AutoProcessor, MllamaForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, AutoConfig

#total_layers = 80  # the 70B model has 80 layers
total_layers = 32  # the 8B model has 32 layers

#cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98]  # 90B
cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]  # 11B

# Source (vision) and target (text-only) model IDs
target_model = "meta-llama/Llama-3.1-8B-Instruct"
print(f"Target model: {target_model}")

source_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
print(f"Source model: {source_model}")

def create_inverse_layer_mapping(total_layers=total_layers, cross_attn_layers=cross_attention_layers):
    """
    Creates a mapping from 90B/11B layer indices to 70B/8B layer indices.
    """
    mapping = {}
    removed_layers = []

    #for i in range(100):  # the 90B model has 100 layers (80 + 20 cross-attention layers)
    for i in range(40):  # the 11B model has 40 layers (32 + 8 cross-attention layers)
        if i not in cross_attn_layers and len(mapping) < total_layers:
            mapping[i] = len(mapping)
        else:
            removed_layers.append(i)
    return mapping, removed_layers
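# Illustrative note (not used by the conversion itself): with the 11B defaults above,
# the mapping drops exactly the 8 cross-attention layers and renumbers the remaining
# 32 decoder layers, e.g. {0: 0, 1: 1, 2: 2, 4: 3, 5: 4, ..., 39: 31}, while
# removed_layers comes back as [3, 8, 13, 18, 23, 28, 33, 38].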

def load_sharded_state_dict(model_id):
    """
    Load a sharded state dict from either a local directory or a Hugging Face model ID.

    Args:
        model_id: Either a local path or a Hugging Face model ID (e.g., "meta-llama/Llama-2-7b")

    Returns:
        dict: The loaded state dictionary
    """
    # Check if model_id is a local path
    if os.path.isdir(model_id):
        model_dir = model_id
    else:
        # If not local, assume it's a Hugging Face model ID and download it
        print(f"Downloading model from Hugging Face: {model_id}")
        model_dir = snapshot_download(
            model_id,
            allow_patterns=["*.safetensors*", "*.json"],
            ignore_patterns=["*.bin", "*.md", "*.py"]
        )

    # Load the index file
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    if not os.path.exists(index_file):
        raise FileNotFoundError(f"Could not find index file: {index_file}")

    with open(index_file, 'r') as f:
        index_data = json.load(f)

    weight_map = index_data['weight_map']
    state_dict = {}
    shard_to_params = {}

    # Group parameters by shard file
    for param_name, shard_file in weight_map.items():
        if shard_file not in shard_to_params:
            shard_to_params[shard_file] = []
        shard_to_params[shard_file].append(param_name)

    # Load parameters from each shard
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)

    return state_dict

def compare_model_states(model, new_state_dict):
    current_state = model.state_dict()
    unchanged_params = []
    changed_params = []
    missing_params = []

    for name, param in current_state.items():
        if name not in new_state_dict:
            missing_params.append(name)
        elif torch.equal(param, new_state_dict[name]):
            unchanged_params.append(name)
        else:
            sum_abs_diff = torch.sum(torch.abs(param - new_state_dict[name]))
            changed_params.append({'name': name, 'sum_abs_diff': sum_abs_diff.item()})

    return {
        'unchanged': unchanged_params,
        'changed': changed_params,
        'missing': missing_params
    }


layer_mapping, removed_layers = create_inverse_layer_mapping()

# Load the source (11B/90B) state dict
source_state_dict = load_sharded_state_dict(source_model)

# Create the new state dict for the target (8B/70B) model
target_state_dict = {}

# Convert parameter names and copy tensors
for name, param in source_state_dict.items():
    # Skip parameters that aren't part of the language model layers
    if not (name.startswith('language_model.model.layers.') or
            name == 'language_model.model.embed_tokens.weight' or
            name == 'language_model.lm_head.weight' or
            name == 'language_model.model.norm.weight'):
        continue

    if name.startswith('language_model.model.layers.'):
        # Handle layer parameters
        layer_match = re.match(r'language_model\.model\.layers\.(\d+)\.(.+)', name)
        if layer_match:
            source_layer = int(layer_match.group(1))
            if source_layer in layer_mapping:
                target_layer = layer_mapping[source_layer]
                new_name = f'model.layers.{target_layer}.{layer_match.group(2)}'
                target_state_dict[new_name] = param
    elif name == 'language_model.lm_head.weight':
        # Handle the lm_head weight
        target_state_dict['lm_head.weight'] = param
    elif name == 'language_model.model.embed_tokens.weight':
        # Handle embeddings - keep the original vocab size of the 8B/70B text model
        original_embed_size = 128256
        target_state_dict['model.embed_tokens.weight'] = param[:original_embed_size, :]
    elif name == 'language_model.model.norm.weight':
        # Handle the model norm weight
        target_state_dict['model.norm.weight'] = param

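# For example, under the mapping above the 11B tensor
# 'language_model.model.layers.4.self_attn.q_proj.weight' is stored as
# 'model.layers.3.self_attn.q_proj.weight' in the extracted 8B state dict.
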
# Write the target_state_dict keys to a file for verification
with open('target_state_dict.txt', 'w') as f:
    f.write('\n'.join(target_state_dict.keys()))

config = AutoConfig.from_pretrained(target_model)

model = AutoModelForCausalLM.from_pretrained(
    None,
    config=config,
    state_dict=target_state_dict,
    torch_dtype=torch.bfloat16,
)

'''
origmodel = AutoModelForCausalLM.from_pretrained(
    target_model,
    torch_dtype=torch.bfloat16,
)

result = compare_model_states(model, origmodel.state_dict())
print("Unchanged parameters:", len(result['unchanged']))
print("Changed parameters:", len(result['changed']))
print("Missing parameters:", len(result['missing']))

# Write the comparison result to a file
with open('result.txt', 'w') as f:
    f.write(json.dumps(result, indent=2))
'''

processor = AutoTokenizer.from_pretrained(target_model)    # 8B/70B tokenizer
#processor = AutoProcessor.from_pretrained(source_model)   # 11B/90B processor

model.save_pretrained("Llama-3.2-8B-extracted")
processor.save_pretrained("Llama-3.2-8B-extracted")
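
For reference, a minimal sketch of how the checkpoint written by the script could be sanity-checked afterwards, assuming the Llama-3.2-8B-extracted directory it saves is in the working directory:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the extracted text-only model and its tokenizer.
model = AutoModelForCausalLM.from_pretrained("Llama-3.2-8B-extracted", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-8B-extracted")

# The result should look like a plain Llama 3.1 8B text model:
# 32 decoder layers and a 128256-row embedding matrix.
assert model.config.num_hidden_layers == 32
assert model.get_input_embeddings().weight.shape[0] == 128256

# Quick generation smoke test.
inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))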