File size: 3,431 Bytes
32ca76b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import os
import json
import sys
from utils import pickle_util

# Checkpoint paths written by save_model/save_model_v4, newest last;
# popped (LIFO) by delete_last_saved_model() for cleanup.
history_array = []


def save_model(epoch, model, optimizer, file_save_path):
    """Save a training checkpoint (epoch, model and optimizer state) to disk.

    Parameters
    ----------
    epoch : int
        Epoch index stored in the checkpoint; ``load_model`` resumes at
        ``epoch + 1``.
    model : torch.nn.Module
        Model whose ``state_dict()`` is serialized.
    optimizer : torch.optim.Optimizer or None
        Optimizer to serialize; ``None`` stores no optimizer state.
    file_save_path : str
        Destination file path; parent directories are created if missing.
    """
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
    # exist_ok avoids the race between the existence check above and the mkdir.
    os.makedirs(dirpath, exist_ok=True)

    opti = optimizer.state_dict() if optimizer is not None else None

    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
    }, f=file_save_path)

    # Remember the path so delete_last_saved_model() can clean it up later.
    history_array.append(file_save_path)


def save_model_v4(epoch, model, optimizer, file_save_path, discriminator):
    """Save a checkpoint like ``save_model`` plus a GAN discriminator entry.

    Parameters
    ----------
    epoch : int
        Epoch index stored in the checkpoint.
    model : torch.nn.Module
        Model whose ``state_dict()`` is serialized.
    optimizer : torch.optim.Optimizer or None
        Optimizer to serialize; ``None`` stores no optimizer state.
    file_save_path : str
        Destination file path; parent directories are created if missing.
    discriminator :
        Stored under the "discriminator" key as-is.
        NOTE(review): unlike ``model``, this is pickled as a whole object,
        not a state_dict — presumably intentional; confirm with callers.
    """
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
    # exist_ok avoids the race between the existence check above and the mkdir.
    os.makedirs(dirpath, exist_ok=True)

    opti = optimizer.state_dict() if optimizer is not None else None

    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
        "discriminator": discriminator,
    }, f=file_save_path)

    # Remember the path so delete_last_saved_model() can clean it up later.
    history_array.append(file_save_path)


def delete_last_saved_model():
    """Delete the most recently saved checkpoint and its ``.json`` sidecar.

    No-op when no checkpoint has been recorded in ``history_array``.
    """
    if not history_array:
        return

    last_path = history_array.pop()
    if os.path.exists(last_path):
        os.remove(last_path)
        print("delete model:", last_path)

    sidecar = last_path + ".json"
    if os.path.exists(sidecar):
        os.remove(sidecar)


def load_model(resume_path, model, optimizer=None, strict=True):
    """Load a checkpoint written by ``save_model`` and return the next epoch.

    Parameters
    ----------
    resume_path : str
        Path to the checkpoint file.
    model : torch.nn.Module
        Model to receive the stored ``state_dict``.
    optimizer : torch.optim.Optimizer or None
        If given, restored from the checkpoint's optimizer state (when present).
    strict : bool
        Passed to ``model.load_state_dict``.

    Returns
    -------
    int
        ``checkpoint['epoch'] + 1`` — the epoch to resume training from.
    """
    # map_location='cpu' lets a GPU-trained checkpoint load on any machine.
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['model'], strict=strict)
    # Checkpoints saved with optimizer=None store None under 'optimizer';
    # skip restoring in that case instead of crashing in load_state_dict.
    if optimizer is not None and checkpoint.get('optimizer') is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    print("checkpoint loaded!")
    return start_epoch


def save_model_v2(model, args, model_save_name):
    """Save *model* (epoch 0, no optimizer) under the project folder layout.

    The destination is ``<model_save_folder>/<project>/<name>/<model_save_name>``.
    """
    target = os.path.join(
        args.model_save_folder, args.project, args.name, model_save_name
    )
    save_model(0, model, None, target)
    print("save:", target)


def save_project_info(args):
    """Record the launch command line and parsed args as ``run_info.json``.

    Writes to ``<model_save_folder>/<project>/<name>/run_info.json``, creating
    the folder if needed.

    Parameters
    ----------
    args :
        Parsed CLI namespace; must expose ``model_save_folder``, ``project``
        and ``name`` attributes, and be serializable via ``vars(args)``.
    """
    run_info = {
        "cmd_str": ' '.join(sys.argv[1:]),
        "args": vars(args),
    }

    folder = os.path.join(args.model_save_folder, args.project, args.name)
    # exist_ok avoids a race between an existence check and the mkdir.
    os.makedirs(folder, exist_ok=True)

    json_file_path = os.path.join(folder, "run_info.json")
    with open(json_file_path, "w") as f:
        json.dump(run_info, f)

    print("save_project_info:", json_file_path)


def get_pkl_json(folder):
    """Find the single ``*.pkl.json`` file in *folder* and return its parsed content.

    Raises
    ------
    AssertionError
        If *folder* does not contain exactly one ``.pkl.json`` file.
    """
    # endswith() instead of a substring test: a substring would also match
    # names like "a.pkl.json.bak".
    names = [name for name in os.listdir(folder) if name.endswith(".pkl.json")]
    assert len(names) == 1, f"expected exactly one .pkl.json in {folder}, found {names}"
    json_path = os.path.join(folder, names[0])
    obj = pickle_util.read_json(json_path)
    return obj


# 并行

def is_data_parallel_checkpoint(state_dict):
    """Return True when any key carries the DataParallel 'module.' prefix."""
    for key in state_dict:
        if key.startswith('module.'):
            return True
    return False


def map_state_dict(state_dict):
    """Strip the DataParallel 'module.' prefix from state-dict keys.

    Returns the input mapping unchanged when no key carries the prefix;
    otherwise returns a new OrderedDict with the prefix removed from the
    keys that have it.
    """
    # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
    if not any(key.startswith('module.') for key in state_dict):
        return state_dict

    from collections import OrderedDict
    prefix_len = len('module.')
    return OrderedDict(
        (key[prefix_len:] if key.startswith('module.') else key, value)
        for key, value in state_dict.items()
    )