# BitNet3-8B-Converted / model_utils.py
# (Hugging Face page residue converted to comments so the file parses:
#  uploader "ejbejaranos", "Upload folder using huggingface_hub",
#  commit 4d061f7 verified)
from typing import Dict, Any
import torch.nn as nn
def extract_model_weights(reference_model, n_layers: int) -> Dict[str, Any]:
    """Snapshot the parameters of the first ``n_layers`` decoder layers.

    Walks ``reference_model.named_modules()`` and copies every ``weight`` /
    ``bias`` tensor it finds into a flat dict keyed by the parameter's full
    name (``"<module>.weight"`` / ``"<module>.bias"``).  The final norm layer
    (``model.norm``) and the ``lm_head`` are always included, regardless of
    ``n_layers``.

    Args:
        reference_model: a model exposing ``model.layers`` (indexed decoder
            layers named ``model.layers.<i>...``), ``model.norm`` and
            ``lm_head`` attributes — e.g. an HF Llama-style causal LM.
            Adjust the attribute paths if your architecture differs.
        n_layers: number of leading decoder layers to extract.

    Returns:
        Dict mapping parameter names to *cloned* tensors, detached from the
        live model (later in-place updates to the model do not affect them).
    """
    params: Dict[str, Any] = {}

    for name, module in reference_model.named_modules():
        if 'model.layers.' in name:
            # Name pattern is 'model.layers.<i>[.submodule...]'.
            layer_index = int(name.split('.')[2])
            # Stop BEFORE storing anything from layers past the cutoff.
            # (Previously the tensor was stored first and the check ran
            # afterwards, leaking the first tensor of layer n_layers.)
            # named_modules() yields layers in registration order, so
            # breaking on the first out-of-range index is safe.
            if layer_index > n_layers - 1:
                break
        if hasattr(module, 'weight') and module.weight is not None:
            params[name + '.weight'] = module.weight.data.clone()
        if hasattr(module, 'bias') and module.bias is not None:
            params[name + '.bias'] = module.bias.data.clone()

    # Final normalization layer — stored explicitly because the loop above
    # may have broken out before reaching it.
    norm_layer = reference_model.model.norm
    if hasattr(norm_layer, 'weight') and norm_layer.weight is not None:
        params['model.norm.weight'] = norm_layer.weight.data.clone()
    if hasattr(norm_layer, 'bias') and norm_layer.bias is not None:
        params['model.norm.bias'] = norm_layer.bias.data.clone()

    # LM head — now cloned like every other parameter (previously a live
    # reference was stored, so later model updates mutated the snapshot).
    lm_head = reference_model.lm_head
    if hasattr(lm_head, 'weight') and lm_head.weight is not None:
        params['lm_head.weight'] = lm_head.weight.data.clone()
    if hasattr(lm_head, 'bias') and lm_head.bias is not None:
        params['lm_head.bias'] = lm_head.bias.data.clone()

    return params