# Utility functions for config loading and model initialization.
import importlib
import os
import sys
from typing import Callable, Dict, Union

import numpy as np
import torch
import yaml
def merge_a_into_b(a, b):
    """Recursively merge dict ``a`` into dict ``b``, in place.

    Values from ``a`` overwrite those in ``b``.  When both sides hold a
    dict under the same key, the two dicts are merged recursively instead
    of being replaced wholesale.

    Raises:
        AssertionError: if a key maps to a dict in ``a`` but to a
            non-dict in ``b``.
    """
    for key, value in a.items():
        # Plain overwrite unless both sides carry a nested dict.
        if not (isinstance(value, dict) and key in b):
            b[key] = value
            continue
        assert isinstance(
            b[key], dict
        ), "Cannot inherit key '{}' from base!".format(key)
        merge_a_into_b(value, b[key])
def load_config(config_file):
    """Load a YAML config file, resolving an optional ``inherit_from`` chain.

    If the config contains an ``inherit_from`` key (a path relative to the
    config file's directory), the base config is loaded recursively and the
    current config is merged on top of it (current values win).

    Raises:
        AssertionError: if a config lists itself as its own base.
    """
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" not in config:
        return config
    # Base path is resolved relative to the inheriting file, not the CWD.
    base_config_file = os.path.join(
        os.path.dirname(config_file), config["inherit_from"]
    )
    assert not os.path.samefile(config_file, base_config_file), \
        "inherit from itself"
    base_config = load_config(base_config_file)
    del config["inherit_from"]
    merge_a_into_b(config, base_config)
    return base_config
def get_cls_from_str(string, reload=False):
    """Resolve a dotted path like ``"pkg.module.ClassName"`` to the object.

    Args:
        string: fully-qualified name; everything before the last dot is
            the module path, the rest is the attribute name.
        reload: when True, reload the module before resolving, so edits
            made after the first import are picked up.
    """
    module_name, attr_name = string.rsplit(".", 1)
    module = importlib.import_module(module_name, package=None)
    if reload:
        # reload() returns the refreshed module object from sys.modules.
        module = importlib.reload(module)
    return getattr(module, attr_name)
def init_obj_from_dict(config, **kwargs):
    """Instantiate ``config["type"]`` with ``config["args"]`` plus overrides.

    Keyword arguments passed here override the configured ``args``.  Any
    other dict-valued entry of ``config`` (besides "type"/"args") that is
    not overridden via kwargs is itself instantiated recursively and passed
    as a keyword argument.

    Raises:
        Whatever the target constructor raises; the failing config is
        printed first to ease debugging.
    """
    call_args = dict(config["args"])
    call_args.update(kwargs)
    for key, value in config.items():
        if key in ("type", "args") or key in kwargs:
            continue
        if isinstance(value, dict):
            call_args[key] = init_obj_from_dict(value)
    try:
        return get_cls_from_str(config["type"])(**call_args)
    except Exception as e:
        print(f"Initializing {config} failed, detailed error stack: ")
        raise e
def init_model_from_config(config, print_fn=sys.stdout.write):
    """Build a model (and its sub-models) from a nested config dict.

    Every key other than "type", "args" and "pretrained" is treated as a
    sub-model config: the sub-model is built recursively, optionally
    initialized from its own "pretrained" entry, and then handed to the
    top-level constructor as a keyword argument.
    """
    sub_models = {}
    for key in config:
        if key in ("type", "args", "pretrained"):
            continue
        sub_config = config[key]
        sub_model = init_model_from_config(sub_config, print_fn)
        if "pretrained" in sub_config:
            load_pretrained_model(sub_model,
                                  sub_config["pretrained"],
                                  print_fn)
        sub_models[key] = sub_model
    return init_obj_from_dict(config, **sub_models)
def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load only the compatible entries of ``state_dict`` into ``model``.

    An entry is compatible when its key exists in the model's state dict
    and the tensor shapes match exactly; all other keys are skipped and
    reported through ``output_fn``.

    Returns:
        The keys that were actually loaded into the model.
    """
    model_dict = model.state_dict()
    matched = {}
    mismatch_keys = []
    for key, value in state_dict.items():
        target = model_dict.get(key)
        if target is not None and target.shape == value.shape:
            matched[key] = value
        else:
            mismatch_keys.append(key)
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
    # strict=True is safe: model_dict already covers every expected key.
    model_dict.update(matched)
    model.load_state_dict(model_dict, strict=True)
    return matched.keys()
def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Initialize ``model`` from pretrained weights.

    ``pretrained`` is either an in-memory state dict or a checkpoint file
    path.  A missing path is reported and silently skipped.  A model that
    defines its own ``load_pretrained`` hook gets full control; otherwise
    the state dict is merged in with shape checking.
    """
    is_state_dict = isinstance(pretrained, dict)
    if not is_state_dict and not os.path.exists(pretrained):
        output_fn(f"pretrained {pretrained} not exist!")
        return
    if hasattr(model, "load_pretrained"):
        # Model-specific loading takes precedence over the generic path.
        model.load_pretrained(pretrained, output_fn)
        return
    if is_state_dict:
        state_dict = pretrained
    else:
        # NOTE(review): torch.load unpickles arbitrary code — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(pretrained, map_location="cpu")
        # Checkpoint files may wrap the weights under a "model" key.
        if "model" in state_dict:
            state_dict = state_dict["model"]
    merge_load_state_dict(state_dict, model, output_fn)
def pad_sequence(data, pad_value=0):
    """Pad a batch of variable-length sequences to a common length.

    Args:
        data: iterable of sequences (torch tensors, numpy arrays, or plain
            Python lists); padding is applied along the first dimension.
        pad_value: fill value for the padded positions.

    Returns:
        (padded, lengths): ``padded`` is a ``(batch, max_len, ...)`` tensor,
        ``lengths`` a numpy int array of each sequence's original length.
    """
    # Convert every element unconditionally: the original guard only
    # checked data[0], so a batch of plain Python lists (or a mixed batch)
    # crashed in torch's pad_sequence / on x.shape below.
    tensors = [torch.as_tensor(seq) for seq in data]
    padded_seq = torch.nn.utils.rnn.pad_sequence(tensors,
                                                 batch_first=True,
                                                 padding_value=pad_value)
    length = np.array([t.shape[0] for t in tensors])
    return padded_seq, length