import os
import sys
import logging
from typing import Callable, Dict, Union

import yaml
import torch
from torch.optim.swa_utils import AveragedModel as torch_average_model
import numpy as np
import pandas as pd
from pprint import pformat


def load_dict_from_csv(csv, cols):
    """Load a two-column mapping from a tab-separated file.

    `cols` is a pair of column names; the first column provides the keys
    and the second the values of the returned dict.
    """
    df = pd.read_csv(csv, sep="\t")
    output = dict(zip(df[cols[0]], df[cols[1]]))
    return output
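
# Usage sketch (hypothetical file and column names): a TSV with columns
# "audio_id" and "duration" becomes {"clip_1": 10.5, ...}:
#   durations = load_dict_from_csv("data/durations.tsv", ("audio_id", "duration"))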


def init_logger(filename, level="INFO"):
    """Create a logger that writes formatted records to `filename`."""
    formatter = logging.Formatter(
        "[ %(levelname)s : %(asctime)s ] - %(message)s")
    logger = logging.getLogger(__name__ + "." + filename)
    logger.setLevel(getattr(logging, level))

    filehandler = logging.FileHandler(filename)
    filehandler.setFormatter(formatter)
    logger.addHandler(filehandler)

    return logger
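
# Usage sketch (hypothetical path): the logger is keyed by the file name,
# so repeated calls with the same path return the same logger instance.
#   logger = init_logger("experiments/run1/train.log")
#   logger.info("training started")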


def init_obj(module, config, **kwargs):
    """Instantiate `config["type"]` from `module` with `config["args"]`.

    Extra keyword arguments override entries in `config["args"]`.
    """
    obj_args = config["args"].copy()
    obj_args.update(kwargs)
    return getattr(module, config["type"])(**obj_args)
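
# Usage sketch (assumed config layout):
#   config = {"type": "Adam", "args": {"lr": 1e-3}}
#   optimizer = init_obj(torch.optim, config, params=model.parameters())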


def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
    """Pretty-print a dict line by line.

    :param in_dict: dict to print
    :param outputfun: function called with each line; expected to add its
        own line break (e.g. ``logging.Logger.info``), defaults to
        ``sys.stdout.write``
    :param formatter: 'yaml' or 'pretty' (``pprint.pformat``)
    """
    if formatter == 'yaml':
        format_fun = yaml.dump
    elif formatter == 'pretty':
        format_fun = pformat
    else:
        raise ValueError(f"Unsupported formatter: {formatter}")
    for line in format_fun(in_dict).split('\n'):
        outputfun(line)
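
# Usage sketch: dump a config through a logger, one line per record.
#   pprint_dict({"lr": 1e-3, "epochs": 10}, logger.info)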


def merge_a_into_b(a, b):
    """Recursively merge dict `a` into dict `b` in place; `a` wins on conflicts."""
    for k, v in a.items():
        if isinstance(v, dict) and k in b:
            assert isinstance(
                b[k], dict
            ), "Cannot inherit key '{}' from base!".format(k)
            merge_a_into_b(v, b[k])
        else:
            b[k] = v
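
# Merge semantics, e.g.:
#   a = {"optimizer": {"lr": 1e-4}}
#   b = {"optimizer": {"lr": 1e-3, "type": "Adam"}, "epochs": 10}
#   merge_a_into_b(a, b)
#   b -> {"optimizer": {"lr": 1e-4, "type": "Adam"}, "epochs": 10}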


def load_config(config_file):
    """Load a YAML config, resolving an optional `inherit_from` base config."""
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" in config:
        base_config_file = config["inherit_from"]
        base_config_file = os.path.join(
            os.path.dirname(config_file), base_config_file
        )
        assert not os.path.samefile(config_file, base_config_file), \
            "Config must not inherit from itself"
        base_config = load_config(base_config_file)
        del config["inherit_from"]
        merge_a_into_b(config, base_config)
        return base_config
    return config
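
# Inheritance sketch (hypothetical files): a `child.yaml` containing
#   inherit_from: base.yaml
#   optimizer:
#     lr: 1.0e-4
# is resolved against `base.yaml` in the same directory, with the child's
# values recursively overriding the base's.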


def parse_config_or_kwargs(config_file, **kwargs):
    """Load a YAML config and override its top-level keys with `kwargs`."""
    yaml_config = load_config(config_file)
    # Passed kwargs take precedence over the file's values
    args = dict(yaml_config, **kwargs)
    return args


def store_yaml(config, config_file):
    with open(config_file, "w") as con_writer:
        yaml.dump(config, con_writer, indent=4, default_flow_style=False)


class MetricImprover:
    """Track the best value of a metric that should be minimized or maximized."""

    def __init__(self, mode):
        assert mode in ("min", "max")
        self.mode = mode
        self.best_value = np.inf if mode == "min" else -np.inf

    def compare(self, x, best_x):
        return x < best_x if self.mode == "min" else x > best_x

    def __call__(self, x):
        if self.compare(x, self.best_value):
            self.best_value = x
            return True
        return False

    def state_dict(self):
        return self.__dict__

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
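
# Usage sketch: returns True (and updates the best value) on improvement.
#   improver = MetricImprover("min")
#   improver(0.9)  # True, first value is always an improvement
#   improver(1.0)  # False, 1.0 is not lower than 0.9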


def fix_batchnorm(model: torch.nn.Module):
    """Put all BatchNorm layers into eval mode so running stats are frozen."""
    def inner(module):
        class_name = module.__class__.__name__
        if class_name.find("BatchNorm") != -1:
            module.eval()
    model.apply(inner)
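
# Typical use: call after model.train() when fine-tuning with small batches,
# so BatchNorm keeps its pretrained running statistics.
#   model.train()
#   fix_batchnorm(model)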


def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Load weights from a checkpoint path or state dict, skipping keys
    that are missing from the model or have mismatched shapes."""
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"Pretrained model {pretrained} does not exist!")
        return

    # Models may implement their own loading logic
    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained)
        return

    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        state_dict = torch.load(pretrained, map_location="cpu")

    if "model" in state_dict:
        state_dict = state_dict["model"]
    model_dict = model.state_dict()
    # Keep only keys present in the model with matching shapes
    pretrained_dict = {
        k: v for k, v in state_dict.items()
        if k in model_dict and model_dict[k].shape == v.shape
    }
    output_fn(f"Loading pretrained keys {list(pretrained_dict.keys())}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=True)
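
# Usage sketch (hypothetical checkpoint path):
#   load_pretrained_model(model, "checkpoints/best.pth", logger.info)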


class AveragedModel(torch_average_model):
    """AveragedModel variant that also averages buffers (e.g. BatchNorm
    running statistics), not just parameters."""

    def update_parameters(self, model):
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))

        # Skip the first buffer, which is `n_averaged` itself
        for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
            device = b_swa.device
            b_model_ = b_model.detach().to(device)
            if self.n_averaged == 0:
                b_swa.detach().copy_(b_model_)
            else:
                b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1
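
# Usage sketch, mirroring torch.optim.swa_utils:
#   swa_model = AveragedModel(model)
#   for epoch in range(num_epochs):
#       train_one_epoch(model)            # hypothetical training step
#       swa_model.update_parameters(model)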