Spaces:
Runtime error
Runtime error
import glob | |
import os | |
import re | |
import torch | |
def get_last_checkpoint(work_dir, steps=None):
    """Load the highest-step checkpoint found in ``work_dir``.

    Args:
        work_dir: Directory holding ``model_ckpt_steps_<N>.ckpt`` files.
        steps: If given, only the checkpoint for exactly this step is considered.

    Returns:
        A ``(checkpoint, path)`` pair. ``checkpoint`` is the object returned by
        ``torch.load`` (mapped to CPU) for the first path produced by
        ``get_all_ckpts``, which sorts by descending step. Returns
        ``(None, None)`` when no file matches.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None
    newest_path = candidates[0]
    return torch.load(newest_path, map_location='cpu'), newest_path
def get_all_ckpts(work_dir, steps=None):
    """Return the checkpoint files in ``work_dir``, highest step first.

    Checkpoint files are named ``model_ckpt_steps_<N>.ckpt``, where ``<N>`` is
    an integer step count.

    Args:
        work_dir: Directory to search. It is interpolated into a glob pattern.
        steps: If given, only ``model_ckpt_steps_{steps}.ckpt`` is matched.

    Returns:
        A list of matching paths sorted by descending step number. Files that
        match the glob but carry no integer step (e.g.
        ``model_ckpt_steps_best.ckpt``) are skipped. The list is empty when
        nothing matches.
    """
    if steps is None:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    # Raw string so '\d' and '\.' reach the regex engine untouched
    # (invalid escapes in a normal literal trigger warnings).
    step_re = re.compile(r'model_ckpt_steps_(\d+)\.ckpt')
    stepped = []
    for path in glob.glob(ckpt_path_pattern):
        match = step_re.fullmatch(os.path.basename(path))
        # A glob hit without an integer step cannot be ordered; skip it rather
        # than failing the whole listing.
        if match is not None:
            stepped.append((int(match.group(1)), path))
    # A negated key (rather than reverse=True) keeps equal-step paths in glob order.
    return [path for _, path in sorted(stepped, key=lambda item: -item[0])]
def _select_model_state_dict(state_dict, model_name):
    """Extract the entries belonging to ``model_name`` from a checkpoint's state dict."""
    if any('.' in k for k in state_dict):
        # Flat layout: keys look like "<model_name>.<param>"; strip the prefix.
        prefix = f'{model_name}.'
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    # Nested layout: top-level keys are model names mapping to their own state dicts.
    if '.' not in model_name:
        return state_dict[model_name]
    # "a.b": take state_dict["a"], then its "b."-prefixed entries with the prefix stripped.
    base_model_name, rest_model_name = model_name.split('.', 1)
    prefix = f'{rest_model_name}.'
    return {k[len(prefix):]: v for k, v in state_dict[base_model_name].items()
            if k.startswith(prefix)}


def _drop_shape_mismatched(cur_model, state_dict):
    """Delete (in place) entries of ``state_dict`` whose shape differs from ``cur_model``'s."""
    cur_model_state_dict = cur_model.state_dict()
    unmatched_keys = []
    for key, param in state_dict.items():
        if key in cur_model_state_dict:
            new_param = cur_model_state_dict[key]
            if new_param.shape != param.shape:
                unmatched_keys.append(key)
                print("| Unmatched keys: ", key, new_param.shape, param.shape)
    # Deleted after the scan: a dict cannot change size while being iterated.
    for key in unmatched_keys:
        del state_dict[key]


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    """Load the weights stored for ``model_name`` into ``cur_model``.

    Args:
        cur_model: Object whose ``load_state_dict`` receives the weights
            (and whose ``state_dict`` is read when ``strict`` is False).
        ckpt_base_dir: Either a checkpoint file, or a directory searched with
            ``get_last_checkpoint``.
        model_name: Selects the weights inside ``checkpoint["state_dict"]``.
            If that dict has any dotted keys, the entries prefixed with
            ``f"{model_name}."`` are used with the prefix stripped. Otherwise
            the dict is treated as nested by model name. A dotted
            ``model_name`` such as ``"a.b"`` selects ``state_dict["a"]``, then
            its ``"b."``-prefixed entries.
        force: If True, a missing checkpoint raises. Otherwise a message is printed.
        strict: Forwarded to ``load_state_dict``. When False, entries whose
            shape differs from the model's are also dropped and printed.

    Raises:
        AssertionError: If ``force`` is True and no checkpoint is found.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        # NOTE(review): torch.load can unpickle arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)

    if checkpoint is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise: `assert False` would be stripped under `python -O`.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    state_dict = _select_model_state_dict(checkpoint["state_dict"], model_name)
    if not strict:
        _drop_shape_mismatched(cur_model, state_dict)
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{model_name}' from '{ckpt_path}'.")