import deepspeed
from transformers import AutoConfig
from transformers.deepspeed import is_deepspeed_zero3_enabled
from torch import nn


def remove_mismatched_weights(model, pretrained_state_dict):
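    """Filter `pretrained_state_dict` against `model`'s own state dict.

    Entries whose key is absent from the model, or whose tensor shape differs from the
    model's parameter, are dropped. Returns the filtered state dict and the list of
    removed (mismatched) keys.
    """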
    own_state = model.state_dict()
    mismatch_keys = []

    for name in list(pretrained_state_dict.keys()):
        if name not in own_state or own_state[name].shape != pretrained_state_dict[name].shape:
            mismatch_keys.append(name)
            pretrained_state_dict.pop(name)

    return pretrained_state_dict, mismatch_keys


def load_zero3_checkpoint(module: nn.Module, state_dict, prefix="", error_msgs=None, top=True):
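    """Load `state_dict` into `module`, handling DeepSpeed ZeRO stage-3 partitioned weights.

    When ZeRO-3 is not enabled this is a plain non-strict `load_state_dict` that skips
    shape-mismatched tensors. Under ZeRO-3, each submodule's partitioned parameters are
    gathered with `deepspeed.zero.GatheredParameters` before its slice of the state dict
    is copied in, and the function recurses into child modules with the extended prefix.
    """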
    # Use a fresh list per top-level call rather than a shared mutable default, so that
    # error messages do not leak between separate loads.
    if error_msgs is None:
        error_msgs = []

    # Under ZeRO-3 the parameters are partitioned across ranks and must be gathered
    # before they can be copied from the state dict.
    zero3_enabled = is_deepspeed_zero3_enabled()

    if not zero3_enabled:
        # Standard (non-ZeRO-3) path: drop shape-mismatched tensors, then load
        # non-strictly and report everything that could not be matched.
        state_dict, mismatch_keys = remove_mismatched_weights(module, state_dict)
        info = module.load_state_dict(state_dict, strict=False)

        if len(mismatch_keys) > 0:
            print("shape mismatch keys: ", mismatch_keys)

        if len(info.missing_keys) > 0:
            print("missing keys: ", info.missing_keys)

        if len(info.unexpected_keys) > 0:
            print("unexpected keys: ", info.unexpected_keys)

    else:
        # ZeRO-3 path: load this module's own (non-recursive) parameters and buffers.
        local_metadata = {}
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)

        # Parameters of module and children will start with prefix. We can exit early
        # if there are none in this state_dict.
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
            # Buffers (e.g. BatchNorm running statistics) are not partitioned by ZeRO-3,
            # but they still have to be loaded for this module.
            named_buffers = dict(module.named_buffers(prefix=prefix[:-1], recurse=False))
            buffers_to_gather = [named_buffers[k] for k in state_dict.keys() if k in named_buffers]

            if len(params_to_gather) > 0 or len(buffers_to_gather) > 0:
                # Gather the partitioned parameters so the full tensors are materialized.
                # The load runs on every rank (restricting it to rank 0 would leave the
                # buffers out of sync); for the gathered parameters, rank 0's values are
                # broadcast and re-partitioned when the context exits.
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    module._load_from_state_dict(*args)

        # Recurse into children, extending the key prefix and sharing the same error list.
        for name, child in module._modules.items():
            if child is not None:
                load_zero3_checkpoint(child, state_dict, prefix + name + ".", error_msgs=error_msgs, top=False)

        if top:
            if len(error_msgs) > 0:
                print("loading zero3 model weights hit errors:")
                print(error_msgs)
            else:
                print("loading zero3 model weights succeeded!")
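
# Example usage (a minimal sketch; the model class and checkpoint path below are
# assumptions for illustration, not part of this module):
#
#     import torch
#     from transformers import AutoModel
#
#     config = AutoConfig.from_pretrained("bert-base-uncased")
#     model = AutoModel.from_config(config)
#     state_dict = torch.load("pytorch_model.bin", map_location="cpu")
#     load_zero3_checkpoint(model, state_dict)
#
# Outside a DeepSpeed ZeRO-3 run this is a shape-tolerant, non-strict load_state_dict;
# inside ZeRO-3 it gathers each submodule's partitioned parameters before copying.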