from typing import Dict, Any

import torch.nn as nn


def extract_model_weights(reference_model: nn.Module, n_layers: int) -> Dict[str, Any]:
    """Collect parameter tensors from the first ``n_layers`` decoder layers of
    ``reference_model``, plus its final norm and LM head, keyed by their fully
    qualified parameter names."""
    params = {}
    current_layer = 0

    for name, module in reference_model.named_modules():
        # Copy this module's weight and bias tensors, if it has them.
        if hasattr(module, 'weight') and module.weight is not None:
            params[name + '.weight'] = module.weight.data.clone()
        if hasattr(module, 'bias') and module.bias is not None:
            params[name + '.bias'] = module.bias.data.clone()

        # Track the decoder layer index and stop once we move past the
        # requested depth. named_modules() yields each layer container before
        # its children, so the break fires before any parameters of layers
        # beyond n_layers are visited.
        if 'model.layers.' in name:
            layer_index = int(name.split('.')[2])
            if layer_index > current_layer:
                current_layer = layer_index
            if current_layer > n_layers - 1:
                break

    # The final norm and LM head live outside model.layers, so the loop above
    # may have broken out before reaching them; copy them explicitly.
    norm_layer = reference_model.model.norm
    if hasattr(norm_layer, 'weight') and norm_layer.weight is not None:
        params['model.norm.weight'] = norm_layer.weight.data.clone()
    if hasattr(norm_layer, 'bias') and norm_layer.bias is not None:
        params['model.norm.bias'] = norm_layer.bias.data.clone()

    lm_head = reference_model.lm_head
    if hasattr(lm_head, 'weight') and lm_head.weight is not None:
        params['lm_head.weight'] = lm_head.weight.data.clone()
    if hasattr(lm_head, 'bias') and lm_head.bias is not None:
        params['lm_head.bias'] = lm_head.bias.data.clone()

    return params
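

# Minimal usage sketch, under an assumption: the attribute paths used above
# (`model.layers.*`, `model.norm`, `lm_head`) match Hugging Face LLaMA-style
# causal LMs, so the example loads such a model. The checkpoint name and
# n_layers value are illustrative only, not requirements of this module.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    reference = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    truncated_weights = extract_model_weights(reference, n_layers=4)
    print(f"extracted {len(truncated_weights)} tensors")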