from typing import Dict, Any

import torch.nn as nn


def extract_model_weights(reference_model: nn.Module, n_layers: int) -> Dict[str, Any]:
    """Copy the parameters of the first ``n_layers`` transformer layers, plus the
    final norm and the lm_head, from ``reference_model`` into a plain dict."""
    params: Dict[str, Any] = {}

    # Iterate over all named modules in definition order and copy their
    # parameters until the first layer beyond the requested range is reached.
    for name, module in reference_model.named_modules():
        if 'model.layers.' in name:
            # Names look like 'model.layers.<idx>.<submodule>'; the third
            # dot-separated element is the layer index.
            layer_index = int(name.split('.')[2])
            if layer_index > n_layers - 1:
                break  # Stop before copying any out-of-range layer.

        # Copy this module's weight and bias, if present.
        if hasattr(module, 'weight') and module.weight is not None:
            params[name + '.weight'] = module.weight.data.clone()
        if hasattr(module, 'bias') and module.bias is not None:
            params[name + '.bias'] = module.bias.data.clone()

    # The final norm and lm_head come after the transformer layers, so copy them
    # explicitly. Adjust these attribute paths to your model's architecture.
    norm_layer = reference_model.model.norm
    if hasattr(norm_layer, 'weight') and norm_layer.weight is not None:
        params['model.norm.weight'] = norm_layer.weight.data.clone()
    if hasattr(norm_layer, 'bias') and norm_layer.bias is not None:
        params['model.norm.bias'] = norm_layer.bias.data.clone()

    lm_head = reference_model.lm_head
    if hasattr(lm_head, 'weight') and lm_head.weight is not None:
        params['lm_head.weight'] = lm_head.weight.data.clone()
    if hasattr(lm_head, 'bias') and lm_head.bias is not None:
        params['lm_head.bias'] = lm_head.bias.data.clone()

    return params
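

# Usage sketch (an assumption, not part of the original file): extract the first
# 16 transformer layers from a Hugging Face causal LM. The checkpoint name below
# is illustrative; any LLaMA-style model exposing a `model.layers` / `model.norm`
# / `lm_head` layout should work.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM

    reference_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
    weights = extract_model_weights(reference_model, n_layers=16)
    print(f"Extracted {len(weights)} tensors")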