bababababooey committed
Commit eab6e7f · 1 Parent(s): 3be96ad
Upload hotswap.py

swapper/hotswap.py ADDED  (+158 -0)
@@ -0,0 +1,158 @@
import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
import os
import json
from safetensors import safe_open
import re

# apologies in advance for shitty gpt-assisted code

# this script should also work with 70b/90b if you change `cross_attention_layers` and `total_layers` accordingly
# but i dont have enough deditated wam to test it and i dont feel like spinning up runpod so

cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]
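# (untested guess, not part of this upload) for 90b the values should presumably be the
# 70b text model's 80 layers plus its 20 cross-attention slots, i.e. something like:
# cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98]
# and create_layer_mapping(total_layers=80); check the checkpoint's config.json
# (`cross_attention_layers`, `num_hidden_layers`) before trusting these numbers.
# the hidden size 4096 hard-coded in the embedding patch further down would also need
# changing (8192 for 70b).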

#b8 = './models/mlabonne_Meta-Llama-3.1-8B-Instruct-abliterated'
b8 = './models/v000000_L3-8B-Stheno-v3.2-abliterated'
#b8 = './models/arcee-ai_Llama-3.1-SuperNova-Lite'
print(b8)

model_id = "./models/meta-llama_Llama-3.2-11B-Vision-Instruct"

def create_layer_mapping(total_layers=32, cross_attn_layers=cross_attention_layers):
    """
    Creates a mapping from llama-3.1-8b layer indices to llama-3.2-11b layer indices.
    """
    mapping = {}
    shift = 0
    next_cross_attn_idx = 0
    for X in range(total_layers):
        # Check if a cross-attention layer is inserted before this layer
        if next_cross_attn_idx < len(cross_attn_layers) and (X + shift) == cross_attn_layers[next_cross_attn_idx]:
            shift += 1
            next_cross_attn_idx += 1
        Y = X + shift
        mapping[X] = Y
    return mapping
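
# quick sanity check on the mapping (values traced from the loop above): self-attention
# layers keep their order but skip past each inserted cross-attention slot, so layer 3
# of the 8b lands on slot 4 of the 11b and the last layer (31) lands on 39
_m = create_layer_mapping()
assert _m[0] == 0 and _m[3] == 4 and _m[7] == 9 and _m[31] == 39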

def load_sharded_state_dict(model_dir):
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    with open(index_file, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    state_dict = {}
    shard_to_params = {}
    for param_name, shard_file in weight_map.items():
        if shard_file not in shard_to_params:
            shard_to_params[shard_file] = []
        shard_to_params[shard_file].append(param_name)
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)
    return state_dict
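
# note: this relies on the standard hf sharded-checkpoint layout, i.e. a
# model.safetensors.index.json whose "weight_map" maps each parameter name to the shard
# file that holds it, roughly (shard names illustrative):
#   {"weight_map": {"model.embed_tokens.weight": "model-00001-of-00004.safetensors", ...}}
# a single-file checkpoint without an index would need a different loader.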

def compare_model_states(model, new_state_dict):
    current_state = model.state_dict()
    unchanged_params = []
    changed_params = []
    missing_params = []

    for name, param in current_state.items():
        if name not in new_state_dict:
            missing_params.append(name)
        elif torch.equal(param.cpu(), new_state_dict[name].cpu()):
            unchanged_params.append(name)
        else:
            changed_params.append(name)

    return {
        'unchanged': unchanged_params,
        'changed': changed_params,
        'missing': missing_params
    }


layer_mapping = create_layer_mapping()

# Load Llama 3.2 state dict
llama_3_2_state_dict = load_sharded_state_dict(model_id)

# Extract the embedding matrix from Llama 3.2
llama_3_2_embeddings = llama_3_2_state_dict['language_model.model.embed_tokens.weight']  # Shape: [128264, 4096]

llama_3_2_state_dict.clear()

b8dict = load_sharded_state_dict(b8)

embed_tokens_weight = b8dict['model.embed_tokens.weight']  # Shape: [128256, 4096]
new_vocab_size = 128264  # From Llama 3.2
new_embed_tokens_weight = torch.zeros((new_vocab_size, 4096), dtype=embed_tokens_weight.dtype)

# Copy the existing embeddings
new_embed_tokens_weight[:128256, :] = embed_tokens_weight
# Copy the additional embeddings from Llama 3.2
new_embed_tokens_weight[128256:, :] = llama_3_2_embeddings[128256:, :]

b8dict['model.embed_tokens.weight'] = new_embed_tokens_weight
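
# cheap sanity check: the patched matrix now covers the full 3.2 vocab while keeping the
# donor's hidden size (4096 for the 8b; this dimension would differ for a 70b donor)
assert new_embed_tokens_weight.shape == (new_vocab_size, embed_tokens_weight.shape[1])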


llama_3_2_embeddings = None

# Adjust Llama 3.1 parameter names to match Llama 3.2 language model
st8dict = {}
for name, param in b8dict.items():
    # Prefix non-layer parameters with 'language_model.'
    if not re.match(r'model\.layers\.\d+\.', name):
        new_name = 'language_model.' + name
    else:
        # Extract the layer index X from 'model.layers.X.'
        match = re.match(r'model\.layers\.(\d+)\.(.+)', name)
        if match:
            X = int(match.group(1))
            suffix = match.group(2)
            # Get the corresponding Y in llama-3.2-11b
            Y = layer_mapping.get(X, X + len(cross_attention_layers))
            new_name = f'language_model.model.layers.{Y}.{suffix}'
        else:
            # If the pattern doesn't match, just prefix with 'language_model.'
            new_name = 'language_model.' + name
    st8dict[new_name] = param
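
# for reference, renamed keys come out like this (layer 3 -> slot 4 per the mapping
# above; non-layer params just gain the prefix):
#   model.layers.3.self_attn.q_proj.weight -> language_model.model.layers.4.self_attn.q_proj.weight
#   model.embed_tokens.weight              -> language_model.model.embed_tokens.weight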

#write st8dict keys to file for verification
with open('st8dict.txt', 'w') as f:
    f.write('\n'.join(st8dict.keys()))


model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

#original_state = {k: v.clone() for k, v in model.state_dict().items()}

model.load_state_dict(st8dict, strict=False)
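
# with strict=False the load silently skips everything the 8b donor doesn't provide
# (vision tower, cross-attention layers, etc.), which is the point: those stay at their
# original llama 3.2 values. load_state_dict returns a named tuple of
# missing_keys / unexpected_keys if you want to inspect what was skipped.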

b8dict.clear()
st8dict.clear()


'''
result = compare_model_states(model, original_state)

print("Unchanged parameters:", len(result['unchanged']))
print("Changed parameters:", len(result['changed']))
print("Missing parameters:", len(result['missing']))

#write result to file
with open('result.txt', 'w') as f:
    f.write(json.dumps(result, indent=2))
'''


processor = AutoProcessor.from_pretrained(model_id)


model.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
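
# the processor is loaded above but never written out; to get a folder that also loads
# cleanly with AutoProcessor you'd presumably want:
# processor.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")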