from typing import Dict

import torch


def extract_model_weights(reference_model, n_layers: int) -> Dict[str, torch.Tensor]:
    """Collect parameter tensors from the first `n_layers` decoder layers,
    plus the final norm and the LM head, keyed by module name."""
    params = {}

    # Iterate over all named modules; stop once we pass layer n_layers - 1
    for name, module in reference_model.named_modules():
        if 'model.layers.' in name:
            # Names look like 'model.layers.<idx>.<submodule>'; the layer
            # index is the third dot-separated component.
            layer_index = int(name.split('.')[2])
            if layer_index > n_layers - 1:
                break  # Stop after reaching the specified main layer

        # Store this module's weight and bias, if present
        if hasattr(module, 'weight') and module.weight is not None:
            params[name + '.weight'] = module.weight.data.clone()
        if hasattr(module, 'bias') and module.bias is not None:
            params[name + '.bias'] = module.bias.data.clone()

    # Final normalization layer; adjust this path for your model's architecture
    norm_layer = reference_model.model.norm
    if hasattr(norm_layer, 'weight') and norm_layer.weight is not None:
        params['model.norm.weight'] = norm_layer.weight.data.clone()
    if hasattr(norm_layer, 'bias') and norm_layer.bias is not None:
        params['model.norm.bias'] = norm_layer.bias.data.clone()

    # Language-model head
    lm_head = reference_model.lm_head
    if hasattr(lm_head, 'weight') and lm_head.weight is not None:
        params['lm_head.weight'] = lm_head.weight.data.clone()
    if hasattr(lm_head, 'bias') and lm_head.bias is not None:
        params['lm_head.bias'] = lm_head.bias.data.clone()

    return params
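

# Example usage (a minimal sketch): this assumes a Hugging Face causal LM whose
# module tree exposes `model.layers`, `model.norm`, and `lm_head`, as in
# LLaMA-style architectures. The checkpoint name below is only illustrative.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM

    reference_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')

    # Keep only the first 4 decoder layers plus the final norm and LM head
    weights = extract_model_weights(reference_model, n_layers=4)

    print(f"Extracted {len(weights)} parameter tensors")
    for key in list(weights)[:5]:
        print(key, tuple(weights[key].shape))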