|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import re |
|
|
|
import yaml |
|
import torch |
|
from collections import OrderedDict |
|
|
|
import datetime |
|
|
|
|
|
def load_checkpoint(model: torch.nn.Module, path: str) -> dict: |
|
if torch.cuda.is_available(): |
|
logging.info("Checkpoint: loading from checkpoint %s for GPU" % path) |
|
checkpoint = torch.load(path) |
|
else: |
|
logging.info("Checkpoint: loading from checkpoint %s for CPU" % path) |
|
checkpoint = torch.load(path, map_location="cpu") |
|
model.load_state_dict(checkpoint, strict=False) |
|
info_path = re.sub(".pt$", ".yaml", path) |
|
configs = {} |
|
if os.path.exists(info_path): |
|
with open(info_path, "r") as fin: |
|
configs = yaml.load(fin, Loader=yaml.FullLoader) |
|
return configs |
|
|
|
|
|
def save_checkpoint(model: torch.nn.Module, path: str, infos=None): |
|
""" |
|
Args: |
|
infos (dict or None): any info you want to save. |
|
""" |
|
logging.info("Checkpoint: save to checkpoint %s" % path) |
|
if isinstance(model, torch.nn.DataParallel): |
|
state_dict = model.module.state_dict() |
|
elif isinstance(model, torch.nn.parallel.DistributedDataParallel): |
|
state_dict = model.module.state_dict() |
|
else: |
|
state_dict = model.state_dict() |
|
torch.save(state_dict, path) |
|
info_path = re.sub(".pt$", ".yaml", path) |
|
if infos is None: |
|
infos = {} |
|
infos["save_time"] = datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S") |
|
with open(info_path, "w") as fout: |
|
data = yaml.dump(infos) |
|
fout.write(data) |
|
|
|
|
|
def filter_modules(model_state_dict, modules): |
|
new_mods = [] |
|
incorrect_mods = [] |
|
mods_model = model_state_dict.keys() |
|
for mod in modules: |
|
if any(key.startswith(mod) for key in mods_model): |
|
new_mods += [mod] |
|
else: |
|
incorrect_mods += [mod] |
|
if incorrect_mods: |
|
logging.warning( |
|
"module(s) %s don't match or (partially match) " |
|
"available modules in model.", |
|
incorrect_mods, |
|
) |
|
logging.warning("for information, the existing modules in model are:") |
|
logging.warning("%s", mods_model) |
|
|
|
return new_mods |
|
|
|
|
|
def load_trained_modules(model: torch.nn.Module, args: None): |
|
|
|
enc_model_path = args.enc_init |
|
enc_modules = args.enc_init_mods |
|
main_state_dict = model.state_dict() |
|
logging.warning("model(s) found for pre-initialization") |
|
if os.path.isfile(enc_model_path): |
|
logging.info("Checkpoint: loading from checkpoint %s for CPU" % enc_model_path) |
|
model_state_dict = torch.load(enc_model_path, map_location="cpu") |
|
modules = filter_modules(model_state_dict, enc_modules) |
|
partial_state_dict = OrderedDict() |
|
for key, value in model_state_dict.items(): |
|
if any(key.startswith(m) for m in modules): |
|
partial_state_dict[key] = value |
|
main_state_dict.update(partial_state_dict) |
|
else: |
|
logging.warning("model was not found : %s", enc_model_path) |
|
|
|
model.load_state_dict(main_state_dict) |
|
configs = {} |
|
return configs |
|
|