WavMark / utils /model_util.py
my
Add application file
32ca76b
raw
history blame
No virus
3.43 kB
import torch
import os
import json
import sys
from utils import pickle_util
# Paths of checkpoints written by the save_model* functions in this module,
# in write order; delete_last_saved_model pops from the end.
history_array = list()
def save_model(epoch, model, optimizer, file_save_path):
    """Write a training checkpoint to ``file_save_path`` and record the path.

    The checkpoint is a dict with keys ``'epoch'``, ``'model'`` (the model's
    ``state_dict()``) and ``'optimizer'`` (the optimizer's ``state_dict()``,
    or ``None`` when no optimizer is passed), written with ``torch.save``.
    The parent directory is created when missing. The path is appended to
    the module-level ``history_array`` so ``delete_last_saved_model`` can
    remove it later.

    Args:
        epoch: Epoch number stored under ``'epoch'``.
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()``, or ``None``.
        file_save_path: Destination file path.
    """
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
        # exist_ok: another process may create the directory between the
        # existence check above and this call; that must not raise.
        os.makedirs(dirpath, exist_ok=True)
    opti = None
    if optimizer is not None:
        opti = optimizer.state_dict()
    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
    }, f=file_save_path)
    history_array.append(file_save_path)
def save_model_v4(epoch, model, optimizer, file_save_path, discriminator):
    """Like ``save_model`` but also stores ``discriminator`` in the checkpoint.

    The checkpoint dict has keys ``'epoch'``, ``'model'`` (model
    ``state_dict()``), ``'optimizer'`` (optimizer ``state_dict()`` or
    ``None``) and ``'discriminator'``. The parent directory is created when
    missing, and the path is appended to the module-level ``history_array``.

    Args:
        epoch: Epoch number stored under ``'epoch'``.
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``, or ``None``.
        file_save_path: Destination file path.
        discriminator: Stored exactly as given; ``state_dict()`` is NOT
            called on it. NOTE(review): if callers pass a module rather than
            a state dict, ``torch.save`` pickles the whole object. Confirm
            this is intended.
    """
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
        # exist_ok: tolerate a concurrent creator between check and create.
        os.makedirs(dirpath, exist_ok=True)
    opti = None
    if optimizer is not None:
        opti = optimizer.state_dict()
    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
        "discriminator": discriminator,
    }, f=file_save_path)
    history_array.append(file_save_path)
def delete_last_saved_model():
    """Remove the most recently recorded checkpoint and its ``.json`` sidecar.

    Pops the last path from the module-level ``history_array``. The
    checkpoint file is deleted (with a log line) if it exists. A file at the
    same path plus ``".json"`` is deleted too if present. Does nothing when
    no path has been recorded.
    """
    if not history_array:
        return
    target = history_array.pop()
    if os.path.exists(target):
        os.remove(target)
        print("delete model:", target)
    sidecar = target + ".json"
    if os.path.exists(sidecar):
        os.remove(sidecar)
def load_model(resume_path, model, optimizer=None, strict=True):
    """Restore a checkpoint written by the ``save_model*`` helpers.

    The checkpoint is loaded onto the CPU. Its ``'model'`` entry is loaded
    into ``model``. When an optimizer is passed and the checkpoint holds
    optimizer state, that state is loaded too.

    Args:
        resume_path: Path of the checkpoint file.
        model: Object exposing ``load_state_dict(state, strict=...)``.
        optimizer: Optional object exposing ``load_state_dict(state)``.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        The epoch to resume from: the stored ``'epoch'`` plus one.
    """
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['model'], strict=strict)
    if optimizer is not None:
        # save_model / save_model_v2 store None when no optimizer was given.
        # Passing None to load_state_dict would crash, so skip it instead.
        optimizer_state = checkpoint.get('optimizer')
        if optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)
        else:
            print("checkpoint has no optimizer state; optimizer left unchanged")
    print("checkpoint loaded!")
    return start_epoch
def save_model_v2(model, args, model_save_name):
    """Save ``model`` (epoch 0, no optimizer) under the run's output folder.

    The target path is
    ``<args.model_save_folder>/<args.project>/<args.name>/<model_save_name>``.
    The write is delegated to ``save_model``, then the path is logged.
    """
    run_folder = os.path.join(args.model_save_folder, args.project, args.name)
    target_path = os.path.join(run_folder, model_save_name)
    save_model(0, model, None, target_path)
    print("save:", target_path)
def save_project_info(args):
    """Write the command line and parsed arguments to ``run_info.json``.

    The file goes to ``<args.model_save_folder>/<args.project>/<args.name>/``,
    which is created when missing. Its content is a JSON object with
    ``"cmd_str"`` (the CLI arguments after the script name, space-joined) and
    ``"args"`` (``vars(args)``).

    Raises:
        TypeError: if a value in ``vars(args)`` is not JSON-serializable.
            In that case no file is written.
    """
    run_info = {
        "cmd_str": ' '.join(sys.argv[1:]),
        "args": vars(args),
    }
    # Serialize before touching the filesystem. json.dump streams chunks into
    # the open file, so an unserializable arg used to leave a truncated file.
    payload = json.dumps(run_info)
    name = "run_info.json"
    folder = os.path.join(args.model_save_folder, args.project, args.name)
    # exist_ok avoids the check-then-create race with concurrent writers.
    os.makedirs(folder, exist_ok=True)
    json_file_path = os.path.join(folder, name)
    with open(json_file_path, "w") as f:
        f.write(payload)
    print("save_project_info:", json_file_path)
def get_pkl_json(folder):
    """Load the single file in ``folder`` whose name contains ``".pkl.json"``.

    The file is read via the project helper ``pickle_util.read_json``. Its
    behavior is not visible here; presumably it parses the file as JSON.

    Args:
        folder: Directory to search (not recursive).

    Returns:
        Whatever ``pickle_util.read_json`` returns for the matched file.

    Raises:
        ValueError: if ``folder`` contains zero or several matching files.
    """
    names = [i for i in os.listdir(folder) if ".pkl.json" in i]
    if len(names) != 1:
        # Explicit check instead of ``assert`` so the guard survives ``python -O``.
        # Without it, 0 matches gave an IndexError and several matches were
        # silently resolved to an arbitrary one.
        raise ValueError(
            "expected exactly one '.pkl.json' file in %r, found %d: %r"
            % (folder, len(names), names)
        )
    json_path = os.path.join(folder, names[0])
    obj = pickle_util.read_json(json_path)
    return obj
# Helpers for checkpoints saved from DataParallel-wrapped models ("module." key prefix)
def is_data_parallel_checkpoint(state_dict):
    """Return True if any key of ``state_dict`` starts with ``'module.'``.

    ``torch.nn.DataParallel`` prefixes parameter names with ``'module.'``, so
    such keys indicate a checkpoint saved from a wrapped model.
    """
    for key in state_dict.keys():
        if key.startswith('module.'):
            return True
    return False
def map_state_dict(state_dict):
    """Strip the DataParallel ``'module.'`` prefix from state-dict keys.

    If no key starts with ``'module.'``, the input object is returned
    unchanged. Otherwise a new ``OrderedDict`` is returned in the original key
    order. Prefixed keys lose the prefix; other keys are kept as-is. If two
    keys collide after stripping, the later value wins.
    """
    prefix = 'module.'
    if not any(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict
    from collections import OrderedDict
    cut = len(prefix)
    return OrderedDict(
        (key[cut:] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )