import torch
from transformers import MllamaForConditionalGeneration, AutoProcessor
import os
import json
from safetensors import safe_open
import re
# apologies in advance for the scrappy GPT-assisted code
# this script should also work with the 70B/90B models if you change `cross_attention_layers` and `total_layers` accordingly,
# but I don't have enough RAM to test it and I don't feel like spinning up RunPod
# Indices (in the 11B model's 40-layer stack) where the cross-attention layers sit
cross_attention_layers = [3, 8, 13, 18, 23, 28, 33, 38]
# Donor 8B text model; swap in one of the alternatives as desired
#b8 = './models/mlabonne_Meta-Llama-3.1-8B-Instruct-abliterated'
b8 = './models/v000000_L3-8B-Stheno-v3.2-abliterated'
#b8 = './models/arcee-ai_Llama-3.1-SuperNova-Lite'
print(b8)
model_id = "./models/meta-llama_Llama-3.2-11B-Vision-Instruct"
def create_layer_mapping(total_layers=32, cross_attn_layers=cross_attention_layers):
    """
    Creates a mapping from Llama-3.1-8B layer indices to Llama-3.2-11B layer indices.
    Each inserted cross-attention layer shifts all subsequent self-attention layers up by one.
    """
    mapping = {}
    shift = 0
    next_cross_attn_idx = 0
    for X in range(total_layers):
        # Check if a cross-attention layer is inserted before this layer
        if next_cross_attn_idx < len(cross_attn_layers) and (X + shift) == cross_attn_layers[next_cross_attn_idx]:
            shift += 1
            next_cross_attn_idx += 1
        Y = X + shift
        mapping[X] = Y
    return mapping
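# Quick check (optional; values computed from the defaults above): layers 0-2
# keep their indices, then each inserted cross-attention layer pushes the rest
# up by one, so the 11B language model ends up with 32 + 8 = 40 layers.
_m = create_layer_mapping()
assert _m[0] == 0 and _m[3] == 4 and _m[7] == 9 and _m[31] == 39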
def load_sharded_state_dict(model_dir):
    """Loads a sharded safetensors checkpoint into a single CPU state dict."""
    index_file = os.path.join(model_dir, 'model.safetensors.index.json')
    with open(index_file, 'r') as f:
        index_data = json.load(f)
    weight_map = index_data['weight_map']
    state_dict = {}
    shard_to_params = {}
    # Group parameter names by the shard file that holds them
    for param_name, shard_file in weight_map.items():
        if shard_file not in shard_to_params:
            shard_to_params[shard_file] = []
        shard_to_params[shard_file].append(param_name)
    # Open each shard once and pull out its tensors
    for shard_file, params_in_shard in shard_to_params.items():
        shard_path = os.path.join(model_dir, shard_file)
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            for name in params_in_shard:
                state_dict[name] = f.get_tensor(name)
    return state_dict
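# For reference, model.safetensors.index.json maps each parameter name to the
# shard that stores it, roughly like this (abridged; shard names are illustrative):
# {"weight_map": {"language_model.model.embed_tokens.weight": "model-00001-of-00005.safetensors", ...}}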
def compare_model_states(model, new_state_dict):
    """Reports which of the model's parameters are unchanged, changed, or absent in new_state_dict."""
    current_state = model.state_dict()
    unchanged_params = []
    changed_params = []
    missing_params = []
    for name, param in current_state.items():
        if name not in new_state_dict:
            missing_params.append(name)
        elif torch.equal(param.cpu(), new_state_dict[name].cpu()):
            unchanged_params.append(name)
        else:
            changed_params.append(name)
    return {
        'unchanged': unchanged_params,
        'changed': changed_params,
        'missing': missing_params
    }
layer_mapping = create_layer_mapping()
# Load Llama 3.2 state dict
llama_3_2_state_dict = load_sharded_state_dict(model_id)
# Extract the embedding matrix from Llama 3.2 (its vocab grew to 128264 with extra special tokens)
llama_3_2_embeddings = llama_3_2_state_dict['language_model.model.embed_tokens.weight'] # Shape: [128264, 4096]
llama_3_2_state_dict.clear() # free the rest of the 3.2 weights; only the embeddings are needed here
b8dict = load_sharded_state_dict(b8)
embed_tokens_weight = b8dict['model.embed_tokens.weight'] # Shape: [128256, 4096]
new_vocab_size = 128264 # From Llama 3.2
new_embed_tokens_weight = torch.zeros((new_vocab_size, 4096), dtype=embed_tokens_weight.dtype)
# Copy the existing embeddings
new_embed_tokens_weight[:128256, :] = embed_tokens_weight
# Copy the additional embeddings from Llama 3.2
new_embed_tokens_weight[128256:, :] = llama_3_2_embeddings[128256:, :]
b8dict['model.embed_tokens.weight'] = new_embed_tokens_weight
llama_3_2_embeddings = None
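# Sanity check (optional; shapes are taken from the comments above): the merged
# matrix must cover the full 11B vocab while keeping the donor's dtype.
assert new_embed_tokens_weight.shape == (new_vocab_size, 4096)
assert new_embed_tokens_weight.dtype == embed_tokens_weight.dtype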
# Adjust Llama 3.1 parameter names to match the Llama 3.2 language model
st8dict = {}
for name, param in b8dict.items():
    # Prefix non-layer parameters (embeddings, final norm, lm_head) with 'language_model.'
    if not re.match(r'model\.layers\.\d+\.', name):
        new_name = 'language_model.' + name
    else:
        # Extract the layer index X from 'model.layers.X.'
        match = re.match(r'model\.layers\.(\d+)\.(.+)', name)
        if match:
            X = int(match.group(1))
            suffix = match.group(2)
            # Get the corresponding index Y in Llama-3.2-11B
            Y = layer_mapping.get(X, X + len(cross_attention_layers))
            new_name = f'language_model.model.layers.{Y}.{suffix}'
        else:
            # If the pattern doesn't match, just prefix with 'language_model.'
            new_name = 'language_model.' + name
    st8dict[new_name] = param
# Write st8dict keys to a file for verification
with open('st8dict.txt', 'w') as f:
    f.write('\n'.join(st8dict.keys()))
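# A few expected renames for eyeballing st8dict.txt (illustrative, derived from
# the layer mapping above):
#   model.embed_tokens.weight              -> language_model.model.embed_tokens.weight
#   model.layers.3.self_attn.q_proj.weight -> language_model.model.layers.4.self_attn.q_proj.weight
#   model.layers.31.mlp.up_proj.weight     -> language_model.model.layers.39.mlp.up_proj.weight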
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
#original_state = {k: v.clone() for k, v in model.state_dict().items()}
# strict=False: st8dict only covers the language model, so the vision tower and
# the inserted cross-attention layers keep their original Llama 3.2 weights
model.load_state_dict(st8dict, strict=False)
b8dict.clear()
st8dict.clear()
'''
result = compare_model_states(model, original_state)
print("Unchanged parameters:", len(result['unchanged']))
print("Changed parameters:", len(result['changed']))
print("Missing parameters:", len(result['missing']))
# write result to file
with open('result.txt', 'w') as f:
    f.write(json.dumps(result, indent=2))
'''
processor = AutoProcessor.from_pretrained(model_id)
model.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
# Save the processor alongside the weights so the output directory is self-contained
processor.save_pretrained("llama-3.2-11b-vision-stheno-abliterated")
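# Optional smoke test (a minimal sketch; "test.jpg" and the prompt are
# placeholder assumptions, and CPU bf16 inference will be slow):
# from PIL import Image
# image = Image.open("test.jpg")
# messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}]
# prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
# print(processor.decode(model.generate(**inputs, max_new_tokens=40)[0]))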