# Source: HuggingFace repo "wsntxxn" — commit "Add AudioCaps checkpoint" (6065472)
# (page residue preserved: raw / history blame / 4.16 kB)
import importlib
import os
import sys
from typing import Callable, Dict, Union
import numpy as np
import yaml
import torch
def merge_a_into_b(a, b):
    """Recursively merge dict ``a`` into dict ``b`` in place.

    Values from ``a`` take precedence.  When both sides hold a dict under
    the same key, the two are merged key-by-key instead of ``b``'s dict
    being replaced wholesale.

    Raises:
        AssertionError: if a dict in ``a`` collides with a non-dict in ``b``.
    """
    for key, value in a.items():
        if not (isinstance(value, dict) and key in b):
            # Plain value (or key absent from b): overwrite / insert.
            b[key] = value
            continue
        assert isinstance(
            b[key], dict
        ), "Cannot inherit key '{}' from base!".format(key)
        merge_a_into_b(value, b[key])
def load_config(config_file):
    """Load a YAML config file, resolving an optional ``inherit_from`` chain.

    If the config contains an ``inherit_from`` key, the referenced file
    (relative to ``config_file``'s directory) is loaded recursively as the
    base, and this config's values are merged on top of it.

    Raises:
        AssertionError: if a config names itself as its own base.
    """
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" not in config:
        return config
    base_config_file = os.path.join(os.path.dirname(config_file),
                                    config["inherit_from"])
    assert not os.path.samefile(config_file, base_config_file), \
        "inherit from itself"
    base_config = load_config(base_config_file)
    del config["inherit_from"]
    # Child values win over inherited base values.
    merge_a_into_b(config, base_config)
    return base_config
def get_cls_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.mod.Cls"`` to the named attribute.

    Args:
        string: fully-qualified name; everything before the last dot is the
            module, the rest is the attribute to fetch from it.
        reload: when true, re-import the module first so a fresh definition
            is picked up.
    """
    module_name, cls_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, cls_name)
def init_obj_from_dict(config, **kwargs):
    """Instantiate the object described by ``config``.

    ``config`` has the shape ``{"type": "pkg.mod.Cls", "args": {...}}``.
    Any other dict-valued key is recursively instantiated and passed as an
    extra keyword argument; explicit ``kwargs`` take precedence over both.

    Args:
        config: object description as above.
        **kwargs: overriding / additional constructor arguments.

    Returns:
        The constructed object.

    Raises:
        Whatever the target constructor raises; the failing config is
        printed first for easier debugging.
    """
    # Robustness: tolerate a missing "args" section, and an "args" section
    # that YAML parsed as None (e.g. a bare "args:" line).
    obj_args = dict(config.get("args") or {})
    obj_args.update(kwargs)
    for k in config:
        if k not in ("type", "args") and isinstance(config[k], dict) \
                and k not in kwargs:
            obj_args[k] = init_obj_from_dict(config[k])
    try:
        return get_cls_from_str(config["type"])(**obj_args)
    except Exception:
        print(f"Initializing {config} failed, detailed error stack: ")
        # Bare raise preserves the original traceback (unlike `raise e`).
        raise
def init_model_from_config(config, print_fn=sys.stdout.write):
    """Recursively build a model from a nested config.

    Every key other than ``type`` / ``args`` / ``pretrained`` is treated as
    a sub-model config: it is built first (optionally loading its own
    ``pretrained`` weights) and passed to the parent constructor as a
    keyword argument.
    """
    sub_models = {}
    for key in config:
        if key in ("type", "args", "pretrained"):
            continue
        child = init_model_from_config(config[key], print_fn)
        if "pretrained" in config[key]:
            load_pretrained_model(child, config[key]["pretrained"], print_fn)
        sub_models[key] = child
    return init_obj_from_dict(config, **sub_models)
def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load into ``model`` only the entries of ``state_dict`` that fit.

    An entry fits when the key exists in the model and the tensor shapes
    match; everything else is reported through ``output_fn`` and skipped.

    Returns:
        The keys that were actually loaded.
    """
    target = model.state_dict()
    matched = {}
    mismatch_keys = []
    for name, tensor in state_dict.items():
        fits = name in target and target[name].shape == tensor.shape
        (matched.__setitem__(name, tensor) if fits
         else mismatch_keys.append(name))
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
    # Loading the full (updated) target dict keeps strict=True valid.
    target.update(matched)
    model.load_state_dict(target, strict=True)
    return matched.keys()
def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Load pre-trained weights into ``model``.

    ``pretrained`` is either an in-memory state dict or a checkpoint path.
    A missing path is reported through ``output_fn`` and ignored.  If the
    model defines its own ``load_pretrained`` hook, that hook is used
    instead of the generic merge.
    """
    is_dict = isinstance(pretrained, dict)
    if not is_dict and not os.path.exists(pretrained):
        output_fn(f"pretrained {pretrained} not exist!")
        return
    if hasattr(model, "load_pretrained"):
        # Model-specific loading takes precedence over the generic path.
        model.load_pretrained(pretrained, output_fn)
        return
    state_dict = pretrained if is_dict else torch.load(pretrained,
                                                       map_location="cpu")
    # Checkpoints may nest the weights under a "model" key.
    if "model" in state_dict:
        state_dict = state_dict["model"]
    merge_load_state_dict(state_dict, model, output_fn)
def pad_sequence(data, pad_value=0):
    """Right-pad a batch of variable-length sequences to a common length.

    Args:
        data: sequence of array-likes (``np.ndarray``, ``torch.Tensor``, or
            plain Python lists), each of shape ``(length, ...)``.
        pad_value: fill value for padded positions.

    Returns:
        tuple: ``(padded, lengths)`` — ``padded`` is a tensor of shape
        ``(batch, max_length, ...)``; ``lengths`` is an ``np.ndarray`` of the
        original first-dimension sizes.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("cannot pad an empty batch of sequences")
    # Generalized: convert every element, not only when the first element is
    # already an ndarray/Tensor — so plain lists of numbers also work.
    tensors = [torch.as_tensor(item) for item in data]
    padded_seq = torch.nn.utils.rnn.pad_sequence(tensors,
                                                 batch_first=True,
                                                 padding_value=pad_value)
    lengths = np.array([t.shape[0] for t in tensors])
    return padded_seq, lengths