File size: 1,628 Bytes
4d061f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Dict, Any
import torch.nn as nn

def _copy_weight_and_bias(params, prefix, module):
    """Store detached copies of ``module.weight`` / ``module.bias`` in *params*.

    Keys are ``f"{prefix}.weight"`` and ``f"{prefix}.bias"``; attributes that
    are missing or ``None`` are skipped.
    """
    for attr in ('weight', 'bias'):
        tensor = getattr(module, attr, None)
        if tensor is not None:
            params[f'{prefix}.{attr}'] = tensor.detach().clone()


def _layer_index(name):
    """Return the decoder-layer index encoded in module *name*, or ``None``.

    Looks for the first ``model.layers.<i>`` segment anywhere in the name, so
    prefixed names (e.g. ``base_model.model.layers.3.mlp``) resolve to ``3``.
    Names without such a segment, or with a non-numeric segment, give ``None``.
    """
    _, sep, tail = name.partition('model.layers.')
    if not sep:
        return None
    index = tail.split('.', 1)[0]
    return int(index) if index.isdecimal() else None


def extract_model_weights(reference_model: nn.Module, n_layers: int) -> Dict[str, Any]:
    """Collect copies of weight/bias tensors for the first *n_layers* decoder layers.

    Walks ``reference_model.named_modules()`` and copies every module's
    ``weight`` and ``bias`` (when present and not ``None``) into a flat dict
    keyed by ``"<module name>.weight"`` / ``"<module name>.bias"``. Iteration
    stops at the first module belonging to decoder layer ``n_layers`` (layers
    are identified by a ``model.layers.<i>`` segment in the module name), so
    modules registered after the layer stack are only visited when
    ``n_layers`` covers every layer. The final norm
    (``reference_model.model.norm``) and ``reference_model.lm_head`` are always
    added explicitly.

    Args:
        reference_model: Module exposing ``.model.norm`` and ``.lm_head``.
            NOTE(review): the paths assume an HF-style causal LM layout;
            adjust them for other architectures.
        n_layers: Number of leading decoder layers to include; ``0`` (or a
            negative value) includes none.

    Returns:
        Dict mapping parameter names to detached copies of the tensors. Every
        entry, including ``lm_head``, is a copy, so later changes to the model
        do not leak into the result.

    Raises:
        AttributeError: If ``reference_model`` has no ``model.norm`` or
            ``lm_head`` attribute.
    """
    params: Dict[str, Any] = {}

    for name, module in reference_model.named_modules():
        index = _layer_index(name)
        # Like the original, this assumes layers are visited in ascending
        # index order (registration order of the layer container). Checking
        # *before* copying keeps the first excluded layer's own params out and
        # makes n_layers=0 exclude layer 0 as well.
        if index is not None and index >= n_layers:
            break
        _copy_weight_and_bias(params, name, module)

    # The final norm and LM head are registered after the decoder stack, so
    # the early break above normally skips them; add them explicitly.
    _copy_weight_and_bias(params, 'model.norm', reference_model.model.norm)
    _copy_weight_and_bias(params, 'lm_head', reference_model.lm_head)

    return params