import importlib
import os
import sys
from typing import Callable, Dict, Union

import numpy as np
import torch
import yaml


def merge_a_into_b(a, b):
    """Recursively merge dict ``a`` into dict ``b`` in place.

    Values from ``a`` overwrite those in ``b``.  When both sides hold a
    dict under the same key, the two are merged recursively instead of
    being replaced wholesale.
    """
    for k, v in a.items():
        if isinstance(v, dict) and k in b:
            # NOTE: assert kept (not converted to raise) so callers that
            # catch AssertionError keep working.
            assert isinstance(
                b[k], dict
            ), "Cannot inherit key '{}' from base!".format(k)
            merge_a_into_b(v, b[k])
        else:
            b[k] = v


def load_config(config_file):
    """Load a YAML config file, resolving ``inherit_from`` chains.

    If the config declares ``inherit_from``, the referenced base file
    (resolved relative to ``config_file``'s directory) is loaded
    recursively and the current config is merged on top of it.

    NOTE(review): ``yaml.load`` with ``FullLoader`` can construct
    arbitrary Python objects — do not point this at untrusted files.
    """
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" in config:
        base_config_file = config["inherit_from"]
        base_config_file = os.path.join(
            os.path.dirname(config_file), base_config_file
        )
        # Guard against a trivial self-inheritance cycle.
        assert not os.path.samefile(config_file, base_config_file), \
            "inherit from itself"
        base_config = load_config(base_config_file)
        del config["inherit_from"]
        merge_a_into_b(config, base_config)
        return base_config
    return config


def get_cls_from_str(string, reload=False):
    """Resolve a dotted path such as ``pkg.mod.Cls`` to the class object.

    Args:
        string: fully-qualified name; everything before the last dot is
            the module path, the rest is the attribute name.
        reload: if True, reload the module before the lookup (useful
            during interactive development).
    """
    module_name, cls_name = string.rsplit(".", 1)
    module = importlib.import_module(module_name)
    if reload:
        # Fix: the original imported the module a second time after
        # reloading; reuse the reloaded module object directly.
        module = importlib.reload(module)
    return getattr(module, cls_name)


def init_obj_from_dict(config, **kwargs):
    """Instantiate ``config["type"]`` with ``config["args"]`` + ``kwargs``.

    Any extra dict-valued key in ``config`` (other than "type"/"args")
    that is not already supplied via ``kwargs`` is itself instantiated
    recursively and passed as a constructor argument.
    """
    obj_args = config["args"].copy()
    obj_args.update(kwargs)
    for k in config:
        if k not in ["type", "args"] and isinstance(config[k], dict) and k not in kwargs:
            obj_args[k] = init_obj_from_dict(config[k])
    try:
        obj = get_cls_from_str(config["type"])(**obj_args)
        return obj
    except Exception:
        # Surface which config failed, then re-raise with the original
        # traceback intact (bare raise instead of `raise e`).
        print(f"Initializing {config} failed, detailed error stack: ")
        raise


def init_model_from_config(config, print_fn=sys.stdout.write):
    """Recursively build a model from a nested config dict.

    Every non-"type"/"args"/"pretrained" key is treated as a sub-model
    config: it is built first, optionally initialised from its own
    "pretrained" checkpoint, and passed to the parent constructor.
    """
    kwargs = {}
    for k in config:
        if k not in ["type", "args", "pretrained"]:
            sub_model = init_model_from_config(config[k], print_fn)
            if "pretrained" in config[k]:
                load_pretrained_model(sub_model, config[k]["pretrained"], print_fn)
            kwargs[k] = sub_model
    model = init_obj_from_dict(config, **kwargs)
    return model


def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load ``state_dict`` into ``model``, skipping mismatched entries.

    Keys absent from the model or with a different tensor shape are
    reported via ``output_fn`` and left at the model's current values.

    Returns:
        The keys that were actually loaded.
    """
    model_dict = model.state_dict()
    pretrained_dict = {}
    mismatch_keys = []
    for key, value in state_dict.items():
        if key in model_dict and model_dict[key].shape == value.shape:
            pretrained_dict[key] = value
        else:
            mismatch_keys.append(key)
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
    model_dict.update(pretrained_dict)
    # strict=True is safe here: model_dict started as the model's own
    # state dict, so all required keys are present.
    model.load_state_dict(model_dict, strict=True)
    return pretrained_dict.keys()


def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Initialise ``model`` from a checkpoint path or an in-memory dict.

    Best-effort: a missing checkpoint path is reported and silently
    skipped rather than raising.  If the model defines its own
    ``load_pretrained`` hook, it is delegated to entirely.
    """
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"pretrained {pretrained} not exist!")
        return
    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained, output_fn)
        return
    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        state_dict = torch.load(pretrained, map_location="cpu")
    # Unwrap the common {"model": state_dict, ...} checkpoint layout.
    if "model" in state_dict:
        state_dict = state_dict["model"]
    merge_load_state_dict(state_dict, model, output_fn)


def pad_sequence(data, pad_value=0):
    """Pad a batch of variable-length sequences to a common length.

    Args:
        data: sequence of array-likes (tensors, numpy arrays, or plain
            lists) whose first dimension is the sequence length.
        pad_value: value used to fill the padded positions.

    Returns:
        (padded, length): a ``(batch, max_len, ...)`` tensor and a numpy
        array of the original lengths.
    """
    # Generalization: convert unconditionally so plain Python lists are
    # accepted too; torch.as_tensor is a no-op on existing tensors.
    data = [torch.as_tensor(arr) for arr in data]
    padded_seq = torch.nn.utils.rnn.pad_sequence(data,
                                                 batch_first=True,
                                                 padding_value=pad_value)
    length = np.array([x.shape[0] for x in data])
    return padded_seq, length